fix list.sort

This commit is contained in:
blueloveTH 2024-07-07 13:19:48 +08:00
parent 400b8fbef4
commit 22ae57fc9b
4 changed files with 43 additions and 14 deletions

View File

@ -37,10 +37,10 @@ extern "C" {
* @param cmp Comparison function that takes two elements and returns an integer similar to * @param cmp Comparison function that takes two elements and returns an integer similar to
* `strcmp`. * `strcmp`.
*/ */
void c11__stable_sort(void* ptr, bool c11__stable_sort(void* ptr,
int count, int count,
int elem_size, int elem_size,
int (*cmp)(const void* a, const void* b)); int (*f_le)(const void* a, const void* b));
#ifdef __cplusplus #ifdef __cplusplus
} }

View File

@ -2,16 +2,19 @@
#include <string.h> #include <string.h>
#include <stdlib.h> #include <stdlib.h>
static void merge(char* a_begin, static bool merge(char* a_begin,
char* a_end, char* a_end,
char* b_begin, char* b_begin,
char* b_end, char* b_end,
char* res, char* res,
int elem_size, int elem_size,
int (*cmp)(const void* a, const void* b)) { int (*f_le)(const void* a, const void* b)) {
char *a = a_begin, *b = b_begin, *r = res; char *a = a_begin, *b = b_begin, *r = res;
while(a < a_end && b < b_end) { while(a < a_end && b < b_end) {
if(cmp(a, b) <= 0) { int res = f_le(a, b);
// check error
if(res == -1) return false;
if(res) {
memcpy(r, a, elem_size); memcpy(r, a, elem_size);
a += elem_size; a += elem_size;
} else { } else {
@ -26,22 +29,27 @@ static void merge(char* a_begin,
memcpy(r, a, elem_size); memcpy(r, a, elem_size);
for(; b < b_end; b += elem_size, r += elem_size) for(; b < b_end; b += elem_size, r += elem_size)
memcpy(r, b, elem_size); memcpy(r, b, elem_size);
return true;
} }
void c11__stable_sort(void* ptr_, bool c11__stable_sort(void* ptr_,
int count, int count,
int elem_size, int elem_size,
int (*cmp)(const void* a, const void* b)) { int (*f_le)(const void* a, const void* b)) {
// merge sort // merge sort
char* ptr = ptr_, *tmp = malloc(count * elem_size); char *ptr = ptr_, *tmp = malloc(count * elem_size);
for(int seg = 1; seg < count; seg *= 2) { for(int seg = 1; seg < count; seg *= 2) {
for(char* a = ptr; a < ptr + (count - seg) * elem_size; a += 2 * seg * elem_size) { for(char* a = ptr; a < ptr + (count - seg) * elem_size; a += 2 * seg * elem_size) {
char* b = a + seg * elem_size, *a_end = b, *b_end = b + seg * elem_size; char *b = a + seg * elem_size, *a_end = b, *b_end = b + seg * elem_size;
if (b_end > ptr + count * elem_size) if(b_end > ptr + count * elem_size) b_end = ptr + count * elem_size;
b_end = ptr + count * elem_size; bool ok = merge(a, a_end, b, b_end, tmp, elem_size, f_le);
merge(a, a_end, b, b_end, tmp, elem_size, cmp); if(!ok) {
memcpy(a, tmp, b_end - a); free(tmp);
return false;
}
memcpy(a, tmp, b_end - a);
} }
} }
free(tmp); free(tmp);
return true;
} }

View File

@ -319,6 +319,14 @@ static bool _py_list__insert(int argc, py_Ref argv) {
return true; return true;
} }
static bool _py_list__sort(int argc, py_Ref argv) {
PY_CHECK_ARGC(1);
List* self = py_touserdata(py_arg(0));
c11__stable_sort(self->data, self->count, sizeof(py_TValue), (int (*)(const void*, const void*))py_le);
py_newnone(py_retval());
return true;
}
py_Type pk_list__register() { py_Type pk_list__register() {
pk_VM* vm = pk_current_vm; pk_VM* vm = pk_current_vm;
py_Type type = pk_VM__new_type(vm, "list", tp_object, NULL, false); py_Type type = pk_VM__new_type(vm, "list", tp_object, NULL, false);
@ -346,5 +354,6 @@ py_Type pk_list__register() {
py_bindmethod(type, "remove", _py_list__remove); py_bindmethod(type, "remove", _py_list__remove);
py_bindmethod(type, "pop", _py_list__pop); py_bindmethod(type, "pop", _py_list__pop);
py_bindmethod(type, "insert", _py_list__insert); py_bindmethod(type, "insert", _py_list__insert);
py_bindmethod(type, "sort", _py_list__sort);
return type; return type;
} }

View File

@ -87,6 +87,18 @@ assert list(range(5, 1, -2)) == [5, 3]
# test sort # test sort
a = [8, 2, 4, 2, 9] a = [8, 2, 4, 2, 9]
assert a.sort() == None
assert a == [2, 2, 4, 8, 9]
a = []
assert a.sort() == None
assert a == []
a = [0, 0, 0, 0, 1, 1, 3, -1]
assert a.sort() == None
assert a == [-1, 0, 0, 0, 0, 1, 1, 3]
# test sorted
assert sorted(a) == [2, 2, 4, 8, 9] assert sorted(a) == [2, 2, 4, 8, 9]
assert sorted(a, reverse=True) == [9, 8, 4, 2, 2] assert sorted(a, reverse=True) == [9, 8, 4, 2, 2]