improve array2d.__setitem__

This commit is contained in:
blueloveTH 2024-02-20 01:56:10 +08:00
parent 74bf8d86a0
commit 3534492bb6
3 changed files with 47 additions and 6 deletions

View File

@ -24,7 +24,7 @@ class array2d(Generic[T]):
@overload @overload
def __setitem__(self, index: tuple[int, int], value: T): ... def __setitem__(self, index: tuple[int, int], value: T): ...
@overload @overload
def __setitem__(self, index: tuple[slice, slice], value: 'array2d[T]'): ... def __setitem__(self, index: tuple[slice, slice], value: int | float | str | bool | None | 'array2d[T]'): ...
def __len__(self) -> int: ... def __len__(self) -> int: ...
def __eq__(self, other: 'array2d') -> bool: ... def __eq__(self, other: 'array2d') -> bool: ...

View File

@ -133,15 +133,34 @@ struct Array2d{
if(is_non_tagged_type(xy[0], VM::tp_slice) && is_non_tagged_type(xy[1], VM::tp_slice)){ if(is_non_tagged_type(xy[0], VM::tp_slice) && is_non_tagged_type(xy[1], VM::tp_slice)){
HANDLE_SLICE(); HANDLE_SLICE();
Array2d& other = CAST(Array2d&, _2); // _2 must be an array2d
bool is_basic_type = false;
switch(vm->_tp(_2).index){
case VM::tp_int.index: is_basic_type = true; break;
case VM::tp_float.index: is_basic_type = true; break;
case VM::tp_str.index: is_basic_type = true; break;
case VM::tp_bool.index: is_basic_type = true; break;
default: is_basic_type = _2 == vm->None;
}
if(is_basic_type){
for(int j = 0; j < slice_height; j++)
for(int i = 0; i < slice_width; i++)
self._set(i + start_col, j + start_row, _2);
return;
}
if(!is_non_tagged_type(_2, Array2d::_type(vm))){
vm->TypeError(_S("expected int/float/str/bool/None or an array2d instance"));
}
Array2d& other = PK_OBJ_GET(Array2d, _2);
if(slice_width != other.n_cols || slice_height != other.n_rows){ if(slice_width != other.n_cols || slice_height != other.n_rows){
vm->ValueError("array2d size does not match the slice size"); vm->ValueError("array2d size does not match the slice size");
} }
for(int j = 0; j < slice_height; j++){ for(int j = 0; j < slice_height; j++)
for(int i = 0; i < slice_width; i++){ for(int i = 0; i < slice_width; i++)
self._set(i + start_col, j + start_row, other._get(i, j)); self._set(i + start_col, j + start_row, other._get(i, j));
}
}
return; return;
} }
vm->TypeError("expected `tuple[int, int]` or `tuple[slice, slice]` as index"); vm->TypeError("expected `tuple[int, int]` or `tuple[slice, slice]` as index");

View File

@ -144,3 +144,25 @@ assert a.find_bounding_rect(0) == (0, 0, 5, 5)
assert a.find_bounding_rect(2) == None assert a.find_bounding_rect(2) == None
a = array2d(3, 2, default='?')
# int/float/str/bool/None
for value in [0, 0.0, '0', False, None]:
a[0:2, 0:1] = value
assert a[2, 1] == '?'
assert a[0, 0] == value
a[:, :] = 3
assert a == array2d(3, 2, default=3)
try:
a[:, :] = array2d(1, 1)
exit(1)
except ValueError:
pass
try:
a[:, :] = []
exit(1)
except TypeError:
pass