This commit is contained in:
BLUELOVETH 2023-08-28 11:42:40 +08:00
parent 0c80a626ba
commit 0ed2d8f3b1
3 changed files with 26 additions and 17 deletions

View File

@ -410,17 +410,23 @@ public:
} }
struct ImportContext{ struct ImportContext{
std::vector<StrName> pending; std::vector<Str> pending;
std::vector<bool> pending_is_init; // a.k.a __init__.py
struct Temp{ struct Temp{
ImportContext* ctx; ImportContext* ctx;
StrName name; Temp(ImportContext* ctx, Str name, bool is_init) : ctx(ctx){
Temp(ImportContext* ctx, StrName name) : ctx(ctx), name(name){
ctx->pending.push_back(name); ctx->pending.push_back(name);
ctx->pending_is_init.push_back(is_init);
}
~Temp(){
ctx->pending.pop_back();
ctx->pending_is_init.pop_back();
} }
~Temp(){ ctx->pending.pop_back(); }
}; };
Temp scope(StrName name){ return {this, name}; } Temp scope(Str name, bool is_init){
return {this, name, is_init};
}
}; };
ImportContext _import_context; ImportContext _import_context;

View File

@ -1229,9 +1229,7 @@ void init_builtins(VM* _vm) {
_vm->bind__repr__(_vm->tp_module, [](VM* vm, PyObject* obj) { _vm->bind__repr__(_vm->tp_module, [](VM* vm, PyObject* obj) {
const Str& package = CAST(Str&, obj->attr(__package__)); const Str& package = CAST(Str&, obj->attr(__package__));
Str name = CAST(Str&, obj->attr(__name__)); Str name = CAST(Str&, obj->attr(__name__));
if(!package.empty()){ if(!package.empty()) name = package + "." + name;
name = package + "." + name;
}
return VAR(fmt("<module ", name.escape(), ">")); return VAR(fmt("<module ", name.escape(), ">"));
}); });

View File

@ -228,12 +228,13 @@ namespace pkpy{
}; };
if(path[0] == '.'){ if(path[0] == '.'){
Str _mod_name = CAST(Str&, _module->attr(__name__)); if(_import_context.pending.empty()){
Str _mod_package = CAST(Str&, _module->attr(__package__)); ImportError("relative import outside of package");
// get _module's fullname }
if(!_mod_package.empty()) _mod_name = _mod_package + "." + _mod_name; Str curr_path = _import_context.pending.back();
bool curr_is_init = _import_context.pending_is_init.back();
// convert relative path to absolute path // convert relative path to absolute path
std::vector<std::string_view> cpnts = _mod_name.split(".", true); std::vector<std::string_view> cpnts = curr_path.split(".", true);
int prefix = 0; // how many dots in the prefix int prefix = 0; // how many dots in the prefix
for(int i=0; i<path.length(); i++){ for(int i=0; i<path.length(); i++){
if(path[i] == '.') prefix++; if(path[i] == '.') prefix++;
@ -241,16 +242,18 @@ namespace pkpy{
} }
if(prefix > cpnts.size()) ImportError("attempted relative import beyond top-level package"); if(prefix > cpnts.size()) ImportError("attempted relative import beyond top-level package");
path = path.substr(prefix); // remove prefix path = path.substr(prefix); // remove prefix
for(int i=1; i<prefix; i++) cpnts.pop_back(); for(int i=(int)curr_is_init; i<prefix; i++) cpnts.pop_back();
cpnts.push_back(path.sv()); cpnts.push_back(path.sv());
path = f_join(cpnts); path = f_join(cpnts);
} }
std::cout << "py_import(" << path.escape() << ")" << std::endl;
StrName name(path); // path to StrName StrName name(path); // path to StrName
// check circular import // check circular import
for(StrName pending_name: _import_context.pending){ for(Str pending_name: _import_context.pending){
if(pending_name == name) ImportError(fmt("circular import ", name.escape())); if(pending_name == path) ImportError(fmt("circular import ", name.escape()));
} }
PyObject* ext_mod = _modules.try_get(name); PyObject* ext_mod = _modules.try_get(name);
@ -259,11 +262,13 @@ namespace pkpy{
// try import // try import
Str filename = path.replace('.', kPlatformSep) + ".py"; Str filename = path.replace('.', kPlatformSep) + ".py";
Str source; Str source;
bool is_init = false;
auto it = _lazy_modules.find(name); auto it = _lazy_modules.find(name);
if(it == _lazy_modules.end()){ if(it == _lazy_modules.end()){
Bytes b = _import_handler(filename); Bytes b = _import_handler(filename);
if(!b){ if(!b){
filename = path.replace('.', kPlatformSep).str() + kPlatformSep + "__init__.py"; filename = path.replace('.', kPlatformSep).str() + kPlatformSep + "__init__.py";
is_init = true;
b = _import_handler(filename); b = _import_handler(filename);
} }
if(!b) ImportError(fmt("module ", path.escape(), " not found")); if(!b) ImportError(fmt("module ", path.escape(), " not found"));
@ -272,7 +277,7 @@ namespace pkpy{
source = it->second; source = it->second;
_lazy_modules.erase(it); _lazy_modules.erase(it);
} }
auto _ = _import_context.scope(name); auto _ = _import_context.scope(path, is_init);
CodeObject_ code = compile(source, filename, EXEC_MODE); CodeObject_ code = compile(source, filename, EXEC_MODE);
auto all_cpnts = path.split(".", true); auto all_cpnts = path.split(".", true);