update array2d

This commit is contained in:
blueloveTH 2024-11-27 13:37:07 +08:00
parent 1ca69b3a35
commit f4597ed01a
3 changed files with 52 additions and 79 deletions

View File

@ -1,11 +1,9 @@
from typing import Callable, Any, Generic, TypeVar, Literal, overload, Iterator from typing import Callable, Any, Generic, TypeVar, Literal, overload, Iterator
from linalg import vec2i from linalg import vec2i
T = TypeVar('T')
Neighborhood = Literal['Moore', 'von Neumann'] Neighborhood = Literal['Moore', 'von Neumann']
class array2d(Generic[T]): class array2d[T]:
@property @property
def n_cols(self) -> int: ... def n_cols(self) -> int: ...
@property @property
@ -17,24 +15,21 @@ class array2d(Generic[T]):
@property @property
def numel(self) -> int: ... def numel(self) -> int: ...
def __new__(cls, n_cols: int, n_rows: int, default=None): ... def __new__(cls, n_cols: int, n_rows: int, default: Callable[[vec2i], T] = None): ...
def __len__(self) -> int: ...
def __eq__(self, other: object) -> array2d[bool]: ... # type: ignore def __eq__(self, other: object) -> array2d[bool]: ... # type: ignore
def __ne__(self, other: object) -> array2d[bool]: ... # type: ignore def __ne__(self, other: object) -> array2d[bool]: ... # type: ignore
def __repr__(self) -> str: ... def __repr__(self) -> str: ...
def __iter__(self) -> Iterator[tuple[int, int, T]]: ... def __iter__(self) -> Iterator[tuple[vec2i, T]]: ...
@overload @overload
def is_valid(self, col: int, row: int) -> bool: ... def is_valid(self, col: int, row: int) -> bool: ...
@overload @overload
def is_valid(self, pos: vec2i) -> bool: ... def is_valid(self, pos: vec2i) -> bool: ...
def get(self, col: int, row: int, default=None) -> T | None: @overload
"""Returns the value at the given position or the default value if out of bounds.""" def get[R](self, col: int, row: int, default: R) -> T | R: ...
def unsafe_get(self, col: int, row: int) -> T: @overload
"""Returns the value at the given position without bounds checking.""" def get[R](self, pos: vec2i, default: R) -> T | R: ...
def unsafe_set(self, col: int, row: int, value: T):
"""Sets the value at the given position without bounds checking."""
@overload @overload
def __getitem__(self, index: tuple[int, int]) -> T: ... def __getitem__(self, index: tuple[int, int]) -> T: ...

View File

@ -40,7 +40,7 @@ static c11_array2d* py_array2d(py_OutRef out, int n_cols, int n_rows) {
/* bindings */ /* bindings */
static bool array2d__new__(int argc, py_Ref argv) { static bool array2d__new__(int argc, py_Ref argv) {
// __new__(cls, n_cols: int, n_rows: int, default=None) // __new__(cls, n_cols: int, n_rows: int, default: Callable[[vec2i], T] = None)
py_Ref default_ = py_arg(3); py_Ref default_ = py_arg(3);
PY_CHECK_ARG_TYPE(0, tp_type); PY_CHECK_ARG_TYPE(0, tp_type);
PY_CHECK_ARG_TYPE(1, tp_int); PY_CHECK_ARG_TYPE(1, tp_int);
@ -52,10 +52,17 @@ static bool array2d__new__(int argc, py_Ref argv) {
c11_array2d* ud = py_array2d(py_pushtmp(), n_cols, n_rows); c11_array2d* ud = py_array2d(py_pushtmp(), n_cols, n_rows);
// setup initial values // setup initial values
if(py_callable(default_)) { if(py_callable(default_)) {
for(int i = 0; i < numel; i++) { for(int j = 0; j < n_rows; j++) {
bool ok = py_call(default_, 0, NULL); for(int i = 0; i < n_cols; i++) {
py_TValue tmp;
py_newvec2i(&tmp,
(c11_vec2i){
{i, j}
});
bool ok = py_call(default_, 1, &tmp);
if(!ok) return false; if(!ok) return false;
ud->data[i] = *py_retval(); ud->data[j * n_cols + i] = *py_retval();
}
} }
} else { } else {
for(int i = 0; i < numel; i++) { for(int i = 0; i < numel; i++) {
@ -111,17 +118,24 @@ static bool array2d_is_valid(int argc, py_Ref argv) {
static bool array2d_get(int argc, py_Ref argv) { static bool array2d_get(int argc, py_Ref argv) {
py_Ref default_; py_Ref default_;
c11_array2d* self = py_touserdata(argv); c11_array2d* self = py_touserdata(argv);
int col, row;
if(argc == 3) {
// get[R](self, pos: vec2i, default: R) -> T | R
PY_CHECK_ARG_TYPE(1, tp_vec2i);
c11_vec2i pos = py_tovec2i(py_arg(1));
col = pos.x;
row = pos.y;
default_ = py_arg(2);
} else if(argc == 4) {
// get(self, col: int, row: int, default: T) -> T
PY_CHECK_ARG_TYPE(1, tp_int); PY_CHECK_ARG_TYPE(1, tp_int);
PY_CHECK_ARG_TYPE(2, tp_int); PY_CHECK_ARG_TYPE(2, tp_int);
if(argc == 3) { col = py_toint(py_arg(1));
default_ = py_None(); row = py_toint(py_arg(2));
} else if(argc == 4) {
default_ = py_arg(3); default_ = py_arg(3);
} else { } else {
return TypeError("get() expected 3 or 4 arguments"); return TypeError("get() expected 3 or 4 arguments");
} }
int col = py_toint(py_arg(1));
int row = py_toint(py_arg(2));
if(py_array2d_is_valid(self, col, row)) { if(py_array2d_is_valid(self, col, row)) {
py_assign(py_retval(), py_array2d__get(self, col, row)); py_assign(py_retval(), py_array2d__get(self, col, row));
} else { } else {
@ -130,36 +144,6 @@ static bool array2d_get(int argc, py_Ref argv) {
return true; return true;
} }
static bool array2d_unsafe_get(int argc, py_Ref argv) {
PY_CHECK_ARGC(3);
c11_array2d* self = py_touserdata(argv);
PY_CHECK_ARG_TYPE(1, tp_int);
PY_CHECK_ARG_TYPE(2, tp_int);
int col = py_toint(py_arg(1));
int row = py_toint(py_arg(2));
py_assign(py_retval(), py_array2d__get(self, col, row));
return true;
}
static bool array2d_unsafe_set(int argc, py_Ref argv) {
PY_CHECK_ARGC(4);
c11_array2d* self = py_touserdata(argv);
PY_CHECK_ARG_TYPE(1, tp_int);
PY_CHECK_ARG_TYPE(2, tp_int);
int col = py_toint(py_arg(1));
int row = py_toint(py_arg(2));
py_array2d__set(self, col, row, py_arg(3));
py_newnone(py_retval());
return true;
}
static bool array2d__len__(int argc, py_Ref argv) {
PY_CHECK_ARGC(1);
c11_array2d* self = py_touserdata(argv);
py_newint(py_retval(), self->numel);
return true;
}
static bool _array2d_check_all_type(c11_array2d* self, py_Type type) { static bool _array2d_check_all_type(c11_array2d* self, py_Type type) {
for(int i = 0; i < self->numel; i++) { for(int i = 0; i < self->numel; i++) {
py_Type item_type = self->data[i].type; py_Type item_type = self->data[i].type;
@ -273,11 +257,13 @@ static bool array2d_iterator__next__(int argc, py_Ref argv) {
c11_array2d_iterator* self = py_touserdata(argv); c11_array2d_iterator* self = py_touserdata(argv);
if(self->index < self->array->numel) { if(self->index < self->array->numel) {
div_t res = div(self->index, self->array->n_cols); div_t res = div(self->index, self->array->n_cols);
py_newtuple(py_retval(), 3); py_newtuple(py_retval(), 2);
py_TValue* data = py_tuple_data(py_retval()); py_TValue* data = py_tuple_data(py_retval());
py_newint(&data[0], res.rem); py_newvec2i(&data[0],
py_newint(&data[1], res.quot); (c11_vec2i){
py_assign(&data[2], self->array->data + self->index); {res.rem, res.quot}
});
py_assign(&data[1], self->array->data + self->index);
self->index++; self->index++;
return true; return true;
} }
@ -725,7 +711,6 @@ void pk__add_module_array2d() {
"__new__(cls, n_cols: int, n_rows: int, default=None)", "__new__(cls, n_cols: int, n_rows: int, default=None)",
array2d__new__); array2d__new__);
py_bindmagic(array2d, __len__, array2d__len__);
py_bindmagic(array2d, __eq__, array2d__eq__); py_bindmagic(array2d, __eq__, array2d__eq__);
py_bindmagic(array2d, __ne__, array2d__ne__); py_bindmagic(array2d, __ne__, array2d__ne__);
py_bindmagic(array2d, __repr__, array2d__repr__); py_bindmagic(array2d, __repr__, array2d__repr__);
@ -742,8 +727,6 @@ void pk__add_module_array2d() {
py_bindmethod(array2d, "is_valid", array2d_is_valid); py_bindmethod(array2d, "is_valid", array2d_is_valid);
py_bindmethod(array2d, "get", array2d_get); py_bindmethod(array2d, "get", array2d_get);
py_bindmethod(array2d, "unsafe_get", array2d_unsafe_get);
py_bindmethod(array2d, "unsafe_set", array2d_unsafe_set);
py_bindmethod(array2d, "map", array2d_map); py_bindmethod(array2d, "map", array2d_map);
py_bindmethod(array2d, "copy", array2d_copy); py_bindmethod(array2d, "copy", array2d_copy);

View File

@ -9,11 +9,16 @@ except ValueError:
pass pass
# test callable constructor # test callable constructor
a = array2d[int](2, 4, lambda: 0) a = array2d[int](2, 4, lambda pos: (pos.x, pos.y))
assert a.width == a.n_cols == 2 assert a.width == a.n_cols == 2
assert a.height == a.n_rows == 4 assert a.height == a.n_rows == 4
assert a.numel == 8 assert a.numel == 8
assert a.tolist() == [
[(0, 0), (1, 0)],
[(0, 1), (1, 1)],
[(0, 2), (1, 2)],
[(0, 3), (1, 3)]]
# test is_valid # test is_valid
assert a.is_valid(0, 0) and a.is_valid(vec2i(0, 0)) assert a.is_valid(0, 0) and a.is_valid(vec2i(0, 0))
@ -24,14 +29,14 @@ assert not a.is_valid(-1, 0) and not a.is_valid(vec2i(-1, 0))
assert not a.is_valid(0, -1) and not a.is_valid(vec2i(0, -1)) assert not a.is_valid(0, -1) and not a.is_valid(vec2i(0, -1))
# test get # test get
assert a.get(0, 0) == 0 assert a.get(0, 0, -1) == (0, 0)
assert a.get(1, 3) == 0 assert a.get(vec2i(1, 3), -1) == (1, 3)
assert a.get(2, 0) is None assert a.get(2, 0, None) is None
assert a.get(0, 4, 'S') == 'S' assert a.get(vec2i(0, 4), 'S') == 'S'
# test __getitem__ # test __getitem__
assert a[0, 0] == 0 assert a[0, 0] == (0, 0)
assert a[1, 3] == 0 assert a[1, 3] == (1, 3)
try: try:
a[2, 0] a[2, 0]
exit(1) exit(1)
@ -39,6 +44,7 @@ except IndexError:
pass pass
# test __setitem__ # test __setitem__
a = array2d[int](2, 4, default=0)
a[0, 0] = 5 a[0, 0] = 5
assert a[0, 0] == 5 assert a[0, 0] == 5
a[1, 3] = 6 a[1, 3] = 6
@ -53,9 +59,6 @@ except IndexError:
a_list = [[5, 0], [0, 0], [0, 0], [0, 6]] a_list = [[5, 0], [0, 0], [0, 0], [0, 6]]
assert a_list == a.tolist() assert a_list == a.tolist()
# test __len__
assert len(a) == 4*2
# test __eq__ # test __eq__
x = array2d(2, 4, default=0) x = array2d(2, 4, default=0)
b = array2d(2, 4, default=0) b = array2d(2, 4, default=0)
@ -174,16 +177,8 @@ except TypeError:
# test __iter__ # test __iter__
a = array2d(3, 4, default=1) a = array2d(3, 4, default=1)
for i, j, x in a: for xy, val in a:
assert a[i, j] == x assert a[xy] == x
assert len(a) == a.numel
# test _get and _set
a = array2d(3, 4, default=1)
assert a.unsafe_get(0, 0) == 1
a.unsafe_set(0, 0, 2)
assert a.unsafe_get(0, 0) == 2
# test convolve # test convolve
a = array2d[int].fromlist([[1, 0, 2, 4, 0], [3, 1, 0, 5, 1]]) a = array2d[int].fromlist([[1, 0, 2, 4, 0], [3, 1, 0, 5, 1]])