diff --git a/include/pocketpy/common/threads.h b/include/pocketpy/common/threads.h index a84731bd..4be7458b 100644 --- a/include/pocketpy/common/threads.h +++ b/include/pocketpy/common/threads.h @@ -44,10 +44,12 @@ void c11_cond__broadcast(c11_cond_t* cond); typedef void (*c11_thrdpool_func_t)(void* arg); + typedef struct c11_thrdpool_tasks { c11_thrdpool_func_t func; void** args; int length; + int sync_val; /* 0: round published, 1: round finished, -1: pool shutting down */ atomic_int current_index; atomic_int completed_count; } c11_thrdpool_tasks; @@ -56,8 +58,6 @@ typedef struct c11_thrdpool_worker { c11_mutex_t* p_mutex; c11_cond_t* p_cond; c11_thrdpool_tasks* p_tasks; - - bool should_exit; c11_thrd_t thread; } c11_thrdpool_worker; @@ -66,8 +66,8 @@ typedef struct c11_thrdpool { c11_thrdpool_worker* workers; atomic_bool is_busy; - c11_mutex_t workers_mutex; - c11_cond_t workers_cond; + c11_mutex_t workers_mutex[2]; + c11_cond_t workers_cond[2]; c11_thrdpool_tasks tasks; } c11_thrdpool; diff --git a/src/common/threads.c b/src/common/threads.c index 95bfc4c5..2b1e2a73 100644 --- a/src/common/threads.c +++ b/src/common/threads.c @@ -71,23 +71,38 @@ void c11_cond__broadcast(c11_cond_t* cond) { cnd_broadcast(cond); } #endif +/* Parks a worker until *p_sync_val equals expected_sync_val (0 or 1); that + * value also selects which mutex/cond pair is used for the wait. + * Returns true once the expected value is seen, false once *p_sync_val is + * -1 (stored by c11_thrdpool__dtor). The mutex is released on both paths. + * The value is tested before every wait, so a broadcast sent before the + * worker gets here is not missed. + * NOTE(review): c11_thrdpool__map stores sync_val while holding only one of + * the two mutexes, so a worker in the other phase reads it unsynchronized; + * confirm whether it should be atomic. */ +static bool _thrdpool_worker_sync(c11_mutex_t* p_mutex, + c11_cond_t* p_cond, + int* p_sync_val, + int expected_sync_val) { + int index = expected_sync_val; + c11_mutex__lock(p_mutex + index); + while(*p_sync_val != expected_sync_val && *p_sync_val != -1) { + c11_cond__wait(p_cond + index, p_mutex + index); + } + bool proceed = *p_sync_val != -1; + c11_mutex__unlock(p_mutex + index); + return proceed; +} + static c11_thrd_retval_t _thrdpool_worker(void* arg) { c11_thrdpool_worker* p_worker = (c11_thrdpool_worker*)arg; + c11_thrdpool_tasks* p_tasks = p_worker->p_tasks; while(true) { - // wait for tasks - c11_mutex__lock(p_worker->p_mutex); - while(!p_worker->p_tasks && !p_worker->should_exit) { - c11_cond__wait(p_worker->p_cond, p_worker->p_mutex); - } - if(p_worker->should_exit) { - 
c11_mutex__unlock(p_worker->p_mutex); + if(!_thrdpool_worker_sync(p_worker->p_mutex, p_worker->p_cond, &p_tasks->sync_val, 0)) { break; } - c11_thrdpool_tasks* p_tasks = p_worker->p_tasks; - c11_mutex__unlock(p_worker->p_mutex); - // execute tasks while(true) { int arg_index = atomic_fetch_add(&p_tasks->current_index, 1); @@ -96,12 +111,13 @@ static c11_thrd_retval_t _thrdpool_worker(void* arg) { p_tasks->func(arg); atomic_fetch_add(&p_tasks->completed_count, 1); } else { - c11_mutex__lock(p_worker->p_mutex); - p_worker->p_tasks = NULL; - c11_mutex__unlock(p_worker->p_mutex); break; } } + + if(!_thrdpool_worker_sync(p_worker->p_mutex, p_worker->p_cond, &p_tasks->sync_val, 1)) { + break; + } } return 0; } @@ -111,28 +127,32 @@ void c11_thrdpool__ctor(c11_thrdpool* pool, int length) { pool->workers = PK_MALLOC(sizeof(c11_thrdpool_worker) * length); atomic_store(&pool->is_busy, false); - c11_mutex__ctor(&pool->workers_mutex); - c11_cond__ctor(&pool->workers_cond); + c11_mutex__ctor(&pool->workers_mutex[0]); + c11_mutex__ctor(&pool->workers_mutex[1]); + c11_cond__ctor(&pool->workers_cond[0]); + c11_cond__ctor(&pool->workers_cond[1]); + /* 1 = idle: workers stay parked in their first sync until map() stores 0 */ + pool->tasks.sync_val = 1; for(int i = 0; i < length; i++) { c11_thrdpool_worker* p_worker = &pool->workers[i]; - - p_worker->p_mutex = &pool->workers_mutex; - p_worker->p_cond = &pool->workers_cond; + p_worker->p_mutex = pool->workers_mutex; + p_worker->p_cond = pool->workers_cond; p_worker->p_tasks = &pool->tasks; - p_worker->should_exit = false; - bool ok = c11_thrd__create(&p_worker->thread, _thrdpool_worker, p_worker); c11__rtassert(ok); } } void c11_thrdpool__dtor(c11_thrdpool* pool) { - for(int i = 0; i < pool->length; i++) { - c11_thrdpool_worker* p_worker = &pool->workers[i]; - atomic_store(&p_worker->should_exit, true); - } - c11_cond__broadcast(&pool->workers_cond); + c11_mutex__lock(&pool->workers_mutex[0]); + c11_mutex__lock(&pool->workers_mutex[1]); + pool->tasks.sync_val = -1; + c11_mutex__unlock(&pool->workers_mutex[1]); + 
c11_mutex__unlock(&pool->workers_mutex[0]); + + c11_cond__broadcast(&pool->workers_cond[0]); + c11_cond__broadcast(&pool->workers_cond[1]); for(int i = 0; i < pool->length; i++) { c11_thrdpool_worker* p_worker = &pool->workers[i]; @@ -140,8 +160,10 @@ void c11_thrdpool__dtor(c11_thrdpool* pool) { } PK_FREE(pool->workers); - c11_mutex__dtor(&pool->workers_mutex); - c11_cond__dtor(&pool->workers_cond); + c11_mutex__dtor(&pool->workers_mutex[0]); + c11_mutex__dtor(&pool->workers_mutex[1]); + c11_cond__dtor(&pool->workers_cond[0]); + c11_cond__dtor(&pool->workers_cond[1]); } void c11_thrdpool__map(c11_thrdpool* pool, c11_thrdpool_func_t func, void** args, int num_tasks) { @@ -152,19 +174,25 @@ void c11_thrdpool__map(c11_thrdpool* pool, c11_thrdpool_func_t func, void** args c11_thrd__yield(); } // assign tasks - c11_mutex__lock(&pool->workers_mutex); + c11_mutex__lock(&pool->workers_mutex[0]); pool->tasks.func = func; pool->tasks.args = args; pool->tasks.length = num_tasks; + pool->tasks.sync_val = 0; atomic_store(&pool->tasks.current_index, 0); atomic_store(&pool->tasks.completed_count, 0); - // wake up all workers - c11_cond__broadcast(&pool->workers_cond); - c11_mutex__unlock(&pool->workers_mutex); + c11_cond__broadcast(&pool->workers_cond[0]); + c11_mutex__unlock(&pool->workers_mutex[0]); // wait for complete while(atomic_load(&pool->tasks.completed_count) < num_tasks) { c11_thrd__yield(); } + // notify workers to proceed + c11_mutex__lock(&pool->workers_mutex[1]); + pool->tasks.sync_val = 1; + c11_cond__broadcast(&pool->workers_cond[1]); + c11_mutex__unlock(&pool->workers_mutex[1]); + // mark as not busy atomic_store(&pool->is_busy, false); }