diff --git a/src/modules/array2d.c b/src/modules/array2d.c index a5a79af7..cf595440 100644 --- a/src/modules/array2d.c +++ b/src/modules/array2d.c @@ -272,17 +272,31 @@ static bool array2d_apply_(int argc, py_Ref argv) { static bool array2d_copy_(int argc, py_Ref argv) { // def copy_(self, src: 'array2d') -> None: ... PY_CHECK_ARGC(2); - PY_CHECK_ARG_TYPE(1, tp_array2d); c11_array2d* self = py_touserdata(argv); - c11_array2d* src = py_touserdata(py_arg(1)); - if(self->n_cols != src->n_cols || self->n_rows != src->n_rows) { - return ValueError("copy_() expected the same shape: (%d, %d) != (%d, %d)", - self->n_cols, - self->n_rows, - src->n_cols, - src->n_rows); + + py_Type src_type = py_typeof(py_arg(1)); + if(src_type == tp_array2d) { + c11_array2d* src = py_touserdata(py_arg(1)); + if(self->n_cols != src->n_cols || self->n_rows != src->n_rows) { + return ValueError("copy_() expected the same shape: (%d, %d) != (%d, %d)", + self->n_cols, + self->n_rows, + src->n_cols, + src->n_rows); + } + memcpy(self->data, src->data, self->numel * sizeof(py_TValue)); + } else { + py_TValue* data; + int length = pk_arrayview(py_arg(1), &data); + if(length != -1) { + if(self->numel != length) { + return ValueError("copy_() expected the same numel: %d != %d", self->numel, length); + } + memcpy(self->data, data, self->numel * sizeof(py_TValue)); + } else { + return TypeError("copy_() expected `array2d`, `list` or `tuple`, got '%t", src_type); + } } - memcpy(self->data, src->data, self->numel * sizeof(py_TValue)); py_newnone(py_retval()); return true; } diff --git a/tests/90_array2d.py b/tests/90_array2d.py index ea8b733e..840e1021 100644 --- a/tests/90_array2d.py +++ b/tests/90_array2d.py @@ -92,6 +92,8 @@ assert a == d and a is not d x = array2d(2, 4, default=0) x.copy_(d) assert x == d and x is not d +x.copy_([1, 2, 3, 4, 5, 6, 7, 8]) +assert x.tolist() == [[1, 2], [3, 4], [5, 6], [7, 8]] # test alive_neighbors a = array2d(3, 3, default=0)