1use std::{clone::Clone, collections::HashMap, fmt::Debug, sync::Arc};
20
21use async_trait::async_trait;
22use smol::{
23 channel::{Receiver, Sender},
24 lock::RwLock,
25 Executor,
26};
27use tracing::debug;
28
29use super::{
30 super::{
31 channel::ChannelPtr, message::Message, message_publisher::MessageSubscription,
32 session::SessionBitFlag,
33 },
34 protocol_base::{ProtocolBase, ProtocolBasePtr},
35 protocol_jobs_manager::{ProtocolJobsManager, ProtocolJobsManagerPtr},
36 P2pPtr,
37};
38use crate::{
39 system::{StoppableTask, StoppableTaskPtr},
40 Error, Result,
41};
42
/// Decision a [`ProtocolGenericHandler`] takes about one message received on a channel.
///
/// The handler delivers it to that channel's [`ProtocolGeneric`] task through
/// [`ProtocolGenericHandler::send_action`], and the task applies it.
#[derive(Debug)]
pub enum ProtocolGenericAction<M> {
    /// Broadcast the received message to the network, excluding the channel it came from.
    Broadcast,
    /// Send the contained message back over the channel the original arrived on.
    Response(M),
    /// Take no action for this message.
    Skip,
    /// Stop the channel the message arrived on.
    Stop,
}
55
56pub type ProtocolGenericHandlerPtr<M, R> = Arc<ProtocolGenericHandler<M, R>>;
57
58pub struct ProtocolGenericHandler<M: Message + Clone, R: Message + Clone + Debug> {
63 sender: Sender<(u32, M)>,
68 pub receiver: Receiver<(u32, M)>,
71 senders: RwLock<HashMap<u32, Sender<ProtocolGenericAction<R>>>>,
74 pub task: StoppableTaskPtr,
77}
78
79impl<M: Message + Clone, R: Message + Clone + Debug> ProtocolGenericHandler<M, R> {
80 pub async fn new(
83 p2p: &P2pPtr,
84 name: &'static str,
85 session: SessionBitFlag,
86 ) -> ProtocolGenericHandlerPtr<M, R> {
87 let (sender, receiver) = smol::channel::unbounded::<(u32, M)>();
89
90 let senders = RwLock::new(HashMap::new());
92
93 let task = StoppableTask::new();
95
96 let handler = Arc::new(Self { sender, receiver, senders, task });
98
99 let _handler = handler.clone();
101 p2p.protocol_registry()
102 .register(session, move |channel, p2p| {
103 let handler = _handler.clone();
104 async move { ProtocolGeneric::init(channel, name, handler, p2p).await.unwrap() }
105 })
106 .await;
107
108 handler
109 }
110
111 async fn register_channel_sender(
114 &self,
115 channel: u32,
116 sender: Sender<ProtocolGenericAction<R>>,
117 ) {
118 let mut lock = self.senders.write().await;
120 lock.insert(channel, sender);
121
122 let mut stale = vec![];
124 for (channel, sender) in lock.iter() {
125 if sender.is_closed() {
126 stale.push(*channel);
127 }
128 }
129
130 for channel in stale {
132 lock.remove(&channel);
133 }
134
135 drop(lock);
136 }
137
138 pub async fn send_action(&self, channel: u32, action: ProtocolGenericAction<R>) {
140 debug!(
141 target: "net::protocol_generic::ProtocolGenericHandler::send_action",
142 "Sending action {action:?} to channel {channel}..."
143 );
144
145 let mut lock = self.senders.write().await;
147 let Some(sender) = lock.get(&channel) else {
148 debug!(
149 target: "net::protocol_generic::ProtocolGenericHandler::send_action",
150 "Channel wasn't found."
151 );
152
153 drop(lock);
154 return
155 };
156
157 if let Err(e) = sender.send(action).await {
159 debug!(
160 target: "net::protocol_generic::ProtocolGenericHandler::send_action",
161 "Channel {channel} send fail: {e}"
162 );
163 lock.remove(&channel);
164 };
165
166 drop(lock);
167 }
168}
169
170pub struct ProtocolGeneric<M: Message + Clone, R: Message + Clone + Debug> {
172 msg_sub: MessageSubscription<M>,
174 sender: Sender<(u32, M)>,
176 receiver: Receiver<ProtocolGenericAction<R>>,
178 channel: ChannelPtr,
180 p2p: P2pPtr,
182 jobsman: ProtocolJobsManagerPtr,
184}
185
186impl<M: Message + Clone, R: Message + Clone + Debug> ProtocolGeneric<M, R> {
187 pub async fn init(
189 channel: ChannelPtr,
190 name: &'static str,
191 handler: ProtocolGenericHandlerPtr<M, R>,
192 p2p: P2pPtr,
193 ) -> Result<ProtocolBasePtr> {
194 debug!(
195 target: "net::protocol_generic::init",
196 "Adding generic protocol for message {name} to the protocol registry"
197 );
198
199 let msg_subsystem = channel.message_subsystem();
201 msg_subsystem.add_dispatch::<M>().await;
202 msg_subsystem.add_dispatch::<R>().await;
203
204 let msg_sub = channel.subscribe_msg::<M>().await?;
206
207 let (action_sender, receiver) = smol::channel::bounded(1);
209 handler.register_channel_sender(channel.info.id, action_sender).await;
210
211 Ok(Arc::new(Self {
212 msg_sub,
213 sender: handler.sender.clone(),
214 receiver,
215 channel: channel.clone(),
216 p2p,
217 jobsman: ProtocolJobsManager::new(name, channel),
218 }))
219 }
220
221 async fn handle_receive_message(self: Arc<Self>) -> Result<()> {
226 debug!(
227 target: "net::protocol_generic::handle_receive_message",
228 "START"
229 );
230 let exclude_list = vec![self.channel.address().clone()];
231
232 loop {
233 let msg = match self.msg_sub.receive().await {
235 Ok(m) => m,
236 Err(e) => {
237 debug!(
238 target: "net::protocol_generic::handle_receive_message",
239 "[{}] recv fail: {e}", self.jobsman.clone().name()
240 );
241 continue
242 }
243 };
244
245 let msg_copy = (*msg).clone();
246
247 if let Err(e) = self.sender.send((self.channel.info.id, msg_copy.clone())).await {
249 debug!(
250 target: "net::protocol_generic::handle_receive_message",
251 "[{}] sending to channel fail: {e}", self.jobsman.clone().name()
252 );
253 continue
254 }
255
256 let action = match self.receiver.recv().await {
258 Ok(a) => a,
259 Err(e) => {
260 debug!(
261 target: "net::protocol_generic::handle_receive_message",
262 "[{}] action signal recv fail: {e}", self.jobsman.clone().name()
263 );
264 continue
265 }
266 };
267
268 match action {
270 ProtocolGenericAction::Broadcast => {
271 self.p2p.broadcast_with_exclude(&msg_copy, &exclude_list).await
272 }
273 ProtocolGenericAction::Response(r) => {
274 if let Err(e) = self.channel.send(&r).await {
275 debug!(
276 target: "net::protocol_generic::handle_receive_message",
277 "[{}] Channel send fail: {e}", self.jobsman.clone().name()
278 )
279 };
280 }
281 ProtocolGenericAction::Skip => {
282 debug!(
283 target: "net::protocol_generic::handle_receive_message",
284 "[{}] Skip action signal received.", self.jobsman.clone().name()
285 );
286 }
287 ProtocolGenericAction::Stop => {
288 self.channel.stop().await;
289 return Err(Error::ChannelStopped)
290 }
291 }
292 }
293 }
294}
295
296#[async_trait]
297impl<M: Message + Clone, R: Message + Clone + Debug> ProtocolBase for ProtocolGeneric<M, R> {
298 async fn start(self: Arc<Self>, ex: Arc<Executor<'_>>) -> Result<()> {
299 debug!(target: "net::protocol_generic::start", "START");
300 self.jobsman.clone().start(ex.clone());
301 self.jobsman.clone().spawn(self.clone().handle_receive_message(), ex).await;
302 debug!(target: "net::protocol_generic::start", "END");
303 Ok(())
304 }
305
306 fn name(&self) -> &'static str {
307 self.jobsman.clone().name()
308 }
309}