WPILibC++ 2027.0.0-alpha-4
Loading...
Searching...
No Matches
CallbackManager.hpp
Go to the documentation of this file.
1// Copyright (c) FIRST and other WPILib contributors.
2// Open Source Software; you can modify and/or share it under the terms of
3// the WPILib BSD license file in the root directory of this project.
4
5#pragma once
6
#include <atomic>
#include <chrono>
#include <climits>
#include <functional>
#include <memory>
#include <queue>
#include <tuple>
#include <utility>
#include <vector>

#include "wpi/util/SafeThread.hpp"
#include "wpi/util/UidVector.hpp"
#include "wpi/util/condition_variable.hpp"
#include "wpi/util/mutex.hpp"
#include "wpi/util/raw_ostream.hpp"
20
21namespace wpi::util {
22
// Per-listener registration record used by CallbackThread.
//
// A listener receives events either by having `callback` invoked directly or,
// when registered against a poller, by having events queued to the poller
// identified by `poller_uid`. A default-constructed record has neither and
// evaluates to false; CallbackThread::Main() skips such records.
//
// @tparam Callback callable type invoked for each matching event
template <typename Callback>
class CallbackListenerData {
 public:
  // Required so records can be default-constructed (e.g. as empty slots in
  // the UidVector that stores them); the explicit constructors below would
  // otherwise suppress the implicit default constructor.
  CallbackListenerData() = default;

  // Creates a listener that is notified by invoking callback_.
  // Taken by value and moved in (sink parameter).
  explicit CallbackListenerData(Callback callback_)
      : callback(std::move(callback_)) {}

  // Creates a listener whose events are queued to poller poller_uid_.
  explicit CallbackListenerData(unsigned int poller_uid_)
      : poller_uid(poller_uid_) {}

  // True if this record is live: it has a callback or is bound to a poller.
  explicit operator bool() const { return callback || poller_uid != UINT_MAX; }

  // Callback invoked for matching events (empty for poller-based listeners).
  Callback callback;

  // Poller that receives matching events; UINT_MAX means "no poller".
  unsigned int poller_uid = UINT_MAX;
};
36
37// CRTP callback manager thread
38// @tparam Derived derived class
39// @tparam NotifierData data buffered for each callback
40// @tparam ListenerData data stored for each listener
41// Derived must define the following functions:
42// bool Matches(const ListenerData& listener, const NotifierData& data);
43// void SetListener(NotifierData* data, unsigned int listener_uid);
44// void DoCallback(Callback callback, const NotifierData& data);
45template <typename Derived, typename TUserInfo,
46 typename TListenerData =
47 CallbackListenerData<std::function<void(const TUserInfo& info)>>,
48 typename TNotifierData = TUserInfo>
50 public:
51 using UserInfo = TUserInfo;
52 using NotifierData = TNotifierData;
53 using ListenerData = TListenerData;
54
55 CallbackThread(std::function<void()> on_start, std::function<void()> on_exit)
56 : m_on_start(std::move(on_start)), m_on_exit(std::move(on_exit)) {}
57
58 ~CallbackThread() override {
59 // Wake up any blocked pollers
60 for (size_t i = 0; i < m_pollers.size(); ++i) {
61 if (auto poller = m_pollers[i]) {
62 poller->Terminate();
63 }
64 }
65 }
66
67 void Main() override;
68
70
71 std::queue<std::pair<unsigned int, NotifierData>> m_queue;
73
74 struct Poller {
75 void Terminate() {
76 {
77 std::scoped_lock lock(poll_mutex);
78 terminating = true;
79 }
80 poll_cond.notify_all();
81 }
82 std::queue<NotifierData> poll_queue;
85 bool terminating = false;
86 bool canceling = false;
87 };
89
90 std::function<void()> m_on_start;
91 std::function<void()> m_on_exit;
92
93 // Must be called with m_mutex held
94 template <typename... Args>
95 void SendPoller(unsigned int poller_uid, Args&&... args) {
96 if (poller_uid > m_pollers.size()) {
97 return;
98 }
99 auto poller = m_pollers[poller_uid];
100 if (!poller) {
101 return;
102 }
103 {
104 std::scoped_lock lock(poller->poll_mutex);
105 poller->poll_queue.emplace(std::forward<Args>(args)...);
106 }
107 poller->poll_cond.notify_one();
108 }
109};
110
111template <typename Derived, typename TUserInfo, typename TListenerData,
112 typename TNotifierData>
114 if (m_on_start) {
115 m_on_start();
116 }
117
118 std::unique_lock lock(m_mutex);
119 while (m_active) {
120 while (m_queue.empty()) {
121 m_cond.wait(lock);
122 if (!m_active) {
123 goto done;
124 }
125 }
126
127 while (!m_queue.empty()) {
128 if (!m_active) {
129 goto done;
130 }
131 auto item = std::move(m_queue.front());
132
133 if (item.first != UINT_MAX) {
134 if (item.first < m_listeners.size()) {
135 auto& listener = m_listeners[item.first];
136 if (listener &&
137 static_cast<Derived*>(this)->Matches(listener, item.second)) {
138 static_cast<Derived*>(this)->SetListener(&item.second, item.first);
139 if (listener.callback) {
140 lock.unlock();
141 static_cast<Derived*>(this)->DoCallback(listener.callback,
142 item.second);
143 lock.lock();
144 } else if (listener.poller_uid != UINT_MAX) {
145 SendPoller(listener.poller_uid, std::move(item.second));
146 }
147 }
148 }
149 } else {
150 // Use index because iterator might get invalidated.
151 for (size_t i = 0; i < m_listeners.size(); ++i) {
152 auto& listener = m_listeners[i];
153 if (!listener) {
154 continue;
155 }
156
157 if (!static_cast<Derived*>(this)->Matches(listener, item.second)) {
158 continue;
159 }
160 static_cast<Derived*>(this)->SetListener(&item.second,
161 static_cast<unsigned>(i));
162 if (listener.callback) {
163 lock.unlock();
164 static_cast<Derived*>(this)->DoCallback(listener.callback,
165 item.second);
166 lock.lock();
167 } else if (listener.poller_uid != UINT_MAX) {
168 SendPoller(listener.poller_uid, item.second);
169 }
170 }
171 }
172 m_queue.pop();
173 }
174
175 m_queue_empty.notify_all();
176 }
177
178done:
179 if (m_on_exit) {
180 m_on_exit();
181 }
182}
183
184// CRTP callback manager
185// @tparam Derived derived class
186// @tparam Thread custom thread (must be derived from impl::CallbackThread)
187//
188// Derived must define the following functions:
189// void Start();
190template <typename Derived, typename Thread>
192 friend class RpcServerTest;
193
194 public:
195 void SetOnStart(std::function<void()> on_start) {
196 m_on_start = std::move(on_start);
197 }
198
199 void SetOnExit(std::function<void()> on_exit) {
200 m_on_exit = std::move(on_exit);
201 }
202
203 void Stop() { m_owner.Stop(); }
204
205 void Remove(unsigned int listener_uid) {
206 auto thr = m_owner.GetThread();
207 if (!thr) {
208 return;
209 }
210 thr->m_listeners.erase(listener_uid);
211 }
212
213 unsigned int CreatePoller() {
214 static_cast<Derived*>(this)->Start();
215 auto thr = m_owner.GetThread();
216 return thr->m_pollers.emplace_back(
217 std::make_shared<typename Thread::Poller>());
218 }
219
220 void RemovePoller(unsigned int poller_uid) {
221 auto thr = m_owner.GetThread();
222 if (!thr) {
223 return;
224 }
225
226 // Remove any listeners that are associated with this poller
227 for (size_t i = 0; i < thr->m_listeners.size(); ++i) {
228 if (thr->m_listeners[i].poller_uid == poller_uid) {
229 thr->m_listeners.erase(i);
230 }
231 }
232
233 // Wake up any blocked pollers
234 if (poller_uid >= thr->m_pollers.size()) {
235 return;
236 }
237 auto poller = thr->m_pollers[poller_uid];
238 if (!poller) {
239 return;
240 }
241 poller->Terminate();
242 thr->m_pollers.erase(poller_uid);
243 }
244
245 bool WaitForQueue(double timeout) {
246 auto thr = m_owner.GetThread();
247 if (!thr) {
248 return true;
249 }
250
251 auto& lock = thr.GetLock();
252 auto timeout_time = std::chrono::steady_clock::now() +
253 std::chrono::duration<double>(timeout);
254 while (!thr->m_queue.empty()) {
255 if (!thr->m_active) {
256 return true;
257 }
258 if (timeout == 0) {
259 return false;
260 }
261 if (timeout < 0) {
262 thr->m_queue_empty.wait(lock);
263 } else {
264 auto cond_timed_out = thr->m_queue_empty.wait_until(lock, timeout_time);
265 if (cond_timed_out == std::cv_status::timeout) {
266 return false;
267 }
268 }
269 }
270
271 return true;
272 }
273
274 std::vector<typename Thread::UserInfo> Poll(unsigned int poller_uid) {
275 bool timed_out = false;
276 return Poll(poller_uid, -1, &timed_out);
277 }
278
279 std::vector<typename Thread::UserInfo> Poll(unsigned int poller_uid,
280 double timeout, bool* timed_out) {
281 std::vector<typename Thread::UserInfo> infos;
282 std::shared_ptr<typename Thread::Poller> poller;
283 {
284 auto thr = m_owner.GetThread();
285 if (!thr) {
286 return infos;
287 }
288 if (poller_uid > thr->m_pollers.size()) {
289 return infos;
290 }
291 poller = thr->m_pollers[poller_uid];
292 if (!poller) {
293 return infos;
294 }
295 }
296
297 std::unique_lock lock(poller->poll_mutex);
298 auto timeout_time = std::chrono::steady_clock::now() +
299 std::chrono::duration<double>(timeout);
300 *timed_out = false;
301 while (poller->poll_queue.empty()) {
302 if (poller->terminating) {
303 return infos;
304 }
305 if (poller->canceling) {
306 // Note: this only works if there's a single thread calling this
307 // function for any particular poller, but that's the intended use.
308 poller->canceling = false;
309 return infos;
310 }
311 if (timeout == 0) {
312 *timed_out = true;
313 return infos;
314 }
315 if (timeout < 0) {
316 poller->poll_cond.wait(lock);
317 } else {
318 auto cond_timed_out = poller->poll_cond.wait_until(lock, timeout_time);
319 if (cond_timed_out == std::cv_status::timeout) {
320 *timed_out = true;
321 return infos;
322 }
323 }
324 }
325
326 while (!poller->poll_queue.empty()) {
327 infos.emplace_back(std::move(poller->poll_queue.front()));
328 poller->poll_queue.pop();
329 }
330 return infos;
331 }
332
333 void CancelPoll(unsigned int poller_uid) {
334 std::shared_ptr<typename Thread::Poller> poller;
335 {
336 auto thr = m_owner.GetThread();
337 if (!thr) {
338 return;
339 }
340 if (poller_uid > thr->m_pollers.size()) {
341 return;
342 }
343 poller = thr->m_pollers[poller_uid];
344 if (!poller) {
345 return;
346 }
347 }
348
349 {
350 std::scoped_lock lock(poller->poll_mutex);
351 poller->canceling = true;
352 }
353 poller->poll_cond.notify_one();
354 }
355
356 protected:
357 template <typename... Args>
358 void DoStart(Args&&... args) {
359 m_owner.Start(m_on_start, m_on_exit, std::forward<Args>(args)...);
360 }
361
362 template <typename... Args>
363 unsigned int DoAdd(Args&&... args) {
364 static_cast<Derived*>(this)->Start();
365 auto thr = m_owner.GetThread();
366 return thr->m_listeners.emplace_back(std::forward<Args>(args)...);
367 }
368
369 template <typename... Args>
370 void Send(unsigned int only_listener, Args&&... args) {
371 auto thr = m_owner.GetThread();
372 if (!thr || thr->m_listeners.empty()) {
373 return;
374 }
375 thr->m_queue.emplace(std::piecewise_construct,
376 std::make_tuple(only_listener),
377 std::forward_as_tuple(std::forward<Args>(args)...));
378 thr->m_cond.notify_one();
379 }
380
382 return m_owner.GetThread();
383 }
384
385 private:
387
388 std::function<void()> m_on_start;
389 std::function<void()> m_on_exit;
390};
391
392} // namespace wpi::util
Definition CallbackManager.hpp:24
unsigned int poller_uid
Definition CallbackManager.hpp:34
CallbackListenerData(Callback callback_)
Definition CallbackManager.hpp:27
Callback callback
Definition CallbackManager.hpp:33
CallbackListenerData(unsigned int poller_uid_)
Definition CallbackManager.hpp:28
Definition CallbackManager.hpp:191
std::vector< typename Thread::UserInfo > Poll(unsigned int poller_uid)
Definition CallbackManager.hpp:274
void SetOnStart(std::function< void()> on_start)
Definition CallbackManager.hpp:195
unsigned int CreatePoller()
Definition CallbackManager.hpp:213
void SetOnExit(std::function< void()> on_exit)
Definition CallbackManager.hpp:199
void CancelPoll(unsigned int poller_uid)
Definition CallbackManager.hpp:333
void Stop()
Definition CallbackManager.hpp:203
void DoStart(Args &&... args)
Definition CallbackManager.hpp:358
bool WaitForQueue(double timeout)
Definition CallbackManager.hpp:245
friend class RpcServerTest
Definition CallbackManager.hpp:192
void Remove(unsigned int listener_uid)
Definition CallbackManager.hpp:205
wpi::util::SafeThreadOwner< Thread >::Proxy GetThread() const
Definition CallbackManager.hpp:381
void RemovePoller(unsigned int poller_uid)
Definition CallbackManager.hpp:220
void Send(unsigned int only_listener, Args &&... args)
Definition CallbackManager.hpp:370
std::vector< typename Thread::UserInfo > Poll(unsigned int poller_uid, double timeout, bool *timed_out)
Definition CallbackManager.hpp:279
unsigned int DoAdd(Args &&... args)
Definition CallbackManager.hpp:363
CallbackThread(std::function< void()> on_start, std::function< void()> on_exit)
Definition CallbackManager.hpp:55
TListenerData ListenerData
Definition CallbackManager.hpp:53
wpi::util::UidVector< ListenerData, 64 > m_listeners
Definition CallbackManager.hpp:69
std::function< void()> m_on_start
Definition CallbackManager.hpp:90
wpi::util::condition_variable m_queue_empty
Definition CallbackManager.hpp:72
void Main() override
Definition CallbackManager.hpp:113
std::function< void()> m_on_exit
Definition CallbackManager.hpp:91
wpi::util::UidVector< std::shared_ptr< Poller >, 64 > m_pollers
Definition CallbackManager.hpp:88
void SendPoller(unsigned int poller_uid, Args &&... args)
Definition CallbackManager.hpp:95
std::queue< std::pair< unsigned int, NotifierData > > m_queue
Definition CallbackManager.hpp:71
~CallbackThread() override
Definition CallbackManager.hpp:58
TUserInfo UserInfo
Definition CallbackManager.hpp:51
TNotifierData NotifierData
Definition CallbackManager.hpp:52
wpi::util::mutex m_mutex
Definition SafeThread.hpp:27
std::atomic_bool m_active
Definition SafeThread.hpp:28
Definition SafeThread.hpp:32
wpi::util::condition_variable m_cond
Definition SafeThread.hpp:36
Definition SafeThread.hpp:123
typename detail::SafeThreadProxy< T > Proxy
Definition SafeThread.hpp:131
Proxy GetThread() const
Definition SafeThread.hpp:132
Vector which provides an integrated freelist for removal and reuse of individual elements.
Definition UidVector.hpp:72
Definition StringMap.hpp:773
Definition raw_os_ostream.hpp:19
::std::condition_variable condition_variable
Definition condition_variable.hpp:16
::std::mutex mutex
Definition mutex.hpp:17
Definition CallbackManager.hpp:74
bool terminating
Definition CallbackManager.hpp:85
void Terminate()
Definition CallbackManager.hpp:75
bool canceling
Definition CallbackManager.hpp:86
wpi::util::condition_variable poll_cond
Definition CallbackManager.hpp:84
wpi::util::mutex poll_mutex
Definition CallbackManager.hpp:83
std::queue< NotifierData > poll_queue
Definition CallbackManager.hpp:82