From 4c15f278d00234297b11b2a6c19625059dc3aab0 Mon Sep 17 00:00:00 2001 From: blueloveTH Date: Sun, 26 May 2024 01:53:21 +0800 Subject: [PATCH] add `_get` and `_set` --- include/typings/array2d.pyi | 3 +++ src/array2d.cpp | 50 ++++++++++++++++++++++++++----------- tests/83_array2d.py | 6 +++++ 3 files changed, 44 insertions(+), 15 deletions(-) diff --git a/include/typings/array2d.pyi b/include/typings/array2d.pyi index 445cf057..48d37ad9 100644 --- a/include/typings/array2d.pyi +++ b/include/typings/array2d.pyi @@ -17,6 +17,9 @@ class array2d(Generic[T]): def get(self, col: int, row: int, default=None) -> T | None: ... + def _get(self, col: int, row: int) -> T: ... + def _set(self, col: int, row: int, value: T) -> None: ... + @overload def __getitem__(self, index: tuple[int, int]) -> T: ... @overload diff --git a/src/array2d.cpp b/src/array2d.cpp index 6a1a09a1..c04b8529 100644 --- a/src/array2d.cpp +++ b/src/array2d.cpp @@ -28,6 +28,11 @@ struct Array2d{ return 0 <= col && col < n_cols && 0 <= row && row < n_rows; } + void check_valid(VM* vm, int col, int row) const{ + if(is_valid(col, row)) return; + vm->IndexError(_S('(', col, ", ", row, ')', " is not a valid index for array2d(", n_cols, ", ", n_rows, ')')); + } + PyVar _get(int col, int row){ return data[row * n_cols + col]; } @@ -64,7 +69,26 @@ struct Array2d{ PY_READONLY_FIELD(Array2d, "height", n_rows); PY_READONLY_FIELD(Array2d, "numel", numel); - vm->bind(type, "is_valid(self, col: int, row: int)", [](VM* vm, ArgsView args){ + // _get + vm->bind_func(type, "_get", 3, [](VM* vm, ArgsView args){ + Array2d& self = PK_OBJ_GET(Array2d, args[0]); + int col = CAST(int, args[1]); + int row = CAST(int, args[2]); + self.check_valid(vm, col, row); + return self._get(col, row); + }); + + // _set + vm->bind_func(type, "_set", 4, [](VM* vm, ArgsView args){ + Array2d& self = PK_OBJ_GET(Array2d, args[0]); + int col = CAST(int, args[1]); + int row = CAST(int, args[2]); + self.check_valid(vm, col, row); + self._set(col, row, args[3]); + return vm->None; + }); + + vm->bind_func(type, "is_valid", 3, [](VM* vm, ArgsView args){ Array2d& self = PK_OBJ_GET(Array2d, args[0]); int col = CAST(int, args[1]); int row = CAST(int, args[2]); @@ -94,9 +118,7 @@ struct Array2d{ const Tuple& xy = CAST(Tuple&, _1); i64 col, row; if(try_cast_int(xy[0], &col) && try_cast_int(xy[1], &row)){ - if(!self.is_valid(col, row)){ - vm->IndexError(_S('(', col, ", ", row, ')', " is not a valid index for array2d(", self.n_cols, ", ", self.n_rows, ')')); - } + self.check_valid(vm, col, row); return self._get(col, row); } @@ -121,9 +143,7 @@ struct Array2d{ const Tuple& xy = CAST(Tuple&, _1); i64 col, row; if(try_cast_int(xy[0], &col) && try_cast_int(xy[1], &row)){ - if(!self.is_valid(col, row)){ - vm->IndexError(_S('(', col, ", ", row, ')', " is not a valid index for array2d(", self.n_cols, ", ", self.n_rows, ')')); - } + self.check_valid(vm, col, row); self._set(col, row, _2); return; } @@ -165,7 +185,7 @@ struct Array2d{ #undef HANDLE_SLICE - vm->bind(type, "tolist(self)", [](VM* vm, ArgsView args){ + vm->bind_func(type, "tolist", 1, [](VM* vm, ArgsView args){ Array2d& self = PK_OBJ_GET(Array2d, args[0]); List t(self.n_rows); for(int j = 0; j < self.n_rows; j++){ @@ -186,7 +206,7 @@ struct Array2d{ return _S("array2d(", self.n_cols, ", ", self.n_rows, ')'); }); - vm->bind(type, "map(self, f)", [](VM* vm, ArgsView args){ + vm->bind_func(type, "map", 2, [](VM* vm, ArgsView args){ Array2d& self = PK_OBJ_GET(Array2d, args[0]); PyVar f = args[1]; PyVar new_array_obj = vm->new_user_object(); @@ -198,7 +218,7 @@ struct Array2d{ return new_array_obj; }); - vm->bind(type, "copy(self)", [](VM* vm, ArgsView args){ + vm->bind_func(type, "copy", 1, [](VM* vm, ArgsView args){ Array2d& self = PK_OBJ_GET(Array2d, args[0]); PyVar new_array_obj = vm->new_user_object(); Array2d& new_array = PK_OBJ_GET(Array2d, new_array_obj); @@ -209,7 +229,7 @@ struct Array2d{ return new_array_obj; }); - vm->bind(type, "fill_(self, value)", [](VM* vm, ArgsView args){ + vm->bind_func(type, "fill_", 2, [](VM* vm, ArgsView args){ Array2d& self = PK_OBJ_GET(Array2d, args[0]); for(int i = 0; i < self.numel; i++){ self.data[i] = args[1]; @@ -217,7 +237,7 @@ struct Array2d{ return vm->None; }); - vm->bind(type, "apply_(self, f)", [](VM* vm, ArgsView args){ + vm->bind_func(type, "apply_", 2, [](VM* vm, ArgsView args){ Array2d& self = PK_OBJ_GET(Array2d, args[0]); PyVar f = args[1]; for(int i = 0; i < self.numel; i++){ @@ -226,7 +246,7 @@ struct Array2d{ return vm->None; }); - vm->bind(type, "copy_(self, other)", [](VM* vm, ArgsView args){ + vm->bind_func(type, "copy_", 2, [](VM* vm, ArgsView args){ Array2d& self = PK_OBJ_GET(Array2d, args[0]); if(is_type(args[1], VM::tp_list)){ const List& list = PK_OBJ_GET(List, args[1]); @@ -300,7 +320,7 @@ struct Array2d{ return new_array_obj; }); - vm->bind(type, "count(self, value) -> int", [](VM* vm, ArgsView args){ + vm->bind_func(type, "count", 2, [](VM* vm, ArgsView args){ Array2d& self = PK_OBJ_GET(Array2d, args[0]); PyVar value = args[1]; int count = 0; @@ -308,7 +328,7 @@ struct Array2d{ return VAR(count); }); - vm->bind(type, "find_bounding_rect(self, value)", [](VM* vm, ArgsView args){ + vm->bind_func(type, "find_bounding_rect", 2, [](VM* vm, ArgsView args){ Array2d& self = PK_OBJ_GET(Array2d, args[0]); PyVar value = args[1]; int left = self.n_cols; diff --git a/tests/83_array2d.py b/tests/83_array2d.py index 8d94fe88..ac0b6b9c 100644 --- a/tests/83_array2d.py +++ b/tests/83_array2d.py @@ -173,6 +173,12 @@ for i, j, x in a: assert len(a) == a.numel +# test _get and _set +a = array2d(3, 4, default=1) +assert a._get(0, 0) == 1 +a._set(0, 0, 2) +assert a._get(0, 0) == 2 + # stackoverflow bug due to recursive mark-and-sweep # class Cell: # neighbors: list['Cell']