darkfi/system/
stoppable_task.rs

1/* This file is part of DarkFi (https://dark.fi)
2 *
3 * Copyright (C) 2020-2026 Dyne.org foundation
4 *
5 * This program is free software: you can redistribute it and/or modify
6 * it under the terms of the GNU Affero General Public License as
7 * published by the Free Software Foundation, either version 3 of the
8 * License, or (at your option) any later version.
9 *
10 * This program is distributed in the hope that it will be useful,
11 * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 * GNU Affero General Public License for more details.
14 *
15 * You should have received a copy of the GNU Affero General Public License
16 * along with this program.  If not, see <https://www.gnu.org/licenses/>.
17 */
18
19use rand::{rngs::OsRng, Rng};
20use smol::{
21    future::{self, Future},
22    Executor,
23};
24use std::sync::Arc;
25use tracing::trace;
26
27use super::CondVar;
28
/// Reference-counted handle to a [`StoppableTask`], the same type returned by
/// [`StoppableTask::new`].
pub type StoppableTaskPtr = Arc<StoppableTask>;
30
/// State shared between the code that controls a task (`stop()` /
/// `stop_nowait()`) and the future spawned by `start()`.
///
/// Equality and hashing are implemented in terms of `task_id` only.
pub struct StoppableTask {
    /// Used to signal to the main running process that it should stop.
    signal: CondVar,
    /// When we call `stop()`, we wait until the process is finished. This is used to prevent
    /// `stop()` from exiting until the task has closed.
    barrier: CondVar,

    /// Used so we can keep StoppableTask in HashMap/HashSet.
    /// Assigned from `OsRng` in `new()`.
    pub task_id: u32,
}
41
42/// A task that can be prematurely stopped at any time.
43///
44/// ```rust
45///     let task = StoppableTask::new();
46///     task.clone().start(
47///         my_method(),
48///         |result| self_.handle_stop(result),
49///         Error::MyStopError,
50///         executor,
51///     );
52/// ```
53///
54/// Then at any time we can call `task.stop()` to close the task.
55impl StoppableTask {
56    pub fn new() -> Arc<Self> {
57        Arc::new(Self { signal: CondVar::new(), barrier: CondVar::new(), task_id: OsRng.gen() })
58    }
59
60    /// Starts the task.
61    ///
62    /// * `main` is a function of the type `async fn foo() -> ()`
63    /// * `stop_handler` is a function of the type `async fn handle_stop(result: Result<()>) -> ()`
64    /// * `stop_value` is the Error code passed to `stop_handler` when `task.stop()` is called
65    pub fn start<'a, MainFut, StopFut, StopFn, Error>(
66        self: Arc<Self>,
67        main: MainFut,
68        stop_handler: StopFn,
69        stop_value: Error,
70        executor: Arc<Executor<'a>>,
71    ) where
72        MainFut: Future<Output = std::result::Result<(), Error>> + Send + 'a,
73        StopFut: Future<Output = ()> + Send,
74        StopFn: FnOnce(std::result::Result<(), Error>) -> StopFut + Send + 'a,
75        Error: std::error::Error + Send + 'a,
76    {
77        // NOTE: we could send the error code from stop() instead of having it specified in start()
78        trace!(target: "system::StoppableTask", "Starting task {}", self.task_id);
79        // Allow stopping and starting task again.
80        // NOTE: maybe we should disallow this with a panic?
81        self.signal.reset();
82        self.barrier.reset();
83
84        executor
85            .spawn(async move {
86                // Task which waits for a stop signal
87                let stop_fut = async {
88                    self.signal.wait().await;
89                    trace!(
90                        target: "system::StoppableTask",
91                        "Stop signal received for task {}",
92                        self.task_id
93                    );
94                    Err(stop_value)
95                };
96
97                // Wait on our main task or stop task - whichever finishes first
98                let result = future::or(main, stop_fut).await;
99
100                trace!(
101                    target: "system::StoppableTask",
102                    "Closing task {} with result: {:?}",
103                    self.task_id,
104                    result
105                );
106
107                stop_handler(result).await;
108                // Allow `stop()` to finish
109                self.barrier.notify();
110            })
111            .detach();
112    }
113
114    /// Stops the task. On completion, guarantees the process has stopped.
115    /// Can be called multiple times. After the first call, this does nothing.
116    pub async fn stop(&self) {
117        trace!(target: "system::StoppableTask", "Stopping task {}", self.task_id);
118        self.signal.notify();
119        self.barrier.wait().await;
120        trace!(target: "system::StoppableTask", "Stopped task {}", self.task_id);
121    }
122
123    /// Sends a stop signal and returns immediately. Doesn't guarantee the task
124    /// stopped on completion.
125    pub fn stop_nowait(&self) {
126        trace!(target: "system::StoppableTask", "Stopping task (nowait) {}", self.task_id);
127        self.signal.notify();
128    }
129}
130
131impl std::hash::Hash for StoppableTask {
132    fn hash<H>(&self, state: &mut H)
133    where
134        H: std::hash::Hasher,
135    {
136        self.task_id.hash(state);
137    }
138}
139
140impl std::cmp::PartialEq for StoppableTask {
141    fn eq(&self, other: &Self) -> bool {
142        self.task_id == other.task_id
143    }
144}
145
146impl std::cmp::Eq for StoppableTask {}
147
148impl Drop for StoppableTask {
149    fn drop(&mut self) {
150        self.stop_nowait()
151    }
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        error::Error,
        system::sleep_forever,
        util::logger::{setup_test_logger, Level},
    };
    use tracing::warn;

    /// Starts a task whose main future never completes, stops it, and checks
    /// that the stop handler receives `Error::DetachedTaskStopped`.
    #[test]
    fn stoppit_mom() {
        // Tests from the same file may run in parallel, in which case only the
        // first one manages to install the logger; later failures are expected.
        // Trace level is used here; pick another `Level` for less output.
        if setup_test_logger(&["async_io", "polling"], false, Level::Trace).is_err() {
            warn!(target: "test_harness", "Logger already initialized");
        }

        let executor = Arc::new(Executor::new());
        let spawner = Arc::clone(&executor);

        let scenario = async move {
            let task = StoppableTask::new();
            let handle = Arc::clone(&task);
            handle.start(
                // The main process never finishes on its own.
                async {
                    sleep_forever().await;
                    unreachable!()
                },
                // The only way out is the stop signal, so we must see the stop value.
                |result| async move {
                    assert!(matches!(result, Err(Error::DetachedTaskStopped)));
                },
                Error::DetachedTaskStopped,
                spawner,
            );
            task.stop().await;
        };

        smol::block_on(executor.run(scenario));
    }
}