mirror of
https://github.com/pocketpy/pocketpy
synced 2025-10-20 11:30:18 +00:00
add array2d.__iter__
This commit is contained in:
parent
9a88fd06f4
commit
665b95a162
@ -1,4 +1,4 @@
|
||||
from typing import Callable, Any, Generic, TypeVar, Literal, overload
|
||||
from typing import Callable, Any, Generic, TypeVar, Literal, overload, Iterator
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
@ -27,6 +27,7 @@ class array2d(Generic[T]):
|
||||
def __setitem__(self, index: tuple[slice, slice], value: int | float | str | bool | None | 'array2d[T]'): ...
|
||||
|
||||
def __len__(self) -> int: ...
|
||||
def __iter__(self) -> Iterator[T]: ...
|
||||
def __eq__(self, other: 'array2d') -> bool: ...
|
||||
def __ne__(self, other: 'array2d') -> bool: ...
|
||||
def __repr__(self): ...
|
||||
|
@ -181,7 +181,7 @@ struct Array2d{
|
||||
|
||||
vm->bind__len__(PK_OBJ_GET(Type, type), [](VM* vm, PyObject* _0){
|
||||
Array2d& self = PK_OBJ_GET(Array2d, _0);
|
||||
return (i64)self.n_rows;
|
||||
return (i64)self.numel;
|
||||
});
|
||||
|
||||
vm->bind__repr__(PK_OBJ_GET(Type, type), [](VM* vm, PyObject* _0){
|
||||
@ -354,10 +354,38 @@ struct Array2d{
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
struct Array2dIter{
|
||||
PY_CLASS(Array2dIter, array2d, _array2d_iterator)
|
||||
PyObject* ref;
|
||||
int i;
|
||||
Array2dIter(PyObject* ref) : ref(ref), i(0) {}
|
||||
|
||||
void _gc_mark() const{ PK_OBJ_MARK(ref); }
|
||||
|
||||
static void _register(VM* vm, PyObject* mod, PyObject* type){
|
||||
vm->_all_types[PK_OBJ_GET(Type, type)].subclass_enabled = false;
|
||||
vm->bind_notimplemented_constructor<Array2dIter>(type);
|
||||
vm->bind__iter__(PK_OBJ_GET(Type, type), [](VM* vm, PyObject* obj){ return obj; });
|
||||
vm->bind__next__(PK_OBJ_GET(Type, type), [](VM* vm, PyObject* obj){
|
||||
Array2dIter& self = _CAST(Array2dIter&, obj);
|
||||
Array2d& a = PK_OBJ_GET(Array2d, self.ref);
|
||||
if(self.i == a.numel) return vm->StopIteration;
|
||||
std::div_t res = std::div(self.i, a.n_cols);
|
||||
return VAR(Tuple(VAR(res.rem), VAR(res.quot), a.data[self.i++]));
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
void add_module_array2d(VM* vm){
|
||||
PyObject* mod = vm->new_module("array2d");
|
||||
|
||||
Array2d::register_class(vm, mod);
|
||||
Array2dIter::register_class(vm, mod);
|
||||
|
||||
vm->bind__iter__(Array2d::_type(vm), [](VM* vm, PyObject* obj){
|
||||
return VAR_T(Array2dIter, obj);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
|
@ -53,7 +53,7 @@ a_list = [[5, 0], [0, 0], [0, 0], [0, 6]]
|
||||
assert a_list == a.tolist()
|
||||
|
||||
# test __len__
|
||||
assert len(a) == 4
|
||||
assert len(a) == 4*2
|
||||
|
||||
# test __eq__
|
||||
x = array2d(2, 4, default=0)
|
||||
@ -172,3 +172,8 @@ a.indexed_apply_(lambda x, y, val: x+y)
|
||||
assert a[0, 0] == 0
|
||||
assert a[1, 2] == 3
|
||||
assert a[2, 0] == 2
|
||||
|
||||
for i, j, x in a:
|
||||
assert a[i, j] == x
|
||||
|
||||
assert len(a) == a.numel
|
||||
|
Loading…
x
Reference in New Issue
Block a user