From bb653bd38377e186513239d3d58c69e962c987ec Mon Sep 17 00:00:00 2001 From: blueloveTH Date: Mon, 30 Sep 2024 02:11:19 +0800 Subject: [PATCH] improve `array2d` --- include/typings/array2d.pyi | 9 +++++--- src/modules/array2d.c | 45 ++++++++++++++++++++++++++----------- tests/90_array2d.py | 9 +++++--- 3 files changed, 44 insertions(+), 19 deletions(-) diff --git a/include/typings/array2d.pyi b/include/typings/array2d.pyi index fe52a0cb..b47d5c70 100644 --- a/include/typings/array2d.pyi +++ b/include/typings/array2d.pyi @@ -1,4 +1,5 @@ from typing import Callable, Any, Generic, TypeVar, Literal, overload, Iterator +from linalg import vec2i T = TypeVar('T') @@ -16,10 +17,8 @@ class array2d(Generic[T]): @property def numel(self) -> int: ... - def __new__(self, n_cols: int, n_rows: int, default=None): ... + def __new__(cls, n_cols: int, n_rows: int, default=None): ... def __len__(self) -> int: ... - def __eq__(self, other: 'array2d') -> bool: ... - def __ne__(self, other: 'array2d') -> bool: ... def __repr__(self) -> str: ... def __iter__(self) -> Iterator[tuple[int, int, T]]: ... @@ -35,10 +34,14 @@ class array2d(Generic[T]): @overload def __getitem__(self, index: tuple[int, int]) -> T: ... @overload + def __getitem__(self, index: vec2i) -> T: ... + @overload def __getitem__(self, index: tuple[slice, slice]) -> 'array2d[T]': ... @overload def __setitem__(self, index: tuple[int, int], value: T): ... @overload + def __setitem__(self, index: vec2i, value: T): ... + @overload def __setitem__(self, index: tuple[slice, slice], value: int | float | str | bool | None | 'array2d[T]'): ... def map(self, f: Callable[[T], Any]) -> 'array2d': ... diff --git a/src/modules/array2d.c b/src/modules/array2d.c index 6502cf99..a860dde7 100644 --- a/src/modules/array2d.c +++ b/src/modules/array2d.c @@ -426,13 +426,30 @@ static bool array2d_count_neighbors(int argc, py_Ref argv) { if(slice_width <= 0 || slice_height <= 0) \ return ValueError("slice width and height must be positive"); +static bool _array2d_IndexError(c11_array2d* self, int col, int row) { + return IndexError("(%d, %d) is not a valid index of array2d(%d, %d)", + col, + row, + self->n_cols, + self->n_rows); +} + static bool array2d__getitem__(int argc, py_Ref argv) { PY_CHECK_ARGC(2); + c11_array2d* self = py_touserdata(argv); + if(argv[1].type == tp_vec2i) { + // fastpath for vec2i + c11_vec2i pos = py_tovec2i(&argv[1]); + if(py_array2d_is_valid(self, pos.x, pos.y)) { + py_assign(py_retval(), py_array2d__get(self, pos.x, pos.y)); + return true; + } + return _array2d_IndexError(self, pos.x, pos.y); + } PY_CHECK_ARG_TYPE(1, tp_tuple); if(py_tuple_len(py_arg(1)) != 2) return TypeError("expected a tuple of 2 elements"); py_Ref x = py_tuple_getitem(py_arg(1), 0); py_Ref y = py_tuple_getitem(py_arg(1), 1); - c11_array2d* self = py_touserdata(argv); if(py_isint(x) && py_isint(y)) { int col = py_toint(x); int row = py_toint(y); @@ -440,11 +457,7 @@ static bool array2d__getitem__(int argc, py_Ref argv) { py_assign(py_retval(), py_array2d__get(self, col, row)); return true; } - return IndexError("(%d, %d) is not a valid index of array2d(%d, %d)", - col, - row, - self->n_cols, - self->n_rows); + return _array2d_IndexError(self, col, row); } else if(py_istype(x, tp_slice) && py_istype(y, tp_slice)) { HANDLE_SLICE(); c11_array2d* res = py_array2d(py_retval(), slice_width, slice_height); @@ -461,12 +474,22 @@ static bool array2d__getitem__(int argc, py_Ref argv) { static bool array2d__setitem__(int argc, py_Ref argv) { PY_CHECK_ARGC(3); + c11_array2d* self = py_touserdata(argv); + py_Ref value = py_arg(2); + if(argv[1].type == tp_vec2i) { + // fastpath for vec2i + c11_vec2i pos = py_tovec2i(&argv[1]); + if(py_array2d_is_valid(self, pos.x, pos.y)) { + py_array2d__set(self, pos.x, pos.y, value); + py_newnone(py_retval()); + return true; + } + return _array2d_IndexError(self, pos.x, pos.y); + } PY_CHECK_ARG_TYPE(1, tp_tuple); if(py_tuple_len(py_arg(1)) != 2) return TypeError("expected a tuple of 2 elements"); py_Ref x = py_tuple_getitem(py_arg(1), 0); py_Ref y = py_tuple_getitem(py_arg(1), 1); - c11_array2d* self = py_touserdata(argv); - py_Ref value = py_arg(2); if(py_isint(x) && py_isint(y)) { int col = py_toint(x); int row = py_toint(y); @@ -475,11 +498,7 @@ static bool array2d__setitem__(int argc, py_Ref argv) { py_newnone(py_retval()); return true; } - return IndexError("(%d, %d) is not a valid index of array2d(%d, %d)", - col, - row, - self->n_cols, - self->n_rows); + return _array2d_IndexError(self, col, row); } else if(py_istype(x, tp_slice) && py_istype(y, tp_slice)) { HANDLE_SLICE(); bool is_basic_type = false; diff --git a/tests/90_array2d.py b/tests/90_array2d.py index 840e1021..c8eb2bc0 100644 --- a/tests/90_array2d.py +++ b/tests/90_array2d.py @@ -108,8 +108,11 @@ moore_result[1, 1] = 0 von_neumann_result = array2d(3, 3, default=0) von_neumann_result[0, 1] = von_neumann_result[1, 0] = von_neumann_result[1, 2] = von_neumann_result[2, 1] = 1 -a.count_neighbors(0, 'Moore') == moore_result -a.count_neighbors(0, 'von Neumann') == von_neumann_result + +_0 = a.count_neighbors(1, 'Moore') +assert _0 == moore_result +_1 = a.count_neighbors(1, 'von Neumann') +assert _1 == von_neumann_result # test slice get a = array2d(5, 5, default=0) @@ -152,7 +155,7 @@ except ValueError: pass try: - a[:, :] = [] + a[:, :] = ... exit(1) except TypeError: pass