diff --git a/include/pocketpy/common/threads.h b/include/pocketpy/common/threads.h index a78a1621..f75c18a0 100644 --- a/include/pocketpy/common/threads.h +++ b/include/pocketpy/common/threads.h @@ -76,5 +76,6 @@ typedef struct c11_thrdpool { void c11_thrdpool__ctor(c11_thrdpool* pool, int length); void c11_thrdpool__dtor(c11_thrdpool* pool); void c11_thrdpool__map(c11_thrdpool* pool, c11_thrdpool_func_t func, void** args, int num_tasks); +void c11_thrdpool__join(c11_thrdpool* pool); #endif \ No newline at end of file diff --git a/src/common/threads.c b/src/common/threads.c index 20e95eb2..a53ed569 100644 --- a/src/common/threads.c +++ b/src/common/threads.c @@ -193,7 +193,6 @@ void c11_thrdpool__dtor(c11_thrdpool* pool) { } void c11_thrdpool__map(c11_thrdpool* pool, c11_thrdpool_func_t func, void** args, int num_tasks) { - if(num_tasks == 0) return; c11_thrdpool_debug_log(-1, "c11_thrdpool__map() called on %d tasks...", num_tasks); while(atomic_load_explicit(&pool->ready_workers_num, memory_order_relaxed) < pool->length) { c11_thrd__yield(); @@ -210,13 +209,15 @@ void c11_thrdpool__map(c11_thrdpool* pool, c11_thrdpool_func_t func, void** args atomic_store_explicit(&pool->tasks.completed_count, 0, memory_order_relaxed); c11_cond__broadcast(&pool->workers_cond); c11_mutex__unlock(&pool->workers_mutex); +} +void c11_thrdpool__join(c11_thrdpool *pool) { // wait for complete + int num_tasks = pool->tasks.length; c11_thrdpool_debug_log(-1, "Waiting for %d tasks to complete...", num_tasks); while(atomic_load_explicit(&pool->tasks.completed_count, memory_order_acquire) < num_tasks) { c11_thrd__yield(); } - atomic_store_explicit(&pool->tasks.sync_val, 0, memory_order_relaxed); c11_thrdpool_debug_log(-1, "All %d tasks completed, `sync_val` was reset.", num_tasks); } diff --git a/src2/test_threads.c b/src2/test_threads.c index f2bf42da..bbf673dd 100644 --- a/src2/test_threads.c +++ b/src2/test_threads.c @@ -33,6 +33,7 @@ int main(int argc, char** argv) { printf("==> %dth run\n", i + 1); int64_t start_ns = time_ns(); c11_thrdpool__map(&pool, func, args, num_tasks); + c11_thrdpool__join(&pool); int64_t end_ns = time_ns(); double elapsed = (end_ns - start_ns) / 1e9; printf(" Results: %lld, %lld, %lld, %lld, %lld\n",