update array2d

This commit is contained in:
blueloveTH 2024-11-27 13:37:07 +08:00
parent 1ca69b3a35
commit f4597ed01a
3 changed files with 52 additions and 79 deletions

View File

@ -1,11 +1,9 @@
from typing import Callable, Any, Generic, TypeVar, Literal, overload, Iterator
from linalg import vec2i
T = TypeVar('T')
Neighborhood = Literal['Moore', 'von Neumann']
class array2d(Generic[T]):
class array2d[T]:
@property
def n_cols(self) -> int: ...
@property
@ -17,24 +15,21 @@ class array2d(Generic[T]):
@property
def numel(self) -> int: ...
def __new__(cls, n_cols: int, n_rows: int, default=None): ...
def __len__(self) -> int: ...
def __new__(cls, n_cols: int, n_rows: int, default: Callable[[vec2i], T] = None): ...
def __eq__(self, other: object) -> array2d[bool]: ... # type: ignore
def __ne__(self, other: object) -> array2d[bool]: ... # type: ignore
def __repr__(self) -> str: ...
def __iter__(self) -> Iterator[tuple[int, int, T]]: ...
def __iter__(self) -> Iterator[tuple[vec2i, T]]: ...
@overload
def is_valid(self, col: int, row: int) -> bool: ...
@overload
def is_valid(self, pos: vec2i) -> bool: ...
def get(self, col: int, row: int, default=None) -> T | None:
"""Returns the value at the given position or the default value if out of bounds."""
def unsafe_get(self, col: int, row: int) -> T:
"""Returns the value at the given position without bounds checking."""
def unsafe_set(self, col: int, row: int, value: T):
"""Sets the value at the given position without bounds checking."""
@overload
def get[R](self, col: int, row: int, default: R) -> T | R: ...
@overload
def get[R](self, pos: vec2i, default: R) -> T | R: ...
@overload
def __getitem__(self, index: tuple[int, int]) -> T: ...

View File

@ -40,7 +40,7 @@ static c11_array2d* py_array2d(py_OutRef out, int n_cols, int n_rows) {
/* bindings */
static bool array2d__new__(int argc, py_Ref argv) {
// __new__(cls, n_cols: int, n_rows: int, default=None)
// __new__(cls, n_cols: int, n_rows: int, default: Callable[[vec2i], T] = None)
py_Ref default_ = py_arg(3);
PY_CHECK_ARG_TYPE(0, tp_type);
PY_CHECK_ARG_TYPE(1, tp_int);
@ -52,10 +52,17 @@ static bool array2d__new__(int argc, py_Ref argv) {
c11_array2d* ud = py_array2d(py_pushtmp(), n_cols, n_rows);
// setup initial values
if(py_callable(default_)) {
for(int i = 0; i < numel; i++) {
bool ok = py_call(default_, 0, NULL);
if(!ok) return false;
ud->data[i] = *py_retval();
for(int j = 0; j < n_rows; j++) {
for(int i = 0; i < n_cols; i++) {
py_TValue tmp;
py_newvec2i(&tmp,
(c11_vec2i){
{i, j}
});
bool ok = py_call(default_, 1, &tmp);
if(!ok) return false;
ud->data[j * n_cols + i] = *py_retval();
}
}
} else {
for(int i = 0; i < numel; i++) {
@ -111,17 +118,24 @@ static bool array2d_is_valid(int argc, py_Ref argv) {
static bool array2d_get(int argc, py_Ref argv) {
py_Ref default_;
c11_array2d* self = py_touserdata(argv);
PY_CHECK_ARG_TYPE(1, tp_int);
PY_CHECK_ARG_TYPE(2, tp_int);
int col, row;
if(argc == 3) {
default_ = py_None();
// get[R](self, pos: vec2i, default: R) -> T | R
PY_CHECK_ARG_TYPE(1, tp_vec2i);
c11_vec2i pos = py_tovec2i(py_arg(1));
col = pos.x;
row = pos.y;
default_ = py_arg(2);
} else if(argc == 4) {
// get(self, col: int, row: int, default: T) -> T
PY_CHECK_ARG_TYPE(1, tp_int);
PY_CHECK_ARG_TYPE(2, tp_int);
col = py_toint(py_arg(1));
row = py_toint(py_arg(2));
default_ = py_arg(3);
} else {
return TypeError("get() expected 3 or 4 arguments");
}
int col = py_toint(py_arg(1));
int row = py_toint(py_arg(2));
if(py_array2d_is_valid(self, col, row)) {
py_assign(py_retval(), py_array2d__get(self, col, row));
} else {
@ -130,36 +144,6 @@ static bool array2d_get(int argc, py_Ref argv) {
return true;
}
static bool array2d_unsafe_get(int argc, py_Ref argv) {
PY_CHECK_ARGC(3);
c11_array2d* self = py_touserdata(argv);
PY_CHECK_ARG_TYPE(1, tp_int);
PY_CHECK_ARG_TYPE(2, tp_int);
int col = py_toint(py_arg(1));
int row = py_toint(py_arg(2));
py_assign(py_retval(), py_array2d__get(self, col, row));
return true;
}
static bool array2d_unsafe_set(int argc, py_Ref argv) {
PY_CHECK_ARGC(4);
c11_array2d* self = py_touserdata(argv);
PY_CHECK_ARG_TYPE(1, tp_int);
PY_CHECK_ARG_TYPE(2, tp_int);
int col = py_toint(py_arg(1));
int row = py_toint(py_arg(2));
py_array2d__set(self, col, row, py_arg(3));
py_newnone(py_retval());
return true;
}
static bool array2d__len__(int argc, py_Ref argv) {
PY_CHECK_ARGC(1);
c11_array2d* self = py_touserdata(argv);
py_newint(py_retval(), self->numel);
return true;
}
static bool _array2d_check_all_type(c11_array2d* self, py_Type type) {
for(int i = 0; i < self->numel; i++) {
py_Type item_type = self->data[i].type;
@ -273,11 +257,13 @@ static bool array2d_iterator__next__(int argc, py_Ref argv) {
c11_array2d_iterator* self = py_touserdata(argv);
if(self->index < self->array->numel) {
div_t res = div(self->index, self->array->n_cols);
py_newtuple(py_retval(), 3);
py_newtuple(py_retval(), 2);
py_TValue* data = py_tuple_data(py_retval());
py_newint(&data[0], res.rem);
py_newint(&data[1], res.quot);
py_assign(&data[2], self->array->data + self->index);
py_newvec2i(&data[0],
(c11_vec2i){
{res.rem, res.quot}
});
py_assign(&data[1], self->array->data + self->index);
self->index++;
return true;
}
@ -725,7 +711,6 @@ void pk__add_module_array2d() {
"__new__(cls, n_cols: int, n_rows: int, default=None)",
array2d__new__);
py_bindmagic(array2d, __len__, array2d__len__);
py_bindmagic(array2d, __eq__, array2d__eq__);
py_bindmagic(array2d, __ne__, array2d__ne__);
py_bindmagic(array2d, __repr__, array2d__repr__);
@ -742,8 +727,6 @@ void pk__add_module_array2d() {
py_bindmethod(array2d, "is_valid", array2d_is_valid);
py_bindmethod(array2d, "get", array2d_get);
py_bindmethod(array2d, "unsafe_get", array2d_unsafe_get);
py_bindmethod(array2d, "unsafe_set", array2d_unsafe_set);
py_bindmethod(array2d, "map", array2d_map);
py_bindmethod(array2d, "copy", array2d_copy);

View File

@ -9,11 +9,16 @@ except ValueError:
pass
# test callable constructor
a = array2d[int](2, 4, lambda: 0)
a = array2d[int](2, 4, lambda pos: (pos.x, pos.y))
assert a.width == a.n_cols == 2
assert a.height == a.n_rows == 4
assert a.numel == 8
assert a.tolist() == [
[(0, 0), (1, 0)],
[(0, 1), (1, 1)],
[(0, 2), (1, 2)],
[(0, 3), (1, 3)]]
# test is_valid
assert a.is_valid(0, 0) and a.is_valid(vec2i(0, 0))
@ -24,14 +29,14 @@ assert not a.is_valid(-1, 0) and not a.is_valid(vec2i(-1, 0))
assert not a.is_valid(0, -1) and not a.is_valid(vec2i(0, -1))
# test get
assert a.get(0, 0) == 0
assert a.get(1, 3) == 0
assert a.get(2, 0) is None
assert a.get(0, 4, 'S') == 'S'
assert a.get(0, 0, -1) == (0, 0)
assert a.get(vec2i(1, 3), -1) == (1, 3)
assert a.get(2, 0, None) is None
assert a.get(vec2i(0, 4), 'S') == 'S'
# test __getitem__
assert a[0, 0] == 0
assert a[1, 3] == 0
assert a[0, 0] == (0, 0)
assert a[1, 3] == (1, 3)
try:
a[2, 0]
exit(1)
@ -39,6 +44,7 @@ except IndexError:
pass
# test __setitem__
a = array2d[int](2, 4, default=0)
a[0, 0] = 5
assert a[0, 0] == 5
a[1, 3] = 6
@ -53,9 +59,6 @@ except IndexError:
a_list = [[5, 0], [0, 0], [0, 0], [0, 6]]
assert a_list == a.tolist()
# test __len__
assert len(a) == 4*2
# test __eq__
x = array2d(2, 4, default=0)
b = array2d(2, 4, default=0)
@ -174,16 +177,8 @@ except TypeError:
# test __iter__
a = array2d(3, 4, default=1)
for i, j, x in a:
assert a[i, j] == x
assert len(a) == a.numel
# test _get and _set
a = array2d(3, 4, default=1)
assert a.unsafe_get(0, 0) == 1
a.unsafe_set(0, 0, 2)
assert a.unsafe_get(0, 0) == 2
for xy, val in a:
assert a[xy] == x
# test convolve
a = array2d[int].fromlist([[1, 0, 2, 4, 0], [3, 1, 0, 5, 1]])