warn return with arg inside generator function

This commit is contained in:
blueloveTH 2024-01-05 22:02:23 +08:00
parent 96b814cc10
commit 02a25de8e5
7 changed files with 57 additions and 41 deletions

View File

@ -43,7 +43,7 @@ struct Expr{
struct CodeEmitContext{
VM* vm;
FuncDecl_ func; // optional
CodeObject_ co;
CodeObject_ co; // 1 CodeEmitContext <=> 1 CodeObject_
// some bugs on MSVC (error C2280) when using std::vector<Expr_>
// so we use stack_no_copy instead
stack_no_copy<Expr_> s_expr;
@ -55,6 +55,9 @@ struct CodeEmitContext{
bool is_compiling_class = false;
int for_loop_depth = 0;
std::map<void*, int> _co_consts_nonstring_dedup_map;
std::map<std::string_view, int> _co_consts_string_dedup_map;
int get_loop() const;
CodeBlock* enter_block(CodeBlockType type);
void exit_block();

View File

@ -630,7 +630,7 @@ __NEXT_STEP:;
PUSH(_0);
} DISPATCH();
TARGET(RETURN_VALUE){
PyObject* _0 = POPX();
PyObject* _0 = byte.arg == BC_NOARG ? POPX() : None;
_pop_frame();
if(frame.index == base_id){ // [ frameBase<- ]
return _0;

View File

@ -30,22 +30,13 @@ namespace pkpy{
// add a `return None` in the end as a guard
// previously, we only do this if the last opcode is not a return
// however, this is buggy...since there may be a jump to the end (out of bound) even if the last opcode is a return
ctx()->emit_(OP_LOAD_NONE, BC_NOARG, BC_KEEPLINE);
ctx()->emit_(OP_RETURN_VALUE, BC_NOARG, BC_KEEPLINE);
ctx()->emit_(OP_RETURN_VALUE, 1, BC_KEEPLINE);
// some check here
std::vector<Bytecode>& codes = ctx()->co->codes;
if(ctx()->co->varnames.size() > PK_MAX_CO_VARNAMES){
SyntaxError("maximum number of local variables exceeded");
}
if(ctx()->co->consts.size() > 65535){
// std::map<std::string_view, int> counts;
// for(PyObject* c: ctx()->co->consts){
// std::string_view key = obj_type_name(vm, vm->_tp(c)).sv();
// counts[key] += 1;
// }
// for(auto pair: counts){
// std::cout << pair.first << ": " << pair.second << std::endl;
// }
SyntaxError("maximum number of constants exceeded");
}
if(codes.size() > 65535 && ctx()->co->src->mode != JSON_MODE){
@ -63,6 +54,7 @@ namespace pkpy{
bc.arg = ctx()->co->_get_block_codei(i).end;
}
}
// pre-compute func->is_simple
FuncDecl_ func = contexts.top().func;
if(func){
func->is_simple = true;
@ -809,12 +801,14 @@ __EAT_DOTS_END:
case TK("return"):
if (contexts.size() <= 1) SyntaxError("'return' outside function");
if(match_end_stmt()){
ctx()->emit_(OP_LOAD_NONE, BC_NOARG, kw_line);
ctx()->emit_(OP_RETURN_VALUE, 1, kw_line);
}else{
EXPR_TUPLE(false);
// check if it is a generator
if(ctx()->co->is_generator) SyntaxError("'return' with argument inside generator function");
consume_end_stmt();
ctx()->emit_(OP_RETURN_VALUE, BC_NOARG, kw_line);
}
ctx()->emit_(OP_RETURN_VALUE, BC_NOARG, kw_line);
break;
/*************************************************/
case TK("if"): compile_if_stmt(); break;

View File

@ -83,21 +83,31 @@ namespace pkpy{
}
int CodeEmitContext::add_const(PyObject* v){
// simple deduplication, only works for int/float
for(int i=0; i<co->consts.size(); i++){
if(co->consts[i] == v) return i;
}
// string deduplication
if(is_non_tagged_type(v, vm->tp_str)){
const Str& v_str = PK_OBJ_GET(Str, v);
for(int i=0; i<co->consts.size(); i++){
if(is_non_tagged_type(co->consts[i], vm->tp_str)){
if(PK_OBJ_GET(Str, co->consts[i]) == v_str) return i;
}
// string deduplication
std::string_view key = PK_OBJ_GET(Str, v).sv();
auto it = _co_consts_string_dedup_map.find(key);
if(it != _co_consts_string_dedup_map.end()){
return it->second;
}else{
co->consts.push_back(v);
int index = co->consts.size() - 1;
_co_consts_string_dedup_map[key] = index;
return index;
}
}else{
// non-string deduplication
auto it = _co_consts_nonstring_dedup_map.find(v);
if(it != _co_consts_nonstring_dedup_map.end()){
return it->second;
}else{
co->consts.push_back(v);
int index = co->consts.size() - 1;
_co_consts_nonstring_dedup_map[v] = index;
return index;
}
}
co->consts.push_back(v);
return co->consts.size() - 1;
PK_UNREACHABLE();
}
int CodeEmitContext::add_func_decl(FuncDecl_ decl){

View File

@ -64,4 +64,26 @@ try:
except ValueError:
pass
assert next(t) == StopIteration
assert next(t) == StopIteration
def f():
yield 1
yield 2
return
yield 3
assert list(f()) == [1, 2]
src = '''
def g():
yield 1
yield 2
return 3
yield 4
'''
try:
exec(src)
exit(1)
except SyntaxError:
pass

View File

@ -899,18 +899,6 @@ time.sleep(0.1)
# test time.localtime
assert type(time.localtime()) is time.struct_time
# /************ module dis ************/
import dis
# 116: 1487: vm->bind_func<1>(mod, "dis", [](VM* vm, ArgsView args) {
# #####: 1488: CodeObject_ code = get_code(vm, args[0]);
# #####: 1489: vm->_stdout(vm, vm->disassemble(code));
# #####: 1490: return vm->None;
# #####: 1491: });
# test dis.dis
def aaa():
pass
assert dis.dis(aaa) is None
# test min/max
assert min(1, 2) == 1
assert min(1, 2, 3) == 1

View File

@ -13,5 +13,4 @@ def f(a):
def g(a):
return f([1,2,3] + a)
# x = _s(g)
# assert type(x) is str
assert dis(g) is None