import can be used in local scope now

This commit is contained in:
blueloveTH 2023-10-12 01:02:46 +08:00
parent 3a3b97c070
commit eb5be9ba41
4 changed files with 32 additions and 19 deletions

View File

@ -65,6 +65,7 @@ struct CodeEmitContext{
int add_varname(StrName name); int add_varname(StrName name);
int add_const(PyObject* v); int add_const(PyObject* v);
int add_func_decl(FuncDecl_ decl); int add_func_decl(FuncDecl_ decl);
void emit_store_name(NameScope scope, StrName name, int line);
}; };
struct NameExpr: Expr{ struct NameExpr: Expr{

View File

@ -477,7 +477,6 @@ __SUBSCR_END:
// import a [as b] // import a [as b]
// import a [as b], c [as d] // import a [as b], c [as d]
void Compiler::compile_normal_import() { void Compiler::compile_normal_import() {
if(name_scope() != NAME_GLOBAL) SyntaxError("import statement should be used in global scope");
do { do {
consume(TK("@id")); consume(TK("@id"));
Str name = prev().str(); Str name = prev().str();
@ -486,7 +485,7 @@ __SUBSCR_END:
consume(TK("@id")); consume(TK("@id"));
name = prev().str(); name = prev().str();
} }
ctx()->emit(OP_STORE_GLOBAL, StrName(name).index, prev().line); ctx()->emit_store_name(name_scope(), StrName(name), prev().line);
} while (match(TK(","))); } while (match(TK(",")));
consume_end_stmt(); consume_end_stmt();
} }
@ -499,7 +498,6 @@ __SUBSCR_END:
// from .a.b import c [as d] // from .a.b import c [as d]
// from xxx import * // from xxx import *
void Compiler::compile_from_import() { void Compiler::compile_from_import() {
if(name_scope() != NAME_GLOBAL) SyntaxError("import statement should be used in global scope");
int dots = 0; int dots = 0;
while(true){ while(true){
@ -538,6 +536,7 @@ __EAT_DOTS_END:
consume(TK("import")); consume(TK("import"));
if (match(TK("*"))) { if (match(TK("*"))) {
if(name_scope() != NAME_GLOBAL) SyntaxError("from <module> import * can only be used in global scope");
// pop the module and import __all__ // pop the module and import __all__
ctx()->emit(OP_POP_IMPORT_STAR, BC_NOARG, prev().line); ctx()->emit(OP_POP_IMPORT_STAR, BC_NOARG, prev().line);
consume_end_stmt(); consume_end_stmt();
@ -553,7 +552,7 @@ __EAT_DOTS_END:
consume(TK("@id")); consume(TK("@id"));
name = prev().str(); name = prev().str();
} }
ctx()->emit(OP_STORE_GLOBAL, StrName(name).index, prev().line); ctx()->emit_store_name(name_scope(), StrName(name), prev().line);
} while (match(TK(","))); } while (match(TK(",")));
ctx()->emit(OP_POP_TOP, BC_NOARG, BC_KEEPLINE); ctx()->emit(OP_POP_TOP, BC_NOARG, BC_KEEPLINE);
consume_end_stmt(); consume_end_stmt();

View File

@ -1,4 +1,5 @@
#include "pocketpy/expr.h" #include "pocketpy/expr.h"
#include "pocketpy/codeobject.h"
namespace pkpy{ namespace pkpy{
@ -96,6 +97,21 @@ namespace pkpy{
return co->func_decls.size() - 1; return co->func_decls.size() - 1;
} }
void CodeEmitContext::emit_store_name(NameScope scope, StrName name, int line){
switch(scope){
case NAME_LOCAL:
emit(OP_STORE_FAST, add_varname(name), line);
break;
case NAME_GLOBAL:
emit(OP_STORE_GLOBAL, StrName(name).index, line);
break;
case NAME_GLOBAL_UNKNOWN:
emit(OP_STORE_NAME, StrName(name).index, line);
break;
default: FATAL_ERROR(); break;
}
}
void NameExpr::emit(CodeEmitContext* ctx) { void NameExpr::emit(CodeEmitContext* ctx) {
int index = ctx->co->varnames_inv.try_get(name); int index = ctx->co->varnames_inv.try_get(name);
@ -127,22 +143,11 @@ namespace pkpy{
bool NameExpr::emit_store(CodeEmitContext* ctx) { bool NameExpr::emit_store(CodeEmitContext* ctx) {
if(ctx->is_compiling_class){ if(ctx->is_compiling_class){
int index = StrName(name).index; int index = name.index;
ctx->emit(OP_STORE_CLASS_ATTR, index, line); ctx->emit(OP_STORE_CLASS_ATTR, index, line);
return true; return true;
} }
switch(scope){ ctx->emit_store_name(scope, name, line);
case NAME_LOCAL:
ctx->emit(OP_STORE_FAST, ctx->add_varname(name), line);
break;
case NAME_GLOBAL:
ctx->emit(OP_STORE_GLOBAL, StrName(name).index, line);
break;
case NAME_GLOBAL_UNKNOWN:
ctx->emit(OP_STORE_NAME, StrName(name).index, line);
break;
default: FATAL_ERROR(); break;
}
return true; return true;
} }

View File

@ -17,5 +17,13 @@ from test2.utils import get_value_2
assert get_value_2() == '123' assert get_value_2() == '123'
from test3.a.b import value from test3.a.b import value
# should test3
assert value == 1 assert value == 1
def f():
import math as m
assert m.pi > 3
from test3.a.b import value
assert value == 1
f()