add array2d.__iter__

This commit is contained in:
blueloveTH 2024-04-27 22:22:51 +08:00
parent 9a88fd06f4
commit 665b95a162
3 changed files with 37 additions and 3 deletions

View File

@ -1,4 +1,4 @@
from typing import Callable, Any, Generic, TypeVar, Literal, overload
from typing import Callable, Any, Generic, TypeVar, Literal, overload, Iterator
T = TypeVar('T')
@ -27,6 +27,7 @@ class array2d(Generic[T]):
def __setitem__(self, index: tuple[slice, slice], value: int | float | str | bool | None | 'array2d[T]'): ...
def __len__(self) -> int: ...
def __iter__(self) -> Iterator[T]: ...
def __eq__(self, other: 'array2d') -> bool: ...
def __ne__(self, other: 'array2d') -> bool: ...
def __repr__(self): ...

View File

@ -181,7 +181,7 @@ struct Array2d{
vm->bind__len__(PK_OBJ_GET(Type, type), [](VM* vm, PyObject* _0){
Array2d& self = PK_OBJ_GET(Array2d, _0);
return (i64)self.n_rows;
return (i64)self.numel;
});
vm->bind__repr__(PK_OBJ_GET(Type, type), [](VM* vm, PyObject* _0){
@ -354,10 +354,38 @@ struct Array2d{
}
};
struct Array2dIter{
PY_CLASS(Array2dIter, array2d, _array2d_iterator)
PyObject* ref;
int i;
Array2dIter(PyObject* ref) : ref(ref), i(0) {}
void _gc_mark() const{ PK_OBJ_MARK(ref); }
static void _register(VM* vm, PyObject* mod, PyObject* type){
vm->_all_types[PK_OBJ_GET(Type, type)].subclass_enabled = false;
vm->bind_notimplemented_constructor<Array2dIter>(type);
vm->bind__iter__(PK_OBJ_GET(Type, type), [](VM* vm, PyObject* obj){ return obj; });
vm->bind__next__(PK_OBJ_GET(Type, type), [](VM* vm, PyObject* obj){
Array2dIter& self = _CAST(Array2dIter&, obj);
Array2d& a = PK_OBJ_GET(Array2d, self.ref);
if(self.i == a.numel) return vm->StopIteration;
std::div_t res = std::div(self.i, a.n_cols);
return VAR(Tuple(VAR(res.rem), VAR(res.quot), a.data[self.i++]));
});
}
};
void add_module_array2d(VM* vm){
PyObject* mod = vm->new_module("array2d");
Array2d::register_class(vm, mod);
Array2dIter::register_class(vm, mod);
vm->bind__iter__(Array2d::_type(vm), [](VM* vm, PyObject* obj){
return VAR_T(Array2dIter, obj);
});
}

View File

@ -53,7 +53,7 @@ a_list = [[5, 0], [0, 0], [0, 0], [0, 6]]
assert a_list == a.tolist()
# test __len__
assert len(a) == 4
assert len(a) == 4*2
# test __eq__
x = array2d(2, 4, default=0)
@ -172,3 +172,8 @@ a.indexed_apply_(lambda x, y, val: x+y)
assert a[0, 0] == 0
assert a[1, 2] == 3
assert a[2, 0] == 2
for i, j, x in a:
assert a[i, j] == x
assert len(a) == a.numel