From a0a5753283deb7ddf6cf996234282f0c9d2cc142 Mon Sep 17 00:00:00 2001 From: blueloveTH Date: Tue, 13 Feb 2024 13:13:22 +0800 Subject: [PATCH] allows `vec` * `vec` --- include/pocketpy/linalg.h | 15 +++-------- include/typings/linalg.pyi | 15 +++++++++++ src/linalg.cpp | 52 +++++++++++++++++++++++++------------- tests/80_linalg.py | 7 ++++- 4 files changed, 58 insertions(+), 31 deletions(-) diff --git a/include/pocketpy/linalg.h b/include/pocketpy/linalg.h index 2640aaab..e6904e7c 100644 --- a/include/pocketpy/linalg.h +++ b/include/pocketpy/linalg.h @@ -13,13 +13,10 @@ struct Vec2{ Vec2(const Vec2& v) = default; Vec2 operator+(const Vec2& v) const { return Vec2(x + v.x, y + v.y); } - Vec2& operator+=(const Vec2& v) { x += v.x; y += v.y; return *this; } Vec2 operator-(const Vec2& v) const { return Vec2(x - v.x, y - v.y); } - Vec2& operator-=(const Vec2& v) { x -= v.x; y -= v.y; return *this; } Vec2 operator*(float s) const { return Vec2(x * s, y * s); } - Vec2& operator*=(float s) { x *= s; y *= s; return *this; } + Vec2 operator*(const Vec2& v) const { return Vec2(x * v.x, y * v.y); } Vec2 operator/(float s) const { return Vec2(x / s, y / s); } - Vec2& operator/=(float s) { x /= s; y /= s; return *this; } Vec2 operator-() const { return Vec2(-x, -y); } bool operator==(const Vec2& v) const { return isclose(x, v.x) && isclose(y, v.y); } bool operator!=(const Vec2& v) const { return !isclose(x, v.x) || !isclose(y, v.y); } @@ -40,13 +37,10 @@ struct Vec3{ Vec3(const Vec3& v) = default; Vec3 operator+(const Vec3& v) const { return Vec3(x + v.x, y + v.y, z + v.z); } - Vec3& operator+=(const Vec3& v) { x += v.x; y += v.y; z += v.z; return *this; } Vec3 operator-(const Vec3& v) const { return Vec3(x - v.x, y - v.y, z - v.z); } - Vec3& operator-=(const Vec3& v) { x -= v.x; y -= v.y; z -= v.z; return *this; } Vec3 operator*(float s) const { return Vec3(x * s, y * s, z * s); } - Vec3& operator*=(float s) { x *= s; y *= s; z *= s; return *this; } + Vec3 operator*(const Vec3& v) const { return Vec3(x * v.x, y * v.y, z * v.z); } Vec3 operator/(float s) const { return Vec3(x / s, y / s, z / s); } - Vec3& operator/=(float s) { x /= s; y /= s; z /= s; return *this; } Vec3 operator-() const { return Vec3(-x, -y, -z); } bool operator==(const Vec3& v) const { return isclose(x, v.x) && isclose(y, v.y) && isclose(z, v.z); } bool operator!=(const Vec3& v) const { return !isclose(x, v.x) || !isclose(y, v.y) || !isclose(z, v.z); } @@ -66,13 +60,10 @@ struct Vec4{ Vec4(const Vec4& v) = default; Vec4 operator+(const Vec4& v) const { return Vec4(x + v.x, y + v.y, z + v.z, w + v.w); } - Vec4& operator+=(const Vec4& v) { x += v.x; y += v.y; z += v.z; w += v.w; return *this; } Vec4 operator-(const Vec4& v) const { return Vec4(x - v.x, y - v.y, z - v.z, w - v.w); } - Vec4& operator-=(const Vec4& v) { x -= v.x; y -= v.y; z -= v.z; w -= v.w; return *this; } Vec4 operator*(float s) const { return Vec4(x * s, y * s, z * s, w * s); } - Vec4& operator*=(float s) { x *= s; y *= s; z *= s; w *= s; return *this; } + Vec4 operator*(const Vec4& v) const { return Vec4(x * v.x, y * v.y, z * v.z, w * v.w); } Vec4 operator/(float s) const { return Vec4(x / s, y / s, z / s, w / s); } - Vec4& operator/=(float s) { x /= s; y /= s; z /= s; w /= s; return *this; } Vec4 operator-() const { return Vec4(-x, -y, -z, -w); } bool operator==(const Vec4& v) const { return isclose(x, v.x) && isclose(y, v.y) && isclose(z, v.z) && isclose(w, v.w); } bool operator!=(const Vec4& v) const { return !isclose(x, v.x) || !isclose(y, v.y) || !isclose(z, v.z) || !isclose(w, v.w); } diff --git a/include/typings/linalg.pyi b/include/typings/linalg.pyi index 30dafeb6..5006a663 100644 --- a/include/typings/linalg.pyi +++ b/include/typings/linalg.pyi @@ -8,7 +8,12 @@ class vec2(_StructLike['vec2']): def __init__(self, x: float, y: float) -> None: ... def __add__(self, other: vec2) -> vec2: ... def __sub__(self, other: vec2) -> vec2: ... + + @overload def __mul__(self, other: float) -> vec2: ... + @overload + def __mul__(self, other: vec2) -> vec2: ... + def __rmul__(self, other: float) -> vec2: ... def __truediv__(self, other: float) -> vec2: ... def dot(self, other: vec2) -> float: ... @@ -44,7 +49,12 @@ class vec3(_StructLike['vec3']): def __init__(self, x: float, y: float, z: float) -> None: ... def __add__(self, other: vec3) -> vec3: ... def __sub__(self, other: vec3) -> vec3: ... + + @overload def __mul__(self, other: float) -> vec3: ... + @overload + def __mul__(self, other: vec3) -> vec3: ... + def __rmul__(self, other: float) -> vec3: ... def __truediv__(self, other: float) -> vec3: ... def dot(self, other: vec3) -> float: ... @@ -65,7 +75,12 @@ class vec4(_StructLike['vec4']): def __init__(self, x: float, y: float, z: float, w: float) -> None: ... def __add__(self, other: vec4) -> vec4: ... def __sub__(self, other: vec4) -> vec4: ... + + @overload def __mul__(self, other: float) -> vec4: ... + @overload + def __mul__(self, other: vec4) -> vec4: ... + def __rmul__(self, other: float) -> vec4: ... def __truediv__(self, other: float) -> vec4: ... def dot(self, other: vec4) -> float: ... diff --git a/src/linalg.cpp b/src/linalg.cpp index 69cb1ce0..9b01f888 100644 --- a/src/linalg.cpp +++ b/src/linalg.cpp @@ -2,18 +2,18 @@ namespace pkpy{ -#define BIND_VEC_VEC_OP(D, name, op) \ - vm->bind_method<1>(type, #name, [](VM* vm, ArgsView args){ \ - PyVec##D& self = _CAST(PyVec##D&, args[0]); \ - PyVec##D& other = CAST(PyVec##D&, args[1]); \ - return VAR(self op other); \ +#define BIND_VEC_VEC_OP(D, name, op) \ + vm->bind##name(PK_OBJ_GET(Type, type), [](VM* vm, PyObject* _0, PyObject* _1){ \ + PyVec##D& self = _CAST(PyVec##D&, _0); \ + PyVec##D& other = CAST(PyVec##D&, _1); \ + return VAR(self op other); \ }); #define BIND_VEC_FLOAT_OP(D, name, op) \ - vm->bind_method<1>(type, #name, [](VM* vm, ArgsView args){ \ - PyVec##D& self = _CAST(PyVec##D&, args[0]); \ - f64 other = CAST(f64, args[1]); \ - return VAR(self op other); \ + vm->bind##name(PK_OBJ_GET(Type, type), [](VM* vm, PyObject* _0, PyObject* _1){ \ + PyVec##D& self = _CAST(PyVec##D&, _0); \ + f64 other = CAST(f64, _1); \ + return VAR(self op other); \ }); #define BIND_VEC_FUNCTION_0(D, name) \ @@ -29,6 +29,27 @@ namespace pkpy{ return VAR(self.name(other)); \ }); +#define BIND_VEC_MUL_OP(D) \ + vm->bind__mul__(PK_OBJ_GET(Type, type), [](VM* vm, PyObject* _0, PyObject* _1){ \ + PyVec##D& self = _CAST(PyVec##D&, _0); \ + if(is_non_tagged_type(_1, PyVec##D::_type(vm))){ \ + PyVec##D& other = _CAST(PyVec##D&, _1); \ + return VAR(self * other); \ + } \ + f64 other = CAST(f64, _1); \ + return VAR(self * other); \ + }); \ + vm->bind_method<1>(type, "__rmul__", [](VM* vm, ArgsView args){ \ + PyVec##D& self = _CAST(PyVec##D&, args[0]); \ + f64 other = CAST(f64, args[1]); \ + return VAR(self * other); \ + }); \ + vm->bind__truediv__(PK_OBJ_GET(Type, type), [](VM* vm, PyObject* _0, PyObject* _1){ \ + PyVec##D& self = _CAST(PyVec##D&, _0); \ + f64 other = CAST(f64, _1); \ + return VAR(self / other); \ + }); + // https://github.com/Unity-Technologies/UnityCsReference/blob/master/Runtime/Export/Math/Vector2.cs#L289 static Vec2 SmoothDamp(Vec2 current, Vec2 target, PyVec2& currentVelocity, float smoothTime, float maxSpeed, float deltaTime) { @@ -142,8 +163,7 @@ static Vec2 SmoothDamp(Vec2 current, Vec2 target, PyVec2& currentVelocity, float BIND_VEC_VEC_OP(2, __add__, +) BIND_VEC_VEC_OP(2, __sub__, -) - BIND_VEC_FLOAT_OP(2, __mul__, *) - BIND_VEC_FLOAT_OP(2, __rmul__, *) + BIND_VEC_MUL_OP(2) BIND_VEC_FLOAT_OP(2, __truediv__, /) BIND_VEC_FUNCTION_1(2, dot) BIND_VEC_FUNCTION_1(2, cross) @@ -178,9 +198,7 @@ static Vec2 SmoothDamp(Vec2 current, Vec2 target, PyVec2& currentVelocity, float BIND_VEC_VEC_OP(3, __add__, +) BIND_VEC_VEC_OP(3, __sub__, -) - BIND_VEC_FLOAT_OP(3, __mul__, *) - BIND_VEC_FLOAT_OP(3, __rmul__, *) - BIND_VEC_FLOAT_OP(3, __truediv__, /) + BIND_VEC_MUL_OP(3) BIND_VEC_FUNCTION_1(3, dot) BIND_VEC_FUNCTION_1(3, cross) BIND_VEC_FUNCTION_1(3, copy_) @@ -216,9 +234,7 @@ static Vec2 SmoothDamp(Vec2 current, Vec2 target, PyVec2& currentVelocity, float BIND_VEC_VEC_OP(4, __add__, +) BIND_VEC_VEC_OP(4, __sub__, -) - BIND_VEC_FLOAT_OP(4, __mul__, *) - BIND_VEC_FLOAT_OP(4, __rmul__, *) - BIND_VEC_FLOAT_OP(4, __truediv__, /) + BIND_VEC_MUL_OP(4) BIND_VEC_FUNCTION_1(4, dot) BIND_VEC_FUNCTION_1(4, copy_) BIND_VEC_FUNCTION_0(4, length) @@ -228,7 +244,7 @@ static Vec2 SmoothDamp(Vec2 current, Vec2 target, PyVec2& currentVelocity, float } #undef BIND_VEC_VEC_OP -#undef BIND_VEC_FLOAT_OP +#undef BIND_VEC_MUL_OP #undef BIND_VEC_FUNCTION_0 #undef BIND_VEC_FUNCTION_1 diff --git a/tests/80_linalg.py b/tests/80_linalg.py index 9b3bb9e6..2d185be0 100644 --- a/tests/80_linalg.py +++ b/tests/80_linalg.py @@ -481,4 +481,9 @@ try: assert d[6, 6] exit(1) except IndexError: - pass \ No newline at end of file + pass + +# test vec * vec +assert vec2(1, 2) * vec2(3, 4) == vec2(3, 8) +assert vec3(1, 2, 3) * vec3(4, 5, 6) == vec3(4, 10, 18) +assert vec4(1, 2, 3, 4) * vec4(5, 6, 7, 8) == vec4(5, 12, 21, 32)