22
33#include < atomic>
44#include < thread>
5- #include < climits >
5+ #include < limits >
66
77namespace tp
88{
@@ -16,6 +16,8 @@ namespace tp
1616template <typename Task, template <typename > class Queue >
1717class Worker
1818{
19+ using WorkerVector = std::vector<std::unique_ptr<Worker<Task, Queue>>>;
20+
1921public:
2022 /* *
2123 * @brief Worker Constructor.
@@ -36,9 +38,9 @@ class Worker
3638 /* *
3739 * @brief start Create the executing thread and start tasks execution.
3840 * @param id Worker ID.
39- * @param steal_donor Sibling worker to steal task from it .
41+ * @param workers Sibling workers for performing round robin work stealing .
4042 */
41- void start (size_t id, Worker* steal_donor );
43+ void start (size_t id, WorkerVector* workers );
4244
4345 /* *
4446 * @brief stop Stop all worker's thread and stealing activity.
@@ -47,19 +49,19 @@ class Worker
4749 void stop ();
4850
4951 /* *
50- * @brief post Post task to queue.
52+ * @brief tryPost Post task to queue.
5153 * @param handler Handler to be executed in executing thread.
5254 * @return true on success.
5355 */
5456 template <typename Handler>
55- bool post (Handler&& handler);
57+ bool tryPost (Handler&& handler);
5658
5759 /* *
58- * @brief steal Steal one task from this worker queue.
59- * @param task Place for stealed task to be stored.
60+ * @brief tryGetLocalTask Get one task from this worker queue.
61+ * @param task Place for the obtained task to be stored.
6062 * @return true on success.
6163 */
62- bool steal (Task& task);
64+ bool tryGetLocalTask (Task& task);
6365
6466 /* *
6567 * @brief getWorkerIdForCurrentThread Return worker ID associated with
@@ -69,16 +71,24 @@ class Worker
6971 static size_t getWorkerIdForCurrentThread ();
7072
7173private:
74+ /* *
75+ * @brief tryRoundRobinSteal Try stealing a thread from sibling workers in a round-robin fashion.
76+ * @param task Place for the obtained task to be stored.
77+ * @param workers Sibling workers for performing round robin work stealing.
78+ */
79+ bool tryRoundRobinSteal (Task& task, WorkerVector* workers);
80+
7281 /* *
7382 * @brief threadFunc Executing thread function.
7483 * @param id Worker ID to be associated with this thread.
75- * @param steal_donor Sibling worker to steal task from it .
84+ * @param workers Sibling workers for performing round robin work stealing .
7685 */
77- void threadFunc (size_t id, Worker* steal_donor );
86+ void threadFunc (size_t id, WorkerVector* workers );
7887
7988 Queue<Task> m_queue;
8089 std::atomic<bool > m_running_flag;
8190 std::thread m_thread;
91+ size_t m_next_donor;
8292};
8393
8494
@@ -88,7 +98,7 @@ namespace detail
8898{
8999 inline size_t * thread_id ()
90100 {
91- static thread_local size_t tss_id = UINT_MAX ;
101+ static thread_local size_t tss_id = std::numeric_limits< size_t >:: max () ;
92102 return &tss_id;
93103 }
94104}
@@ -97,6 +107,7 @@ template <typename Task, template<typename> class Queue>
97107inline Worker<Task, Queue>::Worker(size_t queue_size)
98108 : m_queue(queue_size)
99109 , m_running_flag(true )
110+ , m_next_donor(0 ) // Initialized in threadFunc.
100111{
101112}
102113
@@ -126,9 +137,9 @@ inline void Worker<Task, Queue>::stop()
126137}
127138
128139template <typename Task, template <typename > class Queue >
129- inline void Worker<Task, Queue>::start(size_t id, Worker* steal_donor )
140+ inline void Worker<Task, Queue>::start(size_t id, WorkerVector* workers )
130141{
131- m_thread = std::thread (&Worker<Task, Queue>::threadFunc, this , id, steal_donor );
142+ m_thread = std::thread (&Worker<Task, Queue>::threadFunc, this , id, workers );
132143}
133144
134145template <typename Task, template <typename > class Queue >
@@ -139,35 +150,60 @@ inline size_t Worker<Task, Queue>::getWorkerIdForCurrentThread()
139150
140151template <typename Task, template <typename > class Queue >
141152template <typename Handler>
142- inline bool Worker<Task, Queue>::post (Handler&& handler)
153+ inline bool Worker<Task, Queue>::tryPost (Handler&& handler)
143154{
144155 return m_queue.push (std::forward<Handler>(handler));
145156}
146157
147158template <typename Task, template <typename > class Queue >
148- inline bool Worker<Task, Queue>::steal (Task& task)
159+ inline bool Worker<Task, Queue>::tryGetLocalTask (Task& task)
149160{
150161 return m_queue.pop (task);
151162}
152163
153164template <typename Task, template <typename > class Queue >
154- inline void Worker<Task, Queue>::threadFunc(size_t id, Worker* steal_donor)
165+ inline bool Worker<Task, Queue>::tryRoundRobinSteal(Task& task, WorkerVector* workers)
166+ {
167+ auto starting_index = m_next_donor;
168+
169+ // Iterate once through the worker ring, checking for queued work items on each thread.
170+ do
171+ {
172+ // Don't steal from local queue.
173+ if (m_next_donor != *detail::thread_id () && workers->at (m_next_donor)->tryGetLocalTask (task))
174+ {
175+ // Increment before returning so that m_next_donor always points to the worker that has gone the longest
176+ // without a steal attempt. This helps enforce fairness in the stealing.
177+ ++m_next_donor %= workers->size ();
178+ return true ;
179+ }
180+
181+ ++m_next_donor %= workers->size ();
182+ } while (m_next_donor != starting_index);
183+
184+ return false ;
185+ }
186+
187+ template <typename Task, template <typename > class Queue >
188+ inline void Worker<Task, Queue>::threadFunc(size_t id, WorkerVector* workers)
155189{
156190 *detail::thread_id () = id;
191+ m_next_donor = ++id % workers->size ();
157192
158193 Task handler;
159194
160195 while (m_running_flag.load (std::memory_order_relaxed))
161196 {
162- if (m_queue.pop (handler) || steal_donor->steal (handler))
197+ // Prioritize local queue, then try stealing from sibling workers.
198+ if (tryGetLocalTask (handler) || tryRoundRobinSteal (handler, workers))
163199 {
164200 try
165201 {
166202 handler ();
167203 }
168204 catch (...)
169205 {
170- // suppress all exceptions
206+ // Suppress all exceptions.
171207 }
172208 }
173209 else
0 commit comments