1pub mod error;
8pub mod future;
9pub mod spawn;
10
11pub use spawn::Spawn;
12
13use crate::mock::{error::Error, future::ResponseFuture};
14use core::task::Waker;
15
16use tokio::sync::{mpsc, oneshot};
17use tower_layer::Layer;
18use tower_service::Service;
19
20use std::{
21 collections::HashMap,
22 future::Future,
23 sync::{Arc, Mutex},
24 task::{Context, Poll},
25};
26
27pub fn spawn_layer<T, U, L>(layer: L) -> (Spawn<L::Service>, Handle<T, U>)
32where
33 L: Layer<Mock<T, U>>,
34{
35 let (inner, handle) = pair();
36 let svc = layer.layer(inner);
37
38 (Spawn::new(svc), handle)
39}
40
41pub fn spawn<T, U>() -> (Spawn<Mock<T, U>>, Handle<T, U>) {
47 let (svc, handle) = pair();
48
49 (Spawn::new(svc), handle)
50}
51
52pub fn spawn_with<T, U, F, S>(f: F) -> (Spawn<S>, Handle<T, U>)
58where
59 F: Fn(Mock<T, U>) -> S,
60{
61 let (svc, handle) = pair();
62
63 let svc = f(svc);
64
65 (Spawn::new(svc), handle)
66}
67
/// A mock service.
///
/// Requests passed to `call` are forwarded over a channel to the paired [`Handle`].
/// The handle decides how many requests are admitted (`Handle::allow`), can inject
/// errors (`Handle::send_error`), and answers each request through a [`SendResponse`].
/// Clones share the same state and channel.
#[derive(Debug)]
pub struct Mock<T, U> {
    // Key of this instance in `State::tasks`: 0 for the mock created by `pair`,
    // a fresh id from `State::next_clone_id` for each clone.
    id: u64,
    // Sending half of the request channel; the `Handle` holds the receiving half.
    tx: Mutex<Tx<T, U>>,
    // State shared with every clone and with the `Handle`.
    state: Arc<Mutex<State>>,
    // Set by `poll_ready` when it reserves a slot; consumed (reset) by `call`.
    can_send: bool,
}
81
/// Controls a [`Mock`] and all of its clones.
///
/// The handle receives the requests the mocks send, sets how many requests they may
/// accept, and can inject errors. Dropping it closes the mocks.
#[derive(Debug)]
pub struct Handle<T, U> {
    // Receiving half of the request channel fed by every `Mock` clone.
    rx: Rx<T, U>,
    // State shared with the mocks.
    state: Arc<Mutex<State>>,
}
98
/// A request received by a [`Handle`], paired with the [`SendResponse`] used to answer it.
type Request<T, U> = (T, SendResponse<U>);
100
/// Answers a single request received through a [`Handle`].
///
/// It is consumed by `send_response` or `send_error`. The result travels over a `oneshot`
/// channel whose receiver `Mock::call` hands to the `ResponseFuture` it returns.
#[derive(Debug)]
pub struct SendResponse<T> {
    // Sending half of the per-request oneshot channel.
    tx: oneshot::Sender<Result<T, Error>>,
}
111
/// State shared between a `Mock`, its clones, and the `Handle`.
#[derive(Debug)]
struct State {
    /// Number of further requests the mocks may accept. It starts at `u64::MAX`,
    /// is overwritten by `Handle::allow`, and is decremented by `Mock::call` while non-zero.
    rem: u64,

    /// Wakers of mocks that returned `Pending` from `poll_ready`, keyed by `Mock::id`.
    tasks: HashMap<u64, Waker>,

    /// Set when the `Handle` is dropped; mocks then report a closed error.
    is_closed: bool,

    /// Id assigned to the next `Mock` clone.
    next_clone_id: u64,

    /// Error injected by `Handle::send_error`. It is taken by the next `poll_ready`.
    err_with: Option<Error>,
}
129
/// Sending half of the unbounded channel carrying requests from `Mock` to `Handle`.
type Tx<T, U> = mpsc::UnboundedSender<Request<T, U>>;
/// Receiving half of the unbounded channel carrying requests from `Mock` to `Handle`.
type Rx<T, U> = mpsc::UnboundedReceiver<Request<T, U>>;
132
133pub fn pair<T, U>() -> (Mock<T, U>, Handle<T, U>) {
138 let (tx, rx) = mpsc::unbounded_channel();
139 let tx = Mutex::new(tx);
140
141 let state = Arc::new(Mutex::new(State::new()));
142
143 let mock = Mock {
144 id: 0,
145 tx,
146 state: state.clone(),
147 can_send: false,
148 };
149
150 let handle = Handle { rx, state };
151
152 (mock, handle)
153}
154
155impl<T, U> Service<T> for Mock<T, U> {
156 type Response = U;
157 type Error = Error;
158 type Future = ResponseFuture<U>;
159
160 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
161 let mut state = self.state.lock().unwrap();
162
163 if state.is_closed {
164 return Poll::Ready(Err(error::Closed::new().into()));
165 }
166
167 if let Some(e) = state.err_with.take() {
168 return Poll::Ready(Err(e));
169 }
170
171 if self.can_send {
172 return Poll::Ready(Ok(()));
173 }
174
175 if state.rem > 0 {
176 assert!(!state.tasks.contains_key(&self.id));
177
178 self.can_send = true;
180
181 Poll::Ready(Ok(()))
182 } else {
183 *state
185 .tasks
186 .entry(self.id)
187 .or_insert_with(|| cx.waker().clone()) = cx.waker().clone();
188
189 Poll::Pending
190 }
191 }
192
193 fn call(&mut self, request: T) -> Self::Future {
194 let mut state = self.state.lock().unwrap();
196
197 if state.is_closed {
198 return ResponseFuture::closed();
199 }
200
201 if !self.can_send {
202 panic!("service not ready; poll_ready must be called first");
203 }
204
205 self.can_send = false;
206
207 if state.rem > 0 {
209 state.rem -= 1;
210 }
211
212 let (tx, rx) = oneshot::channel();
213 let send_response = SendResponse { tx };
214
215 match self.tx.lock().unwrap().send((request, send_response)) {
216 Ok(_) => {}
217 Err(_) => {
218 return ResponseFuture::closed();
220 }
221 }
222
223 ResponseFuture::new(rx)
224 }
225}
226
227impl<T, U> Clone for Mock<T, U> {
228 fn clone(&self) -> Self {
229 let id = {
230 let mut state = self.state.lock().unwrap();
231 let id = state.next_clone_id;
232
233 state.next_clone_id += 1;
234
235 id
236 };
237
238 let tx = Mutex::new(self.tx.lock().unwrap().clone());
239
240 Mock {
241 id,
242 tx,
243 state: self.state.clone(),
244 can_send: false,
245 }
246 }
247}
248
249impl<T, U> Drop for Mock<T, U> {
250 fn drop(&mut self) {
251 let mut state = match self.state.lock() {
252 Ok(v) => v,
253 Err(e) => {
254 if ::std::thread::panicking() {
255 return;
256 }
257
258 panic!("{:?}", e);
259 }
260 };
261
262 state.tasks.remove(&self.id);
263 }
264}
265
266impl<T, U> Handle<T, U> {
269 pub fn poll_request(&mut self) -> Poll<Option<Request<T, U>>> {
276 tokio_test::task::spawn(()).enter(|cx, _| Box::pin(self.rx.recv()).as_mut().poll(cx))
277 }
278
279 pub async fn next_request(&mut self) -> Option<Request<T, U>> {
284 self.rx.recv().await
285 }
286
287 pub fn allow(&mut self, num: u64) {
296 let mut state = self.state.lock().unwrap();
297 state.rem = num;
298
299 if num > 0 {
300 for (_, task) in state.tasks.drain() {
301 task.wake();
302 }
303 }
304 }
305
306 pub fn send_error<E: Into<Error>>(&mut self, e: E) {
308 let mut state = self.state.lock().unwrap();
309 state.err_with = Some(e.into());
310
311 for (_, task) in state.tasks.drain() {
312 task.wake();
313 }
314 }
315}
316
317impl<T, U> Drop for Handle<T, U> {
318 fn drop(&mut self) {
319 let mut state = match self.state.lock() {
320 Ok(v) => v,
321 Err(e) => {
322 if ::std::thread::panicking() {
323 return;
324 }
325
326 panic!("{:?}", e);
327 }
328 };
329
330 state.is_closed = true;
331
332 for (_, task) in state.tasks.drain() {
333 task.wake();
334 }
335 }
336}
337
338impl<T> SendResponse<T> {
341 pub fn send_response(self, response: T) {
343 let _ = self.tx.send(Ok(response));
345 }
346
347 pub fn send_error<E: Into<Error>>(self, err: E) {
349 let _ = self.tx.send(Err(err.into()));
351 }
352}
353
354impl State {
357 fn new() -> State {
358 State {
359 rem: u64::MAX,
360 tasks: HashMap::new(),
361 is_closed: false,
362 next_clone_id: 1,
363 err_with: None,
364 }
365 }
366}