1use std::{collections::HashSet, io::ErrorKind, sync::Arc};
20
21use async_trait::async_trait;
22use smol::{
23 io::{BufReader, ReadHalf, WriteHalf},
24 lock::{Mutex, MutexGuard},
25};
26use tinyjson::JsonValue;
27use tracing::{debug, info, warn};
28use url::Url;
29
30use super::{
31 common::{
32 http_read_from_stream_request, http_write_to_stream, read_from_stream, write_to_stream,
33 INIT_BUF_SIZE,
34 },
35 jsonrpc::*,
36 settings::RpcSettings,
37};
38use crate::{
39 net::transport::{Listener, PtListener, PtStream},
40 system::{StoppableTask, StoppableTaskPtr},
41 util::logger::verbose,
42 Error, Result,
43};
44
45#[async_trait]
47pub trait RequestHandler<T>: Sync + Send {
48 async fn handle_request(&self, req: JsonRequest) -> JsonResult;
49
50 async fn pong(&self, id: u16, _params: JsonValue) -> JsonResult {
51 JsonResponse::new(JsonValue::String("pong".to_string()), id).into()
52 }
53
54 async fn connections_mut(&self) -> MutexGuard<'life0, HashSet<StoppableTaskPtr>>;
55
56 async fn connections(&self) -> Vec<StoppableTaskPtr> {
57 self.connections_mut().await.iter().cloned().collect()
58 }
59
60 async fn mark_connection(&self, task: StoppableTaskPtr) {
61 self.connections_mut().await.insert(task);
62 }
63
64 async fn unmark_connection(&self, task: StoppableTaskPtr) {
65 self.connections_mut().await.remove(&task);
66 }
67
68 async fn active_connections(&self) -> usize {
69 self.connections_mut().await.len()
70 }
71
72 async fn stop_connections(&self) {
73 info!(target: "rpc::server", "[RPC] Server stopped, closing connections");
74 for (i, task) in self.connections().await.iter().enumerate() {
75 debug!(target: "rpc::server", "Stopping connection #{i}");
76 task.stop().await;
77 }
78 }
79}
80
81async fn handle_request<T>(
83 writer: Arc<Mutex<WriteHalf<Box<dyn PtStream>>>>,
84 addr: Url,
85 rh: Arc<impl RequestHandler<T> + 'static>,
86 ex: Arc<smol::Executor<'_>>,
87 tasks: Arc<Mutex<HashSet<Arc<StoppableTask>>>>,
88 settings: RpcSettings,
89 req: JsonRequest,
90) -> Result<()> {
91 let rep = if settings.is_method_disabled(&req.method) {
93 debug!(target: "rpc::server", "RPC method {} is disabled", req.method);
94 JsonError::new(ErrorCode::MethodNotFound, None, req.id).into()
95 } else {
96 rh.handle_request(req).await
97 };
98
99 match rep {
100 JsonResult::Subscriber(subscriber) => {
101 let task = StoppableTask::new();
102
103 let task_ = task.clone();
105 let addr_ = addr.clone();
106 let tasks_ = tasks.clone();
107 let writer_ = writer.clone();
108
109 task.clone().start(
111 async move {
112 let subscription = subscriber.publisher.subscribe().await;
114 loop {
115 let notification = subscription.receive().await;
117
118 debug!(target: "rpc::server", "{addr_} <-- {}", notification.stringify().unwrap());
120 let notification = JsonResult::Notification(notification);
121
122 let mut writer_lock = writer_.lock().await;
123
124 #[allow(clippy::collapsible_else_if)]
125 if settings.use_http() {
126 if let Err(e) = http_write_to_stream(&mut writer_lock, ¬ification).await {
127 subscription.unsubscribe().await;
128 return Err(e.into())
129 }
130 } else {
131 if let Err(e) = write_to_stream(&mut writer_lock, ¬ification).await {
132 subscription.unsubscribe().await;
133 return Err(e.into())
134 }
135 }
136
137 drop(writer_lock);
138 }
139 },
140 move |_| async move {
141 debug!(
142 target: "rpc::server",
143 "Removing background task {} from map", task_.task_id,
144 );
145 tasks_.lock().await.remove(&task_);
146 },
147 Error::DetachedTaskStopped,
148 ex.clone(),
149 );
150
151 debug!(target: "rpc::server", "Adding background task {} to map", task.task_id);
152 tasks.lock().await.insert(task);
153 }
154
155 JsonResult::SubscriberWithReply(subscriber, reply) => {
156 debug!(target: "rpc::server", "{addr} <-- {}", reply.stringify()?);
158 let mut writer_lock = writer.lock().await;
159 if settings.use_http() {
160 http_write_to_stream(&mut writer_lock, &reply.into()).await?;
161 } else {
162 write_to_stream(&mut writer_lock, &reply.into()).await?;
163 }
164 drop(writer_lock);
165
166 let task = StoppableTask::new();
167 let task_ = task.clone();
169 let addr_ = addr.clone();
170 let tasks_ = tasks.clone();
171 let writer_ = writer.clone();
172
173 task.clone().start(
175 async move {
176 let subscription = subscriber.publisher.subscribe().await;
178 loop {
179 let notification = subscription.receive().await;
181
182 debug!(target: "rpc::server", "{addr_} <-- {}", notification.stringify().unwrap());
184 let notification = JsonResult::Notification(notification);
185
186 let mut writer_lock = writer_.lock().await;
187 #[allow(clippy::collapsible_else_if)]
188 if settings.use_http() {
189 if let Err(e) = http_write_to_stream(&mut writer_lock, ¬ification).await {
190 subscription.unsubscribe().await;
191 drop(writer_lock);
192 return Err(e.into())
193 }
194 } else {
195 if let Err(e) = write_to_stream(&mut writer_lock, ¬ification).await {
196 subscription.unsubscribe().await;
197 drop(writer_lock);
198 return Err(e.into())
199 }
200 }
201 drop(writer_lock);
202 }
203 },
204 move |_| async move {
205 debug!(
206 target: "rpc::server",
207 "Removing background task {} from map", task_.task_id,
208 );
209 tasks_.lock().await.remove(&task_);
210 },
211 Error::DetachedTaskStopped,
212 ex.clone(),
213 );
214
215 debug!(target: "rpc::server", "Adding background task {} to map", task.task_id);
216 tasks.lock().await.insert(task);
217 }
218
219 JsonResult::Request(_) | JsonResult::Notification(_) => {
220 unreachable!("Should never happen")
221 }
222
223 JsonResult::Response(ref v) => {
224 debug!(target: "rpc::server", "{addr} <-- {}", v.stringify()?);
225 let mut writer_lock = writer.lock().await;
226 if settings.use_http() {
227 http_write_to_stream(&mut writer_lock, &rep).await?;
228 } else {
229 write_to_stream(&mut writer_lock, &rep).await?;
230 }
231 drop(writer_lock);
232 }
233
234 JsonResult::Error(ref v) => {
235 debug!(target: "rpc::server", "{addr} <-- {}", v.stringify()?);
236 let mut writer_lock = writer.lock().await;
237 if settings.use_http() {
238 http_write_to_stream(&mut writer_lock, &rep).await?;
239 } else {
240 write_to_stream(&mut writer_lock, &rep).await?;
241 }
242 drop(writer_lock);
243 }
244 }
245
246 Ok(())
247}
248
249#[allow(clippy::type_complexity)]
252pub async fn accept<'a, T: 'a>(
253 reader: Arc<Mutex<BufReader<ReadHalf<Box<dyn PtStream>>>>>,
254 writer: Arc<Mutex<WriteHalf<Box<dyn PtStream>>>>,
255 addr: Url,
256 rh: Arc<impl RequestHandler<T> + 'static>,
257 conn_limit: Option<usize>,
258 settings: RpcSettings,
259 ex: Arc<smol::Executor<'a>>,
260) -> Result<()> {
261 if let Some(conn_limit) = conn_limit {
264 if rh.clone().active_connections().await >= conn_limit {
265 debug!(
266 target: "rpc::server::accept",
267 "Connection limit reached, refusing new conn"
268 );
269 return Err(Error::RpcConnectionsExhausted)
270 }
271 }
272
273 let tasks = Arc::new(Mutex::new(HashSet::new()));
275
276 loop {
277 let mut buf = Vec::with_capacity(INIT_BUF_SIZE);
278
279 let mut reader_lock = reader.lock().await;
280 if settings.use_http() {
281 let _ = http_read_from_stream_request(&mut reader_lock, &mut buf).await?;
282 } else {
283 let _ = read_from_stream(&mut reader_lock, &mut buf).await?;
284 }
285 drop(reader_lock);
286
287 let line = match String::from_utf8(buf) {
288 Ok(v) => v,
289 Err(e) => {
290 warn!(
291 target: "rpc::server::accept",
292 "[RPC SERVER] Failed parsing string from read buffer: {e}"
293 );
294 return Err(e.into())
295 }
296 };
297
298 let val: JsonValue = match line.trim().parse() {
300 Ok(v) => v,
301 Err(e) => {
302 warn!(
303 target: "rpc::server::accept",
304 "[RPC SERVER] Failed parsing JSON string: {e}"
305 );
306 return Err(e.into())
307 }
308 };
309
310 let req = match JsonRequest::try_from(&val) {
312 Ok(v) => v,
313 Err(e) => {
314 warn!(
315 target: "rpc::server::accept",
316 "[RPC SERVER] Failed casting JSON to a JsonRequest: {e}"
317 );
318 return Err(e.into())
319 }
320 };
321
322 debug!(target: "rpc::server", "{addr} --> {}", val.stringify()?);
323
324 let task = StoppableTask::new();
326
327 let task_ = task.clone();
329 let tasks_ = tasks.clone();
330
331 task.clone().start(
333 handle_request(
334 writer.clone(),
335 addr.clone(),
336 rh.clone(),
337 ex.clone(),
338 tasks.clone(),
339 settings.clone(),
340 req,
341 ),
342 move |_| async move {
343 debug!(
344 target: "rpc::server",
345 "Removing background task {} from map", task_.task_id,
346 );
347 tasks_.lock().await.remove(&task_);
348 },
349 Error::DetachedTaskStopped,
350 ex.clone(),
351 );
352
353 debug!(target: "rpc::server", "Adding background task {} to map", task.task_id);
354 tasks.lock().await.insert(task);
355 }
356}
357
358async fn run_accept_loop<'a, T: 'a>(
361 listener: Box<dyn PtListener>,
362 rh: Arc<impl RequestHandler<T> + 'static>,
363 conn_limit: Option<usize>,
364 settings: RpcSettings,
365 ex: Arc<smol::Executor<'a>>,
366) -> Result<()> {
367 loop {
368 match listener.next().await {
369 Ok((stream, url)) => {
370 let rh_ = rh.clone();
371 verbose!(target: "rpc::server", "[RPC] Server accepted conn from {url}");
372
373 let (reader, writer) = smol::io::split(stream);
374 let reader = Arc::new(Mutex::new(BufReader::new(reader)));
375 let writer = Arc::new(Mutex::new(writer));
376
377 let task = StoppableTask::new();
378 let task_ = task.clone();
379 let ex_ = ex.clone();
380 task.clone().start(
381 accept(
382 reader,
383 writer,
384 url.clone(),
385 rh.clone(),
386 conn_limit,
387 settings.clone(),
388 ex_,
389 ),
390 |_| async move {
391 verbose!(target: "rpc::server", "[RPC] Closed conn from {url}");
392 rh_.clone().unmark_connection(task_.clone()).await;
393 },
394 Error::ChannelStopped,
395 ex.clone(),
396 );
397
398 rh.clone().mark_connection(task.clone()).await;
399 }
400
401 Err(e) if e.raw_os_error().is_some() => match e.raw_os_error().unwrap() {
403 libc::EAGAIN | libc::ECONNABORTED | libc::EPROTO | libc::EINTR => continue,
404 libc::ECONNRESET => {
405 warn!(
406 target: "rpc::server::run_accept_loop",
407 "[RPC] Connection reset by peer in accept_loop"
408 );
409 continue
410 }
411 libc::ETIMEDOUT => {
412 warn!(
413 target: "rpc::server::run_accept_loop",
414 "[RPC] Connection timed out in accept_loop"
415 );
416 continue
417 }
418 libc::EPIPE => {
419 warn!(
420 target: "rpc::server::run_accept_loop",
421 "[RPC] Broken pipe in accept_loop"
422 );
423 continue
424 }
425 x => {
426 warn!(
427 target: "rpc::server::run_accept_loop",
428 "[RPC] Unhandled OS Error: {e} {x}"
429 );
430 continue
431 }
432 },
433
434 Err(e) if e.kind() == ErrorKind::UnexpectedEof => continue,
436
437 Err(e) if e.kind() == ErrorKind::Other => {
439 if let Some(inner) = std::error::Error::source(&e) {
440 if let Some(inner) = inner.downcast_ref::<futures_rustls::rustls::Error>() {
441 warn!(
442 target: "rpc::server::run_accept_loop",
443 "[RPC] rustls listener error: {inner:?}"
444 );
445 continue
446 }
447 }
448
449 warn!(
450 target: "rpc::server::run_accept_loop",
451 "[RPC] Unhandled ErrorKind::Other error: {e:?}"
452 );
453 continue
454 }
455
456 Err(e) => {
458 warn!(
459 target: "rpc::server::run_accept_loop",
460 "[RPC] Unhandled listener.next() error: {e}"
461 );
462 continue
463 }
464 }
465 }
466}
467
468pub async fn listen_and_serve<'a, T: 'a>(
474 settings: RpcSettings,
475 rh: Arc<impl RequestHandler<T> + 'static>,
476 conn_limit: Option<usize>,
477 ex: Arc<smol::Executor<'a>>,
478) -> Result<()> {
479 let mut listen_url = settings.listen.clone();
481 if settings.listen.scheme().starts_with("http+") {
482 let scheme = settings.listen.scheme().strip_prefix("http+").unwrap();
483 let url_str = settings.listen.as_str().replace(settings.listen.scheme(), scheme);
484 listen_url = url_str.parse()?;
485 }
486
487 let listener = Listener::new(listen_url, None).await?.listen().await?;
488
489 run_accept_loop(listener, rh, conn_limit, settings, ex.clone()).await
490}
491
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{rpc::client::RpcClient, system::msleep};
    use smol::{net::TcpListener, Executor};

    /// Minimal handler: answers `ping` and stores the connection registry the
    /// server maintains through `mark_connection` / `unmark_connection`.
    struct RpcServer {
        rpc_connections: Mutex<HashSet<StoppableTaskPtr>>,
    }

    #[async_trait]
    impl RequestHandler<()> for RpcServer {
        async fn handle_request(&self, req: JsonRequest) -> JsonResult {
            match req.method.as_str() {
                "ping" => self.pong(req.id, req.params).await,
                method => panic!("unexpected RPC method in test: {method}"),
            }
        }

        async fn connections_mut(&self) -> MutexGuard<'life0, HashSet<StoppableTaskPtr>> {
            self.rpc_connections.lock().await
        }
    }

    /// Connections are counted while open, dropped from the registry when the
    /// client stops, and all closed when the server task is stopped.
    #[test]
    fn conn_manager() -> Result<()> {
        let executor = Arc::new(Executor::new());

        smol::block_on(executor.run(async {
            // Bind port 0 so the OS picks a free port, then release it for the
            // RPC server to listen on.
            let listener = TcpListener::bind("127.0.0.1:0").await?;
            let sockaddr = listener.local_addr()?;
            let settings = RpcSettings {
                listen: Url::parse(&format!("tcp://127.0.0.1:{}", sockaddr.port()))?,
                disabled_methods: vec![],
            };
            drop(listener);

            let rpc_server = Arc::new(RpcServer { rpc_connections: Mutex::new(HashSet::new()) });
            let rpc_server_ = rpc_server.clone();

            let server_task = StoppableTask::new();
            server_task.clone().start(
                listen_and_serve(settings.clone(), rpc_server.clone(), None, executor.clone()),
                |res| async move {
                    match res {
                        Ok(()) | Err(Error::RpcServerStopped) => {
                            rpc_server_.stop_connections().await
                        }
                        Err(e) => panic!("{e}"),
                    }
                },
                Error::RpcServerStopped,
                executor.clone(),
            );

            // Give the server time to start listening.
            msleep(500).await;

            let rpc_client0 = RpcClient::new(settings.listen.clone(), executor.clone()).await?;
            msleep(500).await;
            assert_eq!(rpc_server.active_connections().await, 1);

            let rpc_client1 = RpcClient::new(settings.listen.clone(), executor.clone()).await?;
            msleep(500).await;
            assert_eq!(rpc_server.active_connections().await, 2);

            let _rpc_client2 = RpcClient::new(settings.listen.clone(), executor.clone()).await?;
            msleep(500).await;
            assert_eq!(rpc_server.active_connections().await, 3);

            rpc_client0.stop().await;
            msleep(500).await;
            assert_eq!(rpc_server.active_connections().await, 2);

            rpc_client1.stop().await;
            msleep(500).await;
            assert_eq!(rpc_server.active_connections().await, 1);

            // Stopping the server must refuse new clients and close the rest.
            server_task.stop().await;
            assert!(RpcClient::new(settings.listen, executor.clone()).await.is_err());

            assert_eq!(rpc_server.active_connections().await, 0);

            Ok(())
        }))
    }
}