mirror of
https://github.com/pocketpy/pocketpy
synced 2025-10-20 11:30:18 +00:00
add pkl support for array2d
This commit is contained in:
parent
d0546c16da
commit
6b332dbfbb
21
include/pocketpy/interpreter/array2d.h
Normal file
21
include/pocketpy/interpreter/array2d.h
Normal file
@ -0,0 +1,21 @@
|
||||
#pragma once
|
||||
|
||||
#include "pocketpy/pocketpy.h"
|
||||
|
||||
#include "pocketpy/common/utils.h"
|
||||
#include "pocketpy/common/sstream.h"
|
||||
#include "pocketpy/interpreter/vm.h"
|
||||
|
||||
typedef struct c11_array2d {
|
||||
py_TValue* data; // slots
|
||||
int n_cols;
|
||||
int n_rows;
|
||||
int numel;
|
||||
} c11_array2d;
|
||||
|
||||
typedef struct c11_array2d_iterator {
|
||||
c11_array2d* array;
|
||||
int index;
|
||||
} c11_array2d_iterator;
|
||||
|
||||
c11_array2d* py_newarray2d(py_OutRef out, int n_cols, int n_rows);
|
@ -1,20 +1,4 @@
|
||||
#include "pocketpy/pocketpy.h"
|
||||
|
||||
#include "pocketpy/common/utils.h"
|
||||
#include "pocketpy/common/sstream.h"
|
||||
#include "pocketpy/interpreter/vm.h"
|
||||
|
||||
typedef struct c11_array2d {
|
||||
py_TValue* data; // slots
|
||||
int n_cols;
|
||||
int n_rows;
|
||||
int numel;
|
||||
} c11_array2d;
|
||||
|
||||
typedef struct c11_array2d_iterator {
|
||||
c11_array2d* array;
|
||||
int index;
|
||||
} c11_array2d_iterator;
|
||||
#include "pocketpy/interpreter/array2d.h"
|
||||
|
||||
static bool py_array2d_is_valid(c11_array2d* self, int col, int row) {
|
||||
return col >= 0 && col < self->n_cols && row >= 0 && row < self->n_rows;
|
||||
@ -28,7 +12,7 @@ static void py_array2d__set(c11_array2d* self, int col, int row, py_Ref value) {
|
||||
self->data[row * self->n_cols + col] = *value;
|
||||
}
|
||||
|
||||
static c11_array2d* py_array2d(py_OutRef out, int n_cols, int n_rows) {
|
||||
c11_array2d* py_newarray2d(py_OutRef out, int n_cols, int n_rows) {
|
||||
int numel = n_cols * n_rows;
|
||||
c11_array2d* ud = py_newobject(out, tp_array2d, numel, sizeof(c11_array2d));
|
||||
ud->data = py_getslot(out, 0);
|
||||
@ -49,7 +33,7 @@ static bool array2d__new__(int argc, py_Ref argv) {
|
||||
int n_rows = argv[2]._i64;
|
||||
int numel = n_cols * n_rows;
|
||||
if(n_cols <= 0 || n_rows <= 0) return ValueError("array2d() expected positive dimensions");
|
||||
c11_array2d* ud = py_array2d(py_pushtmp(), n_cols, n_rows);
|
||||
c11_array2d* ud = py_newarray2d(py_pushtmp(), n_cols, n_rows);
|
||||
// setup initial values
|
||||
if(py_callable(default_)) {
|
||||
for(int j = 0; j < n_rows; j++) {
|
||||
@ -191,7 +175,7 @@ static bool array2d_any(int argc, py_Ref argv) {
|
||||
static bool array2d__eq__(int argc, py_Ref argv) {
|
||||
PY_CHECK_ARGC(2);
|
||||
c11_array2d* self = py_touserdata(argv);
|
||||
c11_array2d* res = py_array2d(py_pushtmp(), self->n_cols, self->n_rows);
|
||||
c11_array2d* res = py_newarray2d(py_pushtmp(), self->n_cols, self->n_rows);
|
||||
if(py_istype(py_arg(1), tp_array2d)) {
|
||||
c11_array2d* other = py_touserdata(py_arg(1));
|
||||
if(!_array2d_check_same_shape(self, other)) return false;
|
||||
@ -268,7 +252,7 @@ static bool array2d_map(int argc, py_Ref argv) {
|
||||
PY_CHECK_ARGC(2);
|
||||
c11_array2d* self = py_touserdata(argv);
|
||||
py_Ref f = py_arg(1);
|
||||
c11_array2d* res = py_array2d(py_pushtmp(), self->n_cols, self->n_rows);
|
||||
c11_array2d* res = py_newarray2d(py_pushtmp(), self->n_cols, self->n_rows);
|
||||
for(int i = 0; i < self->numel; i++) {
|
||||
bool ok = py_call(f, 1, self->data + i);
|
||||
if(!ok) return false;
|
||||
@ -283,7 +267,7 @@ static bool array2d_copy(int argc, py_Ref argv) {
|
||||
// def copy(self) -> 'array2d': ...
|
||||
PY_CHECK_ARGC(1);
|
||||
c11_array2d* self = py_touserdata(argv);
|
||||
c11_array2d* res = py_array2d(py_retval(), self->n_cols, self->n_rows);
|
||||
c11_array2d* res = py_newarray2d(py_retval(), self->n_cols, self->n_rows);
|
||||
memcpy(res->data, self->data, self->numel * sizeof(py_TValue));
|
||||
return true;
|
||||
}
|
||||
@ -356,7 +340,7 @@ static bool array2d_fromlist_STATIC(int argc, py_Ref argv) {
|
||||
return ValueError("fromlist() expected a list of lists with the same length");
|
||||
}
|
||||
}
|
||||
c11_array2d* res = py_array2d(py_retval(), n_cols, n_rows);
|
||||
c11_array2d* res = py_newarray2d(py_retval(), n_cols, n_rows);
|
||||
for(int j = 0; j < n_rows; j++) {
|
||||
py_Ref row_j = py_list_getitem(argv, j);
|
||||
for(int i = 0; i < n_cols; i++) {
|
||||
@ -452,7 +436,7 @@ static bool array2d_get_bounding_rect(int argc, py_Ref argv) {
|
||||
static bool array2d_count_neighbors(int argc, py_Ref argv) {
|
||||
PY_CHECK_ARGC(3);
|
||||
c11_array2d* self = py_touserdata(argv);
|
||||
c11_array2d* res = py_array2d(py_pushtmp(), self->n_cols, self->n_rows);
|
||||
c11_array2d* res = py_newarray2d(py_pushtmp(), self->n_cols, self->n_rows);
|
||||
py_Ref value = py_arg(1);
|
||||
const char* neighborhood = py_tostr(py_arg(2));
|
||||
|
||||
@ -556,7 +540,7 @@ static bool array2d__getitem__(int argc, py_Ref argv) {
|
||||
return _array2d_IndexError(self, col, row);
|
||||
} else if(py_istype(x, tp_slice) && py_istype(y, tp_slice)) {
|
||||
HANDLE_SLICE();
|
||||
c11_array2d* res = py_array2d(py_retval(), slice_width, slice_height);
|
||||
c11_array2d* res = py_newarray2d(py_retval(), slice_width, slice_height);
|
||||
for(int j = start_row; j < stop_row; j++) {
|
||||
for(int i = start_col; i < stop_col; i++) {
|
||||
py_array2d__set(res, i - start_col, j - start_row, py_array2d__get(self, i, j));
|
||||
@ -660,7 +644,7 @@ static bool array2d_convolve(int argc, py_Ref argv) {
|
||||
int ksize_half = ksize / 2;
|
||||
if(!_array2d_check_all_type(self, tp_int)) return false;
|
||||
if(!_array2d_check_all_type(kernel, tp_int)) return false;
|
||||
c11_array2d* res = py_array2d(py_pushtmp(), self->n_cols, self->n_rows);
|
||||
c11_array2d* res = py_newarray2d(py_pushtmp(), self->n_cols, self->n_rows);
|
||||
for(int j = 0; j < self->n_rows; j++) {
|
||||
for(int i = 0; i < self->n_cols; i++) {
|
||||
py_i64 sum = 0;
|
||||
|
@ -3,7 +3,7 @@
|
||||
|
||||
#include "pocketpy/common/utils.h"
|
||||
#include "pocketpy/common/sstream.h"
|
||||
#include "pocketpy/interpreter/vm.h"
|
||||
#include "pocketpy/interpreter/array2d.h"
|
||||
#include <stdint.h>
|
||||
|
||||
typedef enum {
|
||||
@ -23,6 +23,7 @@ typedef enum {
|
||||
PKL_VEC2, PKL_VEC3,
|
||||
PKL_VEC2I, PKL_VEC3I,
|
||||
PKL_TYPE,
|
||||
PKL_ARRAY2D,
|
||||
PKL_EOF,
|
||||
// clang-format on
|
||||
} PickleOp;
|
||||
@ -299,6 +300,20 @@ static bool pickle__write_object(PickleObject* buf, py_TValue* obj) {
|
||||
c11_string__delete(path);
|
||||
break;
|
||||
}
|
||||
case tp_array2d: {
|
||||
c11_array2d* arr = py_touserdata(obj);
|
||||
for(int i = 0; i < arr->numel; i++) {
|
||||
if(arr->data[i].is_ptr)
|
||||
return TypeError(
|
||||
"'array2d' object is not picklable because it contains heap-allocated objects");
|
||||
}
|
||||
pkl__emit_op(buf, PKL_ARRAY2D);
|
||||
pkl__emit_int(buf, arr->n_cols);
|
||||
pkl__emit_int(buf, arr->n_rows);
|
||||
// TODO: fix type index which is not stable
|
||||
PickleObject__write_bytes(buf, arr->data, arr->numel * sizeof(py_TValue));
|
||||
break;
|
||||
}
|
||||
default: return TypeError("'%t' object is not picklable", obj->type);
|
||||
}
|
||||
if(obj->is_ptr) {
|
||||
@ -503,6 +518,15 @@ bool py_pickle_loads(const unsigned char* data, int size) {
|
||||
py_push(py_tpobject(t));
|
||||
break;
|
||||
}
|
||||
case PKL_ARRAY2D: {
|
||||
int n_cols = pkl__read_int(&p);
|
||||
int n_rows = pkl__read_int(&p);
|
||||
c11_array2d* arr = py_newarray2d(py_pushtmp(), n_cols, n_rows);
|
||||
int total_size = arr->numel * sizeof(py_TValue);
|
||||
memcpy(arr->data, p, total_size);
|
||||
p += total_size;
|
||||
break;
|
||||
}
|
||||
case PKL_EOF: {
|
||||
// [memo, obj]
|
||||
if(py_peek(0) - p0 != 2) return ValueError("invalid pickle data");
|
||||
|
@ -31,6 +31,20 @@ test(vec3i(1, 2, 3)) # PKL_VEC3I
|
||||
|
||||
test(vec3i) # PKL_TYPE
|
||||
|
||||
print('-'*50)
|
||||
from array2d import array2d
|
||||
a = array2d[int].fromlist([
|
||||
[1, 2, 3],
|
||||
[4, 5, 6]
|
||||
])
|
||||
a_encoded = pkl.dumps(a)
|
||||
print(a_encoded)
|
||||
a_decoded = pkl.loads(a_encoded)
|
||||
assert isinstance(a_decoded, array2d)
|
||||
assert a_decoded.width == 3 and a_decoded.height == 2
|
||||
assert (a == a_decoded).all()
|
||||
print(a_decoded)
|
||||
|
||||
test([1, 2, 3]) # PKL_LIST
|
||||
test((1, 2, 3)) # PKL_TUPLE
|
||||
test({1: 2, 3: 4}) # PKL_DICT
|
||||
|
Loading…
x
Reference in New Issue
Block a user