diff --git a/amalgamate.py b/amalgamate.py index dc3a5393..a0f4b5ab 100644 --- a/amalgamate.py +++ b/amalgamate.py @@ -6,7 +6,7 @@ with open("include/pocketpy/opcodes.h", "rt", encoding='utf-8') as f: OPCODES_TEXT = '\n' + f.read() + '\n' pipeline = [ - ["config.h", "export.h", "_generated.h", "common.h", "memory.h", "any.h", "vector.h", "str.h", "tuplelist.h", "namedict.h", "error.h"], + ["config.h", "export.h", "_generated.h", "common.h", "memory.h", "vector.h", "str.h", "tuplelist.h", "namedict.h", "error.h", "any.h"], ["obj.h", "dict.h", "codeobject.h", "frame.h", "profiler.h"], ["gc.h", "vm.h", "ceval.h", "lexer.h", "expr.h", "compiler.h", "repl.h"], ["cffi.h", "bindings.h", "iter.h", "base64.h", "csv.h", "collections.h", "array2d.h", "dataclasses.h", "random.h", "linalg.h", "easing.h", "io.h", "modules.h"], diff --git a/include/pocketpy/any.h b/include/pocketpy/any.h index ac4800c1..580c5dce 100644 --- a/include/pocketpy/any.h +++ b/include/pocketpy/any.h @@ -1,6 +1,7 @@ #pragma once #include "common.h" +#include "str.h" namespace pkpy { @@ -52,18 +53,25 @@ struct any{ any& operator=(const any& other) = delete; ~any() { if(data) _vt->deleter(data); } + + template + T& _cast() const noexcept{ + static_assert(std::is_same_v>); + if constexpr (is_sso_v){ + return *((T*)(&data)); + }else{ + return *(static_cast(data)); + } + } + + template + T& cast() const{ + static_assert(std::is_same_v>); + if(type_id() != typeid(T)) __bad_any_cast(typeid(T), type_id()); + return _cast(); + } + + static void __bad_any_cast(const std::type_index expected, const std::type_index actual); }; -template -bool any_cast(const any& a, T** out){ - static_assert(std::is_same_v>); - if(a.type_id() != typeid(T)) return false; - if constexpr (is_sso_v){ - *out = (T*)(&a.data); - }else{ - *out = static_cast(a.data); - } - return true; -} - } // namespace pkpy \ No newline at end of file diff --git a/include/pocketpy/codeobject.h b/include/pocketpy/codeobject.h index 25ede579..812c708e 100644 --- a/include/pocketpy/codeobject.h +++ b/include/pocketpy/codeobject.h @@ -186,17 +186,8 @@ struct Py_ final: PyObject { template T& lambda_get_userdata(PyObject** p){ static_assert(std::is_same_v>); - any* ud; - if(p[-1] != PY_NULL) ud = &PK_OBJ_GET(NativeFunc, p[-1])._userdata; - else ud = &PK_OBJ_GET(NativeFunc, p[-2])._userdata; - T* out; - if(!any_cast(*ud, &out)){ - const char* expected = typeid(T).name(); - const char* actual = ud->type_id().name(); - Str error = _S("lambda_get_userdata: any_cast failed: expected ", expected, ", got ", actual); - throw std::runtime_error(error.c_str()); - } - return *out; + int offset = p[-1] != PY_NULL ? -1 : -2; + return PK_OBJ_GET(NativeFunc, p[offset])._userdata.cast(); } } // namespace pkpy \ No newline at end of file diff --git a/src/any.cpp b/src/any.cpp index 8ec4be7c..ca36fae9 100644 --- a/src/any.cpp +++ b/src/any.cpp @@ -2,6 +2,11 @@ namespace pkpy{ +void any::__bad_any_cast(const std::type_index expected, const std::type_index actual){ + Str error = _S("bad_any_cast: expected ", expected.name(), ", got ", actual.name()); + throw std::runtime_error(error.c_str()); +} + any::any(any&& other) noexcept: data(other.data), _vt(other._vt){ other.data = nullptr; other._vt = nullptr;