1use std::{
20 collections::HashMap,
21 fmt,
22 sync::{
23 atomic::{AtomicBool, Ordering::SeqCst},
24 Arc,
25 },
26 time::UNIX_EPOCH,
27};
28
29use darkfi_serial::{
30 async_trait, AsyncDecodable, AsyncEncodable, SerialDecodable, SerialEncodable, VarInt,
31};
32use rand::{rngs::OsRng, Rng};
33use smol::{
34 io::{self, AsyncRead, AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf},
35 lock::{Mutex as AsyncMutex, OnceCell},
36 Executor,
37};
38use tracing::{debug, error, trace, warn};
39use url::Url;
40
41use super::{
42 dnet::{self, dnetev, DnetEvent},
43 hosts::{HostColor, HostsPtr},
44 message,
45 message::{SerializedMessage, VersionMessage, MAX_COMMAND_LENGTH},
46 message_publisher::{MessageSubscription, MessageSubsystem},
47 metering::{MeteringConfiguration, MeteringQueue},
48 p2p::P2pPtr,
49 session::{
50 Session, SessionBitFlag, SessionWeakPtr, SESSION_ALL, SESSION_INBOUND, SESSION_OUTBOUND,
51 SESSION_REFINE,
52 },
53 transport::PtStream,
54};
55use crate::{
56 net::BanPolicy,
57 system::{msleep, Publisher, PublisherPtr, StoppableTask, StoppableTaskPtr, Subscription},
58 util::{logger::verbose, time::NanoTimestamp},
59 Error, Result,
60};
61
/// Shared, reference-counted pointer to a [`Channel`].
pub type ChannelPtr = Arc<Channel>;
64
/// Metadata describing a channel, fixed at construction time.
///
/// NOTE(review): this type derives `SerialEncodable`/`SerialDecodable`, so the
/// field order presumably defines its serialized layout — do not reorder fields.
#[derive(Clone, Debug, SerialEncodable, SerialDecodable)]
pub struct ChannelInfo {
    /// Resolved peer address, if one was supplied when the channel was created.
    pub resolve_addr: Option<Url>,
    /// Address the transport connected to.
    pub connect_addr: Url,
    /// Channel creation time; `Channel::new()` fills this with UNIX time in seconds.
    pub start_time: u64,
    /// Random channel identifier drawn from `OsRng` at construction.
    pub id: u32,
    /// Mixed-transport flag; when set, `Channel::address()` ignores
    /// `resolve_addr` and returns `connect_addr`.
    pub transport_mixed: bool,
}
74
75impl ChannelInfo {
76 fn new(
77 resolve_addr: Option<Url>,
78 connect_addr: Url,
79 start_time: u64,
80 transport_mixed: bool,
81 ) -> Self {
82 Self { resolve_addr, connect_addr, start_time, id: OsRng.gen(), transport_mixed }
83 }
84}
85
/// A connection to a single peer.
///
/// Holds both halves of the transport stream, the message subsystem that
/// dispatches incoming messages to subscribers, and the background task that
/// runs the receive loop.
pub struct Channel {
    /// Read half of the transport stream; the receive loop keeps it locked
    /// for as long as the loop runs.
    reader: AsyncMutex<ReadHalf<Box<dyn PtStream>>>,
    /// Write half of the transport stream; locked for each outgoing message.
    writer: AsyncMutex<WriteHalf<Box<dyn PtStream>>>,
    /// Dispatches decoded incoming messages to per-message-type subscribers.
    message_subsystem: MessageSubsystem,
    /// Notified with `Error::ChannelStopped` when the channel stops
    /// (see `subscribe_stop()`).
    stop_publisher: PublisherPtr<Error>,
    /// Background task that runs the receive loop.
    receive_task: StoppableTaskPtr,
    /// Set once the receive task has finished; read by `is_stopped()`.
    stopped: AtomicBool,
    /// Weak reference to this channel's session; upgraded (and unwrapped) on use.
    pub(in crate::net) session: SessionWeakPtr,
    /// Version message, set at most once via `set_version()` (presumably the
    /// peer's, received during the handshake — confirm).
    pub version: OnceCell<Arc<VersionMessage>>,
    /// Metadata fixed at construction time.
    pub info: ChannelInfo,
    /// Per-command rate-limiting queues used by `send_serialized()`, keyed by
    /// message command name.
    metering_map: AsyncMutex<HashMap<String, MeteringQueue>>,
}
112
113impl Channel {
114 pub async fn new(
118 stream: Box<dyn PtStream>,
119 resolve_addr: Option<Url>,
120 connect_addr: Url,
121 session: SessionWeakPtr,
122 transport_mixed: bool,
123 ) -> Arc<Self> {
124 let (reader, writer) = io::split(stream);
125 let reader = AsyncMutex::new(reader);
126 let writer = AsyncMutex::new(writer);
127
128 let message_subsystem = MessageSubsystem::new();
129 Self::setup_dispatchers(&message_subsystem).await;
130
131 let start_time = UNIX_EPOCH.elapsed().unwrap().as_secs();
132 let info =
133 ChannelInfo::new(resolve_addr, connect_addr.clone(), start_time, transport_mixed);
134 let metering_map = AsyncMutex::new(HashMap::new());
135
136 Arc::new(Self {
137 reader,
138 writer,
139 message_subsystem,
140 stop_publisher: Publisher::new(),
141 receive_task: StoppableTask::new(),
142 stopped: AtomicBool::new(false),
143 session,
144 version: OnceCell::new(),
145 info,
146 metering_map,
147 })
148 }
149
150 async fn setup_dispatchers(subsystem: &MessageSubsystem) {
152 subsystem.add_dispatch::<message::VersionMessage>().await;
153 subsystem.add_dispatch::<message::VerackMessage>().await;
154 subsystem.add_dispatch::<message::PingMessage>().await;
155 subsystem.add_dispatch::<message::PongMessage>().await;
156 subsystem.add_dispatch::<message::GetAddrsMessage>().await;
157 subsystem.add_dispatch::<message::AddrsMessage>().await;
158 }
159
160 pub fn start(self: Arc<Self>, executor: Arc<Executor<'_>>) {
163 debug!(target: "net::channel::start", "START {self:?}");
164
165 let self_ = self.clone();
166 self.receive_task.clone().start(
167 self.clone().main_receive_loop(),
168 |result| self_.handle_stop(result),
169 Error::ChannelStopped,
170 executor,
171 );
172
173 debug!(target: "net::channel::start", "END {self:?}");
174 }
175
176 pub async fn stop(&self) {
179 debug!(target: "net::channel::stop", "START {self:?}");
180 self.receive_task.stop().await;
181 debug!(target: "net::channel::stop", "END {self:?}");
182 }
183
184 pub async fn subscribe_stop(&self) -> Result<Subscription<Error>> {
187 debug!(target: "net::channel::subscribe_stop", "START {self:?}");
188
189 if self.is_stopped() {
190 return Err(Error::ChannelStopped)
191 }
192
193 let sub = self.stop_publisher.clone().subscribe().await;
194
195 debug!(target: "net::channel::subscribe_stop", "END {self:?}");
196
197 Ok(sub)
198 }
199
200 pub fn is_stopped(&self) -> bool {
201 self.stopped.load(SeqCst)
202 }
203
204 pub async fn send<M: message::Message>(&self, message: &M) -> Result<()> {
208 self.send_serialized(
209 &SerializedMessage::new(message).await,
210 &M::METERING_SCORE,
211 &M::METERING_CONFIGURATION,
212 )
213 .await
214 }
215
216 pub async fn send_serialized(
225 &self,
226 message: &SerializedMessage,
227 metering_score: &u64,
228 metering_config: &MeteringConfiguration,
229 ) -> Result<()> {
230 debug!(
231 target: "net::channel::send", "[START] command={} {self:?}",
232 message.command,
233 );
234
235 let mut lock = self.metering_map.lock().await;
238 if !lock.contains_key(&message.command) {
239 lock.insert(message.command.clone(), MeteringQueue::new(metering_config.clone()));
240 }
241
242 let queue = lock.get_mut(&message.command).unwrap();
246 queue.push(metering_score);
247 let sleep_time = queue.sleep_time();
248 drop(lock);
249
250 if let Some(sleep_time) = sleep_time {
252 let sleep_time = 2 * sleep_time;
253 debug!(
254 target: "net::channel::send",
255 "[P2P] Channel rate limit is active, sleeping before sending for: {sleep_time} (ms)"
256 );
257 msleep(sleep_time).await;
258 }
259
260 if self.is_stopped() {
262 return Err(Error::ChannelStopped)
263 }
264
265 if let Err(e) = self.send_message(message).await {
267 if self.session.upgrade().unwrap().type_id() & (SESSION_ALL & !SESSION_REFINE) != 0 {
268 error!(
269 target: "net::channel::send", "[P2P] Channel send error for [{self:?}]: {e}"
270 );
271 }
272 self.stop().await;
273 return Err(Error::ChannelStopped)
274 }
275
276 debug!(
277 target: "net::channel::send", "[END] command={} {self:?}",
278 message.command
279 );
280
281 Ok(())
282 }
283
284 async fn send_message(&self, message: &SerializedMessage) -> Result<()> {
287 assert!(!message.command.is_empty());
288
289 let stream = &mut *self.writer.lock().await;
290 let mut written: usize = 0;
291
292 dnetev!(self, SendMessage, {
293 chan: self.info.clone(),
294 cmd: message.command.clone(),
295 time: NanoTimestamp::current_time(),
296 });
297
298 trace!(target: "net::channel::send_message", "Sending magic...");
299 let magic_bytes = self.p2p().settings().read().await.magic_bytes.0;
300 written += magic_bytes.encode_async(stream).await?;
301 trace!(target: "net::channel::send_message", "Sent magic");
302
303 trace!(target: "net::channel::send_message", "Sending command...");
304 written += message.command.encode_async(stream).await?;
305 trace!(target: "net::channel::send_message", "Sent command: {}", message.command);
306
307 trace!(target: "net::channel::send_message", "Sending payload...");
308 written += VarInt(message.payload.len() as u64).encode_async(stream).await?;
310 stream.write_all(&message.payload).await?;
312 written += message.payload.len();
313
314 trace!(target: "net::channel::send_message", "Sent payload {} bytes, total bytes {written}",
315 message.payload.len());
316
317 stream.flush().await?;
318
319 Ok(())
320 }
321
322 pub async fn read_command<R: AsyncRead + Unpin + Send + Sized>(
328 &self,
329 stream: &mut R,
330 ) -> Result<String> {
331 let mut magic = [0u8; 4];
334 trace!(target: "net::channel::read_command", "Reading magic...");
335 stream.read_exact(&mut magic).await?;
336
337 trace!(target: "net::channel::read_command", "Read magic {magic:?}");
338 let magic_bytes = self.p2p().settings().read().await.magic_bytes.0;
339 if magic != magic_bytes {
340 error!(target: "net::channel::read_command", "Error: Magic bytes mismatch");
341
342 if self.session_type_id() & SESSION_OUTBOUND != 0 {
344 if let BanPolicy::Strict = self.p2p().settings().read().await.ban_policy {
345 self.ban().await;
346 }
347 }
348
349 return Err(Error::MalformedPacket)
350 }
351
352 let cmd_len = VarInt::decode_async(stream).await?.0;
354 if cmd_len > (MAX_COMMAND_LENGTH as u64) {
355 error!(target: "net::channel::read_command",
356 "Error: Command length ({cmd_len}) exceeds configured limit ({MAX_COMMAND_LENGTH}). Dropping...");
357 return Err(Error::MessageInvalid);
358 }
359
360 let mut take = stream.take(cmd_len);
362
363 let mut bytes = vec![0; cmd_len.try_into().unwrap()];
365 take.read_exact(&mut bytes).await?;
366
367 let command = String::from_utf8(bytes)?;
368
369 Ok(command)
370 }
371
372 pub async fn subscribe_msg<M: message::Message>(&self) -> Result<MessageSubscription<M>> {
374 debug!(
375 target: "net::channel::subscribe_msg", "[START] command={} {self:?}",
376 M::NAME
377 );
378
379 let sub = self.message_subsystem.subscribe::<M>().await;
380
381 debug!(
382 target: "net::channel::subscribe_msg", "[END] command={} {self:?}",
383 M::NAME
384 );
385
386 sub
387 }
388
389 async fn handle_stop(self: Arc<Self>, result: Result<()>) {
392 debug!(target: "net::channel::handle_stop", "[START] {self:?}");
393
394 self.stopped.store(true, SeqCst);
395
396 match result {
397 Ok(()) => panic!("Channel task should never complete without error status"),
398 Err(e) => {
400 self.stop_publisher.notify(Error::ChannelStopped).await;
401 self.message_subsystem.trigger_error(e).await;
402 }
403 }
404
405 debug!(target: "net::channel::handle_stop", "[END] {self:?}");
406 }
407
408 async fn main_receive_loop(self: Arc<Self>) -> Result<()> {
410 debug!(target: "net::channel::main_receive_loop", "[START] {self:?}");
411
412 let reader = &mut *self.reader.lock().await;
414
415 loop {
417 let command = match self.read_command(reader).await {
418 Ok(command) => command,
419 Err(err) => {
420 if Self::is_eof_error(&err) {
421 verbose!(
422 target: "net::channel::main_receive_loop",
423 "[P2P] Channel {} disconnected",
424 self.display_address()
425 );
426 } else if let Error::MessageInvalid = err {
427 if let BanPolicy::Strict = self.p2p().settings().read().await.ban_policy {
429 self.ban().await;
430 }
431 } else if self.session.upgrade().unwrap().type_id() &
432 (SESSION_ALL & !SESSION_REFINE) !=
433 0
434 {
435 error!(
436 target: "net::channel::main_receive_loop",
437 "[P2P] Read error on channel {}: {err}",
438 self.display_address()
439 );
440 }
441
442 debug!(
443 target: "net::channel::main_receive_loop",
444 "Stopping channel {self:?}"
445 );
446 return Err(Error::ChannelStopped)
447 }
448 };
449
450 dnetev!(self, RecvMessage, {
451 chan: self.info.clone(),
452 cmd: command.clone(),
453 time: NanoTimestamp::current_time(),
454 });
455
456 match self.message_subsystem.notify(&command, reader).await {
458 Ok(()) => {}
459 Err(Error::MissingDispatcher) |
460 Err(Error::MessageInvalid) |
461 Err(Error::MeteringLimitExceeded) => {
462 if self.session.upgrade().unwrap().type_id() != SESSION_REFINE {
479 warn!(
480 target: "net::channel::main_receive_loop",
481 "MissingDispatcher|MessageInvalid|MeteringLimitExceeded for command={command}, channel={self:?}"
482 );
483
484 if let BanPolicy::Strict = self.p2p().settings().read().await.ban_policy {
485 self.ban().await;
486 }
487
488 return Err(Error::ChannelStopped)
489 }
490 }
491 Err(_) => unreachable!("You added a new error in notify()"),
492 }
493 }
494 }
495
496 pub async fn ban(&self) {
498 debug!(target: "net::channel::ban", "START {self:?}");
499 debug!(target: "net::channel::ban", "Peer: {:?}", self.display_address());
500
501 let peer = {
505 if self.session_type_id() & SESSION_INBOUND != 0 {
506 if self.address().host().is_none() {
507 error!("[P2P] ban() caught Url without host: {:?}", self.display_address());
508 return
509 }
510
511 #[cfg(feature = "p2p-tor")]
513 if (self.address().scheme() == "tor" || self.address().scheme() == "tor+tls") &&
514 self.p2p().hosts().is_local_host(self.address())
515 {
516 return
517 }
518
519 if self.address().scheme() == "unix" {
520 return
521 }
522
523 if self.hosts().has_existing_connection(self.address()) {
527 self.address().clone()
528 } else {
529 let mut addr = self.address().clone();
530 addr.set_port(None).unwrap();
531 addr
532 }
533 } else {
534 self.address().clone()
535 }
536 };
537
538 let last_seen = UNIX_EPOCH.elapsed().unwrap().as_secs();
539 verbose!(target: "net::channel::ban", "Blacklisting peer={peer}");
540 match self.p2p().hosts().move_host(&peer, last_seen, HostColor::Black).await {
541 Ok(()) => {
542 verbose!(target: "net::channel::ban", "Peer={peer} blacklisted successfully");
543 }
544 Err(e) => {
545 warn!(target: "net::channel::ban", "Could not blacklisted peer={peer}, err={e}");
546 }
547 }
548 self.stop().await;
549 debug!(target: "net::channel::ban", "STOP {self:?}");
550 }
551
552 pub fn address(&self) -> &Url {
559 if !self.info.transport_mixed {
560 if let Some(resolve_addr) = &self.info.resolve_addr {
561 return resolve_addr
562 }
563 }
564 &self.info.connect_addr
565 }
566
567 pub fn display_address(&self) -> &Url {
570 self.info.resolve_addr.as_ref().unwrap_or(&self.info.connect_addr)
571 }
572
573 pub fn resolve_addr(&self) -> Option<Url> {
576 self.info.resolve_addr.clone()
577 }
578
579 pub fn connect_addr(&self) -> &Url {
581 &self.info.connect_addr
582 }
583
584 pub(crate) async fn set_version(&self, version: Arc<VersionMessage>) {
587 self.version.set(version).await.unwrap();
588 }
589 pub fn get_version(&self) -> Arc<VersionMessage> {
591 self.version.get().unwrap().clone()
592 }
593
594 pub fn message_subsystem(&self) -> &MessageSubsystem {
596 &self.message_subsystem
597 }
598
599 fn session(&self) -> Arc<dyn Session> {
600 self.session.upgrade().unwrap()
601 }
602
603 pub fn session_type_id(&self) -> SessionBitFlag {
604 let session = self.session();
605 session.type_id()
606 }
607
608 #[inline]
609 pub fn p2p(&self) -> P2pPtr {
610 self.session().p2p()
611 }
612 #[inline]
613 pub fn hosts(&self) -> HostsPtr {
614 self.p2p().hosts()
615 }
616
617 fn is_eof_error(err: &Error) -> bool {
618 match err {
619 Error::Io(ioerr) => ioerr == &std::io::ErrorKind::UnexpectedEof,
620 _ => false,
621 }
622 }
623}
624
625impl fmt::Debug for Channel {
626 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
627 write!(f, "<Channel addr='{}' id={}>", self.display_address(), self.info.id)
628 }
629}