darkfi/system/stoppable_task.rs
1/* This file is part of DarkFi (https://dark.fi)
2 *
3 * Copyright (C) 2020-2026 Dyne.org foundation
4 *
5 * This program is free software: you can redistribute it and/or modify
6 * it under the terms of the GNU Affero General Public License as
7 * published by the Free Software Foundation, either version 3 of the
8 * License, or (at your option) any later version.
9 *
10 * This program is distributed in the hope that it will be useful,
11 * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 * GNU Affero General Public License for more details.
14 *
15 * You should have received a copy of the GNU Affero General Public License
16 * along with this program. If not, see <https://www.gnu.org/licenses/>.
17 */
18
19use rand::{rngs::OsRng, Rng};
20use smol::{
21 future::{self, Future},
22 Executor,
23};
24use std::sync::Arc;
25use tracing::trace;
26
27use super::CondVar;
28
/// Shared, reference-counted handle to a [`StoppableTask`]; this is the type returned by
/// `StoppableTask::new()`.
pub type StoppableTaskPtr = Arc<StoppableTask>;
30
/// Handle for a future that is spawned on an executor and can be asked to stop early.
///
/// The handle bundles the two condition variables used to coordinate `start()` and `stop()`
/// with an id that identifies the task in hashed collections.
pub struct StoppableTask {
    /// Used to signal to the main running process that it should stop.
    signal: CondVar,
    /// When we call `stop()`, we wait until the process is finished. This is used to prevent
    /// `stop()` from exiting until the task has closed.
    barrier: CondVar,

    /// Used so we can keep StoppableTask in HashMap/HashSet.
    ///
    /// Drawn from `OsRng` in `new()`. The `Hash` and `PartialEq` impls look at this field
    /// only, so two handles whose random ids happen to collide compare equal.
    pub task_id: u32,
}
41
42/// A task that can be prematurely stopped at any time.
43///
44/// ```rust
45/// let task = StoppableTask::new();
46/// task.clone().start(
47/// my_method(),
48/// |result| self_.handle_stop(result),
49/// Error::MyStopError,
50/// executor,
51/// );
52/// ```
53///
54/// Then at any time we can call `task.stop()` to close the task.
55impl StoppableTask {
56 pub fn new() -> Arc<Self> {
57 Arc::new(Self { signal: CondVar::new(), barrier: CondVar::new(), task_id: OsRng.gen() })
58 }
59
60 /// Starts the task.
61 ///
62 /// * `main` is a function of the type `async fn foo() -> ()`
63 /// * `stop_handler` is a function of the type `async fn handle_stop(result: Result<()>) -> ()`
64 /// * `stop_value` is the Error code passed to `stop_handler` when `task.stop()` is called
65 pub fn start<'a, MainFut, StopFut, StopFn, Error>(
66 self: Arc<Self>,
67 main: MainFut,
68 stop_handler: StopFn,
69 stop_value: Error,
70 executor: Arc<Executor<'a>>,
71 ) where
72 MainFut: Future<Output = std::result::Result<(), Error>> + Send + 'a,
73 StopFut: Future<Output = ()> + Send,
74 StopFn: FnOnce(std::result::Result<(), Error>) -> StopFut + Send + 'a,
75 Error: std::error::Error + Send + 'a,
76 {
77 // NOTE: we could send the error code from stop() instead of having it specified in start()
78 trace!(target: "system::StoppableTask", "Starting task {}", self.task_id);
79 // Allow stopping and starting task again.
80 // NOTE: maybe we should disallow this with a panic?
81 self.signal.reset();
82 self.barrier.reset();
83
84 executor
85 .spawn(async move {
86 // Task which waits for a stop signal
87 let stop_fut = async {
88 self.signal.wait().await;
89 trace!(
90 target: "system::StoppableTask",
91 "Stop signal received for task {}",
92 self.task_id
93 );
94 Err(stop_value)
95 };
96
97 // Wait on our main task or stop task - whichever finishes first
98 let result = future::or(main, stop_fut).await;
99
100 trace!(
101 target: "system::StoppableTask",
102 "Closing task {} with result: {:?}",
103 self.task_id,
104 result
105 );
106
107 stop_handler(result).await;
108 // Allow `stop()` to finish
109 self.barrier.notify();
110 })
111 .detach();
112 }
113
114 /// Stops the task. On completion, guarantees the process has stopped.
115 /// Can be called multiple times. After the first call, this does nothing.
116 pub async fn stop(&self) {
117 trace!(target: "system::StoppableTask", "Stopping task {}", self.task_id);
118 self.signal.notify();
119 self.barrier.wait().await;
120 trace!(target: "system::StoppableTask", "Stopped task {}", self.task_id);
121 }
122
123 /// Sends a stop signal and returns immediately. Doesn't guarantee the task
124 /// stopped on completion.
125 pub fn stop_nowait(&self) {
126 trace!(target: "system::StoppableTask", "Stopping task (nowait) {}", self.task_id);
127 self.signal.notify();
128 }
129}
130
131impl std::hash::Hash for StoppableTask {
132 fn hash<H>(&self, state: &mut H)
133 where
134 H: std::hash::Hasher,
135 {
136 self.task_id.hash(state);
137 }
138}
139
140impl std::cmp::PartialEq for StoppableTask {
141 fn eq(&self, other: &Self) -> bool {
142 self.task_id == other.task_id
143 }
144}
145
// Equality is plain `u32` comparison of `task_id`, which is a full equivalence relation, so
// the marker trait `Eq` holds. Together with `Hash` this allows `StoppableTask` to be kept in
// `HashMap`/`HashSet`.
impl std::cmp::Eq for StoppableTask {}
147
148impl Drop for StoppableTask {
149 fn drop(&mut self) {
150 self.stop_nowait()
151 }
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        error::Error,
        system::sleep_forever,
        util::logger::{setup_test_logger, Level},
    };
    use tracing::warn;

    /// Starts a task whose main future never completes, stops it, and checks that the stop
    /// handler receives `Err(Error::DetachedTaskStopped)`.
    #[test]
    fn stoppit_mom() {
        // Tests from the same file may run in parallel, in which case a second attempt to
        // initialise the logger fails. That is harmless, so it is only reported.
        // Swap `Level::Trace` for `Level::Info`, `Level::Verbose` or `Level::Debug` to get
        // less output.
        let logger_init_failed =
            setup_test_logger(&["async_io", "polling"], false, Level::Trace).is_err();
        if logger_init_failed {
            warn!(target: "test_harness", "Logger already initialized");
        }

        let ex = Arc::new(Executor::new());
        // Second handle to the same executor, handed to `start()` for spawning.
        let spawn_ex = Arc::clone(&ex);

        let scenario = async move {
            let task = StoppableTask::new();

            // Main process is an infinite loop
            let never_ending = async {
                sleep_forever().await;
                unreachable!()
            };

            task.clone().start(
                never_ending,
                // Handle stop
                |result| async move {
                    assert!(matches!(result, Err(Error::DetachedTaskStopped)));
                },
                Error::DetachedTaskStopped,
                spawn_ex,
            );

            task.stop().await;
        };

        smol::block_on(ex.run(scenario));
    }
}