From e187a6162412daf49f407d5286ae4fd63105de6e Mon Sep 17 00:00:00 2001 From: blueloveTH Date: Sun, 22 Jun 2025 16:31:46 +0800 Subject: [PATCH] fix module reload bug --- build_g.sh | 2 +- include/pocketpy/interpreter/typeinfo.h | 4 +-- include/pocketpy/interpreter/vm.h | 2 +- include/pocketpy/pocketpy.h | 2 +- src/interpreter/typeinfo.c | 37 ++++++++++++++++----- src/interpreter/vm.c | 2 +- src/modules/os.c | 2 ++ src/objects/codeobject.c | 1 + src/public/modules.c | 10 ++++-- src/public/py_mappingproxy.c | 10 ++++++ tests/31_modulereload.py | 43 +++++++++++++++++++++++++ tests/testreload/__init__.py | 2 ++ tests/testreload/a.py | 18 +++++++++++ 13 files changed, 117 insertions(+), 18 deletions(-) create mode 100644 tests/31_modulereload.py create mode 100644 tests/testreload/__init__.py create mode 100644 tests/testreload/a.py diff --git a/build_g.sh b/build_g.sh index 1a02c970..22b626a3 100644 --- a/build_g.sh +++ b/build_g.sh @@ -6,7 +6,7 @@ SRC=$(find src/ -name "*.c") FLAGS="-std=c11 -lm -ldl -lpthread -Iinclude -O0 -Wfatal-errors -g -DDEBUG -DPK_ENABLE_OS=1" -SANITIZE_FLAGS="-fsanitize=address,leak,undefined" +SANITIZE_FLAGS="-fsanitize=address,leak,undefined -fno-sanitize=function" if [ "$(uname)" == "Darwin" ]; then SANITIZE_FLAGS="-fsanitize=address,undefined" diff --git a/include/pocketpy/interpreter/typeinfo.h b/include/pocketpy/interpreter/typeinfo.h index e01b7146..5dbdcb97 100644 --- a/include/pocketpy/interpreter/typeinfo.h +++ b/include/pocketpy/interpreter/typeinfo.h @@ -16,10 +16,8 @@ typedef struct py_TypeInfo { bool is_python; // is it a python class? (not derived from c object) bool is_sealed; // can it be subclassed? - py_Dtor dtor; // destructor for this type, NULL if no dtor - py_TValue annotations; - + py_Dtor dtor; // destructor for this type, NULL if no dtor void (*on_end_subclass)(struct py_TypeInfo*); // backdoor for enum module } py_TypeInfo; diff --git a/include/pocketpy/interpreter/vm.h b/include/pocketpy/interpreter/vm.h index c8bf1797..7436ffea 100644 --- a/include/pocketpy/interpreter/vm.h +++ b/include/pocketpy/interpreter/vm.h @@ -38,9 +38,9 @@ typedef struct py_ModuleInfo { c11_string* name; c11_string* package; c11_string* path; + py_GlobalRef self; // weakref to the original module object } py_ModuleInfo; - typedef struct VM { py_Frame* top_frame; diff --git a/include/pocketpy/pocketpy.h b/include/pocketpy/pocketpy.h index 3ce3938f..33e5f5db 100644 --- a/include/pocketpy/pocketpy.h +++ b/include/pocketpy/pocketpy.h @@ -558,7 +558,7 @@ PK_API py_GlobalRef py_newmodule(const char* path); /// Get a module by path. PK_API py_GlobalRef py_getmodule(const char* path); /// Reload an existing module. -PK_API bool py_importlib_reload(py_GlobalRef module) PY_RAISE PY_RETURN; +PK_API bool py_importlib_reload(py_Ref module) PY_RAISE PY_RETURN; /// Import a module. /// The result will be set to `py_retval()`. diff --git a/src/interpreter/typeinfo.c b/src/interpreter/typeinfo.c index 8c5fa460..49d2eeeb 100644 --- a/src/interpreter/typeinfo.c +++ b/src/interpreter/typeinfo.c @@ -55,26 +55,28 @@ static void py_TypeInfo__common_init(py_Name name, void (*dtor)(void*), bool is_python, bool is_sealed, - py_TypeInfo* self) { + py_TypeInfo* self, + py_TValue* typeobject) { py_TypeInfo* base_ti = base ? pk_typeinfo(base) : NULL; if(base_ti && base_ti->is_sealed) { c11__abort("type '%s' is not an acceptable base type", py_name2str(base_ti->name)); } - memset(self, 0, sizeof(py_TypeInfo)); self->name = name; self->index = index; self->base = base; self->base_ti = base_ti; - self->self = *py_retval(); + py_assign(&self->self, typeobject); self->module = module ? module : py_NIL(); - self->annotations = *py_NIL(); if(!dtor && base) dtor = base_ti->dtor; self->is_python = is_python; self->is_sealed = is_sealed; + + self->annotations = *py_NIL(); self->dtor = dtor; + self->on_end_subclass = NULL; } py_Type pk_newtype(const char* name, @@ -85,7 +87,15 @@ py_Type pk_newtype(const char* name, bool is_sealed) { py_Type index = pk_current_vm->types.length; py_TypeInfo* self = py_newobject(py_retval(), tp_type, -1, sizeof(py_TypeInfo)); - py_TypeInfo__common_init(py_name(name), base, index, module, dtor, is_python, is_sealed, self); + py_TypeInfo__common_init(py_name(name), + base, + index, + module, + dtor, + is_python, + is_sealed, + self, + py_retval()); TypePointer* pointer = c11_vector__emplace(&pk_current_vm->types); pointer->ti = self; pointer->dtor = self->dtor; @@ -102,11 +112,22 @@ py_Type pk_newtypewithmode(py_Name name, if(mode == RELOAD_MODE && module != NULL) { py_ItemRef old_class = py_getdict(module, name); if(old_class != NULL && py_istype(old_class, tp_type)) { - NameDict* old_dict = PyObject__dict(old_class->_obj); - NameDict__clear(old_dict); +#ifndef NDEBUG + const char* name_cstr = py_name2str(name); + (void)name_cstr; // avoid unused warning +#endif + py_cleardict(old_class); py_TypeInfo* self = py_touserdata(old_class); py_Type index = self->index; - py_TypeInfo__common_init(name, base, index, module, dtor, is_python, is_sealed, self); + py_TypeInfo__common_init(name, + base, + index, + module, + dtor, + is_python, + is_sealed, + self, + &self->self); TypePointer* pointer = c11__at(TypePointer, &pk_current_vm->types, index); pointer->ti = self; pointer->dtor = self->dtor; diff --git a/src/interpreter/vm.c b/src/interpreter/vm.c index 488d76da..c402c1ee 100644 --- a/src/interpreter/vm.c +++ b/src/interpreter/vm.c @@ -52,7 +52,7 @@ void VM__ctor(VM* self) { .f_cmp = BinTree__cmp_cstr, .need_free_key = false, }; - BinTree__ctor(&self->modules, c11_strdup(""), py_NIL(), &modules_config); + BinTree__ctor(&self->modules, "", py_NIL(), &modules_config); c11_vector__ctor(&self->types, sizeof(TypePointer)); self->builtins = NULL; diff --git a/src/modules/os.c b/src/modules/os.c index 45804407..abeb1e06 100644 --- a/src/modules/os.c +++ b/src/modules/os.c @@ -97,6 +97,8 @@ void pk__add_module_os() { py_ItemRef path_object = py_emplacedict(mod, py_name("path")); py_newobject(path_object, tp_object, -1, 0); py_bindfunc(path_object, "exists", os_path_exists); + + py_newdict(py_emplacedict(mod, py_name("environ"))); } typedef struct { diff --git a/src/objects/codeobject.c b/src/objects/codeobject.c index 857d0a69..9e441527 100644 --- a/src/objects/codeobject.c +++ b/src/objects/codeobject.c @@ -197,4 +197,5 @@ void Function__dtor(Function* self) { // self->decl->code.src->filename->data); PK_DECREF(self->decl); if(self->closure) NameDict__delete(self->closure); + memset(self, 0, sizeof(Function)); } \ No newline at end of file diff --git a/src/public/modules.c b/src/public/modules.c index e8148dd3..7eb54c46 100644 --- a/src/public/modules.c +++ b/src/public/modules.c @@ -83,7 +83,9 @@ py_Ref py_newmodule(const char* path) { if(exists) c11__abort("module '%s' already exists", path); BinTree__set(&pk_current_vm->modules, (void*)path, py_retval()); - return py_getmodule(path); + py_GlobalRef retval = py_getmodule(path); + mi->self = retval; + return retval; } int load_module_from_dll_desktop_only(const char* path) PY_RAISE PY_RETURN; @@ -181,9 +183,11 @@ __SUCCESS: return ok ? 1 : -1; } -bool py_importlib_reload(py_GlobalRef module) { +bool py_importlib_reload(py_Ref module) { VM* vm = pk_current_vm; py_ModuleInfo* mi = py_touserdata(module); + // We should ensure that the module is its original py_GlobalRef + module = mi->self; c11_sv path = c11_string__sv(mi->path); c11_string* slashed_path = c11_sv__replace(path, '.', PK_PLATFORM_SEP); c11_string* filename = c11_string__new3("%s.py", slashed_path->data); @@ -195,7 +199,7 @@ bool py_importlib_reload(py_GlobalRef module) { } c11_string__delete(slashed_path); if(data == NULL) return ImportError("module '%v' not found", path); - py_cleardict(module); + // py_cleardict(module); BUG: removing old classes will cause RELOAD_MODE to fail bool ok = py_exec(data, filename->data, RELOAD_MODE, module); c11_string__delete(filename); PK_FREE(data); diff --git a/src/public/py_mappingproxy.c b/src/public/py_mappingproxy.c index fd492469..474d9b0c 100644 --- a/src/public/py_mappingproxy.c +++ b/src/public/py_mappingproxy.c @@ -20,6 +20,15 @@ static bool namedict__getitem__(int argc, py_Ref argv) { return true; } +static bool namedict__get(int argc, py_Ref argv) { + PY_CHECK_ARGC(3); + PY_CHECK_ARG_TYPE(1, tp_str); + py_Name name = py_namev(py_tosv(py_arg(1))); + py_Ref res = py_getdict(py_getslot(argv, 0), name); + py_assign(py_retval(), res ? res : py_arg(2)); + return true; +} + static bool namedict__setitem__(int argc, py_Ref argv) { PY_CHECK_ARGC(3); PY_CHECK_ARG_TYPE(1, tp_str); @@ -82,5 +91,6 @@ py_Type pk_namedict__register() { py_setdict(py_tpobject(type), __hash__, py_None()); py_bindmethod(type, "items", namedict_items); py_bindmethod(type, "clear", namedict_clear); + py_bindmethod(type, "get", namedict__get); return type; } diff --git a/tests/31_modulereload.py b/tests/31_modulereload.py new file mode 100644 index 00000000..644d675a --- /dev/null +++ b/tests/31_modulereload.py @@ -0,0 +1,43 @@ +try: + import os +except ImportError: + exit(0) + +import importlib + +os.chdir('tests') +assert os.getcwd().endswith('tests') + +# test +os.environ['TEST_RELOAD_VALUE'] = '123' +os.environ['SET_X'] = '1' +os.environ['SET_Y'] = '0' + +from testreload import MyClass, a + +objid = id(MyClass) +funcid = id(MyClass.some_func) +getxyid = id(MyClass.get_xy) + +assert MyClass.value == '123' +assert MyClass.get_xy() == (1, 0) + +inst = MyClass() +assert inst.some_func() == '123' + +# reload +os.environ['TEST_RELOAD_VALUE'] = '456' +os.environ['SET_X'] = '0' +os.environ['SET_Y'] = '1' + +importlib.reload(a) + +assert id(MyClass) == objid +assert id(MyClass.some_func) != funcid +assert id(MyClass.get_xy) != getxyid + +assert MyClass.value == '456' +assert inst.some_func() == '456' +assert (MyClass.get_xy() == (1, 1)), MyClass.get_xy() + + diff --git a/tests/testreload/__init__.py b/tests/testreload/__init__.py new file mode 100644 index 00000000..0fa0df2e --- /dev/null +++ b/tests/testreload/__init__.py @@ -0,0 +1,2 @@ +from .a import MyClass +from . import a \ No newline at end of file diff --git a/tests/testreload/a.py b/tests/testreload/a.py new file mode 100644 index 00000000..c02b8c5f --- /dev/null +++ b/tests/testreload/a.py @@ -0,0 +1,18 @@ +import os + +class MyClass: + value = os.environ['TEST_RELOAD_VALUE'] + + def some_func(self): + return self.value + + @staticmethod + def get_xy(): + g = globals() + return g.get('x', 0), g.get('y', 0) + + +if os.environ['SET_X'] == '1': + x = 1 +elif os.environ['SET_Y'] == '1': + y = 1 \ No newline at end of file