add pkl support for array2d

This commit is contained in:
blueloveTH 2024-12-14 17:08:59 +08:00
parent d0546c16da
commit 6b332dbfbb
4 changed files with 70 additions and 27 deletions

View File

@ -0,0 +1,21 @@
#pragma once
#include "pocketpy/pocketpy.h"
#include "pocketpy/common/utils.h"
#include "pocketpy/common/sstream.h"
#include "pocketpy/interpreter/vm.h"
typedef struct c11_array2d {
py_TValue* data; // slots
int n_cols;
int n_rows;
int numel;
} c11_array2d;
typedef struct c11_array2d_iterator {
c11_array2d* array;
int index;
} c11_array2d_iterator;
c11_array2d* py_newarray2d(py_OutRef out, int n_cols, int n_rows);

View File

@ -1,20 +1,4 @@
#include "pocketpy/pocketpy.h"
#include "pocketpy/common/utils.h"
#include "pocketpy/common/sstream.h"
#include "pocketpy/interpreter/vm.h"
typedef struct c11_array2d {
py_TValue* data; // slots
int n_cols;
int n_rows;
int numel;
} c11_array2d;
typedef struct c11_array2d_iterator {
c11_array2d* array;
int index;
} c11_array2d_iterator;
#include "pocketpy/interpreter/array2d.h"
static bool py_array2d_is_valid(c11_array2d* self, int col, int row) {
return col >= 0 && col < self->n_cols && row >= 0 && row < self->n_rows;
@ -28,7 +12,7 @@ static void py_array2d__set(c11_array2d* self, int col, int row, py_Ref value) {
self->data[row * self->n_cols + col] = *value;
}
static c11_array2d* py_array2d(py_OutRef out, int n_cols, int n_rows) {
c11_array2d* py_newarray2d(py_OutRef out, int n_cols, int n_rows) {
int numel = n_cols * n_rows;
c11_array2d* ud = py_newobject(out, tp_array2d, numel, sizeof(c11_array2d));
ud->data = py_getslot(out, 0);
@ -49,7 +33,7 @@ static bool array2d__new__(int argc, py_Ref argv) {
int n_rows = argv[2]._i64;
int numel = n_cols * n_rows;
if(n_cols <= 0 || n_rows <= 0) return ValueError("array2d() expected positive dimensions");
c11_array2d* ud = py_array2d(py_pushtmp(), n_cols, n_rows);
c11_array2d* ud = py_newarray2d(py_pushtmp(), n_cols, n_rows);
// setup initial values
if(py_callable(default_)) {
for(int j = 0; j < n_rows; j++) {
@ -191,7 +175,7 @@ static bool array2d_any(int argc, py_Ref argv) {
static bool array2d__eq__(int argc, py_Ref argv) {
PY_CHECK_ARGC(2);
c11_array2d* self = py_touserdata(argv);
c11_array2d* res = py_array2d(py_pushtmp(), self->n_cols, self->n_rows);
c11_array2d* res = py_newarray2d(py_pushtmp(), self->n_cols, self->n_rows);
if(py_istype(py_arg(1), tp_array2d)) {
c11_array2d* other = py_touserdata(py_arg(1));
if(!_array2d_check_same_shape(self, other)) return false;
@ -268,7 +252,7 @@ static bool array2d_map(int argc, py_Ref argv) {
PY_CHECK_ARGC(2);
c11_array2d* self = py_touserdata(argv);
py_Ref f = py_arg(1);
c11_array2d* res = py_array2d(py_pushtmp(), self->n_cols, self->n_rows);
c11_array2d* res = py_newarray2d(py_pushtmp(), self->n_cols, self->n_rows);
for(int i = 0; i < self->numel; i++) {
bool ok = py_call(f, 1, self->data + i);
if(!ok) return false;
@ -283,7 +267,7 @@ static bool array2d_copy(int argc, py_Ref argv) {
// def copy(self) -> 'array2d': ...
PY_CHECK_ARGC(1);
c11_array2d* self = py_touserdata(argv);
c11_array2d* res = py_array2d(py_retval(), self->n_cols, self->n_rows);
c11_array2d* res = py_newarray2d(py_retval(), self->n_cols, self->n_rows);
memcpy(res->data, self->data, self->numel * sizeof(py_TValue));
return true;
}
@ -356,7 +340,7 @@ static bool array2d_fromlist_STATIC(int argc, py_Ref argv) {
return ValueError("fromlist() expected a list of lists with the same length");
}
}
c11_array2d* res = py_array2d(py_retval(), n_cols, n_rows);
c11_array2d* res = py_newarray2d(py_retval(), n_cols, n_rows);
for(int j = 0; j < n_rows; j++) {
py_Ref row_j = py_list_getitem(argv, j);
for(int i = 0; i < n_cols; i++) {
@ -452,7 +436,7 @@ static bool array2d_get_bounding_rect(int argc, py_Ref argv) {
static bool array2d_count_neighbors(int argc, py_Ref argv) {
PY_CHECK_ARGC(3);
c11_array2d* self = py_touserdata(argv);
c11_array2d* res = py_array2d(py_pushtmp(), self->n_cols, self->n_rows);
c11_array2d* res = py_newarray2d(py_pushtmp(), self->n_cols, self->n_rows);
py_Ref value = py_arg(1);
const char* neighborhood = py_tostr(py_arg(2));
@ -556,7 +540,7 @@ static bool array2d__getitem__(int argc, py_Ref argv) {
return _array2d_IndexError(self, col, row);
} else if(py_istype(x, tp_slice) && py_istype(y, tp_slice)) {
HANDLE_SLICE();
c11_array2d* res = py_array2d(py_retval(), slice_width, slice_height);
c11_array2d* res = py_newarray2d(py_retval(), slice_width, slice_height);
for(int j = start_row; j < stop_row; j++) {
for(int i = start_col; i < stop_col; i++) {
py_array2d__set(res, i - start_col, j - start_row, py_array2d__get(self, i, j));
@ -660,7 +644,7 @@ static bool array2d_convolve(int argc, py_Ref argv) {
int ksize_half = ksize / 2;
if(!_array2d_check_all_type(self, tp_int)) return false;
if(!_array2d_check_all_type(kernel, tp_int)) return false;
c11_array2d* res = py_array2d(py_pushtmp(), self->n_cols, self->n_rows);
c11_array2d* res = py_newarray2d(py_pushtmp(), self->n_cols, self->n_rows);
for(int j = 0; j < self->n_rows; j++) {
for(int i = 0; i < self->n_cols; i++) {
py_i64 sum = 0;

View File

@ -3,7 +3,7 @@
#include "pocketpy/common/utils.h"
#include "pocketpy/common/sstream.h"
#include "pocketpy/interpreter/vm.h"
#include "pocketpy/interpreter/array2d.h"
#include <stdint.h>
typedef enum {
@ -23,6 +23,7 @@ typedef enum {
PKL_VEC2, PKL_VEC3,
PKL_VEC2I, PKL_VEC3I,
PKL_TYPE,
PKL_ARRAY2D,
PKL_EOF,
// clang-format on
} PickleOp;
@ -299,6 +300,20 @@ static bool pickle__write_object(PickleObject* buf, py_TValue* obj) {
c11_string__delete(path);
break;
}
case tp_array2d: {
c11_array2d* arr = py_touserdata(obj);
for(int i = 0; i < arr->numel; i++) {
if(arr->data[i].is_ptr)
return TypeError(
"'array2d' object is not picklable because it contains heap-allocated objects");
}
pkl__emit_op(buf, PKL_ARRAY2D);
pkl__emit_int(buf, arr->n_cols);
pkl__emit_int(buf, arr->n_rows);
// TODO: fix type index which is not stable
PickleObject__write_bytes(buf, arr->data, arr->numel * sizeof(py_TValue));
break;
}
default: return TypeError("'%t' object is not picklable", obj->type);
}
if(obj->is_ptr) {
@ -503,6 +518,15 @@ bool py_pickle_loads(const unsigned char* data, int size) {
py_push(py_tpobject(t));
break;
}
case PKL_ARRAY2D: {
int n_cols = pkl__read_int(&p);
int n_rows = pkl__read_int(&p);
c11_array2d* arr = py_newarray2d(py_pushtmp(), n_cols, n_rows);
int total_size = arr->numel * sizeof(py_TValue);
memcpy(arr->data, p, total_size);
p += total_size;
break;
}
case PKL_EOF: {
// [memo, obj]
if(py_peek(0) - p0 != 2) return ValueError("invalid pickle data");

View File

@ -31,6 +31,20 @@ test(vec3i(1, 2, 3)) # PKL_VEC3I
test(vec3i) # PKL_TYPE
print('-'*50)
from array2d import array2d
a = array2d[int].fromlist([
[1, 2, 3],
[4, 5, 6]
])
a_encoded = pkl.dumps(a)
print(a_encoded)
a_decoded = pkl.loads(a_encoded)
assert isinstance(a_decoded, array2d)
assert a_decoded.width == 3 and a_decoded.height == 2
assert (a == a_decoded).all()
print(a_decoded)
test([1, 2, 3]) # PKL_LIST
test((1, 2, 3)) # PKL_TUPLE
test({1: 2, 3: 4}) # PKL_DICT