diff --git a/docs/modules/linalg.md b/docs/modules/linalg.md index a952e053..d9ab07e9 100644 --- a/docs/modules/linalg.md +++ b/docs/modules/linalg.md @@ -115,6 +115,9 @@ class mat3x3(_StructLike['mat3x3']): @overload def __matmul__(self, other: vec3) -> vec3: ... + def __imatmul__(self, other: mat3x3) -> None: ... + def invert_(self) -> None: ... + @staticmethod def zeros() -> mat3x3: ... @staticmethod diff --git a/include/pocketpy/linalg.h b/include/pocketpy/linalg.h index fda0a3af..577280b7 100644 --- a/include/pocketpy/linalg.h +++ b/include/pocketpy/linalg.h @@ -158,26 +158,22 @@ struct Mat3x3{ return *this; } - Mat3x3 matmul(const Mat3x3& other) const{ - Mat3x3 ret; - ret._11 = _11 * other._11 + _12 * other._21 + _13 * other._31; - ret._12 = _11 * other._12 + _12 * other._22 + _13 * other._32; - ret._13 = _11 * other._13 + _12 * other._23 + _13 * other._33; - ret._21 = _21 * other._11 + _22 * other._21 + _23 * other._31; - ret._22 = _21 * other._12 + _22 * other._22 + _23 * other._32; - ret._23 = _21 * other._13 + _22 * other._23 + _23 * other._33; - ret._31 = _31 * other._11 + _32 * other._21 + _33 * other._31; - ret._32 = _31 * other._12 + _32 * other._22 + _33 * other._32; - ret._33 = _31 * other._13 + _32 * other._23 + _33 * other._33; - return ret; + void matmul(const Mat3x3& other, Mat3x3& out) const{ + out._11 = _11 * other._11 + _12 * other._21 + _13 * other._31; + out._12 = _11 * other._12 + _12 * other._22 + _13 * other._32; + out._13 = _11 * other._13 + _12 * other._23 + _13 * other._33; + out._21 = _21 * other._11 + _22 * other._21 + _23 * other._31; + out._22 = _21 * other._12 + _22 * other._22 + _23 * other._32; + out._23 = _21 * other._13 + _22 * other._23 + _23 * other._33; + out._31 = _31 * other._11 + _32 * other._21 + _33 * other._31; + out._32 = _31 * other._12 + _32 * other._22 + _33 * other._32; + out._33 = _31 * other._13 + _32 * other._23 + _33 * other._33; } - Vec3 matmul(const Vec3& other) const{ - Vec3 ret; - ret.x = _11 * other.x + _12 * other.y + _13 * other.z; - ret.y = _21 * other.x + _22 * other.y + _23 * other.z; - ret.z = _31 * other.x + _32 * other.y + _33 * other.z; - return ret; + void matmul(const Vec3& other, Vec3& out) const{ + out.x = _11 * other.x + _12 * other.y + _13 * other.z; + out.y = _21 * other.x + _22 * other.y + _23 * other.z; + out.z = _31 * other.x + _32 * other.y + _33 * other.z; } bool operator==(const Mat3x3& other) const{ @@ -207,19 +203,19 @@ struct Mat3x3{ return ret; } - bool inverse(Mat3x3& ret) const{ + bool inverse(Mat3x3& out) const{ float det = determinant(); if (isclose(det, 0)) return false; float inv_det = 1.0f / det; - ret._11 = (_22 * _33 - _23 * _32) * inv_det; - ret._12 = (_13 * _32 - _12 * _33) * inv_det; - ret._13 = (_12 * _23 - _13 * _22) * inv_det; - ret._21 = (_23 * _31 - _21 * _33) * inv_det; - ret._22 = (_11 * _33 - _13 * _31) * inv_det; - ret._23 = (_13 * _21 - _11 * _23) * inv_det; - ret._31 = (_21 * _32 - _22 * _31) * inv_det; - ret._32 = (_12 * _31 - _11 * _32) * inv_det; - ret._33 = (_11 * _22 - _12 * _21) * inv_det; + out._11 = (_22 * _33 - _23 * _32) * inv_det; + out._12 = (_13 * _32 - _12 * _33) * inv_det; + out._13 = (_12 * _23 - _13 * _22) * inv_det; + out._21 = (_23 * _31 - _21 * _33) * inv_det; + out._22 = (_11 * _33 - _13 * _31) * inv_det; + out._23 = (_13 * _21 - _11 * _23) * inv_det; + out._31 = (_21 * _32 - _22 * _31) * inv_det; + out._32 = (_12 * _31 - _11 * _32) * inv_det; + out._33 = (_11 * _22 - _12 * _21) * inv_det; return true; } diff --git a/include/typings/linalg.pyi b/include/typings/linalg.pyi index 9bf342c0..a8504c98 100644 --- a/include/typings/linalg.pyi +++ b/include/typings/linalg.pyi @@ -105,6 +105,9 @@ class mat3x3(_StructLike['mat3x3']): @overload def __matmul__(self, other: vec3) -> vec3: ... + def __imatmul__(self, other: mat3x3) -> None: ... + def invert_(self) -> None: ... + @staticmethod def zeros() -> mat3x3: ... @staticmethod diff --git a/src/linalg.cpp b/src/linalg.cpp index 6199d0d0..366f08f6 100644 --- a/src/linalg.cpp +++ b/src/linalg.cpp @@ -279,26 +279,17 @@ static Vec2 SmoothDamp(Vec2 current, Vec2 target, PyVec2& currentVelocity, float return VAR(std::move(t)); }); -#define METHOD_PROXY_NONE(name) \ - vm->bind_method<0>(type, #name, [](VM* vm, ArgsView args){ \ - PyMat3x3& self = _CAST(PyMat3x3&, args[0]); \ - self.name(); \ - return vm->None; \ - }); - - METHOD_PROXY_NONE(set_zeros) - METHOD_PROXY_NONE(set_ones) - METHOD_PROXY_NONE(set_identity) - -#undef METHOD_PROXY_NONE + vm->bind_method<0>(type, "set_zeros", PK_ACTION(PK_OBJ_GET(PyMat3x3, args[0]).set_zeros())); + vm->bind_method<0>(type, "set_ones", PK_ACTION(PK_OBJ_GET(PyMat3x3, args[0]).set_ones())); + vm->bind_method<0>(type, "set_identity", PK_ACTION(PK_OBJ_GET(PyMat3x3, args[0]).set_identity())); vm->bind__repr__(PK_OBJ_GET(Type, type), [](VM* vm, PyObject* obj){ PyMat3x3& self = _CAST(PyMat3x3&, obj); std::stringstream ss; ss << std::fixed << std::setprecision(3); - ss << "mat3x3([[" << self._11 << ", " << self._12 << ", " << self._13 << "],\n"; - ss << " [" << self._21 << ", " << self._22 << ", " << self._23 << "],\n"; - ss << " [" << self._31 << ", " << self._32 << ", " << self._33 << "]])"; + ss << "mat3x3([" << self._11 << ", " << self._12 << ", " << self._13 << ",\n"; + ss << " " << self._21 << ", " << self._22 << ", " << self._23 << ",\n"; + ss << " " << self._31 << ", " << self._32 << ", " << self._33 << "])"; return VAR(ss.str()); }); @@ -390,16 +381,30 @@ static Vec2 SmoothDamp(Vec2 current, Vec2 target, PyVec2& currentVelocity, float vm->bind__matmul__(PK_OBJ_GET(Type, type), [](VM* vm, PyObject* _0, PyObject* _1){ PyMat3x3& self = _CAST(PyMat3x3&, _0); if(is_non_tagged_type(_1, PyMat3x3::_type(vm))){ - PyMat3x3& other = _CAST(PyMat3x3&, _1); - return VAR_T(PyMat3x3, self.matmul(other)); + const PyMat3x3& other = _CAST(PyMat3x3&, _1); + Mat3x3 out; + self.matmul(other, out); + return VAR_T(PyMat3x3, out); } if(is_non_tagged_type(_1, PyVec3::_type(vm))){ - PyVec3& other = _CAST(PyVec3&, _1); - return VAR_T(PyVec3, self.matmul(other)); + const PyVec3& other = _CAST(PyVec3&, _1); + Vec3 out; + self.matmul(other, out); + return VAR_T(PyVec3, out); } return vm->NotImplemented; }); + vm->bind_method<1>(type, "__imatmul__", [](VM* vm, ArgsView args){ + PyMat3x3& self = _CAST(PyMat3x3&, args[0]); + vm->check_non_tagged_type(args[1], PyMat3x3::_type(vm)); + const PyMat3x3& other = _CAST(PyMat3x3&, args[1]); + Mat3x3 out; + self.matmul(other, out); + self = out; + return vm->None; + }); + vm->bind_method<0>(type, "determinant", [](VM* vm, ArgsView args){ PyMat3x3& self = _CAST(PyMat3x3&, args[0]); return VAR(self.determinant()); @@ -418,6 +423,15 @@ static Vec2 SmoothDamp(Vec2 current, Vec2 target, PyVec2& currentVelocity, float return VAR_T(PyMat3x3, ret); }); + vm->bind_method<0>(type, "invert_", [](VM* vm, ArgsView args){ + PyMat3x3& self = _CAST(PyMat3x3&, args[0]); + Mat3x3 ret; + bool ok = self.inverse(ret); + if(!ok) vm->ValueError("matrix is not invertible"); + self = ret; + return vm->None; + }); + // @staticmethod vm->bind(type, "zeros()", [](VM* vm, ArgsView args){ PK_UNUSED(args); diff --git a/tests/80_linalg.py b/tests/80_linalg.py index 08db4665..eadcc114 100644 --- a/tests/80_linalg.py +++ b/tests/80_linalg.py @@ -330,6 +330,9 @@ for i in range(3): correct_result_mat[i, j] = sum([e1*e2 for e1, e2 in zip(get_row(test_mat_copy, i), get_col(test_mat_copy_2, j))]) assert result_mat == correct_result_mat +test_mat_copy.__imatmul__(test_mat_copy_2) +assert test_mat_copy == correct_result_mat + # test determinant test_mat_copy = test_mat.copy() test_mat_copy.determinant() @@ -382,6 +385,8 @@ assert test_mat_copy.transpose() == test_mat_copy.transpose().transpose().transp # test inverse assert ~static_test_mat_float == static_test_mat_float_inv +assert static_test_mat_float.invert_() is None +assert static_test_mat_float == static_test_mat_float_inv try: ~mat3x3([1, 2, 3, 2, 4, 6, 3, 6, 9])