ndmspc  v1.1.1-1
NDimensionalExecutor.h
1 #ifndef N_DIMENSIONAL_EXECUTOR_H
2 #define N_DIMENSIONAL_EXECUTOR_H
3 
4 #include <iomanip>
5 #include <sstream>
6 #include <vector>
7 #include <functional>
8 #include <cstddef>
9 #include <thread>
10 #include <mutex>
11 #include <condition_variable>
12 #include <atomic>
13 #include <queue>
14 #include <stdexcept>
15 #include <utility>
16 #include <exception>
17 #include <THnSparse.h>
18 #include "NLogger.h"
19 #include "NThreadData.h"
20 
21 namespace Ndmspc {
22 
29  public:
35  NDimensionalExecutor(const std::vector<int> & minBounds, const std::vector<int> & maxBounds);
36 
42  NDimensionalExecutor(THnSparse * hist, bool onlyfilled = false);
43 
48  void Execute(const std::function<void(const std::vector<int> & coords)> & func);
49 
56  template <typename TObject>
57  void ExecuteParallel(const std::function<void(const std::vector<int> & coords, TObject & thread_object)> & func,
58  std::vector<TObject> & thread_objects);
59 
64  size_t Dimensions() const { return fNumDimensions; }
65 
70  const std::vector<int> & GetMinBounds() const { return fMinBounds; }
71 
76  const std::vector<int> & GetMaxBounds() const { return fMaxBounds; }
77 
78  private:
79  size_t fNumDimensions;
80  std::vector<int> fMinBounds;
81  std::vector<int> fMaxBounds;
82  std::vector<int> fCurrentCoords;
83 
88  bool Increment();
89 };
90 
91 // --- Template Implementation for ExecuteParallel ---
98 template <typename TObject>
100  const std::function<void(const std::vector<int> & coords, TObject & thread_object)> & func,
101  std::vector<TObject> & thread_objects)
102 {
103  if (fNumDimensions == 0) {
104  return;
105  }
106  size_t threads_to_use = thread_objects.size();
107  if (threads_to_use == 0) {
108  throw std::invalid_argument("Thread objects vector cannot be empty.");
109  }
110 
111  std::vector<std::thread> workers;
112  std::queue<std::function<void(TObject &)>> tasks;
113  std::mutex queue_mutex;
114  std::condition_variable condition_producer;
115  std::condition_variable condition_consumer;
116  std::atomic<size_t> active_tasks = 0;
117  std::atomic<bool> stop_pool = false;
118  // Optional: Store first exception encountered in workers
119  std::exception_ptr first_exception = nullptr;
120  std::mutex exception_mutex;
121 
122  // Worker thread logic: fetch and execute tasks, handle exceptions, signal completion.
123  auto worker_logic = [&](TObject & my_object) {
124  NThreadData * md = (NThreadData *)&my_object;
125 
126  std::ostringstream oss;
127  oss << "wk_" << std::setw(3) << std::setfill('0') << md->GetAssignedIndex();
128 
129  NLogger::SetThreadName(oss.str());
130  while (true) {
131  std::function<void(TObject &)> task_payload;
132  bool task_acquired = false; // Track if we actually got a task this iteration
133 
134  try {
135  { // Lock scope for queue access
136  std::unique_lock<std::mutex> lock(queue_mutex);
137  condition_producer.wait(lock, [&] { return stop_pool || !tasks.empty(); });
138 
139  // Check stop condition *after* waking up
140  if (stop_pool && tasks.empty()) {
141  break; // Exit the while loop normally
142  }
143  // If stopping but tasks remain, continue processing them
144 
145  // Only proceed if not stopping or if tasks are still present
146  if (!tasks.empty()) {
147  task_payload = std::move(tasks.front());
148  tasks.pop();
149  task_acquired = true; // We got a task
150  }
151  else {
152  // Spurious wakeup or stop_pool=true with empty queue
153  continue; // Go back to wait
154  }
155  } // Mutex unlocked
156 
157  // Execute the task if we acquired one
158  if (task_acquired) {
159  task_payload(my_object); // Execute task with assigned object
160  }
161  }
162  catch (...) {
163  // --- Exception Handling ---
164  { // Lock to safely store the first exception
165  std::lock_guard<std::mutex> lock(exception_mutex);
166  if (!first_exception) {
167  first_exception = std::current_exception(); // Store it
168  }
169  }
170  // Signal pool to stop immediately on any error
171  {
172  std::unique_lock<std::mutex> lock(queue_mutex);
173  stop_pool = true;
174  }
175  condition_producer.notify_all(); // Wake all threads to check stop flag
176 
177  // *** Crucial Fix: Decrement active_tasks even on exception ***
178  // Check if we actually acquired a task before decrementing
179  if (task_acquired) {
180  if (--active_tasks == 0 && stop_pool) {
181  // Also notify consumer here in case this was the last task
182  condition_consumer.notify_one();
183  }
184  }
185  // Decide whether to exit the worker or try processing remaining tasks
186  // For simplicity, let's exit the worker on error.
187  return; // Exit worker thread immediately on error
188  }
189 
190  // --- Normal Task Completion ---
191  // Decrement active task count *after* successful execution
192  // Check if we actually acquired and processed a task
193  if (task_acquired) {
194  if (--active_tasks == 0 && stop_pool) {
195  condition_consumer.notify_one();
196  }
197  }
198  } // End of while loop
199  }; // End of worker_logic lambda
200 
201  // --- Start Worker Threads ---
202  workers.reserve(threads_to_use);
203  for (size_t i = 0; i < threads_to_use; ++i) {
204  workers.emplace_back(worker_logic, std::ref(thread_objects[i]));
205  }
206 
207  // --- Main Thread: Iterate and Enqueue Tasks ---
208  try {
210  do {
211  // Check if pool was stopped prematurely (e.g., by an exception in a worker)
212  // Lock needed to safely check stop_pool
213  {
214  std::unique_lock<std::mutex> lock(queue_mutex);
215  if (stop_pool) break;
216  }
217 
218  std::vector<int> coords_copy = fCurrentCoords;
219  {
220  std::unique_lock<std::mutex> lock(queue_mutex);
221  // Double check stop_pool after acquiring lock
222  if (stop_pool) break;
223 
224  active_tasks++;
225  tasks.emplace([func, coords_copy](TObject & obj) { func(coords_copy, obj); });
226  }
227  condition_producer.notify_one();
228  } while (Increment());
229  }
230  catch (...) {
231  // Exception during iteration/enqueueing
232  {
233  std::unique_lock<std::mutex> lock(queue_mutex);
234  stop_pool = true; // Signal workers to stop
235  if (!first_exception) { // Store exception if none from workers yet
236  first_exception = std::current_exception();
237  }
238  }
239  condition_producer.notify_all();
240  // Proceed to join threads
241  }
242 
243  // --- Signal Workers to Stop (if not already stopped by error) ---
244  {
245  std::unique_lock<std::mutex> lock(queue_mutex);
246  stop_pool = true;
247  }
248  condition_producer.notify_all();
249 
250  // --- Wait for Tasks to Complete ---
251  {
252  std::unique_lock<std::mutex> lock(queue_mutex);
253  condition_consumer.wait(lock, [&] { return stop_pool && active_tasks == 0; });
254  }
255 
256  // --- Join Worker Threads ---
257  for (std::thread & worker : workers) {
258  if (worker.joinable()) {
259  worker.join();
260  }
261  }
262 
263  // --- Check for and rethrow exception from workers ---
264  if (first_exception) {
265  std::rethrow_exception(first_exception);
266  }
267 }
268 
269 } // namespace Ndmspc
270 
271 #endif
Executes a function over all points in an N-dimensional space, optionally in parallel.
std::vector< int > fMinBounds
Minimum bounds for each dimension.
void ExecuteParallel(const std::function< void(const std::vector< int > &coords, TObject &thread_object)> &func, std::vector< TObject > &thread_objects)
Execute a function in parallel over all coordinates, giving each worker thread its own object from thread_objects.
const std::vector< int > & GetMaxBounds() const
Returns the maximum bounds for each dimension.
std::vector< int > fMaxBounds
Maximum bounds for each dimension.
std::vector< int > fCurrentCoords
Current coordinates during iteration.
bool Increment()
Increment the current coordinates to the next point in the N-dimensional space.
NDimensionalExecutor(const std::vector< int > &minBounds, const std::vector< int > &maxBounds)
Constructor from min/max bounds for each dimension.
size_t Dimensions() const
Returns the number of dimensions.
size_t fNumDimensions
Number of dimensions.
void Execute(const std::function< void(const std::vector< int > &coords)> &func)
Execute a function over all coordinates in the N-dimensional space.
const std::vector< int > & GetMinBounds() const
Returns the minimum bounds for each dimension.
static void SetThreadName(const std::string &name, std::thread::id thread_id=std::this_thread::get_id())
Sets the name of a thread.
Definition: NLogger.cxx:128
Thread-local data object for NDMSPC processing.
Definition: NThreadData.h:21
size_t GetAssignedIndex() const
Get the assigned index for the thread.
Definition: NThreadData.h:96