diff --git a/include/pocketpy/common/algorithm.h b/include/pocketpy/common/algorithm.h index 4caed690..2e63fa8e 100644 --- a/include/pocketpy/common/algorithm.h +++ b/include/pocketpy/common/algorithm.h @@ -37,10 +37,10 @@ extern "C" { * @param cmp Comparison function that takes two elements and returns an integer similar to * `strcmp`. */ -void c11__stable_sort(void* ptr, +bool c11__stable_sort(void* ptr, int count, int elem_size, - int (*cmp)(const void* a, const void* b)); + int (*f_le)(const void* a, const void* b)); #ifdef __cplusplus } diff --git a/src/common/algorithm.c b/src/common/algorithm.c index e99c7e69..f60c725c 100644 --- a/src/common/algorithm.c +++ b/src/common/algorithm.c @@ -2,16 +2,19 @@ #include #include -static void merge(char* a_begin, +static bool merge(char* a_begin, char* a_end, char* b_begin, char* b_end, char* res, int elem_size, - int (*cmp)(const void* a, const void* b)) { + int (*f_le)(const void* a, const void* b)) { char *a = a_begin, *b = b_begin, *r = res; while(a < a_end && b < b_end) { - if(cmp(a, b) <= 0) { + int res = f_le(a, b); + // check error + if(res == -1) return false; + if(res) { memcpy(r, a, elem_size); a += elem_size; } else { @@ -26,22 +29,27 @@ static void merge(char* a_begin, memcpy(r, a, elem_size); for(; b < b_end; b += elem_size, r += elem_size) memcpy(r, b, elem_size); + return true; } -void c11__stable_sort(void* ptr_, +bool c11__stable_sort(void* ptr_, int count, int elem_size, - int (*cmp)(const void* a, const void* b)) { + int (*f_le)(const void* a, const void* b)) { // merge sort - char* ptr = ptr_, *tmp = malloc(count * elem_size); + char *ptr = ptr_, *tmp = malloc(count * elem_size); for(int seg = 1; seg < count; seg *= 2) { for(char* a = ptr; a < ptr + (count - seg) * elem_size; a += 2 * seg * elem_size) { - char* b = a + seg * elem_size, *a_end = b, *b_end = b + seg * elem_size; - if (b_end > ptr + count * elem_size) - b_end = ptr + count * elem_size; - merge(a, a_end, b, b_end, tmp, elem_size, cmp); - memcpy(a, tmp, b_end - a); + char *b = a + seg * elem_size, *a_end = b, *b_end = b + seg * elem_size; + if(b_end > ptr + count * elem_size) b_end = ptr + count * elem_size; + bool ok = merge(a, a_end, b, b_end, tmp, elem_size, f_le); + if(!ok) { + free(tmp); + return false; + } + memcpy(a, tmp, b_end - a); } } - free(tmp); + free(tmp); + return true; } diff --git a/src/public/py_list.c b/src/public/py_list.c index b995b2da..8d189a23 100644 --- a/src/public/py_list.c +++ b/src/public/py_list.c @@ -319,6 +319,14 @@ static bool _py_list__insert(int argc, py_Ref argv) { return true; } +static bool _py_list__sort(int argc, py_Ref argv) { + PY_CHECK_ARGC(1); + List* self = py_touserdata(py_arg(0)); + c11__stable_sort(self->data, self->count, sizeof(py_TValue), (int (*)(const void*, const void*))py_le); + py_newnone(py_retval()); + return true; +} + py_Type pk_list__register() { pk_VM* vm = pk_current_vm; py_Type type = pk_VM__new_type(vm, "list", tp_object, NULL, false); @@ -346,5 +354,6 @@ py_Type pk_list__register() { py_bindmethod(type, "remove", _py_list__remove); py_bindmethod(type, "pop", _py_list__pop); py_bindmethod(type, "insert", _py_list__insert); + py_bindmethod(type, "sort", _py_list__sort); return type; } \ No newline at end of file diff --git a/tests/05_list.py b/tests/05_list.py index 81be02ea..ab71605b 100644 --- a/tests/05_list.py +++ b/tests/05_list.py @@ -87,6 +87,18 @@ assert list(range(5, 1, -2)) == [5, 3] # test sort a = [8, 2, 4, 2, 9] +assert a.sort() == None +assert a == [2, 2, 4, 8, 9] + +a = [] +assert a.sort() == None +assert a == [] + +a = [0, 0, 0, 0, 1, 1, 3, -1] +assert a.sort() == None +assert a == [-1, 0, 0, 0, 0, 1, 1, 3] + +# test sorted assert sorted(a) == [2, 2, 4, 8, 9] assert sorted(a, reverse=True) == [9, 8, 4, 2, 2]