From 6b332dbfbbd52e830dfe0fb98e4fa7071f8fd0d8 Mon Sep 17 00:00:00 2001 From: blueloveTH Date: Sat, 14 Dec 2024 17:08:59 +0800 Subject: [PATCH] add pkl support for array2d --- include/pocketpy/interpreter/array2d.h | 21 +++++++++++++++ src/modules/array2d.c | 36 +++++++------------------- src/modules/pickle.c | 26 ++++++++++++++++++- tests/90_pickle.py | 14 ++++++++++ 4 files changed, 70 insertions(+), 27 deletions(-) create mode 100644 include/pocketpy/interpreter/array2d.h diff --git a/include/pocketpy/interpreter/array2d.h b/include/pocketpy/interpreter/array2d.h new file mode 100644 index 00000000..ac040e67 --- /dev/null +++ b/include/pocketpy/interpreter/array2d.h @@ -0,0 +1,21 @@ +#pragma once + +#include "pocketpy/pocketpy.h" + +#include "pocketpy/common/utils.h" +#include "pocketpy/common/sstream.h" +#include "pocketpy/interpreter/vm.h" + +typedef struct c11_array2d { + py_TValue* data; // slots + int n_cols; + int n_rows; + int numel; +} c11_array2d; + +typedef struct c11_array2d_iterator { + c11_array2d* array; + int index; +} c11_array2d_iterator; + +c11_array2d* py_newarray2d(py_OutRef out, int n_cols, int n_rows); diff --git a/src/modules/array2d.c b/src/modules/array2d.c index 4c010351..9a5bdae7 100644 --- a/src/modules/array2d.c +++ b/src/modules/array2d.c @@ -1,20 +1,4 @@ -#include "pocketpy/pocketpy.h" - -#include "pocketpy/common/utils.h" -#include "pocketpy/common/sstream.h" -#include "pocketpy/interpreter/vm.h" - -typedef struct c11_array2d { - py_TValue* data; // slots - int n_cols; - int n_rows; - int numel; -} c11_array2d; - -typedef struct c11_array2d_iterator { - c11_array2d* array; - int index; -} c11_array2d_iterator; +#include "pocketpy/interpreter/array2d.h" static bool py_array2d_is_valid(c11_array2d* self, int col, int row) { return col >= 0 && col < self->n_cols && row >= 0 && row < self->n_rows; @@ -28,7 +12,7 @@ static void py_array2d__set(c11_array2d* self, int col, int row, py_Ref value) { self->data[row * self->n_cols + col] = *value; } -static c11_array2d* py_array2d(py_OutRef out, int n_cols, int n_rows) { +c11_array2d* py_newarray2d(py_OutRef out, int n_cols, int n_rows) { int numel = n_cols * n_rows; c11_array2d* ud = py_newobject(out, tp_array2d, numel, sizeof(c11_array2d)); ud->data = py_getslot(out, 0); @@ -49,7 +33,7 @@ static bool array2d__new__(int argc, py_Ref argv) { int n_rows = argv[2]._i64; int numel = n_cols * n_rows; if(n_cols <= 0 || n_rows <= 0) return ValueError("array2d() expected positive dimensions"); - c11_array2d* ud = py_array2d(py_pushtmp(), n_cols, n_rows); + c11_array2d* ud = py_newarray2d(py_pushtmp(), n_cols, n_rows); // setup initial values if(py_callable(default_)) { for(int j = 0; j < n_rows; j++) { @@ -191,7 +175,7 @@ static bool array2d_any(int argc, py_Ref argv) { static bool array2d__eq__(int argc, py_Ref argv) { PY_CHECK_ARGC(2); c11_array2d* self = py_touserdata(argv); - c11_array2d* res = py_array2d(py_pushtmp(), self->n_cols, self->n_rows); + c11_array2d* res = py_newarray2d(py_pushtmp(), self->n_cols, self->n_rows); if(py_istype(py_arg(1), tp_array2d)) { c11_array2d* other = py_touserdata(py_arg(1)); if(!_array2d_check_same_shape(self, other)) return false; @@ -268,7 +252,7 @@ static bool array2d_map(int argc, py_Ref argv) { PY_CHECK_ARGC(2); c11_array2d* self = py_touserdata(argv); py_Ref f = py_arg(1); - c11_array2d* res = py_array2d(py_pushtmp(), self->n_cols, self->n_rows); + c11_array2d* res = py_newarray2d(py_pushtmp(), self->n_cols, self->n_rows); for(int i = 0; i < self->numel; i++) { bool ok = py_call(f, 1, self->data + i); if(!ok) return false; @@ -283,7 +267,7 @@ static bool array2d_copy(int argc, py_Ref argv) { // def copy(self) -> 'array2d': ... PY_CHECK_ARGC(1); c11_array2d* self = py_touserdata(argv); - c11_array2d* res = py_array2d(py_retval(), self->n_cols, self->n_rows); + c11_array2d* res = py_newarray2d(py_retval(), self->n_cols, self->n_rows); memcpy(res->data, self->data, self->numel * sizeof(py_TValue)); return true; } @@ -356,7 +340,7 @@ static bool array2d_fromlist_STATIC(int argc, py_Ref argv) { return ValueError("fromlist() expected a list of lists with the same length"); } } - c11_array2d* res = py_array2d(py_retval(), n_cols, n_rows); + c11_array2d* res = py_newarray2d(py_retval(), n_cols, n_rows); for(int j = 0; j < n_rows; j++) { py_Ref row_j = py_list_getitem(argv, j); for(int i = 0; i < n_cols; i++) { @@ -452,7 +436,7 @@ static bool array2d_get_bounding_rect(int argc, py_Ref argv) { static bool array2d_count_neighbors(int argc, py_Ref argv) { PY_CHECK_ARGC(3); c11_array2d* self = py_touserdata(argv); - c11_array2d* res = py_array2d(py_pushtmp(), self->n_cols, self->n_rows); + c11_array2d* res = py_newarray2d(py_pushtmp(), self->n_cols, self->n_rows); py_Ref value = py_arg(1); const char* neighborhood = py_tostr(py_arg(2)); @@ -556,7 +540,7 @@ static bool array2d__getitem__(int argc, py_Ref argv) { return _array2d_IndexError(self, col, row); } else if(py_istype(x, tp_slice) && py_istype(y, tp_slice)) { HANDLE_SLICE(); - c11_array2d* res = py_array2d(py_retval(), slice_width, slice_height); + c11_array2d* res = py_newarray2d(py_retval(), slice_width, slice_height); for(int j = start_row; j < stop_row; j++) { for(int i = start_col; i < stop_col; i++) { py_array2d__set(res, i - start_col, j - start_row, py_array2d__get(self, i, j)); @@ -660,7 +644,7 @@ static bool array2d_convolve(int argc, py_Ref argv) { int ksize_half = ksize / 2; if(!_array2d_check_all_type(self, tp_int)) return false; if(!_array2d_check_all_type(kernel, tp_int)) return false; - c11_array2d* res = py_array2d(py_pushtmp(), self->n_cols, self->n_rows); + c11_array2d* res = py_newarray2d(py_pushtmp(), self->n_cols, self->n_rows); for(int j = 0; j < self->n_rows; j++) { for(int i = 0; i < self->n_cols; i++) { py_i64 sum = 0; diff --git a/src/modules/pickle.c b/src/modules/pickle.c index fd3a35d9..c87e0346 100644 --- a/src/modules/pickle.c +++ b/src/modules/pickle.c @@ -3,7 +3,7 @@ #include "pocketpy/common/utils.h" #include "pocketpy/common/sstream.h" -#include "pocketpy/interpreter/vm.h" +#include "pocketpy/interpreter/array2d.h" #include typedef enum { @@ -23,6 +23,7 @@ typedef enum { PKL_VEC2, PKL_VEC3, PKL_VEC2I, PKL_VEC3I, PKL_TYPE, + PKL_ARRAY2D, PKL_EOF, // clang-format on } PickleOp; @@ -299,6 +300,20 @@ static bool pickle__write_object(PickleObject* buf, py_TValue* obj) { c11_string__delete(path); break; } + case tp_array2d: { + c11_array2d* arr = py_touserdata(obj); + for(int i = 0; i < arr->numel; i++) { + if(arr->data[i].is_ptr) + return TypeError( + "'array2d' object is not picklable because it contains heap-allocated objects"); + } + pkl__emit_op(buf, PKL_ARRAY2D); + pkl__emit_int(buf, arr->n_cols); + pkl__emit_int(buf, arr->n_rows); + // TODO: fix type index which is not stable + PickleObject__write_bytes(buf, arr->data, arr->numel * sizeof(py_TValue)); + break; + } default: return TypeError("'%t' object is not picklable", obj->type); } if(obj->is_ptr) { @@ -503,6 +518,15 @@ bool py_pickle_loads(const unsigned char* data, int size) { py_push(py_tpobject(t)); break; } + case PKL_ARRAY2D: { + int n_cols = pkl__read_int(&p); + int n_rows = pkl__read_int(&p); + c11_array2d* arr = py_newarray2d(py_pushtmp(), n_cols, n_rows); + int total_size = arr->numel * sizeof(py_TValue); + memcpy(arr->data, p, total_size); + p += total_size; + break; + } case PKL_EOF: { // [memo, obj] if(py_peek(0) - p0 != 2) return ValueError("invalid pickle data"); diff --git a/tests/90_pickle.py b/tests/90_pickle.py index 21cbc156..a3976097 100644 --- a/tests/90_pickle.py +++ b/tests/90_pickle.py @@ -31,6 +31,20 @@ test(vec3i(1, 2, 3)) # PKL_VEC3I test(vec3i) # PKL_TYPE +print('-'*50) +from array2d import array2d +a = array2d[int].fromlist([ + [1, 2, 3], + [4, 5, 6] +]) +a_encoded = pkl.dumps(a) +print(a_encoded) +a_decoded = pkl.loads(a_encoded) +assert isinstance(a_decoded, array2d) +assert a_decoded.width == 3 and a_decoded.height == 2 +assert (a == a_decoded).all() +print(a_decoded) + test([1, 2, 3]) # PKL_LIST test((1, 2, 3)) # PKL_TUPLE test({1: 2, 3: 4}) # PKL_DICT