use std::collections::vec_deque::Drain;
use std::collections::VecDeque;
use tokio::sync::oneshot::error::RecvError;
use tokio::sync::oneshot::Receiver;
/// A FIFO window of pending oneshot receivers, each of which resolves to a
/// `Result<(), E>`.
///
/// `add` parks new receivers until `size` of them are held; from then on each
/// `add` first awaits the oldest parked receiver before parking the new one,
/// so the queue never holds more than `size` receivers. With `size == 0`
/// nothing is parked and `add` awaits the receiver it was given directly.
pub struct OneShotHolder<E> {
/// Maximum number of receivers parked in `inflight`; `0` disables parking.
size: usize,
/// Parked receivers, oldest at the front (filled with `push_back`, emptied
/// with `pop_front` / `drain`).
inflight: VecDeque<Receiver<Result<(), E>>>,
}
impl<E> OneShotHolder<E> {
pub fn new(size: usize) -> OneShotHolder<E> {
OneShotHolder {
size,
inflight: VecDeque::with_capacity(size),
}
}
pub async fn add(&mut self, item: Receiver<Result<(), E>>) -> Result<Result<(), E>, RecvError> {
if self.size == 0 {
return item.await;
}
let result = if self.inflight.len() >= self.size {
let fut = self.inflight.pop_front().unwrap();
fut.await
} else {
Ok(Ok(()))
};
self.inflight.push_back(item);
result
}
pub fn drain(&mut self) -> Drain<'_, Receiver<Result<(), E>>> {
self.inflight.drain(..)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::sync::oneshot::channel;

    /// Minimal error type used as the `E` parameter in these tests.
    #[derive(Debug)]
    struct CustomError;

    /// Drains `holder` and checks that it held exactly one receiver, and that
    /// this receiver delivers an `Err` payload.
    async fn assert_single_pending_error(holder: &mut OneShotHolder<CustomError>) {
        let mut pending = holder.drain();
        let receiver = match pending.next() {
            Some(receiver) => receiver,
            None => panic!("Expected an entry."),
        };
        if receiver.await.unwrap().is_ok() {
            panic!("Error expected");
        }
        assert!(pending.next().is_none());
    }

    /// With a window of one, the second `add` waits on the first receiver and
    /// leaves the second one parked for `drain`.
    #[tokio::test]
    async fn test_oneshot_holder() {
        let mut holder = OneShotHolder::<CustomError>::new(1);
        let (tx1, rx1) = channel::<Result<(), CustomError>>();
        let (tx2, rx2) = channel::<Result<(), CustomError>>();

        // Window is empty: rx1 is parked and `add` returns immediately.
        let first = holder.add(rx1).await.unwrap();
        assert!(first.is_ok());
        tokio::spawn(async move {
            if tx1.send(Ok(())).is_err() {
                panic!("error is not expected");
            }
        });

        // Window is full: this waits on rx1, then parks rx2.
        let second = holder.add(rx2).await.unwrap();
        assert!(second.is_ok());
        tokio::spawn(async move {
            if tx2.send(Err(CustomError)).is_err() {
                panic!("error is not expected");
            }
        });

        assert_single_pending_error(&mut holder).await;
    }

    /// A zero-sized holder awaits the receiver directly and parks nothing.
    #[tokio::test]
    async fn test_zero_size_oneshot_holder() {
        let mut holder = OneShotHolder::<CustomError>::new(0);
        let (tx1, rx1) = channel::<Result<(), CustomError>>();
        tokio::spawn(async move {
            if tx1.send(Ok(())).is_err() {
                panic!("error is not expected");
            }
        });

        let outcome = holder.add(rx1).await.unwrap();
        assert!(outcome.is_ok());

        let mut pending = holder.drain();
        assert!(pending.next().is_none());
    }

    /// Dropping a sender surfaces as a receive error from the `add` call that
    /// waits on it, while the newly added receiver is still parked.
    #[tokio::test]
    async fn test_receiver_error() {
        let mut holder = OneShotHolder::<CustomError>::new(1);
        let (tx1, rx1) = channel::<Result<(), CustomError>>();
        let (tx2, rx2) = channel::<Result<(), CustomError>>();

        let first = holder.add(rx1).await.unwrap();
        assert!(first.is_ok());
        tokio::spawn(async move {
            drop(tx1);
        });

        // rx1's sender is gone, so waiting on it fails at the channel level.
        let second = holder.add(rx2).await;
        assert!(second.is_err());
        tokio::spawn(async move {
            if tx2.send(Err(CustomError)).is_err() {
                panic!("error is not expected");
            }
        });

        assert_single_pending_error(&mut holder).await;
    }
}