fix module reload bug

This commit is contained in:
blueloveTH 2025-06-22 16:31:46 +08:00
parent 68a2186728
commit e187a61624
13 changed files with 117 additions and 18 deletions

View File

@ -6,7 +6,7 @@ SRC=$(find src/ -name "*.c")
FLAGS="-std=c11 -lm -ldl -lpthread -Iinclude -O0 -Wfatal-errors -g -DDEBUG -DPK_ENABLE_OS=1"
SANITIZE_FLAGS="-fsanitize=address,leak,undefined"
SANITIZE_FLAGS="-fsanitize=address,leak,undefined -fno-sanitize=function"
if [ "$(uname)" == "Darwin" ]; then
SANITIZE_FLAGS="-fsanitize=address,undefined"

View File

@ -16,10 +16,8 @@ typedef struct py_TypeInfo {
bool is_python; // is it a python class? (not derived from c object)
bool is_sealed; // can it be subclassed?
py_Dtor dtor; // destructor for this type, NULL if no dtor
py_TValue annotations;
py_Dtor dtor; // destructor for this type, NULL if no dtor
void (*on_end_subclass)(struct py_TypeInfo*); // backdoor for enum module
} py_TypeInfo;

View File

@ -38,9 +38,9 @@ typedef struct py_ModuleInfo {
c11_string* name;
c11_string* package;
c11_string* path;
py_GlobalRef self; // weakref to the original module object
} py_ModuleInfo;
typedef struct VM {
py_Frame* top_frame;

View File

@ -558,7 +558,7 @@ PK_API py_GlobalRef py_newmodule(const char* path);
/// Get a module by path.
PK_API py_GlobalRef py_getmodule(const char* path);
/// Reload an existing module.
PK_API bool py_importlib_reload(py_GlobalRef module) PY_RAISE PY_RETURN;
PK_API bool py_importlib_reload(py_Ref module) PY_RAISE PY_RETURN;
/// Import a module.
/// The result will be set to `py_retval()`.

View File

@ -55,26 +55,28 @@ static void py_TypeInfo__common_init(py_Name name,
void (*dtor)(void*),
bool is_python,
bool is_sealed,
py_TypeInfo* self) {
py_TypeInfo* self,
py_TValue* typeobject) {
py_TypeInfo* base_ti = base ? pk_typeinfo(base) : NULL;
if(base_ti && base_ti->is_sealed) {
c11__abort("type '%s' is not an acceptable base type", py_name2str(base_ti->name));
}
memset(self, 0, sizeof(py_TypeInfo));
self->name = name;
self->index = index;
self->base = base;
self->base_ti = base_ti;
self->self = *py_retval();
py_assign(&self->self, typeobject);
self->module = module ? module : py_NIL();
self->annotations = *py_NIL();
if(!dtor && base) dtor = base_ti->dtor;
self->is_python = is_python;
self->is_sealed = is_sealed;
self->annotations = *py_NIL();
self->dtor = dtor;
self->on_end_subclass = NULL;
}
py_Type pk_newtype(const char* name,
@ -85,7 +87,15 @@ py_Type pk_newtype(const char* name,
bool is_sealed) {
py_Type index = pk_current_vm->types.length;
py_TypeInfo* self = py_newobject(py_retval(), tp_type, -1, sizeof(py_TypeInfo));
py_TypeInfo__common_init(py_name(name), base, index, module, dtor, is_python, is_sealed, self);
py_TypeInfo__common_init(py_name(name),
base,
index,
module,
dtor,
is_python,
is_sealed,
self,
py_retval());
TypePointer* pointer = c11_vector__emplace(&pk_current_vm->types);
pointer->ti = self;
pointer->dtor = self->dtor;
@ -102,11 +112,22 @@ py_Type pk_newtypewithmode(py_Name name,
if(mode == RELOAD_MODE && module != NULL) {
py_ItemRef old_class = py_getdict(module, name);
if(old_class != NULL && py_istype(old_class, tp_type)) {
NameDict* old_dict = PyObject__dict(old_class->_obj);
NameDict__clear(old_dict);
#ifndef NDEBUG
const char* name_cstr = py_name2str(name);
(void)name_cstr; // avoid unused warning
#endif
py_cleardict(old_class);
py_TypeInfo* self = py_touserdata(old_class);
py_Type index = self->index;
py_TypeInfo__common_init(name, base, index, module, dtor, is_python, is_sealed, self);
py_TypeInfo__common_init(name,
base,
index,
module,
dtor,
is_python,
is_sealed,
self,
&self->self);
TypePointer* pointer = c11__at(TypePointer, &pk_current_vm->types, index);
pointer->ti = self;
pointer->dtor = self->dtor;

View File

@ -52,7 +52,7 @@ void VM__ctor(VM* self) {
.f_cmp = BinTree__cmp_cstr,
.need_free_key = false,
};
BinTree__ctor(&self->modules, c11_strdup(""), py_NIL(), &modules_config);
BinTree__ctor(&self->modules, "", py_NIL(), &modules_config);
c11_vector__ctor(&self->types, sizeof(TypePointer));
self->builtins = NULL;

View File

@ -97,6 +97,8 @@ void pk__add_module_os() {
py_ItemRef path_object = py_emplacedict(mod, py_name("path"));
py_newobject(path_object, tp_object, -1, 0);
py_bindfunc(path_object, "exists", os_path_exists);
py_newdict(py_emplacedict(mod, py_name("environ")));
}
typedef struct {

View File

@ -197,4 +197,5 @@ void Function__dtor(Function* self) {
// self->decl->code.src->filename->data);
PK_DECREF(self->decl);
if(self->closure) NameDict__delete(self->closure);
memset(self, 0, sizeof(Function));
}

View File

@ -83,7 +83,9 @@ py_Ref py_newmodule(const char* path) {
if(exists) c11__abort("module '%s' already exists", path);
BinTree__set(&pk_current_vm->modules, (void*)path, py_retval());
return py_getmodule(path);
py_GlobalRef retval = py_getmodule(path);
mi->self = retval;
return retval;
}
int load_module_from_dll_desktop_only(const char* path) PY_RAISE PY_RETURN;
@ -181,9 +183,11 @@ __SUCCESS:
return ok ? 1 : -1;
}
bool py_importlib_reload(py_GlobalRef module) {
bool py_importlib_reload(py_Ref module) {
VM* vm = pk_current_vm;
py_ModuleInfo* mi = py_touserdata(module);
// We should ensure that the module is its original py_GlobalRef
module = mi->self;
c11_sv path = c11_string__sv(mi->path);
c11_string* slashed_path = c11_sv__replace(path, '.', PK_PLATFORM_SEP);
c11_string* filename = c11_string__new3("%s.py", slashed_path->data);
@ -195,7 +199,7 @@ bool py_importlib_reload(py_GlobalRef module) {
}
c11_string__delete(slashed_path);
if(data == NULL) return ImportError("module '%v' not found", path);
py_cleardict(module);
// py_cleardict(module); BUG: removing old classes will cause RELOAD_MODE to fail
bool ok = py_exec(data, filename->data, RELOAD_MODE, module);
c11_string__delete(filename);
PK_FREE(data);

View File

@ -20,6 +20,15 @@ static bool namedict__getitem__(int argc, py_Ref argv) {
return true;
}
static bool namedict__get(int argc, py_Ref argv) {
PY_CHECK_ARGC(3);
PY_CHECK_ARG_TYPE(1, tp_str);
py_Name name = py_namev(py_tosv(py_arg(1)));
py_Ref res = py_getdict(py_getslot(argv, 0), name);
py_assign(py_retval(), res ? res : py_arg(2));
return true;
}
static bool namedict__setitem__(int argc, py_Ref argv) {
PY_CHECK_ARGC(3);
PY_CHECK_ARG_TYPE(1, tp_str);
@ -82,5 +91,6 @@ py_Type pk_namedict__register() {
py_setdict(py_tpobject(type), __hash__, py_None());
py_bindmethod(type, "items", namedict_items);
py_bindmethod(type, "clear", namedict_clear);
py_bindmethod(type, "get", namedict__get);
return type;
}

43
tests/31_modulereload.py Normal file
View File

@ -0,0 +1,43 @@
try:
import os
except ImportError:
exit(0)
import importlib
os.chdir('tests')
assert os.getcwd().endswith('tests')
# test
os.environ['TEST_RELOAD_VALUE'] = '123'
os.environ['SET_X'] = '1'
os.environ['SET_Y'] = '0'
from testreload import MyClass, a
objid = id(MyClass)
funcid = id(MyClass.some_func)
getxyid = id(MyClass.get_xy)
assert MyClass.value == '123'
assert MyClass.get_xy() == (1, 0)
inst = MyClass()
assert inst.some_func() == '123'
# reload
os.environ['TEST_RELOAD_VALUE'] = '456'
os.environ['SET_X'] = '0'
os.environ['SET_Y'] = '1'
importlib.reload(a)
assert id(MyClass) == objid
assert id(MyClass.some_func) != funcid
assert id(MyClass.get_xy) != getxyid
assert MyClass.value == '456'
assert inst.some_func() == '456'
assert (MyClass.get_xy() == (1, 1)), MyClass.get_xy()

View File

@ -0,0 +1,2 @@
from .a import MyClass
from . import a

18
tests/testreload/a.py Normal file
View File

@ -0,0 +1,18 @@
import os
class MyClass:
value = os.environ['TEST_RELOAD_VALUE']
def some_func(self):
return self.value
@staticmethod
def get_xy():
g = globals()
return g.get('x', 0), g.get('y', 0)
if os.environ['SET_X'] == '1':
x = 1
elif os.environ['SET_Y'] == '1':
y = 1