diff --git a/include/typings/array2d.pyi b/include/typings/array2d.pyi index a14d0d25..81313d70 100644 --- a/include/typings/array2d.pyi +++ b/include/typings/array2d.pyi @@ -1,4 +1,4 @@ -from typing import Callable, Any, Generic, TypeVar, Literal, overload +from typing import Callable, Any, Generic, TypeVar, Literal, overload, Iterator T = TypeVar('T') @@ -27,6 +27,7 @@ class array2d(Generic[T]): def __setitem__(self, index: tuple[slice, slice], value: int | float | str | bool | None | 'array2d[T]'): ... def __len__(self) -> int: ... + def __iter__(self) -> Iterator[T]: ... def __eq__(self, other: 'array2d') -> bool: ... def __ne__(self, other: 'array2d') -> bool: ... def __repr__(self): ... diff --git a/src/array2d.cpp b/src/array2d.cpp index 89e457d1..3519dbc2 100644 --- a/src/array2d.cpp +++ b/src/array2d.cpp @@ -181,7 +181,7 @@ struct Array2d{ vm->bind__len__(PK_OBJ_GET(Type, type), [](VM* vm, PyObject* _0){ Array2d& self = PK_OBJ_GET(Array2d, _0); - return (i64)self.n_rows; + return (i64)self.numel; }); vm->bind__repr__(PK_OBJ_GET(Type, type), [](VM* vm, PyObject* _0){ @@ -354,10 +354,38 @@ struct Array2d{ } }; + +struct Array2dIter{ + PY_CLASS(Array2dIter, array2d, _array2d_iterator) + PyObject* ref; + int i; + Array2dIter(PyObject* ref) : ref(ref), i(0) {} + + void _gc_mark() const{ PK_OBJ_MARK(ref); } + + static void _register(VM* vm, PyObject* mod, PyObject* type){ + vm->_all_types[PK_OBJ_GET(Type, type)].subclass_enabled = false; + vm->bind_notimplemented_constructor(type); + vm->bind__iter__(PK_OBJ_GET(Type, type), [](VM* vm, PyObject* obj){ return obj; }); + vm->bind__next__(PK_OBJ_GET(Type, type), [](VM* vm, PyObject* obj){ + Array2dIter& self = _CAST(Array2dIter&, obj); + Array2d& a = PK_OBJ_GET(Array2d, self.ref); + if(self.i == a.numel) return vm->StopIteration; + std::div_t res = std::div(self.i, a.n_cols); + return VAR(Tuple(VAR(res.rem), VAR(res.quot), a.data[self.i++])); + }); + } +}; + void add_module_array2d(VM* vm){ PyObject* mod = vm->new_module("array2d"); Array2d::register_class(vm, mod); + Array2dIter::register_class(vm, mod); + + vm->bind__iter__(Array2d::_type(vm), [](VM* vm, PyObject* obj){ + return VAR_T(Array2dIter, obj); + }); } diff --git a/tests/83_array2d.py b/tests/83_array2d.py index 2444be30..577da838 100644 --- a/tests/83_array2d.py +++ b/tests/83_array2d.py @@ -53,7 +53,7 @@ a_list = [[5, 0], [0, 0], [0, 0], [0, 6]] assert a_list == a.tolist() # test __len__ -assert len(a) == 4 +assert len(a) == 4*2 # test __eq__ x = array2d(2, 4, default=0) @@ -172,3 +172,8 @@ a.indexed_apply_(lambda x, y, val: x+y) assert a[0, 0] == 0 assert a[1, 2] == 3 assert a[2, 0] == 2 + +for i, j, x in a: + assert a[i, j] == x + +assert len(a) == a.numel