use std::{
    cmp::Eq,
    collections::{HashMap, HashSet},
    fmt::Debug,
    hash::Hash,
    marker::{Send, Sync},
    sync::{Arc, Weak},
};

use futures::stream::FuturesUnordered;
use num_bigint::BigUint;
use smol::{
    channel,
    lock::{Mutex, RwLock, Semaphore},
    stream::StreamExt,
};
use tracing::{info, warn};
use url::Url;

use crate::{
    dht::event::DhtEvent,
    net::{
        connector::Connector,
        session::{SESSION_DIRECT, SESSION_MANUAL},
        ChannelPtr, Message, P2pPtr,
    },
    system::{msleep, ExecutorPtr, Publisher, PublisherPtr, Subscription},
    util::time::Timestamp,
    Error, Result,
};

pub mod settings;
pub use settings::{DhtSettings, DhtSettingsOpt};

pub mod handler;
pub use handler::DhtHandler;

pub mod tasks;

pub mod event;
/// A node in the DHT. Implementors must provide a stable node id and the
/// addresses the node can be reached on.
pub trait DhtNode: Debug + Clone + Send + Sync + PartialEq + Eq + Hash {
    fn id(&self) -> blake3::Hash;
    fn addresses(&self) -> Vec<Url>;
}

/// Implement `Hash`, `PartialEq` and `Eq` for a [`DhtNode`] type by
/// delegating to its `id()`, so two nodes compare equal iff their ids match.
#[macro_export]
macro_rules! impl_dht_node_defaults {
    ($t:ty) => {
        impl std::hash::Hash for $t {
            fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
                self.id().hash(state);
            }
        }
        impl std::cmp::PartialEq for $t {
            fn eq(&self, other: &Self) -> bool {
                self.id() == other.id()
            }
        }
        impl std::cmp::Eq for $t {}
    };
}
pub use impl_dht_node_defaults;
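
// Usage sketch for the trait and macro above (a minimal example; `MyNode`
// and its fields are hypothetical, not part of this module):
//
//     #[derive(Debug, Clone)]
//     struct MyNode {
//         id: blake3::Hash,
//         addrs: Vec<Url>,
//     }
//
//     impl DhtNode for MyNode {
//         fn id(&self) -> blake3::Hash {
//             self.id
//         }
//         fn addresses(&self) -> Vec<Url> {
//             self.addrs.clone()
//         }
//     }
//
//     impl_dht_node_defaults!(MyNode);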

/// Whether a lookup queries for nodes only, or for a value.
enum DhtLookupType {
    Nodes,
    Value,
}

/// Reply to a lookup query: closer nodes, a value, or both.
pub enum DhtLookupReply<N: DhtNode, V> {
    Nodes(Vec<N>),
    Value(V),
    NodesAndValue(Vec<N>, V),
}

/// A single bucket in the routing table.
pub struct DhtBucket<N: DhtNode> {
    pub nodes: Vec<N>,
}

/// Key/value storage shared across tasks.
pub type DhtHashTable<V> = Arc<RwLock<HashMap<blake3::Hash, V>>>;

/// Per-channel lock holding the result of an in-flight or completed ping.
type PingLock<N> = Arc<Mutex<Option<Result<N>>>>;

#[derive(Clone, Debug)]
pub struct ChannelCacheItem<N: DhtNode> {
    /// Node on the other end of the channel, once a ping revealed it
    pub node: Option<N>,
    /// Last time the channel was used
    pub last_used: Timestamp,
    /// Whether we received a ping on this channel
    pub ping_received: bool,
    /// Whether we sent a ping on this channel
    pub ping_sent: bool,
}

#[derive(Clone, Debug)]
pub struct HostCacheItem {
    /// Last time we pinged this host
    pub last_ping: Timestamp,
    /// Node id of this host
    pub node_id: blake3::Hash,
}

pub struct Dht<H: DhtHandler> {
    /// Handler implementing the protocol-specific behavior of this DHT
    pub handler: RwLock<Weak<H>>,
    /// Whether a bootstrap was already triggered
    pub bootstrapped: Arc<RwLock<bool>>,
    /// Routing table buckets
    pub buckets: Arc<RwLock<Vec<DhtBucket<H::Node>>>>,
    /// Local key/value storage
    pub hash_table: DhtHashTable<H::Value>,
    /// Number of buckets in the routing table (one per bit of the key space)
    pub n_buckets: usize,
    /// Maps channel ids to cached channel state
    pub channel_cache: Arc<RwLock<HashMap<u32, ChannelCacheItem<H::Node>>>>,
    /// Maps host addresses to cached host state
    pub host_cache: Arc<RwLock<HashMap<Url, HostCacheItem>>>,
    /// Per-channel locks deduplicating concurrent pings
    ping_locks: Arc<Mutex<HashMap<u32, PingLock<H::Node>>>>,
    /// Sender used to queue nodes for insertion into the buckets
    pub add_node_tx: channel::Sender<(H::Node, ChannelPtr)>,
    /// Receiver side of the add-node queue
    pub add_node_rx: channel::Receiver<(H::Node, ChannelPtr)>,
    /// DHT settings
    pub settings: DhtSettings,
    /// Publisher of [`DhtEvent`]s
    pub event_publisher: PublisherPtr<DhtEvent<H::Node, H::Value>>,
    /// Pointer to the P2P network instance
    pub p2p: P2pPtr,
    /// Connector for establishing outbound channels
    pub connector: Connector,
    /// Global executor
    pub executor: ExecutorPtr,
}

impl<H: DhtHandler> Dht<H> {
    pub async fn new(settings: &DhtSettings, p2p: P2pPtr, ex: ExecutorPtr) -> Self {
        let mut buckets = vec![];
        for _ in 0..256 {
            buckets.push(DhtBucket { nodes: vec![] })
        }

        let (add_node_tx, add_node_rx) = smol::channel::unbounded();

        let session_weak = Arc::downgrade(&p2p.session_manual());
        let connector = Connector::new(p2p.settings(), session_weak);

        Self {
            handler: RwLock::new(Weak::new()),
            buckets: Arc::new(RwLock::new(buckets)),
            hash_table: Arc::new(RwLock::new(HashMap::new())),
            n_buckets: 256,
            bootstrapped: Arc::new(RwLock::new(false)),
            channel_cache: Arc::new(RwLock::new(HashMap::new())),
            host_cache: Arc::new(RwLock::new(HashMap::new())),
            ping_locks: Arc::new(Mutex::new(HashMap::new())),
            add_node_tx,
            add_node_rx,

            event_publisher: Publisher::new(),

            settings: settings.clone(),

            p2p: p2p.clone(),
            connector,
            executor: ex,
        }
    }

    /// Get a strong reference to the [`DhtHandler`].
    /// Panics if the handler was dropped or never set.
    pub async fn handler(&self) -> Arc<H> {
        self.handler.read().await.upgrade().unwrap()
    }

    pub async fn is_bootstrapped(&self) -> bool {
        let bootstrapped = self.bootstrapped.read().await;
        *bootstrapped
    }

    pub async fn set_bootstrapped(&self, value: bool) {
        let mut bootstrapped = self.bootstrapped.write().await;
        *bootstrapped = value;
    }

    /// Subscribe to DHT events.
    pub async fn subscribe(&self) -> Subscription<DhtEvent<H::Node, H::Value>> {
        self.event_publisher.clone().subscribe().await
    }
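
    // Consuming events from the subscription, as a minimal sketch (assumes
    // `dht` is an `Arc<Dht<MyHandler>>` and that `Subscription::receive()`
    // is awaited from a long-running task):
    //
    //     let sub = dht.subscribe().await;
    //     loop {
    //         match sub.receive().await {
    //             DhtEvent::ValueFound { key, value } => { /* use the value */ }
    //             _ => {}
    //         }
    //     }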

    /// XOR distance between two 32-byte keys.
    pub fn distance(&self, key_1: &blake3::Hash, key_2: &blake3::Hash) -> [u8; 32] {
        let bytes1 = key_1.as_bytes();
        let bytes2 = key_2.as_bytes();

        let mut result_bytes = [0u8; 32];

        for i in 0..32 {
            result_bytes[i] = bytes1[i] ^ bytes2[i];
        }

        result_bytes
    }
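
    // Properties of the XOR metric, as an illustrative sketch (not part of
    // the original API surface):
    //
    //     let a = blake3::hash(b"alice");
    //     let b = blake3::hash(b"bob");
    //     assert_eq!(dht.distance(&a, &b), dht.distance(&b, &a)); // symmetric
    //     assert_eq!(dht.distance(&a, &a), [0u8; 32]);            // d(x, x) = 0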

    /// Sort `nodes` in place from closest to furthest to `key`.
    pub fn sort_by_distance(&self, nodes: &mut [H::Node], key: &blake3::Hash) {
        nodes.sort_by(|a, b| {
            let distance_a = BigUint::from_bytes_be(&self.distance(key, &a.id()));
            let distance_b = BigUint::from_bytes_be(&self.distance(key, &b.id()));
            distance_a.cmp(&distance_b)
        });
    }

    /// Get the index of the bucket responsible for `key`, derived from the
    /// number of leading zero bits in the XOR distance to our own node id.
    pub async fn get_bucket_index(&self, self_node_id: &blake3::Hash, key: &blake3::Hash) -> usize {
        if key == self_node_id {
            return 0;
        }
        let distance = self.distance(self_node_id, key);
        let mut leading_zeros = 0;

        for &byte in &distance {
            if byte == 0 {
                leading_zeros += 8;
            } else {
                leading_zeros += byte.leading_zeros() as usize;
                break;
            }
        }

        let bucket_index = self.n_buckets - leading_zeros;
        std::cmp::min(bucket_index, self.n_buckets - 1)
    }
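
    // Worked example: with 256 buckets, a key differing from our id in the
    // most significant bit has 0 leading zeros in its distance, so
    // 256 - 0 = 256 is clamped to bucket 255 (the "far half" of the key
    // space). A key sharing the first 255 bits lands in bucket 1, and only
    // our own id maps to bucket 0 via the early return above.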

    /// Find the `n` known nodes closest to `key`.
    pub async fn find_neighbors(&self, key: &blake3::Hash, n: usize) -> Vec<H::Node> {
        let buckets_lock = self.buckets.clone();
        let buckets = buckets_lock.read().await;

        let mut neighbors = Vec::new();

        for i in 0..self.n_buckets {
            if let Some(bucket) = buckets.get(i) {
                neighbors.extend(bucket.nodes.iter().cloned());
            }
        }

        self.sort_by_distance(&mut neighbors, key);

        neighbors.truncate(n);

        neighbors
    }

    /// Get the node attached to a channel id, if a ping has revealed it.
    pub async fn get_node_from_channel(&self, channel_id: u32) -> Option<H::Node> {
        let channel_cache = self.channel_cache.read().await;
        channel_cache.get(&channel_id).and_then(|cached| cached.node.clone())
    }

    /// Clear the bootstrap flag, the routing table and the hash table.
    pub async fn reset(&self) {
        let mut bootstrapped = self.bootstrapped.write().await;
        *bootstrapped = false;

        let mut buckets = vec![];
        for _ in 0..256 {
            buckets.push(DhtBucket { nodes: vec![] })
        }

        *self.buckets.write().await = buckets;
        *self.hash_table.write().await = HashMap::new();
    }

    /// Store `value` locally, then send `message` to the nodes closest to
    /// `key` so they store it as well.
    pub async fn announce<M: Message>(
        &self,
        key: &blake3::Hash,
        value: &H::Value,
        message: &M,
    ) -> Result<()> {
        // We cannot announce without an external address to advertise.
        let self_node = self.handler().await.node().await?;
        if self_node.addresses().is_empty() {
            return Err(Error::Custom("Cannot announce without external addresses".to_string()))
        }

        self.handler().await.add_value(key, value).await;
        let nodes = self.lookup_nodes(key).await;
        info!(target: "dht::announce", "[DHT] Announcing {} to {} nodes", H::key_to_string(key), nodes.len());

        for node in nodes {
            if let Ok((channel, _)) = self.get_channel(&node).await {
                let _ = channel.send(message).await;
                self.cleanup_channel(channel).await;
            }
        }

        Ok(())
    }
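
    // Usage sketch (illustrative only: `MyAnnounceMsg` is a hypothetical
    // protocol message type implementing `Message`):
    //
    //     let key = blake3::hash(b"resource");
    //     let msg = MyAnnounceMsg { key, value: value.clone() };
    //     dht.announce(&key, &value, &msg).await?;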

    /// Bootstrap the DHT by looking up our own node id.
    pub async fn bootstrap(&self) {
        let self_node = self.handler().await.node().await;
        if self_node.is_err() {
            return;
        }
        let self_node = self_node.unwrap();

        self.set_bootstrapped(true).await;

        info!(target: "dht::bootstrap", "[DHT] Bootstrapping");
        self.event_publisher.notify(DhtEvent::BootstrapStarted).await;

        let _nodes = self.lookup_nodes(&self_node.id()).await;

        self.event_publisher.notify(DhtEvent::BootstrapCompleted).await;
    }

    /// Called when a new node is discovered. Triggers a bootstrap if needed,
    /// then forwards to the node every stored key it is at least as close to
    /// as we are.
    async fn on_new_node(&self, node: &H::Node, channel: ChannelPtr) {
        info!(target: "dht::on_new_node", "[DHT] Found new node {}", H::key_to_string(&node.id()));

        if !self.is_bootstrapped().await {
            self.bootstrap().await;
        }

        let self_node = self.handler().await.node().await;
        if self_node.is_err() {
            return;
        }
        let self_id = self_node.unwrap().id();
        for (key, value) in self.hash_table.read().await.iter() {
            let node_distance = BigUint::from_bytes_be(&self.distance(key, &node.id()));
            let self_distance = BigUint::from_bytes_be(&self.distance(key, &self_id));
            if node_distance <= self_distance {
                let _ = self.handler().await.store(channel.clone(), key, value).await;
            }
        }
    }

    /// Queue a node for insertion into the routing table.
    pub async fn update_node(&self, node: &H::Node, channel: ChannelPtr) {
        self.p2p.session_direct().inc_channel_usage(&channel, 1).await;
        if let Err(e) = self.add_node_tx.send((node.clone(), channel.clone())).await {
            warn!(target: "dht::update_node", "[DHT] Cannot add node {}: {e}", H::key_to_string(&node.id()))
        }
    }

    /// Remove a node from the routing table.
    pub async fn remove_node(&self, node_id: &blake3::Hash) {
        let handler = self.handler().await;
        let self_node = handler.node().await;
        if self_node.is_err() {
            return;
        }
        let bucket_index = handler.dht().get_bucket_index(&self_node.unwrap().id(), node_id).await;
        let buckets_lock = handler.dht().buckets.clone();
        let mut buckets = buckets_lock.write().await;
        let bucket = &mut buckets[bucket_index];
        bucket.nodes.retain(|node| node.id() != *node_id);
    }

    /// Ping a channel, deduplicating concurrent pings: only the first caller
    /// actually sends a ping, later callers wait on the per-channel lock and
    /// reuse the cached result.
    pub async fn ping(&self, channel: ChannelPtr) -> Result<H::Node> {
        let lock_map = self.ping_locks.clone();
        let mut locks = lock_map.lock().await;

        let lock = if let Some(lock) = locks.get(&channel.info.id) {
            lock.clone()
        } else {
            let lock = Arc::new(Mutex::new(None));
            locks.insert(channel.info.id, lock.clone());
            lock
        };
        drop(locks);

        let mut result = lock.lock().await;

        if let Some(res) = result.clone() {
            return res
        }

        let ping_result = self.handler().await.ping(channel.clone()).await;
        *result = Some(ping_result.clone());
        ping_result
    }

    /// Iterative lookup of the nodes closest to `key`, optionally also
    /// retrieving the value stored under `key`.
    async fn lookup(
        &self,
        key: blake3::Hash,
        lookup_type: DhtLookupType,
    ) -> (Vec<H::Node>, Vec<H::Value>) {
        let net_settings = self.p2p.settings().read_arc().await;
        let active_profiles = net_settings.active_profiles.clone();
        drop(net_settings);
        let external_addrs = self.p2p.hosts().external_addrs().await;

        let (k, a) = (self.settings.k, self.settings.alpha);
        let semaphore = Arc::new(Semaphore::new(self.settings.concurrency));
        let queried_addrs = Arc::new(Mutex::new(HashSet::new()));
        let mut seen_nodes = HashSet::new();
        let mut nodes_to_visit = self.find_neighbors(&key, k).await;
        let mut result = Vec::new();
        let mut futures = FuturesUnordered::new();
        let mut consecutive_stalls = 0;

        let mut values = Vec::new();

        // True if the furthest node in our result set is closer to `key`
        // than the next pending candidate, i.e. visiting more nodes cannot
        // improve the result.
        let distance_check = |(furthest, next): (&H::Node, &H::Node)| {
            BigUint::from_bytes_be(&self.distance(&key, &furthest.id())) <
                BigUint::from_bytes_be(&self.distance(&key, &next.id()))
        };

        // Query a single node, trying each of its addresses until one
        // query succeeds.
        let lookup = async |node: H::Node, key, addrs: Vec<Url>| {
            let _permit = semaphore.acquire().await;

            let mut last_err = None;
            for addr in addrs {
                // Skip addresses we already queried during this lookup.
                let mut queried_addrs_set = queried_addrs.lock().await;
                if queried_addrs_set.contains(&addr) {
                    continue;
                }
                queried_addrs_set.insert(addr.clone());
                drop(queried_addrs_set);

                let channel = self.create_channel(&addr).await.map(|(ch, _)| ch);

                if let Err(e) = channel {
                    last_err = Some(e);
                    continue
                }
                let channel = channel.unwrap();

                let handler = self.handler().await;
                let res = match &lookup_type {
                    DhtLookupType::Nodes => {
                        info!(target: "dht::lookup", "[DHT] [LOOKUP] Querying node {} for nodes lookup of key {}", H::key_to_string(&node.id()), H::key_to_string(key));
                        handler.find_nodes(channel.clone(), key).await.map(DhtLookupReply::Nodes)
                    }
                    DhtLookupType::Value => {
                        info!(target: "dht::lookup", "[DHT] [LOOKUP] Querying node {} for value lookup of key {}", H::key_to_string(&node.id()), H::key_to_string(key));
                        handler.find_value(channel.clone(), key).await
                    }
                };

                self.cleanup_channel(channel).await;
                if res.is_ok() {
                    return (node, res)
                }
                last_err = res.err();
            }
            if let Some(e) = last_err {
                return (node, Err(e))
            }

            (node, Err(Error::Custom("All of the node's addresses failed".to_string())))
        };

        // Spawn up to `alpha` lookup futures for the next nodes to visit,
        // keeping only addresses matching our active transport profiles and
        // skipping our own external addresses.
        let spawn_futures = async |nodes_to_visit: &mut Vec<H::Node>,
                                   futures: &mut FuturesUnordered<_>| {
            for _ in 0..a {
                if !nodes_to_visit.is_empty() {
                    let node = nodes_to_visit.remove(0);
                    let valid_addrs: Vec<Url> = node
                        .addresses()
                        .iter()
                        .filter(|addr| {
                            active_profiles.contains(&addr.scheme().to_string()) &&
                                !external_addrs.contains(addr)
                        })
                        .cloned()
                        .collect();
                    if !valid_addrs.is_empty() {
                        futures.push(Box::pin(lookup(node, &key, valid_addrs)));
                    }
                }
            }
        };

        spawn_futures(&mut nodes_to_visit, &mut futures).await;

        while let Some((queried_node, res)) = futures.next().await {
            if let Err(e) = res {
                warn!(target: "dht::lookup", "[DHT] [LOOKUP] Error in lookup: {e}");

                if futures.is_empty() {
                    spawn_futures(&mut nodes_to_visit, &mut futures).await;
                }

                continue;
            }

            let (nodes, value) = match res.unwrap() {
                DhtLookupReply::Nodes(nodes) => (Some(nodes), None),
                DhtLookupReply::Value(value) => (None, Some(value)),
                DhtLookupReply::NodesAndValue(nodes, value) => (Some(nodes), Some(value)),
            };

            if let Some(value) = value {
                info!(target: "dht::lookup", "[DHT] [LOOKUP] Found value for {} from {}", H::key_to_string(&key), H::key_to_string(&queried_node.id()));
                values.push(value.clone());
                self.event_publisher.notify(DhtEvent::ValueFound { key, value }).await;
            }

            if let Some(mut nodes) = nodes {
                if !nodes.is_empty() {
                    info!(target: "dht::lookup", "[DHT] [LOOKUP] Found {} nodes from {}", nodes.len(), H::key_to_string(&queried_node.id()));

                    self.event_publisher
                        .notify(DhtEvent::NodesFound { key, nodes: nodes.clone() })
                        .await;

                    // Keep only nodes we have not seen yet, excluding ourselves.
                    if let Ok(self_node) = self.handler().await.node().await {
                        let self_id = self_node.id();
                        nodes.retain(|node: &H::Node| {
                            node.id() != self_id && seen_nodes.insert(node.id())
                        });
                    }

                    nodes_to_visit.extend(nodes.clone());
                    self.sort_by_distance(&mut nodes_to_visit, &key);
                }
            }

            result.push(queried_node);
            self.sort_by_distance(&mut result, &key);

            // Stop once we have `k` nodes and three consecutive rounds did
            // not produce a candidate closer than our furthest result.
            if result.len() >= k &&
                result.last().zip(nodes_to_visit.first()).is_some_and(distance_check)
            {
                consecutive_stalls += 1;
                if consecutive_stalls >= 3 {
                    break;
                }
            } else {
                consecutive_stalls = 0;
            }

            spawn_futures(&mut nodes_to_visit, &mut futures).await;
        }

        info!(target: "dht::lookup", "[DHT] [LOOKUP] Lookup for {} completed", H::key_to_string(&key));

        let nodes: Vec<_> = result.into_iter().take(k).collect();
        (nodes, values)
    }

    /// Find the nodes closest to `key`.
    pub async fn lookup_nodes(&self, key: &blake3::Hash) -> Vec<H::Node> {
        info!(target: "dht::lookup_nodes", "[DHT] [LOOKUP] Starting node lookup for key {}", H::key_to_string(key));

        self.event_publisher.notify(DhtEvent::NodesLookupStarted { key: *key }).await;

        let (nodes, _) = self.lookup(*key, DhtLookupType::Nodes).await;

        self.event_publisher
            .notify(DhtEvent::NodesLookupCompleted { key: *key, nodes: nodes.clone() })
            .await;

        nodes
    }

    /// Find the value stored under `key`, along with the closest nodes.
    pub async fn lookup_value(&self, key: &blake3::Hash) -> (Vec<H::Node>, Vec<H::Value>) {
        info!(target: "dht::lookup_value", "[DHT] [LOOKUP] Starting value lookup for key {}", H::key_to_string(key));

        self.event_publisher.notify(DhtEvent::ValueLookupStarted { key: *key }).await;

        let (nodes, values) = self.lookup(*key, DhtLookupType::Value).await;

        self.event_publisher
            .notify(DhtEvent::ValueLookupCompleted {
                key: *key,
                nodes: nodes.clone(),
                values: values.clone(),
            })
            .await;

        (nodes, values)
    }
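
    // Lookup usage sketch (assuming an existing `dht: Arc<Dht<MyHandler>>`):
    //
    //     let key = blake3::hash(b"some key");
    //     let (closest_nodes, values) = dht.lookup_value(&key).await;
    //     if values.is_empty() {
    //         // No peer returned a value; `closest_nodes` are the k nodes
    //         // closest to `key` that answered during the lookup.
    //     }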

    /// Refresh the `last_used` timestamp of a channel.
    pub async fn update_channel(&self, channel_id: u32) {
        let channel_cache_lock = self.channel_cache.clone();
        let mut channel_cache = channel_cache_lock.write().await;

        if let Some(cached) = channel_cache.get_mut(&channel_id) {
            cached.last_used = Timestamp::current_time();
        }
    }

    /// Get a channel to a node, reusing a cached channel when possible and
    /// creating a new one otherwise.
    pub async fn get_channel(&self, node: &H::Node) -> Result<(ChannelPtr, H::Node)> {
        let node_id = node.id();

        let channel_cache = self.channel_cache.read().await.clone();
        if let Some((channel_id, cached)) = channel_cache
            .iter()
            .find(|(_, cached)| cached.node.clone().is_some_and(|n| n.id() == node_id))
        {
            if let Some(channel) = self.p2p.get_channel(*channel_id) {
                if channel.session_type_id() & SESSION_DIRECT == 0 {
                    if channel.is_stopped() {
                        self.cleanup_channel(channel).await;
                    } else {
                        return Ok((channel, cached.node.clone().unwrap()))
                    }
                }
            }
        }

        self.create_channel_to_node(node).await
    }

    /// Create a channel to an address, then ping it to learn which node is
    /// on the other end.
    pub async fn create_channel(&self, addr: &Url) -> Result<(ChannelPtr, H::Node)> {
        let external_addrs = self.p2p.hosts().external_addrs().await;
        if external_addrs.contains(addr) {
            return Err(Error::Custom(
                "Can't create a channel to our own external address".to_string(),
            ))
        }

        let channel = self.p2p.session_direct().get_channel(addr).await?;
        let channel_cache = self.channel_cache.read().await;
        if let Some(cached) = channel_cache.get(&channel.info.id) {
            if let Some(node) = &cached.node {
                return Ok((channel, node.clone()))
            }
        }
        drop(channel_cache);

        let node = self.ping(channel.clone()).await;
        if let Err(e) = node {
            self.cleanup_channel(channel).await;
            return Err(e);
        }
        let node = node.unwrap();
        self.add_channel_to_cache(channel.info.id, &node).await;
        Ok((channel, node))
    }

    /// Create a channel to a node, trying each of its addresses that matches
    /// our active transport profiles.
    pub async fn create_channel_to_node(&self, node: &H::Node) -> Result<(ChannelPtr, H::Node)> {
        let net_settings = self.p2p.settings().read_arc().await;
        let active_profiles = net_settings.active_profiles.clone();
        drop(net_settings);

        let mut addrs = node.addresses();
        addrs.retain(|addr| active_profiles.contains(&addr.scheme().to_string()));
        for addr in addrs {
            if let Ok((channel, node)) = self.create_channel(&addr).await {
                return Ok((channel, node));
            }
        }

        Err(Error::Custom("Could not create channel".to_string()))
    }

    /// Insert a channel into the channel cache, or refresh its `last_used`
    /// timestamp if it is already cached.
    pub async fn add_channel_to_cache(&self, channel_id: u32, node: &H::Node) {
        let mut channel_cache = self.channel_cache.write().await;
        channel_cache
            .entry(channel_id)
            .and_modify(|c| c.last_used = Timestamp::current_time())
            .or_insert(ChannelCacheItem {
                node: Some(node.clone()),
                last_used: Timestamp::current_time(),
                ping_received: false,
                ping_sent: false,
            });
    }

    /// Wait until we both sent and received a ping on a channel, polling the
    /// channel cache every 100ms.
    pub async fn wait_fully_pinged(&self, channel_id: u32) -> Result<()> {
        loop {
            let channel_cache = self.channel_cache.read().await;
            let cached = channel_cache
                .get(&channel_id)
                .ok_or(Error::Custom("Missing channel".to_string()))?;
            if cached.ping_received && cached.ping_sent {
                return Ok(())
            }
            drop(channel_cache);
            msleep(100).await;
        }
    }

    /// Release a channel, dropping its cache entries once the session layer
    /// reports it is no longer in use.
    pub async fn cleanup_channel(&self, channel: ChannelPtr) {
        let channel_cache_lock = self.channel_cache.clone();
        let mut channel_cache = channel_cache_lock.write().await;
        let mut ping_locks = self.ping_locks.lock().await;
        if self.p2p.session_direct().cleanup_channel(channel.clone()).await {
            channel_cache.remove(&channel.info.id);
            ping_locks.remove(&channel.info.id);
        }
    }
}