From 119e35a2f4123df21d1f9e76deec44d04d9049d3 Mon Sep 17 00:00:00 2001 From: soruh Date: Mon, 13 Jul 2020 00:41:39 +0200 Subject: [PATCH 01/15] switched from using async-std for channels to using async-channels --- .github/workflows/ci.yml | 21 ++++++++++----------- Cargo.toml | 8 +++++--- src/lib.rs | 7 ++++--- tests/tests.rs | 31 +++++++++++++++++-------------- 4 files changed, 36 insertions(+), 31 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5c8a577..ae4c197 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -17,16 +17,15 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@master + - uses: actions/checkout@master - - name: Install nightly - uses: actions-rs/toolchain@v1 - with: - toolchain: nightly - override: true + - name: Install nightly + uses: actions-rs/toolchain@v1 + with: + toolchain: nightly + override: true - - name: tests - uses: actions-rs/cargo@v1 - with: - command: test - args: --features unstable + - name: tests + uses: actions-rs/cargo@v1 + with: + command: test diff --git a/Cargo.toml b/Cargo.toml index 97bd9c1..40370fc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,7 +10,9 @@ description = "Experimental cooperative cancellation for async-std" [dependencies] pin-project-lite = "0.1.0" -async-std = "1.0" +async-channel = "1.1.1" +futures = "0.3.5" -[features] -unstable = ["async-std/unstable"] + +[dev-dependencies] +async-std = "1.0" \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index 3a7026d..385e9af 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -67,11 +67,12 @@ //! The cancellation system is a subset of `C#` [`CancellationToken / CancellationTokenSource`](https://docs.microsoft.com/en-us/dotnet/standard/threading/cancellation-in-managed-threads). //! The `StopToken / StopTokenSource` terminology is borrowed from C++ paper P0660: https://wg21.link/p0660. +use futures::stream::Stream; +use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; -use async_std::prelude::*; -use async_std::sync::{channel, Receiver, Sender}; +use async_channel::{bounded, Receiver, Sender}; use pin_project_lite::pin_project; enum Never {} @@ -101,7 +102,7 @@ pub struct StopToken { impl Default for StopSource { fn default() -> StopSource { - let (sender, receiver) = channel::(1); + let (sender, receiver) = bounded::(1); StopSource { _chan: sender, diff --git a/tests/tests.rs b/tests/tests.rs index 8fe28f3..b79609e 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -1,36 +1,39 @@ use std::time::Duration; -use async_std::{prelude::*, task, sync::channel}; +use async_channel::bounded; +use async_std::{prelude::*, task}; use stop_token::StopSource; #[test] fn smoke() { task::block_on(async { - let (sender, receiver) = channel::(10); + let (sender, receiver) = bounded::(10); let stop_source = StopSource::new(); let task = task::spawn({ let stop_token = stop_source.stop_token(); let receiver = receiver.clone(); async move { - let mut xs = Vec::new(); - let mut stream = stop_token.stop_stream(receiver); - while let Some(x) = stream.next().await { - xs.push(x) + let mut xs = Vec::new(); + let mut stream = stop_token.stop_stream(receiver); + while let Some(x) = stream.next().await { + xs.push(x) + } + xs } - xs - }}); - sender.send(1).await; - sender.send(2).await; - sender.send(3).await; + }); + sender.send(1).await.unwrap(); + sender.send(2).await.unwrap(); + sender.send(3).await.unwrap(); task::sleep(Duration::from_millis(250)).await; drop(stop_source); task::sleep(Duration::from_millis(250)).await; - sender.send(4).await; - sender.send(5).await; - sender.send(6).await; + sender.send(4).await.unwrap(); + sender.send(5).await.unwrap(); + sender.send(6).await.unwrap(); + assert_eq!(task.await, vec![1, 2, 3]); }) } From 3fe39cd1e9fa66b988ab402a172641e48bf75b4f Mon Sep 17 00:00:00 2001 From: soruh Date: Mon, 13 Jul 2020 01:45:57 +0200 Subject: [PATCH 02/15] use a custom CondVar instead of channels --- Cargo.toml | 4 ++-- src/lib.rs | 58 ++++++++++++++++++++++++++++++++++++-------------- tests/tests.rs | 18 ++++++++-------- 3 files changed, 53 insertions(+), 27 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 40370fc..6ef2d57 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,9 +10,9 @@ description = "Experimental cooperative cancellation for async-std" [dependencies] pin-project-lite = "0.1.0" -async-channel = "1.1.1" futures = "0.3.5" +event-listener = "2.2.0" [dev-dependencies] -async-std = "1.0" \ No newline at end of file +async-std = { version = "1.0", features = ["unstable"] } diff --git a/src/lib.rs b/src/lib.rs index 385e9af..e906ae3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -67,16 +67,18 @@ //! The cancellation system is a subset of `C#` [`CancellationToken / CancellationTokenSource`](https://docs.microsoft.com/en-us/dotnet/standard/threading/cancellation-in-managed-threads). //! The `StopToken / StopTokenSource` terminology is borrowed from C++ paper P0660: https://wg21.link/p0660. -use futures::stream::Stream; use std::future::Future; use std::pin::Pin; +use std::sync::{ + atomic::{AtomicBool, Ordering}, + Arc, +}; use std::task::{Context, Poll}; -use async_channel::{bounded, Receiver, Sender}; +use event_listener::Event; +use futures::stream::Stream; use pin_project_lite::pin_project; -enum Never {} - /// `StopSource` produces `StopToken` and cancels all of its tokens on drop. /// /// # Example: @@ -87,30 +89,46 @@ enum Never {} /// schedule_some_work(stop_token); /// drop(stop_source); // At this point, scheduled work notices that it is canceled. /// ``` + +#[derive(Debug)] +struct CondVar { + event: Event, + signaled: AtomicBool, +} + #[derive(Debug)] pub struct StopSource { - /// Solely for `Drop`. - _chan: Sender, + signal: Arc, stop_token: StopToken, } /// `StopToken` is a future which completes when the associated `StopSource` is dropped. #[derive(Debug, Clone)] pub struct StopToken { - chan: Receiver, + signal: Arc, } impl Default for StopSource { fn default() -> StopSource { - let (sender, receiver) = bounded::(1); - + let signal = Arc::new(CondVar { + event: Event::new(), + signaled: AtomicBool::new(false), + }); StopSource { - _chan: sender, - stop_token: StopToken { chan: receiver }, + signal: signal.clone(), + stop_token: StopToken { signal }, } } } +impl Drop for StopSource { + fn drop(&mut self) { + // This can probably be `Relaxed` and notify_relaxed + self.signal.signaled.store(true, Ordering::SeqCst); + self.signal.event.notify(usize::MAX); + } +} + impl StopSource { /// Creates a new `StopSource`. pub fn new() -> StopSource { @@ -128,12 +146,20 @@ impl StopSource { impl Future for StopToken { type Output = (); - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { - let chan = Pin::new(&mut self.chan); - match Stream::poll_next(chan, cx) { + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + // We probably need to register this event listener before checking the flag + // to prevent race conditions + let mut event = self.signal.event.listen(); + + // This can also probably be `Relaxed` + if self.signal.signaled.load(Ordering::SeqCst) { + return Poll::Ready(()); + } + + let event = Pin::new(&mut event); + match Future::poll(event, cx) { Poll::Pending => Poll::Pending, - Poll::Ready(Some(never)) => match never {}, - Poll::Ready(None) => Poll::Ready(()), + Poll::Ready(_) => Poll::Ready(()), } } } diff --git a/tests/tests.rs b/tests/tests.rs index b79609e..b2215ca 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -1,14 +1,13 @@ use std::time::Duration; -use async_channel::bounded; -use async_std::{prelude::*, task}; +use async_std::{prelude::*, sync::channel, task}; use stop_token::StopSource; #[test] fn smoke() { task::block_on(async { - let (sender, receiver) = bounded::(10); + let (sender, receiver) = channel::(10); let stop_source = StopSource::new(); let task = task::spawn({ let stop_token = stop_source.stop_token(); @@ -22,17 +21,18 @@ fn smoke() { xs } }); - sender.send(1).await.unwrap(); - sender.send(2).await.unwrap(); - sender.send(3).await.unwrap(); + sender.send(1).await; + sender.send(2).await; + sender.send(3).await; task::sleep(Duration::from_millis(250)).await; drop(stop_source); + task::sleep(Duration::from_millis(250)).await; - sender.send(4).await.unwrap(); - sender.send(5).await.unwrap(); - sender.send(6).await.unwrap(); + sender.send(4).await; + sender.send(5).await; + sender.send(6).await; assert_eq!(task.await, vec![1, 2, 3]); }) From 9046b391eb02c9e54a5841daaa63c02e2240193f Mon Sep 17 00:00:00 2001 From: soruh Date: Mon, 13 Jul 2020 02:03:26 +0200 Subject: [PATCH 03/15] cleaned up the CondVar implementation --- src/lib.rs | 60 +++++++++++++++++++++++++++++++++++------------------- 1 file changed, 39 insertions(+), 21 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index e906ae3..5dffe41 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -75,7 +75,7 @@ use std::sync::{ }; use std::task::{Context, Poll}; -use event_listener::Event; +use event_listener::{Event, EventListener}; use futures::stream::Stream; use pin_project_lite::pin_project; @@ -91,26 +91,51 @@ use pin_project_lite::pin_project; /// ``` #[derive(Debug)] -struct CondVar { +/// a custom implementation of a CondVar that short-circuits after +/// being signaled once +struct ShortCircuitingCondVar { event: Event, signaled: AtomicBool, } +impl ShortCircuitingCondVar { + fn notify(&self, n: usize) { + // TODO: This can probably be `Relaxed` and `notify_relaxed` + self.signaled.store(true, Ordering::SeqCst); + self.event.notify(n); + } + + fn listen(&self) -> Option { + // We probably need to register this event listener before checking the flag + // to prevent race conditions + let listener = self.event.listen(); + + // TODO: This can probably also be `Relaxed` + if self.signaled.load(Ordering::SeqCst) { + // This happens implicitely as `listener` goes out of scope: + + None // The `CondVar` has already been triggered so we don't need to wait + } else { + Some(listener) + } + } +} + #[derive(Debug)] pub struct StopSource { - signal: Arc, + signal: Arc, stop_token: StopToken, } /// `StopToken` is a future which completes when the associated `StopSource` is dropped. #[derive(Debug, Clone)] pub struct StopToken { - signal: Arc, + signal: Arc, } impl Default for StopSource { fn default() -> StopSource { - let signal = Arc::new(CondVar { + let signal = Arc::new(ShortCircuitingCondVar { event: Event::new(), signaled: AtomicBool::new(false), }); @@ -123,9 +148,8 @@ impl Default for StopSource { impl Drop for StopSource { fn drop(&mut self) { - // This can probably be `Relaxed` and notify_relaxed - self.signal.signaled.store(true, Ordering::SeqCst); - self.signal.event.notify(usize::MAX); + // TODO: notifying only one StopToken should be sufficient + self.signal.notify(usize::MAX); } } @@ -147,19 +171,13 @@ impl Future for StopToken { type Output = (); fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { - // We probably need to register this event listener before checking the flag - // to prevent race conditions - let mut event = self.signal.event.listen(); - - // This can also probably be `Relaxed` - if self.signal.signaled.load(Ordering::SeqCst) { - return Poll::Ready(()); - } - - let event = Pin::new(&mut event); - match Future::poll(event, cx) { - Poll::Pending => Poll::Pending, - Poll::Ready(_) => Poll::Ready(()), + if let Some(mut listener) = self.signal.listen() { + match Future::poll(Pin::new(&mut listener), cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(_) => Poll::Ready(()), + } + } else { + Poll::Ready(()) } } } From e58195ec969c8fb391ed52336c2d3dbb1eab4ead Mon Sep 17 00:00:00 2001 From: soruh Date: Mon, 13 Jul 2020 02:07:29 +0200 Subject: [PATCH 04/15] fixed up Cargo.toml --- Cargo.toml | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 6ef2d57..ea7c1a9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "stop-token" -version = "0.1.2" +version = "0.1.3" authors = ["Aleksey Kladov "] edition = "2018" license = "MIT OR Apache-2.0" @@ -16,3 +16,8 @@ event-listener = "2.2.0" [dev-dependencies] async-std = { version = "1.0", features = ["unstable"] } + +[features] +unstable = [] +# This feature doesn't do anything anymore, +# but is needed for backwards-compatibility \ No newline at end of file From e33d393a01e58847b905a874c4b6dc5c3f19ef38 Mon Sep 17 00:00:00 2001 From: soruh Date: Mon, 13 Jul 2020 02:08:32 +0200 Subject: [PATCH 05/15] added newline at the end of Cargo.toml --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index ea7c1a9..35b6076 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,4 +20,4 @@ async-std = { version = "1.0", features = ["unstable"] } [features] unstable = [] # This feature doesn't do anything anymore, -# but is needed for backwards-compatibility \ No newline at end of file +# but is needed for backwards-compatibility From 11a3c21adaed537b46dad10ee5a20dda765aff38 Mon Sep 17 00:00:00 2001 From: soruh Date: Mon, 13 Jul 2020 13:17:02 +0200 Subject: [PATCH 06/15] Don't create before checking if the has already been triggered --- src/lib.rs | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 5dffe41..688c7aa 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -106,18 +106,21 @@ impl ShortCircuitingCondVar { } fn listen(&self) -> Option { - // We probably need to register this event listener before checking the flag - // to prevent race conditions + // TODO: These could maybe be `Aquire` + // Check if the `CondVar` has already been triggered + if self.signaled.load(Ordering::SeqCst) { + return None; // The `CondVar` has already been triggered so we don't need to wait + } + + // Register a new listener let listener = self.event.listen(); - // TODO: This can probably also be `Relaxed` + // Make sure the `CondVar` still has not been triggered to prevent race conditions if self.signaled.load(Ordering::SeqCst) { - // This happens implicitely as `listener` goes out of scope: - - None // The `CondVar` has already been triggered so we don't need to wait - } else { - Some(listener) + return None; } + + Some(listener) } } From d6700d8a46d45b3a12081ea5c63a623def91ddb4 Mon Sep 17 00:00:00 2001 From: soruh Date: Mon, 13 Jul 2020 14:52:38 +0200 Subject: [PATCH 07/15] implement caching `EventListeners` --- src/lib.rs | 57 +++++++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 52 insertions(+), 5 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 688c7aa..d14d34c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -69,8 +69,9 @@ use std::future::Future; use std::pin::Pin; +use std::ptr::null_mut; use std::sync::{ - atomic::{AtomicBool, Ordering}, + atomic::{AtomicBool, AtomicPtr, Ordering}, Arc, }; use std::task::{Context, Poll}; @@ -94,6 +95,8 @@ use pin_project_lite::pin_project; /// a custom implementation of a CondVar that short-circuits after /// being signaled once struct ShortCircuitingCondVar { + // TODO: is there a better (safer) way to have an atomic `Option`? + cached_listener: AtomicPtr, // This is a raw pointer to a Box event: Event, signaled: AtomicBool, } @@ -112,8 +115,18 @@ impl ShortCircuitingCondVar { return None; // The `CondVar` has already been triggered so we don't need to wait } - // Register a new listener - let listener = self.event.listen(); + let ptr = self.cached_listener.swap(null_mut(), Ordering::SeqCst); + + let listener = if ptr.is_null() { + // Register a new listener + self.event.listen() + } else { + // turn the cached listener back into a `EventListener` object + + // Safety: the `cached_listener` is not null and can only come from a Box::into_raw + // it is also replaced with a `null` pointer when read so there can not be another owner. + unsafe { *Box::from_raw(ptr) } + }; // Make sure the `CondVar` still has not been triggered to prevent race conditions if self.signaled.load(Ordering::SeqCst) { @@ -122,6 +135,33 @@ impl ShortCircuitingCondVar { Some(listener) } + + fn cache_listener(&self, listener: EventListener) -> Result<(), EventListener> { + // Check if there is already a cached listener + if self.cached_listener.load(Ordering::SeqCst).is_null() { + let listener = Box::new(listener); + + unsafe { + let res = self.cached_listener.compare_and_swap( + null_mut(), + Box::into_raw(listener), + Ordering::SeqCst, + ); + if res.is_null() { + Ok(()) + } else { + // We failed to write our new cached listener due to a race + // Turn it back into a Box and return it to the caller + + // Safety: the `cached_listener` is not null and can only come from a Box::into_raw + // it is also replaced with a `null` pointer when read so there can not be another owner. + Err(*Box::from_raw(res)) + } + } + } else { + Err(listener) + } + } } #[derive(Debug)] @@ -139,6 +179,7 @@ pub struct StopToken { impl Default for StopSource { fn default() -> StopSource { let signal = Arc::new(ShortCircuitingCondVar { + cached_listener: AtomicPtr::new(null_mut()), event: Event::new(), signaled: AtomicBool::new(false), }); @@ -175,10 +216,16 @@ impl Future for StopToken { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { if let Some(mut listener) = self.signal.listen() { - match Future::poll(Pin::new(&mut listener), cx) { + let result = match Future::poll(Pin::new(&mut listener), cx) { Poll::Pending => Poll::Pending, Poll::Ready(_) => Poll::Ready(()), - } + }; + + // Try to cache the listener, if there already is a cached listener + // drop the one we have + let _ = self.signal.cache_listener(listener); + + return result; } else { Poll::Ready(()) } From 35516ff0b8db590cf63bc3fa676bce2c0cb8594b Mon Sep 17 00:00:00 2001 From: soruh Date: Mon, 13 Jul 2020 14:58:40 +0200 Subject: [PATCH 08/15] make `StopToken` a tuple struct --- src/lib.rs | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index d14d34c..e94cd9e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -172,9 +172,7 @@ pub struct StopSource { /// `StopToken` is a future which completes when the associated `StopSource` is dropped. #[derive(Debug, Clone)] -pub struct StopToken { - signal: Arc, -} +pub struct StopToken(Arc); impl Default for StopSource { fn default() -> StopSource { @@ -185,7 +183,7 @@ impl Default for StopSource { }); StopSource { signal: signal.clone(), - stop_token: StopToken { signal }, + stop_token: StopToken(signal), } } } @@ -215,7 +213,7 @@ impl Future for StopToken { type Output = (); fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { - if let Some(mut listener) = self.signal.listen() { + if let Some(mut listener) = self.0.listen() { let result = match Future::poll(Pin::new(&mut listener), cx) { Poll::Pending => Poll::Pending, Poll::Ready(_) => Poll::Ready(()), @@ -223,7 +221,7 @@ impl Future for StopToken { // Try to cache the listener, if there already is a cached listener // drop the one we have - let _ = self.signal.cache_listener(listener); + let _ = self.0.cache_listener(listener); return result; } else { From 6f4fab201c80f3126198a02d3c84118ba94f3d95 Mon Sep 17 00:00:00 2001 From: soruh Date: Mon, 13 Jul 2020 16:07:18 +0200 Subject: [PATCH 09/15] added test to contest cached_listener --- src/lib.rs | 32 ++++++------- tests/tests.rs | 122 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 137 insertions(+), 17 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index e94cd9e..04fa23c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -137,26 +137,24 @@ impl ShortCircuitingCondVar { } fn cache_listener(&self, listener: EventListener) -> Result<(), EventListener> { - // Check if there is already a cached listener + // Check if there is already a cached listener to prevent `Box` if it's not neccessary if self.cached_listener.load(Ordering::SeqCst).is_null() { let listener = Box::new(listener); - unsafe { - let res = self.cached_listener.compare_and_swap( - null_mut(), - Box::into_raw(listener), - Ordering::SeqCst, - ); - if res.is_null() { - Ok(()) - } else { - // We failed to write our new cached listener due to a race - // Turn it back into a Box and return it to the caller - - // Safety: the `cached_listener` is not null and can only come from a Box::into_raw - // it is also replaced with a `null` pointer when read so there can not be another owner. - Err(*Box::from_raw(res)) - } + let res = self.cached_listener.compare_and_swap( + null_mut(), + Box::into_raw(listener), + Ordering::SeqCst, + ); + if res.is_null() { + Ok(()) + } else { + // We failed to write our new cached listener due to a race + // Turn it back into a Box and return it to the caller + + // Safety: the `cached_listener` is not null and can only come from a Box::into_raw + // it is also replaced with a `null` pointer when read so there can not be another owner. + Err(*unsafe { Box::from_raw(res) }) } } else { Err(listener) diff --git a/tests/tests.rs b/tests/tests.rs index b2215ca..7da8c2e 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -37,3 +37,125 @@ fn smoke() { assert_eq!(task.await, vec![1, 2, 3]); }) } + +#[test] +fn multiple_tokens() { + task::block_on(async { + let stop_source = StopSource::new(); + + let (sender_a, receiver_a) = channel::(10); + let task_a = task::spawn({ + let stop_token = stop_source.stop_token(); + let receiver = receiver_a.clone(); + async move { + let mut xs = Vec::new(); + let mut stream = stop_token.stop_stream(receiver); + while let Some(x) = stream.next().await { + xs.push(x) + } + xs + } + }); + + let (sender_b, receiver_b) = channel::(10); + let task_b = task::spawn({ + let stop_token = stop_source.stop_token(); + let receiver = receiver_b.clone(); + async move { + let mut xs = Vec::new(); + let mut stream = stop_token.stop_stream(receiver); + while let Some(x) = stream.next().await { + xs.push(x) + } + xs + } + }); + + sender_a.send(1).await; + sender_a.send(2).await; + sender_a.send(3).await; + + sender_b.send(101).await; + sender_b.send(102).await; + sender_b.send(103).await; + + task::sleep(Duration::from_millis(250)).await; + + drop(stop_source); + + task::sleep(Duration::from_millis(250)).await; + + sender_a.send(4).await; + sender_a.send(5).await; + sender_a.send(6).await; + + sender_b.send(104).await; + sender_b.send(105).await; + sender_b.send(106).await; + + assert_eq!(task_a.await, vec![1, 2, 3]); + assert_eq!(task_b.await, vec![101, 102, 103]); + }) +} + +#[test] +fn contest_cached_listener() { + task::block_on(async { + let stop_source = StopSource::new(); + + let (sender_a, receiver_a) = channel::(10); + let recv_task_a = task::spawn({ + let stop_token = stop_source.stop_token(); + let receiver = receiver_a.clone(); + async move { + let mut xs = Vec::new(); + let mut stream = stop_token.stop_stream(receiver); + while let Some(x) = stream.next().await { + xs.push(x) + } + xs + } + }); + + let (sender_b, receiver_b) = channel::(10); + let recv_task_b = task::spawn({ + let stop_token = stop_source.stop_token(); + let receiver = receiver_b.clone(); + async move { + let mut xs = Vec::new(); + let mut stream = stop_token.stop_stream(receiver); + while let Some(x) = stream.next().await { + xs.push(x) + } + xs + } + }); + + let _send_task_a = task::spawn({ + let sender = sender_a.clone(); + async move { + for i in 0.. { + sender.send(i).await; + } + } + }); + + let _send_task_b = task::spawn({ + let sender = sender_b.clone(); + async move { + for i in 0.. { + sender.send(i).await; + } + } + }); + + task::sleep(Duration::from_millis(250)).await; + + drop(stop_source); + + task::sleep(Duration::from_millis(250)).await; + + assert!(!recv_task_a.await.is_empty()); + assert!(!recv_task_b.await.is_empty()) + }) +} From 890ca45bd0aa4d93c819dddf3f1243e55fa319a8 Mon Sep 17 00:00:00 2001 From: soruh Date: Mon, 13 Jul 2020 17:23:23 +0200 Subject: [PATCH 10/15] put Event behind an AtomicPointer to prevent races --- src/lib.rs | 71 ++++++++++++++++++++++++++++++++++-------------------- 1 file changed, 45 insertions(+), 26 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 04fa23c..58594ff 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -71,7 +71,7 @@ use std::future::Future; use std::pin::Pin; use std::ptr::null_mut; use std::sync::{ - atomic::{AtomicBool, AtomicPtr, Ordering}, + atomic::{AtomicPtr, Ordering}, Arc, }; use std::task::{Context, Poll}; @@ -97,43 +97,65 @@ use pin_project_lite::pin_project; struct ShortCircuitingCondVar { // TODO: is there a better (safer) way to have an atomic `Option`? cached_listener: AtomicPtr, // This is a raw pointer to a Box - event: Event, - signaled: AtomicBool, + event: AtomicPtr, +} + +impl Drop for ShortCircuitingCondVar { + fn drop(&mut self) { + // make sure we don't leak any listeners or the event + { + self.take_cached_listener(); + } + self.notify(usize::MAX); + } } impl ShortCircuitingCondVar { fn notify(&self, n: usize) { // TODO: This can probably be `Relaxed` and `notify_relaxed` - self.signaled.store(true, Ordering::SeqCst); - self.event.notify(n); + let event = self.event.swap(null_mut(), Ordering::SeqCst); + if !event.is_null() { + // Safety: event can only either be a valid raw Box pointer of null + let event = *unsafe { Box::from_raw(event) }; + event.notify(n); + } + } + + fn take_cached_listener(&self) -> Option { + let ptr = self.cached_listener.swap(null_mut(), Ordering::SeqCst); + + if ptr.is_null() { + None + } else { + // Safety: the `cached_listener` is not null and can only come from a Box::into_raw + // it is also replaced with a `null` pointer when read so there can not be another owner. + Some(unsafe { *Box::from_raw(ptr) }) + } } fn listen(&self) -> Option { // TODO: These could maybe be `Aquire` // Check if the `CondVar` has already been triggered - if self.signaled.load(Ordering::SeqCst) { + if self.event.load(Ordering::SeqCst).is_null() { return None; // The `CondVar` has already been triggered so we don't need to wait } - let ptr = self.cached_listener.swap(null_mut(), Ordering::SeqCst); + let listener = self.take_cached_listener().or_else(|| { + let event = self.event.load(Ordering::SeqCst); - let listener = if ptr.is_null() { - // Register a new listener - self.event.listen() - } else { - // turn the cached listener back into a `EventListener` object + if event.is_null() { + return None; + } - // Safety: the `cached_listener` is not null and can only come from a Box::into_raw - // it is also replaced with a `null` pointer when read so there can not be another owner. - unsafe { *Box::from_raw(ptr) } - }; + // Safety: `event` is not null and is not used mutably by anyone until it is dropped + let event: &Event = unsafe { std::mem::transmute(event) }; - // Make sure the `CondVar` still has not been triggered to prevent race conditions - if self.signaled.load(Ordering::SeqCst) { - return None; - } + let listener = event.listen(); + + Some(listener) + }); - Some(listener) + listener } fn cache_listener(&self, listener: EventListener) -> Result<(), EventListener> { @@ -165,7 +187,6 @@ impl ShortCircuitingCondVar { #[derive(Debug)] pub struct StopSource { signal: Arc, - stop_token: StopToken, } /// `StopToken` is a future which completes when the associated `StopSource` is dropped. @@ -176,12 +197,10 @@ impl Default for StopSource { fn default() -> StopSource { let signal = Arc::new(ShortCircuitingCondVar { cached_listener: AtomicPtr::new(null_mut()), - event: Event::new(), - signaled: AtomicBool::new(false), + event: AtomicPtr::new(Box::into_raw(Box::new(Event::new()))), }); StopSource { signal: signal.clone(), - stop_token: StopToken(signal), } } } @@ -203,7 +222,7 @@ impl StopSource { /// /// Once the source is destroyed, `StopToken` future completes. pub fn stop_token(&self) -> StopToken { - self.stop_token.clone() + StopToken(self.signal.clone()) } } From bbac6b62f98162a142ae7353895653261661b81d Mon Sep 17 00:00:00 2001 From: soruh Date: Mon, 13 Jul 2020 18:18:22 +0200 Subject: [PATCH 11/15] cache listeners in `StopToken` instead of in `StopSource` move most of the `unsafe`ty into a mostly safe `AtomicBool` type --- src/lib.rs | 206 +++++++++++++++++++++++++++++++---------------------- 1 file changed, 121 insertions(+), 85 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 58594ff..b10b188 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -74,11 +74,11 @@ use std::sync::{ atomic::{AtomicPtr, Ordering}, Arc, }; -use std::task::{Context, Poll}; use event_listener::{Event, EventListener}; use futures::stream::Stream; use pin_project_lite::pin_project; +use std::task::{Context, Poll}; /// `StopSource` produces `StopToken` and cancels all of its tokens on drop. /// @@ -91,123 +91,163 @@ use pin_project_lite::pin_project; /// drop(stop_source); // At this point, scheduled work notices that it is canceled. /// ``` -#[derive(Debug)] -/// a custom implementation of a CondVar that short-circuits after -/// being signaled once -struct ShortCircuitingCondVar { - // TODO: is there a better (safer) way to have an atomic `Option`? - cached_listener: AtomicPtr, // This is a raw pointer to a Box - event: AtomicPtr, -} +struct AtomicOption(AtomicPtr); -impl Drop for ShortCircuitingCondVar { - fn drop(&mut self) { - // make sure we don't leak any listeners or the event - { - self.take_cached_listener(); - } - self.notify(usize::MAX); +impl AtomicOption { + fn is_none(&self) -> bool { + self.0.load(Ordering::SeqCst).is_null() } -} -impl ShortCircuitingCondVar { - fn notify(&self, n: usize) { - // TODO: This can probably be `Relaxed` and `notify_relaxed` - let event = self.event.swap(null_mut(), Ordering::SeqCst); - if !event.is_null() { - // Safety: event can only either be a valid raw Box pointer of null - let event = *unsafe { Box::from_raw(event) }; - event.notify(n); - } + #[allow(dead_code)] + fn is_some(&self) -> bool { + !self.is_none() } - fn take_cached_listener(&self) -> Option { - let ptr = self.cached_listener.swap(null_mut(), Ordering::SeqCst); - + /// Safety: the caller needs to make sure no one is writing to this value + /// and that no one will do so until the returned refence it dropped + #[allow(unused_unsafe)] + unsafe fn as_ref<'s>(&'s self) -> Option<&'s T> { + let ptr = self.0.load(Ordering::SeqCst); if ptr.is_null() { None } else { - // Safety: the `cached_listener` is not null and can only come from a Box::into_raw - // it is also replaced with a `null` pointer when read so there can not be another owner. - Some(unsafe { *Box::from_raw(ptr) }) + // Safety: we know that ptr is a valid ptr to a `T` + // since it is not null and it can only be written in `new` or `replace` + Some(unsafe { std::mem::transmute(ptr) }) } } - fn listen(&self) -> Option { - // TODO: These could maybe be `Aquire` - // Check if the `CondVar` has already been triggered - if self.event.load(Ordering::SeqCst).is_null() { - return None; // The `CondVar` has already been triggered so we don't need to wait - } + fn new(value: Option) -> Self { + let ptr = if let Some(value) = value { + Box::into_raw(Box::new(value)) + } else { + null_mut() + }; - let listener = self.take_cached_listener().or_else(|| { - let event = self.event.load(Ordering::SeqCst); + Self(AtomicPtr::new(ptr)) + } - if event.is_null() { - return None; - } + fn take(&self) -> Option { + let ptr = self.0.swap(null_mut(), Ordering::SeqCst); - // Safety: `event` is not null and is not used mutably by anyone until it is dropped - let event: &Event = unsafe { std::mem::transmute(event) }; + if ptr.is_null() { + None + } else { + // Safety: we know that `ptr` is not null and can only have been created from a `Box` by `new` or `replace` + // means it's safe to turn back into a `Box` + Some(*unsafe { Box::from_raw(ptr) }) + } + } - let listener = event.listen(); + #[allow(dead_code)] + fn replace(&self, new: Option) -> Option { + let new_ptr = if let Some(new) = new { + Box::into_raw(Box::new(new)) + } else { + null_mut() + }; - Some(listener) - }); + let ptr = self.0.swap(new_ptr, Ordering::SeqCst); - listener + if ptr.is_null() { + None + } else { + // Safety: we know that `ptr` is not null and can only have been created from a `Box` by `new` or `replace` + // means it's safe to turn back into a `Box` + Some(*unsafe { Box::from_raw(ptr) }) + } + } +} + +impl Drop for AtomicOption { + fn drop(&mut self) { + std::mem::drop(self.take()); } +} - fn cache_listener(&self, listener: EventListener) -> Result<(), EventListener> { - // Check if there is already a cached listener to prevent `Box` if it's not neccessary - if self.cached_listener.load(Ordering::SeqCst).is_null() { - let listener = Box::new(listener); - - let res = self.cached_listener.compare_and_swap( - null_mut(), - Box::into_raw(listener), - Ordering::SeqCst, - ); - if res.is_null() { - Ok(()) - } else { - // We failed to write our new cached listener due to a race - // Turn it back into a Box and return it to the caller - - // Safety: the `cached_listener` is not null and can only come from a Box::into_raw - // it is also replaced with a `null` pointer when read so there can not be another owner. - Err(*unsafe { Box::from_raw(res) }) - } +impl std::fmt::Debug for AtomicOption +where + T: std::fmt::Debug, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if self.is_none() { + write!(f, "None") } else { - Err(listener) + write!(f, "Some()") } } } +/// a custom implementation of a CondVar that short-circuits after +/// being signaled once +#[derive(Debug)] +struct ShortCircuitingCondVar(AtomicOption); + +impl ShortCircuitingCondVar { + fn is_done(&self) -> bool { + self.0.is_none() + } + + fn notify(&self, n: usize) -> bool { + self.0.take().map(|x| x.notify(n)).is_some() + } + + fn listen(&self) -> Option { + // safety: + unsafe { self.0.as_ref() }.map(|event| event.listen()) + } +} + #[derive(Debug)] pub struct StopSource { signal: Arc, } /// `StopToken` is a future which completes when the associated `StopSource` is dropped. -#[derive(Debug, Clone)] -pub struct StopToken(Arc); +#[derive(Debug)] +pub struct StopToken { + cond_var: Arc, + cached_listener: Option, +} + +impl StopToken { + fn new(cond_var: Arc) -> Self { + Self { + cond_var, + cached_listener: None, + } + } + + fn listen(&mut self) -> Option<&mut EventListener> { + if self.cond_var.is_done() { + return None; + } + + if self.cached_listener.is_none() { + self.cached_listener = self.cond_var.listen(); + } + self.cached_listener.as_mut() + } +} + +impl Clone for StopToken { + fn clone(&self) -> Self { + Self::new(self.cond_var.clone()) + } +} impl Default for StopSource { fn default() -> StopSource { - let signal = Arc::new(ShortCircuitingCondVar { - cached_listener: AtomicPtr::new(null_mut()), - event: AtomicPtr::new(Box::into_raw(Box::new(Event::new()))), - }); StopSource { - signal: signal.clone(), + signal: Arc::new(ShortCircuitingCondVar(AtomicOption::new( + Some(Event::new()), + ))), } } } impl Drop for StopSource { fn drop(&mut self) { - // TODO: notifying only one StopToken should be sufficient self.signal.notify(usize::MAX); } } @@ -222,24 +262,20 @@ impl StopSource { /// /// Once the source is destroyed, `StopToken` future completes. pub fn stop_token(&self) -> StopToken { - StopToken(self.signal.clone()) + StopToken::new(self.signal.clone()) } } impl Future for StopToken { type Output = (); - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { - if let Some(mut listener) = self.0.listen() { + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + if let Some(mut listener) = self.listen() { let result = match Future::poll(Pin::new(&mut listener), cx) { Poll::Pending => Poll::Pending, Poll::Ready(_) => Poll::Ready(()), }; - // Try to cache the listener, if there already is a cached listener - // drop the one we have - let _ = self.0.cache_listener(listener); - return result; } else { Poll::Ready(()) From 785bb0c2cfb0147d60b66286a9f3e0ebf47f3f5b Mon Sep 17 00:00:00 2001 From: soruh Date: Mon, 13 Jul 2020 18:23:50 +0200 Subject: [PATCH 12/15] ammended safety argument for `listen` --- src/lib.rs | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/lib.rs b/src/lib.rs index b10b188..69f7f6d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -193,7 +193,13 @@ impl ShortCircuitingCondVar { } fn listen(&self) -> Option { - // safety: + // Safety: + // The `Event` (`self.0`) is only written when it's dropped and since + // it oviously still exists we can have an immutable reference to it. + // Our reference is also only used during our function scope + // during which we have borrowed `self`. + // This means that the reference returned by `as_ref` adheres to rust borrowing + // rust which makes this operation safe. unsafe { self.0.as_ref() }.map(|event| event.listen()) } } From 351ee6eec87115bf02eb8ba4cbd038f979c9177d Mon Sep 17 00:00:00 2001 From: soruh <33131839+soruh@users.noreply.github.com> Date: Mon, 13 Jul 2020 18:26:45 +0200 Subject: [PATCH 13/15] Add "TODO: relax orderings on atomic accesses" for `AtomicOption` --- src/lib.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/lib.rs b/src/lib.rs index 69f7f6d..a5794f3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -93,6 +93,7 @@ use std::task::{Context, Poll}; struct AtomicOption(AtomicPtr); +// TODO: relax orderings on atomic accesses impl AtomicOption { fn is_none(&self) -> bool { self.0.load(Ordering::SeqCst).is_null() From 735b561ead04cee6aaf77b14745f9a211b87dae3 Mon Sep 17 00:00:00 2001 From: soruh Date: Mon, 13 Jul 2020 18:55:35 +0200 Subject: [PATCH 14/15] Make AtomicBool use Arcs to be able to safely get references to it's inner value --- src/lib.rs | 52 ++++++++++++++++++++-------------------------------- 1 file changed, 20 insertions(+), 32 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index a5794f3..25297c6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -91,7 +91,8 @@ use std::task::{Context, Poll}; /// drop(stop_source); // At this point, scheduled work notices that it is canceled. /// ``` -struct AtomicOption(AtomicPtr); +/// An immutable, atomic option type that store data in Boxed Arcs +struct AtomicOption(AtomicPtr>); // TODO: relax orderings on atomic accesses impl AtomicOption { @@ -104,23 +105,26 @@ impl AtomicOption { !self.is_none() } - /// Safety: the caller needs to make sure no one is writing to this value - /// and that no one will do so until the returned refence it dropped - #[allow(unused_unsafe)] - unsafe fn as_ref<'s>(&'s self) -> Option<&'s T> { + fn get(&self) -> Option> { let ptr = self.0.load(Ordering::SeqCst); if ptr.is_null() { None } else { - // Safety: we know that ptr is a valid ptr to a `T` - // since it is not null and it can only be written in `new` or `replace` - Some(unsafe { std::mem::transmute(ptr) }) + // Safety: we know that `ptr` is not null and can only have been created from a `Box` by `new` or `replace` + // this means it's safe to turn back into a `Box` + let arc_box = unsafe { Box::from_raw(ptr as *mut Arc) }; + + let arc = *arc_box.clone(); // Clone the Arc + + Box::leak(arc_box); // And make sure rust doesn't drop our inner value + + Some(arc) } } fn new(value: Option) -> Self { let ptr = if let Some(value) = value { - Box::into_raw(Box::new(value)) + Box::into_raw(Box::new(Arc::new(value))) } else { null_mut() }; @@ -128,22 +132,13 @@ impl AtomicOption { Self(AtomicPtr::new(ptr)) } - fn take(&self) -> Option { - let ptr = self.0.swap(null_mut(), Ordering::SeqCst); - - if ptr.is_null() { - None - } else { - // Safety: we know that `ptr` is not null and can only have been created from a `Box` by `new` or `replace` - // means it's safe to turn back into a `Box` - Some(*unsafe { Box::from_raw(ptr) }) - } + fn take(&self) -> Option> { + self.replace(None) } - #[allow(dead_code)] - fn replace(&self, new: Option) -> Option { + fn replace(&self, new: Option) -> Option> { let new_ptr = if let Some(new) = new { - Box::into_raw(Box::new(new)) + Box::into_raw(Box::new(Arc::new(new))) } else { null_mut() }; @@ -154,8 +149,8 @@ impl AtomicOption { None } else { // Safety: we know that `ptr` is not null and can only have been created from a `Box` by `new` or `replace` - // means it's safe to turn back into a `Box` - Some(*unsafe { Box::from_raw(ptr) }) + // this means it's safe to turn back into a `Box` + Some(unsafe { *Box::from_raw(ptr) }) } } } @@ -194,14 +189,7 @@ impl ShortCircuitingCondVar { } fn listen(&self) -> Option { - // Safety: - // The `Event` (`self.0`) is only written when it's dropped and since - // it oviously still exists we can have an immutable reference to it. - // Our reference is also only used during our function scope - // during which we have borrowed `self`. - // This means that the reference returned by `as_ref` adheres to rust borrowing - // rust which makes this operation safe. - unsafe { self.0.as_ref() }.map(|event| event.listen()) + self.0.get().map(|event| event.listen()) } } From e46ead66e23396c60ee9f90afd2dd747c5374ad3 Mon Sep 17 00:00:00 2001 From: soruh Date: Mon, 13 Jul 2020 19:04:45 +0200 Subject: [PATCH 15/15] make `contest_cached_listener` use more tasks --- tests/tests.rs | 77 +++++++++++++++++++++----------------------------- 1 file changed, 32 insertions(+), 45 deletions(-) diff --git a/tests/tests.rs b/tests/tests.rs index 7da8c2e..98bdbc7 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -103,59 +103,46 @@ fn contest_cached_listener() { task::block_on(async { let stop_source = StopSource::new(); - let (sender_a, receiver_a) = channel::(10); - let recv_task_a = task::spawn({ - let stop_token = stop_source.stop_token(); - let receiver = receiver_a.clone(); - async move { - let mut xs = Vec::new(); - let mut stream = stop_token.stop_stream(receiver); - while let Some(x) = stream.next().await { - xs.push(x) + const N: usize = 8; + + let mut recv_tasks = Vec::with_capacity(N); + let mut send_tasks = Vec::with_capacity(N); + + for _ in 0..N { + let (sender, receiver) = channel::(10); + let recv_task = task::spawn({ + let stop_token = stop_source.stop_token(); + let receiver = receiver.clone(); + async move { + let mut messages = Vec::new(); + let mut stream = stop_token.stop_stream(receiver); + while let Some(msg) = stream.next().await { + messages.push(msg) + } + messages } - xs - } - }); + }); - let (sender_b, receiver_b) = channel::(10); - let recv_task_b = task::spawn({ - let stop_token = stop_source.stop_token(); - let receiver = receiver_b.clone(); - async move { - let mut xs = Vec::new(); - let mut stream = stop_token.stop_stream(receiver); - while let Some(x) = stream.next().await { - xs.push(x) - } - xs - } - }); - - let _send_task_a = task::spawn({ - let sender = sender_a.clone(); - async move { - for i in 0.. { - sender.send(i).await; + let send_task = task::spawn({ + async move { + for msg in 0.. { + sender.send(msg).await; + } } - } - }); + }); - let _send_task_b = task::spawn({ - let sender = sender_b.clone(); - async move { - for i in 0.. { - sender.send(i).await; - } - } - }); + recv_tasks.push(recv_task); + send_tasks.push(send_task); + } - task::sleep(Duration::from_millis(250)).await; + task::sleep(Duration::from_millis(500)).await; drop(stop_source); - task::sleep(Duration::from_millis(250)).await; + task::sleep(Duration::from_millis(500)).await; - assert!(!recv_task_a.await.is_empty()); - assert!(!recv_task_b.await.is_empty()) + for (i, recv_task) in recv_tasks.into_iter().enumerate() { + eprintln!("receiver {} got {} messages", i, recv_task.await.len()); + } }) }