Skip to content

Commit 52e1755

Browse files
committed
Add stdexec::simple_counting_scope
This diff adds `stdexec::simple_counting_scope` and one compeletely trivial test case. Still a work in progress.
1 parent e75778e commit 52e1755

File tree

4 files changed

+425
-0
lines changed

4 files changed

+425
-0
lines changed
Lines changed: 395 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,395 @@
1+
/*
2+
* Copyright (c) 2025 Ian Petersen
3+
* Copyright (c) 2025 NVIDIA Corporation
4+
*
5+
* Licensed under the Apache License Version 2.0 with LLVM Exceptions
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* https://llvm.org/LICENSE.txt
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
#pragma once
18+
19+
#include "__execution_fwd.hpp"
20+
21+
#include "__concepts.hpp"
22+
#include "__env.hpp"
23+
#include "__receivers.hpp"
24+
#include "__schedulers.hpp"
25+
#include "__senders_core.hpp"
26+
#include "__sender_introspection.hpp"
27+
#include "__type_traits.hpp"
28+
29+
#include <atomic>
30+
#include <cstddef>
31+
#include <exception>
32+
#include <limits>
33+
#include <memory>
34+
#include <type_traits>
35+
#include <utility>
36+
37+
namespace stdexec {
38+
class simple_counting_scope;
39+
40+
namespace __counting_scopes {
41+
struct __state_base {
42+
__state_base* __next_{};
43+
44+
__state_base() = default;
45+
__state_base(__state_base&&) = delete;
46+
47+
virtual void __complete() noexcept = 0;
48+
49+
protected:
50+
~__state_base() = default;
51+
};
52+
53+
struct __scope_join_t { };
54+
55+
struct __scope_join_impl : __sexpr_defaults {
56+
template <class _Scope, class _Rcvr>
57+
struct __state final : __state_base {
58+
struct __rcvr_t {
59+
using receiver_concept = receiver_t;
60+
61+
_Rcvr& __rcvr_;
62+
63+
void set_value() && noexcept {
64+
stdexec::set_value(std::move(__rcvr_));
65+
}
66+
67+
template <class E>
68+
void set_error(E&& e) && noexcept {
69+
stdexec::set_error(std::move(__rcvr_), static_cast<E&&>(e));
70+
};
71+
72+
void set_stopped() && noexcept {
73+
stdexec::set_stopped(std::move(__rcvr_));
74+
}
75+
76+
decltype(auto) get_env() const noexcept {
77+
return stdexec::get_env(__rcvr_);
78+
}
79+
};
80+
81+
using __sched_sender = decltype(schedule(get_scheduler(get_env(__declval<_Rcvr&>()))));
82+
83+
using __op_t = connect_result_t<__sched_sender, __rcvr_t>;
84+
85+
_Scope* __scope_;
86+
_Rcvr& __receiver_;
87+
__op_t __op_;
88+
89+
__state(_Scope* __scope, _Rcvr& __rcvr)
90+
noexcept(__nothrow_callable<connect_t, __sched_sender, __rcvr_t>)
91+
: __scope_(__scope)
92+
, __receiver_(__rcvr)
93+
, __op_(connect(schedule(get_scheduler(get_env(__rcvr)))), __rcvr_t(__rcvr)) {
94+
}
95+
96+
void __complete() noexcept override {
97+
start(__op_);
98+
}
99+
100+
void __complete_inline() noexcept {
101+
set_value(std::move(__receiver_));
102+
}
103+
};
104+
105+
static constexpr auto get_state =
106+
[]<class _Sndr, class _Rcvr>(_Sndr&& __sender, _Rcvr& __receiver) noexcept(
107+
__nothrow_constructible_from<
108+
__state<_Rcvr, __data_of<std::remove_cvref_t<_Sndr>>>,
109+
__data_of<std::remove_cvref_t<_Sndr>>,
110+
_Rcvr&
111+
>) {
112+
auto [_, self] = __sender;
113+
return __state(self, __receiver);
114+
};
115+
116+
static constexpr auto start = [](auto& __s, auto&) noexcept {
117+
if (__s.__scope_->__start_join_sender(__s)) {
118+
__s.__complete_inline();
119+
}
120+
};
121+
};
122+
123+
template <class _Scope>
124+
struct __association_t {
125+
constexpr __association_t() = default;
126+
127+
constexpr __association_t(__association_t&& __other) noexcept
128+
: __scope_(std::exchange(__other.__scope_, nullptr)) {
129+
}
130+
131+
~__association_t() {
132+
if (__scope_ != nullptr) {
133+
__scope_->__disassociate();
134+
}
135+
}
136+
137+
__association_t& operator=(__association_t __rhs) noexcept {
138+
std::swap(__scope_, __rhs.__scope_);
139+
return *this;
140+
}
141+
142+
constexpr explicit operator bool() const noexcept {
143+
return __scope_ != nullptr;
144+
}
145+
146+
__association_t try_associate() const noexcept {
147+
if (__scope_) {
148+
return __scope_->__try_associate();
149+
} else {
150+
return __association_t();
151+
}
152+
}
153+
154+
private:
155+
friend simple_counting_scope;
156+
157+
_Scope* __scope_{};
158+
159+
constexpr __association_t(_Scope& __scope) noexcept
160+
: __scope_(std::addressof(__scope)) {
161+
}
162+
};
163+
} // namespace __counting_scopes
164+
165+
template <>
166+
struct __sexpr_impl<__counting_scopes::__scope_join_t> : __counting_scopes::__scope_join_impl { };
167+
168+
class simple_counting_scope {
169+
public:
170+
using __assoc_t = __counting_scopes::__association_t<simple_counting_scope>;
171+
172+
struct token {
173+
template <sender _Sender>
174+
_Sender&& wrap(_Sender&& __snd) const noexcept {
175+
return static_cast<_Sender&&>(__snd);
176+
}
177+
178+
__assoc_t try_associate() const noexcept {
179+
return __scope_->__try_associate();
180+
}
181+
182+
private:
183+
friend class simple_counting_scope;
184+
185+
simple_counting_scope* __scope_;
186+
187+
explicit token(simple_counting_scope* __scope) noexcept
188+
: __scope_(__scope) {
189+
}
190+
};
191+
192+
static constexpr std::size_t max_associations = std::numeric_limits<std::size_t>::max() >> 3;
193+
194+
simple_counting_scope() = default;
195+
196+
simple_counting_scope(simple_counting_scope&&) = delete;
197+
198+
~simple_counting_scope() {
199+
auto state = __state_.load(std::memory_order_relaxed);
200+
if (__is_join_needed(state) || __count(state) != 0ul) {
201+
std::terminate();
202+
}
203+
}
204+
205+
token get_token() noexcept {
206+
return token{this};
207+
}
208+
209+
void close() noexcept {
210+
__state_.fetch_or(__closed, std::memory_order_relaxed);
211+
}
212+
213+
sender auto join() noexcept {
214+
return __make_sexpr<__counting_scopes::__scope_join_t>(this);
215+
}
216+
217+
private:
218+
friend __assoc_t;
219+
220+
static constexpr std::size_t __closed{1ul};
221+
static constexpr std::size_t __join_needed{2ul};
222+
static constexpr std::size_t __join_running{4ul};
223+
224+
std::atomic<std::size_t> __state_{0ul};
225+
std::atomic<void*> __waitingJoinOps_{nullptr};
226+
227+
__assoc_t __try_associate() noexcept {
228+
auto state = __state_.load(std::memory_order_relaxed);
229+
230+
do {
231+
if (__is_closed(state) || __count(state) == max_associations) {
232+
// the scope is closed or full so deny a new association
233+
return __assoc_t();
234+
}
235+
236+
// increment the count and ensure the join-needed bit is set
237+
const auto newState = __make_state(__count(state) + 1ul, __bits(state) | __join_needed);
238+
239+
assert(__count(newState) <= max_associations);
240+
241+
if (__state_.compare_exchange_weak(
242+
state,
243+
newState,
244+
// this is effectively a ref-count increment so
245+
// there's no need to synchronize
246+
std::memory_order_relaxed)) {
247+
return __assoc_t{*this};
248+
}
249+
} while (true);
250+
}
251+
252+
void __disassociate() noexcept {
253+
auto state = __state_.load(std::memory_order_relaxed);
254+
255+
assert(__count(state) > 0ul);
256+
257+
do {
258+
const auto newCount = __count(state) - 1ul;
259+
const auto newBits = (newCount == 0ul && __is_joining(state) ? __closed : __bits(state));
260+
261+
if (__state_.compare_exchange_weak(
262+
state,
263+
__make_state(newCount, newBits),
264+
std::memory_order_acq_rel,
265+
std::memory_order_relaxed)) {
266+
// successfully updated; now decide whether we need to complete
267+
// outstanding join-senders or just bail out
268+
if (newCount == 0ul && newBits == __closed) {
269+
// launch the outstanding join-senders after the loop
270+
break;
271+
} else {
272+
// just bail out--either we're not the last op, or the scope
273+
// isn't closed yet
274+
return;
275+
}
276+
}
277+
} while (true);
278+
279+
// mark the linked list as having been consumed
280+
auto* joinOpsToComplete = static_cast<__counting_scopes::__state_base*>(
281+
__waitingJoinOps_.exchange(this, std::memory_order_acq_rel));
282+
283+
for (; joinOpsToComplete != nullptr; joinOpsToComplete = joinOpsToComplete->__next_) {
284+
joinOpsToComplete->__complete();
285+
}
286+
}
287+
288+
constexpr static bool __is_unused(std::size_t __state) noexcept {
289+
return (__state & __join_needed) == 0ul;
290+
}
291+
292+
constexpr static bool __is_open(std::size_t __state) noexcept {
293+
return (__state & __closed) == 0ul;
294+
}
295+
296+
constexpr static bool __is_closed(std::size_t __state) noexcept {
297+
return !__is_open(__state);
298+
}
299+
300+
constexpr static bool __is_joined(std::size_t __state) noexcept {
301+
return __is_closed(__state) && !__is_unused(__state);
302+
}
303+
304+
constexpr static bool __is_joining(std::size_t __state) noexcept {
305+
return (__state & __join_running) != 0ul;
306+
}
307+
308+
constexpr static bool __is_join_needed(std::size_t __state) noexcept {
309+
return !__is_unused(__state);
310+
}
311+
312+
constexpr static std::size_t __count(std::size_t __state) noexcept {
313+
return __state >> 3;
314+
}
315+
316+
constexpr static std::size_t __bits(std::size_t __state) noexcept {
317+
return __state & 7ul;
318+
}
319+
320+
constexpr static std::size_t
321+
__make_state(std::size_t __opCount, std::size_t __lowBits) noexcept {
322+
// no high bits set
323+
assert(__count(__opCount << 3) == __opCount);
324+
325+
// no high bits set
326+
assert(__bits(__lowBits) == __lowBits);
327+
328+
return (__opCount << 3) | __bits(__lowBits);
329+
}
330+
331+
bool __start_join_sender(__counting_scopes::__state_base& __joinOp) noexcept {
332+
auto state = __state_.load(std::memory_order_relaxed);
333+
334+
do {
335+
// [exec.simple.counting.mem] para (9.1)
336+
// unused, unused-and-closed, or joined -> joined
337+
if (__is_unused(state) || __is_closed(state)) {
338+
assert(__count(state) == 0ul);
339+
340+
const auto newState = __make_state(__count(state), __closed);
341+
342+
assert(__is_joined(newState));
343+
344+
// try to make it joined
345+
if (__state_.compare_exchange_weak(
346+
state, newState, std::memory_order_acq_rel, std::memory_order_relaxed)) {
347+
return true;
348+
}
349+
}
350+
// [exec.simple.counting.mem] para (9.2)
351+
// open or open-and-joining -> open-and-joining
352+
// [exec.simple.counting.mem] para (9.3)
353+
// closed or closed-and-joining -> closed-and-joining
354+
else {
355+
assert(__is_join_needed(state));
356+
357+
// try to make it {open|closed}-and-joining
358+
const auto newState = state | __join_running;
359+
360+
assert(__is_joining(newState));
361+
362+
// TODO: does this need to do any synchronizing or is relaxed OK?
363+
if (__state_.compare_exchange_weak(state, newState, std::memory_order_relaxed)) {
364+
return !__register(__joinOp);
365+
}
366+
}
367+
} while (true);
368+
}
369+
370+
bool __register(__counting_scopes::__state_base& __joinOp) noexcept {
371+
auto* ptr = __waitingJoinOps_.load(std::memory_order_relaxed);
372+
373+
do {
374+
if (ptr == this) {
375+
// __waitingJoinOps_ == this when the list has been cleared
376+
return false;
377+
}
378+
379+
// make __joinOp's next point to the current head
380+
__joinOp.__next_ = static_cast<__counting_scopes::__state_base*>(ptr);
381+
} while (
382+
// try to make the head point to __joinOp
383+
__waitingJoinOps_.compare_exchange_weak(
384+
ptr,
385+
&__joinOp,
386+
// I don't know what synchronization semantics we need
387+
// on success, but acquire-release feels safe
388+
std::memory_order_acq_rel,
389+
// on failure, we'll try again so relaxed should be ok
390+
std::memory_order_relaxed));
391+
392+
return true;
393+
}
394+
};
395+
} // namespace stdexec

0 commit comments

Comments
 (0)