diff --git a/include/pocketpy/interpreter/vm.h b/include/pocketpy/interpreter/vm.h index 30c0af50..4ad8cb94 100644 --- a/include/pocketpy/interpreter/vm.h +++ b/include/pocketpy/interpreter/vm.h @@ -30,7 +30,7 @@ typedef struct py_TypeInfo { typedef struct VM { Frame* top_frame; - NameDict modules; + ModuleDict modules; c11_vector /*T=py_TypeInfo*/ types; py_TValue builtins; // builtins module diff --git a/include/pocketpy/objects/namedict.h b/include/pocketpy/objects/namedict.h index 0bd1c3ca..c68463db 100644 --- a/include/pocketpy/objects/namedict.h +++ b/include/pocketpy/objects/namedict.h @@ -12,3 +12,16 @@ #include "pocketpy/xmacros/smallmap.h" #undef SMALLMAP_T__HEADER +/* A simple binary tree for storing modules. */ +typedef struct ModuleDict { + const char* path; + py_TValue module; + struct ModuleDict* left; + struct ModuleDict* right; +} ModuleDict; + +void ModuleDict__ctor(ModuleDict* self, const char* path, py_TValue module); +void ModuleDict__dtor(ModuleDict* self); +void ModuleDict__set(ModuleDict* self, const char* key, py_TValue val); +py_TValue* ModuleDict__try_get(ModuleDict* self, const char* path); +bool ModuleDict__contains(ModuleDict* self, const char* path); \ No newline at end of file diff --git a/include/pocketpy/pocketpy.h b/include/pocketpy/pocketpy.h index 58ee8990..dfd06ba1 100644 --- a/include/pocketpy/pocketpy.h +++ b/include/pocketpy/pocketpy.h @@ -377,9 +377,9 @@ bool py_vectorcall(uint16_t argc, uint16_t kwargc) PY_RAISE; /************* Modules *************/ /// Create a new module. -py_TmpRef py_newmodule(const char* path); +py_GlobalRef py_newmodule(const char* path); /// Get a module by path. -py_TmpRef py_getmodule(const char* path); +py_GlobalRef py_getmodule(const char* path); /// Import a module. /// The result will be set to `py_retval()`. diff --git a/src/interpreter/vm.c b/src/interpreter/vm.c index 4e009910..b2eb72d1 100644 --- a/src/interpreter/vm.c +++ b/src/interpreter/vm.c @@ -57,7 +57,7 @@ static void py_TypeInfo__dtor(py_TypeInfo* self) { c11_vector__dtor(&self->annot void VM__ctor(VM* self) { self->top_frame = NULL; - NameDict__ctor(&self->modules); + ModuleDict__ctor(&self->modules, NULL, *py_NIL); c11_vector__ctor(&self->types, sizeof(py_TypeInfo)); self->builtins = *py_NIL; @@ -221,7 +221,7 @@ void VM__dtor(VM* self) { // clear frames while(self->top_frame) VM__pop_frame(self); - NameDict__dtor(&self->modules); + ModuleDict__dtor(&self->modules); c11__foreach(py_TypeInfo, &self->types, ti) py_TypeInfo__dtor(ti); c11_vector__dtor(&self->types); ValueStack__clear(&self->stack); @@ -315,7 +315,7 @@ py_Type pk_newtype(const char* name, py_Type index = types->count; py_TypeInfo* ti = c11_vector__emplace(types); py_TypeInfo* base_ti = base ? c11__at(py_TypeInfo, types, base) : NULL; - if(base_ti && base_ti->is_sealed){ + if(base_ti && base_ti->is_sealed) { c11__abort("type '%s' is not an acceptable base type", py_name2str(base_ti->name)); } py_TypeInfo__ctor(ti, py_name(name), index, base, module ? *module : *py_NIL); diff --git a/src/modules/json.c b/src/modules/json.c index 5a5ca0bf..693f13e3 100644 --- a/src/modules/json.c +++ b/src/modules/json.c @@ -10,7 +10,7 @@ static bool json_loads(int argc, py_Ref argv) { PY_CHECK_ARGC(1); PY_CHECK_ARG_TYPE(0, tp_str); const char* source = py_tostr(argv); - py_TmpRef mod = py_getmodule("json"); + py_GlobalRef mod = py_getmodule("json"); return py_exec(source, "", EVAL_MODE, mod); } diff --git a/src/objects/namedict.c b/src/objects/namedict.c index b8ba304b..a221171b 100644 --- a/src/objects/namedict.c +++ b/src/objects/namedict.c @@ -6,3 +6,70 @@ #define NAME NameDict #include "pocketpy/xmacros/smallmap.h" #undef SMALLMAP_T__SOURCE + +void ModuleDict__ctor(ModuleDict* self, const char* path, py_TValue module) { + self->path = path; + self->module = module; + self->left = NULL; + self->right = NULL; +} + +void ModuleDict__dtor(ModuleDict* self) { + if(self->left) { + ModuleDict__dtor(self->left); + free(self->left); + } + if(self->right) { + ModuleDict__dtor(self->right); + free(self->right); + } +} + +void ModuleDict__set(ModuleDict* self, const char* key, py_TValue val) { + if(self->path == NULL) { + self->path = key; + self->module = val; + } + int cmp = strcmp(key, self->path); + if(cmp < 0) { + if(self->left) { + ModuleDict__set(self->left, key, val); + } else { + self->left = malloc(sizeof(ModuleDict)); + ModuleDict__ctor(self->left, key, val); + } + } else if(cmp > 0) { + if(self->right) { + ModuleDict__set(self->right, key, val); + } else { + self->right = malloc(sizeof(ModuleDict)); + ModuleDict__ctor(self->right, key, val); + } + } else { + self->module = val; + } +} + +py_TValue* ModuleDict__try_get(ModuleDict* self, const char* path) { + if(self->path == NULL) return NULL; + int cmp = strcmp(path, self->path); + if(cmp < 0) { + if(self->left) { + return ModuleDict__try_get(self->left, path); + } else { + return NULL; + } + } else if(cmp > 0) { + if(self->right) { + return ModuleDict__try_get(self->right, path); + } else { + return NULL; + } + } else { + return &self->module; + } +} + +bool ModuleDict__contains(ModuleDict* self, const char* path) { + return ModuleDict__try_get(self, path) != NULL; +} \ No newline at end of file diff --git a/src/public/modules.c b/src/public/modules.c index b01d5751..ec71fd82 100644 --- a/src/public/modules.c +++ b/src/public/modules.c @@ -10,7 +10,7 @@ py_Ref py_getmodule(const char* path) { VM* vm = pk_current_vm; - return NameDict__try_get(&vm->modules, py_name(path)); + return ModuleDict__try_get(&vm->modules, path); } py_Ref py_getbuiltin(py_Name name) { return py_getdict(&pk_current_vm->builtins, name); } @@ -51,10 +51,12 @@ py_Ref py_newmodule(const char* path) { // we do not allow override in order to avoid memory leak // it is because Module objects are not garbage collected - py_Name path_name = py_name(path); - bool exists = NameDict__contains(&pk_current_vm->modules, path_name); + bool exists = ModuleDict__contains(&pk_current_vm->modules, path); if(exists) c11__abort("module '%s' already exists", path); - NameDict__set(&pk_current_vm->modules, path_name, *r0); + + // convert to a weak (const char*) + path = py_tostr(py_getdict(r0, __path__)); + ModuleDict__set(&pk_current_vm->modules, path, *r0); py_shrink(2); return py_getmodule(path); @@ -112,7 +114,7 @@ int py_import(const char* path_cstr) { assert(path.data[0] != '.' && path.data[path.size - 1] != '.'); // check existing module - py_TmpRef ext_mod = py_getmodule(path.data); + py_GlobalRef ext_mod = py_getmodule(path.data); if(ext_mod) { py_assign(py_retval(), ext_mod); return true;