darkfi/net/
channel.rs

1/* This file is part of DarkFi (https://dark.fi)
2 *
3 * Copyright (C) 2020-2026 Dyne.org foundation
4 *
5 * This program is free software: you can redistribute it and/or modify
6 * it under the terms of the GNU Affero General Public License as
7 * published by the Free Software Foundation, either version 3 of the
8 * License, or (at your option) any later version.
9 *
10 * This program is distributed in the hope that it will be useful,
11 * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 * GNU Affero General Public License for more details.
14 *
15 * You should have received a copy of the GNU Affero General Public License
16 * along with this program.  If not, see <https://www.gnu.org/licenses/>.
17 */
18
19use std::{
20    collections::HashMap,
21    fmt,
22    sync::{
23        atomic::{AtomicBool, Ordering::SeqCst},
24        Arc,
25    },
26    time::UNIX_EPOCH,
27};
28
29use darkfi_serial::{
30    async_trait, AsyncDecodable, AsyncEncodable, SerialDecodable, SerialEncodable, VarInt,
31};
32use rand::{rngs::OsRng, Rng};
33use smol::{
34    io::{self, AsyncRead, AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf},
35    lock::{Mutex as AsyncMutex, OnceCell},
36    Executor,
37};
38use tracing::{debug, error, trace, warn};
39use url::Url;
40
41use super::{
42    dnet::{self, dnetev, DnetEvent},
43    hosts::{HostColor, HostsPtr},
44    message,
45    message::{SerializedMessage, VersionMessage, MAX_COMMAND_LENGTH},
46    message_publisher::{MessageSubscription, MessageSubsystem},
47    metering::{MeteringConfiguration, MeteringQueue},
48    p2p::P2pPtr,
49    session::{
50        Session, SessionBitFlag, SessionWeakPtr, SESSION_ALL, SESSION_INBOUND, SESSION_OUTBOUND,
51        SESSION_REFINE,
52    },
53    transport::PtStream,
54};
55use crate::{
56    net::BanPolicy,
57    system::{msleep, Publisher, PublisherPtr, StoppableTask, StoppableTaskPtr, Subscription},
58    util::{logger::verbose, time::NanoTimestamp},
59    Error, Result,
60};
61
/// Atomic pointer to async channel: a shared, atomically reference-counted
/// ([`Arc`]) handle to a [`Channel`].
pub type ChannelPtr = Arc<Channel>;
64
/// Channel debug info
///
/// NOTE: this type derives `SerialEncodable`/`SerialDecodable`; reordering
/// the fields below likely changes its encoded form.
#[derive(Clone, Debug, SerialEncodable, SerialDecodable)]
pub struct ChannelInfo {
    /// Transport-processed address, if any (see `Channel::resolve_addr`).
    pub resolve_addr: Option<Url>,
    /// Address the stream was connected with, before transport processing.
    pub connect_addr: Url,
    /// Channel creation time, in seconds since the UNIX epoch.
    pub start_time: u64,
    /// Random identifier for the channel, drawn from `OsRng` on creation.
    pub id: u32,
    /// Whether this channel runs over a mixed transport.
    pub transport_mixed: bool,
}
74
75impl ChannelInfo {
76    fn new(
77        resolve_addr: Option<Url>,
78        connect_addr: Url,
79        start_time: u64,
80        transport_mixed: bool,
81    ) -> Self {
82        Self { resolve_addr, connect_addr, start_time, id: OsRng.gen(), transport_mixed }
83    }
84}
85
/// Async channel for communication between nodes.
///
/// Wraps a transport stream split into separately locked reader and writer
/// halves, together with the per-channel message subsystem, stop signalling
/// and outbound rate limiting state.
pub struct Channel {
    /// The reading half of the transport stream. `main_receive_loop` holds
    /// this lock for as long as the loop runs.
    reader: AsyncMutex<ReadHalf<Box<dyn PtStream>>>,
    /// The writing half of the transport stream. `send_message` holds this
    /// lock while writing a whole packet.
    writer: AsyncMutex<WriteHalf<Box<dyn PtStream>>>,
    /// The message subsystem instance for this channel; incoming messages are
    /// handed to it by command name in `main_receive_loop`.
    message_subsystem: MessageSubsystem,
    /// Publisher notified with `Error::ChannelStopped` when this channel
    /// stops (see `handle_stop`).
    stop_publisher: PublisherPtr<Error>,
    /// Task running `main_receive_loop`; stopping it stops the channel.
    receive_task: StoppableTaskPtr,
    /// Set to `true` by `handle_stop` once the channel has stopped.
    stopped: AtomicBool,
    /// Weak pointer to respective session
    pub(in crate::net) session: SessionWeakPtr,
    /// The version message of the node we are connected to.
    /// Set once by `set_version` when the version exchange happens;
    /// empty before that.
    pub version: OnceCell<Arc<VersionMessage>>,
    /// Channel debug info
    pub info: ChannelInfo,
    /// Map holding a `MeteringQueue` for each [`crate::net::Message`],
    /// keyed by command name, to perform rate limiting of propagation
    /// towards the stream.
    metering_map: AsyncMutex<HashMap<String, MeteringQueue>>,
}
112
113impl Channel {
114    /// Sets up a new channel. Creates a reader and writer [`PtStream`] and
115    /// the message publisher subsystem. Performs a network handshake on the
116    /// subsystem dispatchers.
117    pub async fn new(
118        stream: Box<dyn PtStream>,
119        resolve_addr: Option<Url>,
120        connect_addr: Url,
121        session: SessionWeakPtr,
122        transport_mixed: bool,
123    ) -> Arc<Self> {
124        let (reader, writer) = io::split(stream);
125        let reader = AsyncMutex::new(reader);
126        let writer = AsyncMutex::new(writer);
127
128        let message_subsystem = MessageSubsystem::new();
129        Self::setup_dispatchers(&message_subsystem).await;
130
131        let start_time = UNIX_EPOCH.elapsed().unwrap().as_secs();
132        let info =
133            ChannelInfo::new(resolve_addr, connect_addr.clone(), start_time, transport_mixed);
134        let metering_map = AsyncMutex::new(HashMap::new());
135
136        Arc::new(Self {
137            reader,
138            writer,
139            message_subsystem,
140            stop_publisher: Publisher::new(),
141            receive_task: StoppableTask::new(),
142            stopped: AtomicBool::new(false),
143            session,
144            version: OnceCell::new(),
145            info,
146            metering_map,
147        })
148    }
149
    /// Registers a dispatcher on `subsystem` for each core P2P message:
    /// version/verack, ping/pong and getaddrs/addrs.
    ///
    /// This only calls `add_dispatch` for each message type; presumably no
    /// network I/O happens here (TODO confirm in `MessageSubsystem`).
    /// Messages whose command has no dispatcher are treated as spam by
    /// `main_receive_loop` (`Error::MissingDispatcher`), which is why `new`
    /// registers these before the channel can be started.
    async fn setup_dispatchers(subsystem: &MessageSubsystem) {
        subsystem.add_dispatch::<message::VersionMessage>().await;
        subsystem.add_dispatch::<message::VerackMessage>().await;
        subsystem.add_dispatch::<message::PingMessage>().await;
        subsystem.add_dispatch::<message::PongMessage>().await;
        subsystem.add_dispatch::<message::GetAddrsMessage>().await;
        subsystem.add_dispatch::<message::AddrsMessage>().await;
    }
159
160    /// Starts the channel. Runs a receive loop to start receiving messages
161    /// or handles a network failure.
162    pub fn start(self: Arc<Self>, executor: Arc<Executor<'_>>) {
163        debug!(target: "net::channel::start", "START {self:?}");
164
165        let self_ = self.clone();
166        self.receive_task.clone().start(
167            self.clone().main_receive_loop(),
168            |result| self_.handle_stop(result),
169            Error::ChannelStopped,
170            executor,
171        );
172
173        debug!(target: "net::channel::start", "END {self:?}");
174    }
175
176    /// Stops the channel.
177    /// Notifies all publishers that the channel has been closed in `handle_stop()`.
178    pub async fn stop(&self) {
179        debug!(target: "net::channel::stop", "START {self:?}");
180        self.receive_task.stop().await;
181        debug!(target: "net::channel::stop", "END {self:?}");
182    }
183
184    /// Creates a subscription to a stopped signal.
185    /// If the channel is stopped then this will return a ChannelStopped error.
186    pub async fn subscribe_stop(&self) -> Result<Subscription<Error>> {
187        debug!(target: "net::channel::subscribe_stop", "START {self:?}");
188
189        if self.is_stopped() {
190            return Err(Error::ChannelStopped)
191        }
192
193        let sub = self.stop_publisher.clone().subscribe().await;
194
195        debug!(target: "net::channel::subscribe_stop", "END {self:?}");
196
197        Ok(sub)
198    }
199
200    pub fn is_stopped(&self) -> bool {
201        self.stopped.load(SeqCst)
202    }
203
204    /// Sends a message across a channel. First it converts the message
205    /// into a `SerializedMessage` and then calls `send_serialized` to send it.
206    /// Returns an error if something goes wrong.
207    pub async fn send<M: message::Message>(&self, message: &M) -> Result<()> {
208        self.send_serialized(
209            &SerializedMessage::new(message).await,
210            &M::METERING_SCORE,
211            &M::METERING_CONFIGURATION,
212        )
213        .await
214    }
215
216    /// Sends the encoded payload of provided `SerializedMessage` across the channel.
217    ///
218    /// We first check if we should apply some throttling, based on the provided
219    /// `Message` configuration. We always sleep 2x times more than the expected one,
220    /// so we don't flood the peer.
221    /// Then, calls `send_message` that creates a new payload and sends it over the
222    /// network transport as a packet.
223    /// Returns an error if something goes wrong.
224    pub async fn send_serialized(
225        &self,
226        message: &SerializedMessage,
227        metering_score: &u64,
228        metering_config: &MeteringConfiguration,
229    ) -> Result<()> {
230        debug!(
231             target: "net::channel::send", "[START] command={} {self:?}",
232             message.command,
233        );
234
235        // Check if we need to initialize a `MeteringQueue`
236        // for this specific `Message`.
237        let mut lock = self.metering_map.lock().await;
238        if !lock.contains_key(&message.command) {
239            lock.insert(message.command.clone(), MeteringQueue::new(metering_config.clone()));
240        }
241
242        // Insert metering information and grab potential sleep time.
243        // It's safe to unwrap here since we initialized the value
244        // previously.
245        let queue = lock.get_mut(&message.command).unwrap();
246        queue.push(metering_score);
247        let sleep_time = queue.sleep_time();
248        drop(lock);
249
250        // Check if we need to sleep
251        if let Some(sleep_time) = sleep_time {
252            let sleep_time = 2 * sleep_time;
253            debug!(
254                target: "net::channel::send",
255                "[P2P] Channel rate limit is active, sleeping before sending for: {sleep_time} (ms)"
256            );
257            msleep(sleep_time).await;
258        }
259
260        // Check if the channel is stopped, so we can abort
261        if self.is_stopped() {
262            return Err(Error::ChannelStopped)
263        }
264
265        // Catch failure and stop channel, return a net error
266        if let Err(e) = self.send_message(message).await {
267            if self.session.upgrade().unwrap().type_id() & (SESSION_ALL & !SESSION_REFINE) != 0 {
268                error!(
269                    target: "net::channel::send", "[P2P] Channel send error for [{self:?}]: {e}"
270                );
271            }
272            self.stop().await;
273            return Err(Error::ChannelStopped)
274        }
275
276        debug!(
277            target: "net::channel::send", "[END] command={} {self:?}",
278            message.command
279        );
280
281        Ok(())
282    }
283
284    /// Sends the encoded payload of provided `SerializedMessage` by writing
285    /// the data to the channel async stream.
286    async fn send_message(&self, message: &SerializedMessage) -> Result<()> {
287        assert!(!message.command.is_empty());
288
289        let stream = &mut *self.writer.lock().await;
290        let mut written: usize = 0;
291
292        dnetev!(self, SendMessage, {
293            chan: self.info.clone(),
294            cmd: message.command.clone(),
295            time: NanoTimestamp::current_time(),
296        });
297
298        trace!(target: "net::channel::send_message", "Sending magic...");
299        let magic_bytes = self.p2p().settings().read().await.magic_bytes.0;
300        written += magic_bytes.encode_async(stream).await?;
301        trace!(target: "net::channel::send_message", "Sent magic");
302
303        trace!(target: "net::channel::send_message", "Sending command...");
304        written += message.command.encode_async(stream).await?;
305        trace!(target: "net::channel::send_message", "Sent command: {}", message.command);
306
307        trace!(target: "net::channel::send_message", "Sending payload...");
308        // First extract the length of the payload as a VarInt and write it to the stream.
309        written += VarInt(message.payload.len() as u64).encode_async(stream).await?;
310        // Then write the encoded payload itself to the stream.
311        stream.write_all(&message.payload).await?;
312        written += message.payload.len();
313
314        trace!(target: "net::channel::send_message", "Sent payload {} bytes, total bytes {written}",
315            message.payload.len());
316
317        stream.flush().await?;
318
319        Ok(())
320    }
321
322    /// Returns a decoded Message command. We start by extracting the length
323    /// from the stream, then allocate the precise buffer for this length
324    /// using stream.take(). This manual deserialization provides a basic
325    /// DDOS protection, since it prevents nodes from sending an arbitarily
326    /// large payload.
327    pub async fn read_command<R: AsyncRead + Unpin + Send + Sized>(
328        &self,
329        stream: &mut R,
330    ) -> Result<String> {
331        // Messages should have a 4 byte header of magic digits.
332        // This is used for network debugging.
333        let mut magic = [0u8; 4];
334        trace!(target: "net::channel::read_command", "Reading magic...");
335        stream.read_exact(&mut magic).await?;
336
337        trace!(target: "net::channel::read_command", "Read magic {magic:?}");
338        let magic_bytes = self.p2p().settings().read().await.magic_bytes.0;
339        if magic != magic_bytes {
340            error!(target: "net::channel::read_command", "Error: Magic bytes mismatch");
341
342            // If it is outbound, ban the host so we don't share it with other nodes
343            if self.session_type_id() & SESSION_OUTBOUND != 0 {
344                if let BanPolicy::Strict = self.p2p().settings().read().await.ban_policy {
345                    self.ban().await;
346                }
347            }
348
349            return Err(Error::MalformedPacket)
350        }
351
352        // First extract the length from the stream
353        let cmd_len = VarInt::decode_async(stream).await?.0;
354        if cmd_len > (MAX_COMMAND_LENGTH as u64) {
355            error!(target: "net::channel::read_command",
356                "Error: Command length ({cmd_len}) exceeds configured limit ({MAX_COMMAND_LENGTH}). Dropping...");
357            return Err(Error::MessageInvalid);
358        }
359
360        // Then extract precisely `cmd_len` items from the stream.
361        let mut take = stream.take(cmd_len);
362
363        // Deserialize into a vector of `cmd_len` size.
364        let mut bytes = vec![0; cmd_len.try_into().unwrap()];
365        take.read_exact(&mut bytes).await?;
366
367        let command = String::from_utf8(bytes)?;
368
369        Ok(command)
370    }
371
372    /// Subscribe to a message on the message subsystem.
373    pub async fn subscribe_msg<M: message::Message>(&self) -> Result<MessageSubscription<M>> {
374        debug!(
375            target: "net::channel::subscribe_msg", "[START] command={} {self:?}",
376            M::NAME
377        );
378
379        let sub = self.message_subsystem.subscribe::<M>().await;
380
381        debug!(
382            target: "net::channel::subscribe_msg", "[END] command={} {self:?}",
383            M::NAME
384        );
385
386        sub
387    }
388
389    /// Handle network errors. Panic if error passes silently, otherwise
390    /// broadcast the error.
391    async fn handle_stop(self: Arc<Self>, result: Result<()>) {
392        debug!(target: "net::channel::handle_stop", "[START] {self:?}");
393
394        self.stopped.store(true, SeqCst);
395
396        match result {
397            Ok(()) => panic!("Channel task should never complete without error status"),
398            // Send this error to all channel subscribers
399            Err(e) => {
400                self.stop_publisher.notify(Error::ChannelStopped).await;
401                self.message_subsystem.trigger_error(e).await;
402            }
403        }
404
405        debug!(target: "net::channel::handle_stop", "[END] {self:?}");
406    }
407
408    /// Run the receive loop. Start receiving messages or handle network failure.
409    async fn main_receive_loop(self: Arc<Self>) -> Result<()> {
410        debug!(target: "net::channel::main_receive_loop", "[START] {self:?}");
411
412        // Acquire reader lock
413        let reader = &mut *self.reader.lock().await;
414
415        // Run loop
416        loop {
417            let command = match self.read_command(reader).await {
418                Ok(command) => command,
419                Err(err) => {
420                    if Self::is_eof_error(&err) {
421                        verbose!(
422                            target: "net::channel::main_receive_loop",
423                            "[P2P] Channel {} disconnected",
424                            self.display_address()
425                        );
426                    } else if let Error::MessageInvalid = err {
427                        // The command name length has exceeded the limit, this is possibly a malicious attack so ban it
428                        if let BanPolicy::Strict = self.p2p().settings().read().await.ban_policy {
429                            self.ban().await;
430                        }
431                    } else if self.session.upgrade().unwrap().type_id() &
432                        (SESSION_ALL & !SESSION_REFINE) !=
433                        0
434                    {
435                        error!(
436                            target: "net::channel::main_receive_loop",
437                            "[P2P] Read error on channel {}: {err}",
438                            self.display_address()
439                        );
440                    }
441
442                    debug!(
443                        target: "net::channel::main_receive_loop",
444                        "Stopping channel {self:?}"
445                    );
446                    return Err(Error::ChannelStopped)
447                }
448            };
449
450            dnetev!(self, RecvMessage, {
451                chan: self.info.clone(),
452                cmd: command.clone(),
453                time: NanoTimestamp::current_time(),
454            });
455
456            // Send result to our publishers
457            match self.message_subsystem.notify(&command, reader).await {
458                Ok(()) => {}
459                Err(Error::MissingDispatcher) |
460                Err(Error::MessageInvalid) |
461                Err(Error::MeteringLimitExceeded) => {
462                    // If we're getting messages without dispatchers or its invalid,
463                    // it's spam. We therefore ban this channel if:
464                    //
465                    // 1) This channel is NOT part of a refine session.
466                    //
467                    // It's possible that nodes can send messages without
468                    // dispatchers during the refinery process. If that happens
469                    // we simply ignore it. Otherwise, it's spam.
470                    //
471                    // 2) BanPolicy is set to Strict.
472                    //
473                    // We only ban if the BanPolicy is set to Strict, which is
474                    // the default setting for most nodes. The exception to
475                    // this is a seed node like Lilith which has BanPolicy::Relaxed
476                    // since it regularly forms connections with nodes sending
477                    // messages it does not have dispatchers for.
478                    if self.session.upgrade().unwrap().type_id() != SESSION_REFINE {
479                        warn!(
480                        target: "net::channel::main_receive_loop",
481                        "MissingDispatcher|MessageInvalid|MeteringLimitExceeded for command={command}, channel={self:?}"
482                        );
483
484                        if let BanPolicy::Strict = self.p2p().settings().read().await.ban_policy {
485                            self.ban().await;
486                        }
487
488                        return Err(Error::ChannelStopped)
489                    }
490                }
491                Err(_) => unreachable!("You added a new error in notify()"),
492            }
493        }
494    }
495
496    /// Ban a malicious peer and stop the channel.
497    pub async fn ban(&self) {
498        debug!(target: "net::channel::ban", "START {self:?}");
499        debug!(target: "net::channel::ban", "Peer: {:?}", self.display_address());
500
501        // Just store the hostname if this is an inbound session.
502        // This will block all ports from this peer by setting
503        // `hosts.block_all_ports()` to true.
504        let peer = {
505            if self.session_type_id() & SESSION_INBOUND != 0 {
506                if self.address().host().is_none() {
507                    error!("[P2P] ban() caught Url without host: {:?}", self.display_address());
508                    return
509                }
510
511                // An inbound Tor connection can't really be banned :)
512                #[cfg(feature = "p2p-tor")]
513                if (self.address().scheme() == "tor" || self.address().scheme() == "tor+tls") &&
514                    self.p2p().hosts().is_local_host(self.address())
515                {
516                    return
517                }
518
519                if self.address().scheme() == "unix" {
520                    return
521                }
522
523                // If we already have a successful connection with this host on another port,
524                // this might indicate a misconfiguration or unintended overlap between separate P2P networks.
525                // To prevent interference, we block only this specific port rather than the entire host.
526                if self.hosts().has_existing_connection(self.address()) {
527                    self.address().clone()
528                } else {
529                    let mut addr = self.address().clone();
530                    addr.set_port(None).unwrap();
531                    addr
532                }
533            } else {
534                self.address().clone()
535            }
536        };
537
538        let last_seen = UNIX_EPOCH.elapsed().unwrap().as_secs();
539        verbose!(target: "net::channel::ban", "Blacklisting peer={peer}");
540        match self.p2p().hosts().move_host(&peer, last_seen, HostColor::Black).await {
541            Ok(()) => {
542                verbose!(target: "net::channel::ban", "Peer={peer} blacklisted successfully");
543            }
544            Err(e) => {
545                warn!(target: "net::channel::ban", "Could not blacklisted peer={peer}, err={e}");
546            }
547        }
548        self.stop().await;
549        debug!(target: "net::channel::ban", "STOP {self:?}");
550    }
551
552    /// Returns the relevant socket address for this connection. If this is
553    /// an outbound connection, the transport-processed resolve_addr will
554    /// be returned except for transport mixed connections, to make sure
555    /// mixed hosts don't enter hostlist.
556    /// Otherwise for inbound connections it will default
557    /// to connect_addr.
558    pub fn address(&self) -> &Url {
559        if !self.info.transport_mixed {
560            if let Some(resolve_addr) = &self.info.resolve_addr {
561                return resolve_addr
562            }
563        }
564        &self.info.connect_addr
565    }
566
567    /// Returns the address used for UI purposes like in logging or tools like dnet.
568    /// For transport_mixed connection shows the mixed address.
569    pub fn display_address(&self) -> &Url {
570        self.info.resolve_addr.as_ref().unwrap_or(&self.info.connect_addr)
571    }
572
573    /// Returns the socket address that has undergone transport
574    /// processing, if it exists. Returns None otherwise.
575    pub fn resolve_addr(&self) -> Option<Url> {
576        self.info.resolve_addr.clone()
577    }
578
579    /// Return the socket address without transport processing.
580    pub fn connect_addr(&self) -> &Url {
581        &self.info.connect_addr
582    }
583
584    /// Set the VersionMessage of the node this channel is connected
585    /// to. Called on receiving a version message in `ProtocolVersion`.
586    pub(crate) async fn set_version(&self, version: Arc<VersionMessage>) {
587        self.version.set(version).await.unwrap();
588    }
589    /// Should only be called after the version exchange has been completed.
590    pub fn get_version(&self) -> Arc<VersionMessage> {
591        self.version.get().unwrap().clone()
592    }
593
    /// Returns the inner [`MessageSubsystem`] reference, i.e. the subsystem
    /// that incoming messages on this channel are dispatched through.
    pub fn message_subsystem(&self) -> &MessageSubsystem {
        &self.message_subsystem
    }
598
599    fn session(&self) -> Arc<dyn Session> {
600        self.session.upgrade().unwrap()
601    }
602
603    pub fn session_type_id(&self) -> SessionBitFlag {
604        let session = self.session();
605        session.type_id()
606    }
607
608    #[inline]
609    pub fn p2p(&self) -> P2pPtr {
610        self.session().p2p()
611    }
612    #[inline]
613    pub fn hosts(&self) -> HostsPtr {
614        self.p2p().hosts()
615    }
616
617    fn is_eof_error(err: &Error) -> bool {
618        match err {
619            Error::Io(ioerr) => ioerr == &std::io::ErrorKind::UnexpectedEof,
620            _ => false,
621        }
622    }
623}
624
625impl fmt::Debug for Channel {
626    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
627        write!(f, "<Channel addr='{}' id={}>", self.display_address(), self.info.id)
628    }
629}