update array2d.copy_

This commit is contained in:
blueloveTH 2024-02-07 14:57:09 +08:00
parent 72ed48fc6b
commit 0bbf6af3ff
3 changed files with 19 additions and 4 deletions

View File

@ -78,14 +78,18 @@ class array2d(Generic[T]):
return new_a return new_a
def fill_(self, value: T) -> None: def fill_(self, value: T) -> None:
for i in range(self.n_cols * self.n_rows): for i in range(self.numel):
self.data[i] = value self.data[i] = value
def apply_(self, f: Callable[[T], T]) -> None: def apply_(self, f: Callable[[T], T]) -> None:
for i in range(self.n_cols * self.n_rows): for i in range(self.numel):
self.data[i] = f(self.data[i]) self.data[i] = f(self.data[i])
def copy_(self, other: 'array2d[T]') -> None: def copy_(self, other: 'array2d[T]' | list['T']) -> None:
if isinstance(other, list):
assert len(other) == self.numel
self.data = other.copy()
return
self.n_cols = other.n_cols self.n_cols = other.n_cols
self.n_rows = other.n_rows self.n_rows = other.n_rows
self.data = other.data.copy() self.data = other.data.copy()
@ -106,4 +110,3 @@ class array2d(Generic[T]):
count += int(self.is_valid(i+1, j+1) and self[i+1, j+1] == value) count += int(self.is_valid(i+1, j+1) and self[i+1, j+1] == value)
new_a[i, j] = count new_a[i, j] = count
return new_a return new_a

View File

@ -167,6 +167,16 @@ struct Array2d{
vm->bind(type, "copy_(self, other)", [](VM* vm, ArgsView args){ vm->bind(type, "copy_(self, other)", [](VM* vm, ArgsView args){
Array2d& self = PK_OBJ_GET(Array2d, args[0]); Array2d& self = PK_OBJ_GET(Array2d, args[0]);
if(is_non_tagged_type(args[1], VM::tp_list)){
const List& list = PK_OBJ_GET(List, args[1]);
if(list.size() != self.numel){
vm->ValueError("list size must be equal to the number of elements in the array2d");
}
for(int i = 0; i < self.numel; i++){
self.data[i] = list[i];
}
return vm->None;
}
Array2d& other = CAST(Array2d&, args[1]); Array2d& other = CAST(Array2d&, args[1]);
// if self and other have different sizes, re-initialize self // if self and other have different sizes, re-initialize self
if(self.n_cols != other.n_cols || self.n_rows != other.n_rows){ if(self.n_cols != other.n_cols || self.n_rows != other.n_rows){

View File

@ -92,6 +92,8 @@ assert a == d and a is not d
x = array2d(4, 4, default=0) x = array2d(4, 4, default=0)
x.copy_(d) x.copy_(d)
assert x == d and x is not d assert x == d and x is not d
x.copy_(['a']*d.numel)
assert x == array2d(d.width, d.height, default='a')
# test subclass array2d # test subclass array2d
class A(array2d): class A(array2d):