forked from mirrors/gecko-dev
		
	
		
			
				
	
	
		
			157 lines
		
	
	
	
		
			4.5 KiB
		
	
	
	
		
			Rust
		
	
	
	
	
	
			
		
		
	
	
			157 lines
		
	
	
	
		
			4.5 KiB
		
	
	
	
		
			Rust
		
	
	
	
	
	
| #![warn(rust_2018_idioms)]
 | |
| #![cfg(all(feature = "full", not(tokio_wasi)))] // Wasi doesn't support bind
 | |
| 
 | |
| use tokio::net::{TcpListener, TcpStream};
 | |
| use tokio::sync::{mpsc, oneshot};
 | |
| use tokio_test::assert_ok;
 | |
| 
 | |
| use std::io;
 | |
| use std::net::{IpAddr, SocketAddr};
 | |
| 
 | |
| macro_rules! test_accept {
 | |
|     ($(($ident:ident, $target:expr),)*) => {
 | |
|         $(
 | |
|             #[tokio::test]
 | |
|             async fn $ident() {
 | |
|                 let listener = assert_ok!(TcpListener::bind($target).await);
 | |
|                 let addr = listener.local_addr().unwrap();
 | |
| 
 | |
|                 let (tx, rx) = oneshot::channel();
 | |
| 
 | |
|                 tokio::spawn(async move {
 | |
|                     let (socket, _) = assert_ok!(listener.accept().await);
 | |
|                     assert_ok!(tx.send(socket));
 | |
|                 });
 | |
| 
 | |
|                 let cli = assert_ok!(TcpStream::connect(&addr).await);
 | |
|                 let srv = assert_ok!(rx.await);
 | |
| 
 | |
|                 assert_eq!(cli.local_addr().unwrap(), srv.peer_addr().unwrap());
 | |
|             }
 | |
|         )*
 | |
|     }
 | |
| }
 | |
| 
 | |
| test_accept! {
 | |
|     (ip_str, "127.0.0.1:0"),
 | |
|     (host_str, "localhost:0"),
 | |
|     (socket_addr, "127.0.0.1:0".parse::<SocketAddr>().unwrap()),
 | |
|     (str_port_tuple, ("127.0.0.1", 0)),
 | |
|     (ip_port_tuple, ("127.0.0.1".parse::<IpAddr>().unwrap(), 0)),
 | |
| }
 | |
| 
 | |
| use std::pin::Pin;
 | |
| use std::sync::{
 | |
|     atomic::{AtomicUsize, Ordering::SeqCst},
 | |
|     Arc,
 | |
| };
 | |
| use std::task::{Context, Poll};
 | |
| use tokio_stream::{Stream, StreamExt};
 | |
| 
 | |
| struct TrackPolls<'a> {
 | |
|     npolls: Arc<AtomicUsize>,
 | |
|     listener: &'a mut TcpListener,
 | |
| }
 | |
| 
 | |
| impl<'a> Stream for TrackPolls<'a> {
 | |
|     type Item = io::Result<(TcpStream, SocketAddr)>;
 | |
| 
 | |
|     fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
 | |
|         self.npolls.fetch_add(1, SeqCst);
 | |
|         self.listener.poll_accept(cx).map(Some)
 | |
|     }
 | |
| }
 | |
| 
 | |
| #[tokio::test]
 | |
| async fn no_extra_poll() {
 | |
|     let mut listener = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
 | |
|     let addr = listener.local_addr().unwrap();
 | |
| 
 | |
|     let (tx, rx) = oneshot::channel();
 | |
|     let (accepted_tx, mut accepted_rx) = mpsc::unbounded_channel();
 | |
| 
 | |
|     tokio::spawn(async move {
 | |
|         let mut incoming = TrackPolls {
 | |
|             npolls: Arc::new(AtomicUsize::new(0)),
 | |
|             listener: &mut listener,
 | |
|         };
 | |
|         assert_ok!(tx.send(Arc::clone(&incoming.npolls)));
 | |
|         while incoming.next().await.is_some() {
 | |
|             accepted_tx.send(()).unwrap();
 | |
|         }
 | |
|     });
 | |
| 
 | |
|     let npolls = assert_ok!(rx.await);
 | |
|     tokio::task::yield_now().await;
 | |
| 
 | |
|     // should have been polled exactly once: the initial poll
 | |
|     assert_eq!(npolls.load(SeqCst), 1);
 | |
| 
 | |
|     let _ = assert_ok!(TcpStream::connect(&addr).await);
 | |
|     accepted_rx.recv().await.unwrap();
 | |
| 
 | |
|     // should have been polled twice more: once to yield Some(), then once to yield Pending
 | |
|     assert_eq!(npolls.load(SeqCst), 1 + 2);
 | |
| }
 | |
| 
 | |
| #[tokio::test]
 | |
| async fn accept_many() {
 | |
|     use futures::future::poll_fn;
 | |
|     use std::future::Future;
 | |
|     use std::sync::atomic::AtomicBool;
 | |
| 
 | |
|     const N: usize = 50;
 | |
| 
 | |
|     let listener = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
 | |
|     let listener = Arc::new(listener);
 | |
|     let addr = listener.local_addr().unwrap();
 | |
|     let connected = Arc::new(AtomicBool::new(false));
 | |
| 
 | |
|     let (pending_tx, mut pending_rx) = mpsc::unbounded_channel();
 | |
|     let (notified_tx, mut notified_rx) = mpsc::unbounded_channel();
 | |
| 
 | |
|     for _ in 0..N {
 | |
|         let listener = listener.clone();
 | |
|         let connected = connected.clone();
 | |
|         let pending_tx = pending_tx.clone();
 | |
|         let notified_tx = notified_tx.clone();
 | |
| 
 | |
|         tokio::spawn(async move {
 | |
|             let accept = listener.accept();
 | |
|             tokio::pin!(accept);
 | |
| 
 | |
|             let mut polled = false;
 | |
| 
 | |
|             poll_fn(|cx| {
 | |
|                 if !polled {
 | |
|                     polled = true;
 | |
|                     assert!(Pin::new(&mut accept).poll(cx).is_pending());
 | |
|                     pending_tx.send(()).unwrap();
 | |
|                     Poll::Pending
 | |
|                 } else if connected.load(SeqCst) {
 | |
|                     notified_tx.send(()).unwrap();
 | |
|                     Poll::Ready(())
 | |
|                 } else {
 | |
|                     Poll::Pending
 | |
|                 }
 | |
|             })
 | |
|             .await;
 | |
| 
 | |
|             pending_tx.send(()).unwrap();
 | |
|         });
 | |
|     }
 | |
| 
 | |
|     // Wait for all tasks to have polled at least once
 | |
|     for _ in 0..N {
 | |
|         pending_rx.recv().await.unwrap();
 | |
|     }
 | |
| 
 | |
|     // Establish a TCP connection
 | |
|     connected.store(true, SeqCst);
 | |
|     let _sock = TcpStream::connect(addr).await.unwrap();
 | |
| 
 | |
|     // Wait for all notifications
 | |
|     for _ in 0..N {
 | |
|         notified_rx.recv().await.unwrap();
 | |
|     }
 | |
| }
 | 
