Skip to main content

tower_test/mock/
mod.rs

1//! Mock [`Service`]s for use in tests.
2//!
3//! See the [crate-level documentation](crate) for an overview and an example.
4//!
5//! [`Service`]: tower_service::Service
6
7pub mod error;
8pub mod future;
9pub mod spawn;
10
11pub use spawn::Spawn;
12
13use crate::mock::{error::Error, future::ResponseFuture};
14use core::task::Waker;
15
16use tokio::sync::{mpsc, oneshot};
17use tower_layer::Layer;
18use tower_service::Service;
19
20use std::{
21    collections::HashMap,
22    future::Future,
23    sync::{Arc, Mutex},
24    task::{Context, Poll},
25};
26
27/// Apply a [`Layer`] to a mock [`Service`] and spawn the result on a mock task.
28///
29/// Returns the layered service wrapped in a [`Spawn`], along with the [`Handle`]
30/// for the underlying [`Mock`].
31pub fn spawn_layer<T, U, L>(layer: L) -> (Spawn<L::Service>, Handle<T, U>)
32where
33    L: Layer<Mock<T, U>>,
34{
35    let (inner, handle) = pair();
36    let svc = layer.layer(inner);
37
38    (Spawn::new(svc), handle)
39}
40
41/// Create a mock [`Service`] spawned on a mock task.
42///
43/// The returned [`Spawn`] wraps a [`Mock`] so that its readiness can be polled
44/// synchronously in tests; the paired [`Handle`] is used to receive requests
45/// and send responses. See [`pair`] for the un-spawned equivalent.
46pub fn spawn<T, U>() -> (Spawn<Mock<T, U>>, Handle<T, U>) {
47    let (svc, handle) = pair();
48
49    (Spawn::new(svc), handle)
50}
51
52/// Create a mock [`Service`], pass it through `f`, and spawn the result on a
53/// mock task.
54///
55/// This is like [`spawn()`], but the closure `f` may wrap the [`Mock`] in
56/// additional middleware before it is spawned.
57pub fn spawn_with<T, U, F, S>(f: F) -> (Spawn<S>, Handle<T, U>)
58where
59    F: Fn(Mock<T, U>) -> S,
60{
61    let (svc, handle) = pair();
62
63    let svc = f(svc);
64
65    (Spawn::new(svc), handle)
66}
67
68/// A mock [`Service`].
69///
70/// Every request is forwarded to the paired [`Handle`], which decides whether
71/// and how to respond. Construct one with [`pair`] (or one of the `spawn*`
72/// functions). Cloning a `Mock` produces another service backed by the same
73/// [`Handle`], so a single handle can observe the requests of every clone.
74#[derive(Debug)]
75pub struct Mock<T, U> {
76    id: u64,
77    tx: Mutex<Tx<T, U>>,
78    state: Arc<Mutex<State>>,
79    can_send: bool,
80}
81
82/// Drives a paired [`Mock`].
83///
84/// A `Handle` receives the requests made to its [`Mock`] (via
85/// [`next_request`]/[`poll_request`], each of which yields a [`SendResponse`]
86/// for replying), can fail the mock's readiness with [`send_error`], and can
87/// limit how many requests the mock accepts with [`allow`].
88///
89/// [`next_request`]: Handle::next_request
90/// [`poll_request`]: Handle::poll_request
91/// [`send_error`]: Handle::send_error
92/// [`allow`]: Handle::allow
93#[derive(Debug)]
94pub struct Handle<T, U> {
95    rx: Rx<T, U>,
96    state: Arc<Mutex<State>>,
97}
98
99type Request<T, U> = (T, SendResponse<U>);
100
101/// Sends a response (or error) back for a single request received by a [`Mock`].
102///
103/// Returned, paired with the request, by [`Handle::next_request`] and
104/// [`Handle::poll_request`] (and by the [`assert_request_eq!`] macro).
105///
106/// [`assert_request_eq!`]: crate::assert_request_eq
107#[derive(Debug)]
108pub struct SendResponse<T> {
109    tx: oneshot::Sender<Result<T, Error>>,
110}
111
112#[derive(Debug)]
113struct State {
114    /// Tracks the number of requests that can be sent through
115    rem: u64,
116
117    /// Tasks that are blocked
118    tasks: HashMap<u64, Waker>,
119
120    /// Tracks if the `Handle` dropped
121    is_closed: bool,
122
123    /// Tracks the ID for the next mock clone
124    next_clone_id: u64,
125
126    /// Tracks the next error to yield (if any)
127    err_with: Option<Error>,
128}
129
130type Tx<T, U> = mpsc::UnboundedSender<Request<T, U>>;
131type Rx<T, U> = mpsc::UnboundedReceiver<Request<T, U>>;
132
133/// Create a [`Mock`] [`Service`] paired with its [`Handle`].
134///
135/// By default the mock accepts any number of requests (its `poll_ready` is
136/// always ready); use [`Handle::allow`] to apply backpressure.
137pub fn pair<T, U>() -> (Mock<T, U>, Handle<T, U>) {
138    let (tx, rx) = mpsc::unbounded_channel();
139    let tx = Mutex::new(tx);
140
141    let state = Arc::new(Mutex::new(State::new()));
142
143    let mock = Mock {
144        id: 0,
145        tx,
146        state: state.clone(),
147        can_send: false,
148    };
149
150    let handle = Handle { rx, state };
151
152    (mock, handle)
153}
154
155impl<T, U> Service<T> for Mock<T, U> {
156    type Response = U;
157    type Error = Error;
158    type Future = ResponseFuture<U>;
159
160    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
161        let mut state = self.state.lock().unwrap();
162
163        if state.is_closed {
164            return Poll::Ready(Err(error::Closed::new().into()));
165        }
166
167        if let Some(e) = state.err_with.take() {
168            return Poll::Ready(Err(e));
169        }
170
171        if self.can_send {
172            return Poll::Ready(Ok(()));
173        }
174
175        if state.rem > 0 {
176            assert!(!state.tasks.contains_key(&self.id));
177
178            // Returning `Ready` means the next call to `call` must succeed.
179            self.can_send = true;
180
181            Poll::Ready(Ok(()))
182        } else {
183            // Bit weird... but whatevz
184            *state
185                .tasks
186                .entry(self.id)
187                .or_insert_with(|| cx.waker().clone()) = cx.waker().clone();
188
189            Poll::Pending
190        }
191    }
192
193    fn call(&mut self, request: T) -> Self::Future {
194        // Make sure that the service has capacity
195        let mut state = self.state.lock().unwrap();
196
197        if state.is_closed {
198            return ResponseFuture::closed();
199        }
200
201        if !self.can_send {
202            panic!("service not ready; poll_ready must be called first");
203        }
204
205        self.can_send = false;
206
207        // Decrement the number of remaining requests that can be sent
208        if state.rem > 0 {
209            state.rem -= 1;
210        }
211
212        let (tx, rx) = oneshot::channel();
213        let send_response = SendResponse { tx };
214
215        match self.tx.lock().unwrap().send((request, send_response)) {
216            Ok(_) => {}
217            Err(_) => {
218                // TODO: Can this be reached
219                return ResponseFuture::closed();
220            }
221        }
222
223        ResponseFuture::new(rx)
224    }
225}
226
227impl<T, U> Clone for Mock<T, U> {
228    fn clone(&self) -> Self {
229        let id = {
230            let mut state = self.state.lock().unwrap();
231            let id = state.next_clone_id;
232
233            state.next_clone_id += 1;
234
235            id
236        };
237
238        let tx = Mutex::new(self.tx.lock().unwrap().clone());
239
240        Mock {
241            id,
242            tx,
243            state: self.state.clone(),
244            can_send: false,
245        }
246    }
247}
248
249impl<T, U> Drop for Mock<T, U> {
250    fn drop(&mut self) {
251        let mut state = match self.state.lock() {
252            Ok(v) => v,
253            Err(e) => {
254                if ::std::thread::panicking() {
255                    return;
256                }
257
258                panic!("{:?}", e);
259            }
260        };
261
262        state.tasks.remove(&self.id);
263    }
264}
265
266// ===== impl Handle =====
267
268impl<T, U> Handle<T, U> {
269    /// Polls for the next request made to the [`Mock`].
270    ///
271    /// On [`Ready`], yields the request together with a [`SendResponse`] used to
272    /// reply to it, or [`None`] once every [`Mock`] clone has been dropped.
273    ///
274    /// [`Ready`]: std::task::Poll::Ready
275    pub fn poll_request(&mut self) -> Poll<Option<Request<T, U>>> {
276        tokio_test::task::spawn(()).enter(|cx, _| Box::pin(self.rx.recv()).as_mut().poll(cx))
277    }
278
279    /// Waits for the next request made to the [`Mock`].
280    ///
281    /// Resolves to the request together with a [`SendResponse`] used to reply to
282    /// it, or [`None`] once every [`Mock`] clone has been dropped.
283    pub async fn next_request(&mut self) -> Option<Request<T, U>> {
284        self.rx.recv().await
285    }
286
287    /// Allow the [`Mock`] to accept `num` more requests.
288    ///
289    /// Once the mock has accepted that many requests, its `poll_ready` returns
290    /// [`Pending`] until `allow` is called again. A newly-created mock starts
291    /// out allowing `u64::MAX` requests, so this is only needed to exert
292    /// backpressure in a test.
293    ///
294    /// [`Pending`]: std::task::Poll::Pending
295    pub fn allow(&mut self, num: u64) {
296        let mut state = self.state.lock().unwrap();
297        state.rem = num;
298
299        if num > 0 {
300            for (_, task) in state.tasks.drain() {
301                task.wake();
302            }
303        }
304    }
305
306    /// Make the [`Mock`]'s next `poll_ready` resolve to the given error.
307    pub fn send_error<E: Into<Error>>(&mut self, e: E) {
308        let mut state = self.state.lock().unwrap();
309        state.err_with = Some(e.into());
310
311        for (_, task) in state.tasks.drain() {
312            task.wake();
313        }
314    }
315}
316
317impl<T, U> Drop for Handle<T, U> {
318    fn drop(&mut self) {
319        let mut state = match self.state.lock() {
320            Ok(v) => v,
321            Err(e) => {
322                if ::std::thread::panicking() {
323                    return;
324                }
325
326                panic!("{:?}", e);
327            }
328        };
329
330        state.is_closed = true;
331
332        for (_, task) in state.tasks.drain() {
333            task.wake();
334        }
335    }
336}
337
338// ===== impl SendResponse =====
339
340impl<T> SendResponse<T> {
341    /// Resolve the request's response future with the given response.
342    pub fn send_response(self, response: T) {
343        // TODO: Should the result be dropped?
344        let _ = self.tx.send(Ok(response));
345    }
346
347    /// Resolve the request's response future with the given error.
348    pub fn send_error<E: Into<Error>>(self, err: E) {
349        // TODO: Should the result be dropped?
350        let _ = self.tx.send(Err(err.into()));
351    }
352}
353
354// ===== impl State =====
355
356impl State {
357    fn new() -> State {
358        State {
359            rem: u64::MAX,
360            tasks: HashMap::new(),
361            is_closed: false,
362            next_clone_id: 1,
363            err_with: None,
364        }
365    }
366}