From 8d9761c6b53406d715e93f2a55e58ec751051920 Mon Sep 17 00:00:00 2001 From: blueloveTH Date: Fri, 12 May 2023 14:13:57 +0800 Subject: [PATCH] ... --- src/linalg.h | 23 ++++++++++++++--------- src/linalg.pyi | 5 +++++ 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/src/linalg.h b/src/linalg.h index 802ec484..38071e47 100644 --- a/src/linalg.h +++ b/src/linalg.h @@ -562,17 +562,22 @@ struct PyMat3x3: Mat3x3{ return VAR_T(PyMat3x3, self / other); }); - vm->bind_method<1>(type, "__matmul__", [](VM* vm, ArgsView args){ + auto f_mm = [](VM* vm, ArgsView args){ PyMat3x3& self = _CAST(PyMat3x3&, args[0]); - PyMat3x3& other = CAST(PyMat3x3&, args[1]); - return VAR_T(PyMat3x3, self.matmul(other)); - }); + if(is_non_tagged_type(args[1], PyMat3x3::_type(vm))){ + PyMat3x3& other = _CAST(PyMat3x3&, args[1]); + return VAR_T(PyMat3x3, self.matmul(other)); + } + if(is_non_tagged_type(args[1], PyVec3::_type(vm))){ + PyVec3& other = _CAST(PyVec3&, args[1]); + return VAR_T(PyVec3, self.matmul(other)); + } + vm->TypeError("unsupported operand type(s) for @"); + return vm->None; + }; - vm->bind_method<1>(type, "matmul", [](VM* vm, ArgsView args){ - PyMat3x3& self = _CAST(PyMat3x3&, args[0]); - PyMat3x3& other = CAST(PyMat3x3&, args[1]); - return VAR_T(PyMat3x3, self.matmul(other)); - }); + vm->bind_method<1>(type, "__matmul__", f_mm); + vm->bind_method<1>(type, "matmul", f_mm); vm->bind_method<1>(type, "__eq__", [](VM* vm, ArgsView args){ PyMat3x3& self = _CAST(PyMat3x3&, args[0]); diff --git a/src/linalg.pyi b/src/linalg.pyi index fade6bfa..9f2bdb22 100644 --- a/src/linalg.pyi +++ b/src/linalg.pyi @@ -64,8 +64,13 @@ class mat3x3: def __sub__(self, other: mat3x3) -> mat3x3: ... def __mul__(self, other: float) -> mat3x3: ... def __truediv__(self, other: float) -> mat3x3: ... + @overload def __matmul__(self, other: mat3x3) -> mat3x3: ... + @overload + def __matmul__(self, other: vec3) -> vec3: ... + @overload def matmul(self, other: mat3x3) -> mat3x3: ... + @overload def matmul(self, other: vec3) -> vec3: ... def determinant(self) -> float: ... def transpose(self) -> mat3x3: ...