diff --git a/include/typings/array2d.pyi b/include/typings/array2d.pyi index 223694b7..8af47eed 100644 --- a/include/typings/array2d.pyi +++ b/include/typings/array2d.pyi @@ -24,7 +24,7 @@ class array2d(Generic[T]): @overload def __setitem__(self, index: tuple[int, int], value: T): ... @overload - def __setitem__(self, index: tuple[slice, slice], value: 'array2d[T]'): ... + def __setitem__(self, index: tuple[slice, slice], value: int | float | str | bool | None | 'array2d[T]'): ... def __len__(self) -> int: ... def __eq__(self, other: 'array2d') -> bool: ... diff --git a/src/array2d.cpp b/src/array2d.cpp index 4249cae1..4fd8548d 100644 --- a/src/array2d.cpp +++ b/src/array2d.cpp @@ -133,15 +133,34 @@ struct Array2d{ if(is_non_tagged_type(xy[0], VM::tp_slice) && is_non_tagged_type(xy[1], VM::tp_slice)){ HANDLE_SLICE(); - Array2d& other = CAST(Array2d&, _2); // _2 must be an array2d + + bool is_basic_type = false; + switch(vm->_tp(_2).index){ + case VM::tp_int.index: is_basic_type = true; break; + case VM::tp_float.index: is_basic_type = true; break; + case VM::tp_str.index: is_basic_type = true; break; + case VM::tp_bool.index: is_basic_type = true; break; + default: is_basic_type = _2 == vm->None; + } + + if(is_basic_type){ + for(int j = 0; j < slice_height; j++) + for(int i = 0; i < slice_width; i++) + self._set(i + start_col, j + start_row, _2); + return; + } + + if(!is_non_tagged_type(_2, Array2d::_type(vm))){ + vm->TypeError(_S("expected int/float/str/bool/None or an array2d instance")); + } + + Array2d& other = PK_OBJ_GET(Array2d, _2); if(slice_width != other.n_cols || slice_height != other.n_rows){ vm->ValueError("array2d size does not match the slice size"); } - for(int j = 0; j < slice_height; j++){ - for(int i = 0; i < slice_width; i++){ + for(int j = 0; j < slice_height; j++) + for(int i = 0; i < slice_width; i++) self._set(i + start_col, j + start_row, other._get(i, j)); - } - } return; } vm->TypeError("expected `tuple[int, int]` or `tuple[slice, slice]` as index"); diff --git a/tests/83_array2d.py b/tests/83_array2d.py index 7158c4f3..9c1825e5 100644 --- a/tests/83_array2d.py +++ b/tests/83_array2d.py @@ -144,3 +144,25 @@ assert a.find_bounding_rect(0) == (0, 0, 5, 5) assert a.find_bounding_rect(2) == None +a = array2d(3, 2, default='?') +# int/float/str/bool/None + +for value in [0, 0.0, '0', False, None]: + a[0:2, 0:1] = value + assert a[2, 1] == '?' + assert a[0, 0] == value + +a[:, :] = 3 +assert a == array2d(3, 2, default=3) + +try: + a[:, :] = array2d(1, 1) + exit(1) +except ValueError: + pass + +try: + a[:, :] = [] + exit(1) +except TypeError: + pass