From 3bb3b0d58443248fa09e3a7ea14fbbef91ffff41 Mon Sep 17 00:00:00 2001 From: blueloveTH Date: Sun, 19 Feb 2023 02:39:49 +0800 Subject: [PATCH] bug fix --- src/ceval.h | 16 ++++++++-------- src/compiler.h | 35 ++++++++++++++++++----------------- src/obj.h | 2 -- src/pocketpy.h | 2 +- src/vm.h | 26 +++++++++++++------------- tests/_closure.py | 10 +++++++++- 6 files changed, 49 insertions(+), 42 deletions(-) diff --git a/src/ceval.h b/src/ceval.h index a928b5d1..4106abea 100644 --- a/src/ceval.h +++ b/src/ceval.h @@ -13,14 +13,14 @@ PyVar VM::run_frame(Frame* frame){ case OP_NO_OP: continue; case OP_LOAD_CONST: frame->push(frame->co->consts[byte.arg]); continue; case OP_LOAD_FUNCTION: { - PyVar obj = frame->co->consts[byte.arg]; - auto& f = PyFunction_AS_C(obj); - f->_module = frame->_module; - frame->push(obj); + const PyVar obj = frame->co->consts[byte.arg]; + pkpy::Function f = PyFunction_AS_C(obj); // copy + f._module = frame->_module; + frame->push(PyFunction(f)); } continue; case OP_SETUP_CLOSURE: { - auto& f = PyFunction_AS_C(frame->top()); - f->_closure = frame->_locals; + pkpy::Function& f = PyFunction_AS_C(frame->top()); // reference + f._closure = frame->_locals; } continue; case OP_LOAD_NAME_REF: { frame->push(PyRef(NameRef(frame->co->names[byte.arg]))); @@ -98,8 +98,8 @@ PyVar VM::run_frame(Frame* frame){ while(true){ PyVar fn = frame->pop_value(this); if(fn == None) break; - const pkpy::Function_& f = PyFunction_AS_C(fn); - setattr(cls, f->name, fn); + const pkpy::Function& f = PyFunction_AS_C(fn); + setattr(cls, f.name, fn); } } continue; case OP_RETURN_VALUE: return frame->pop_value(this); diff --git a/src/compiler.h b/src/compiler.h index 60620158..d5b167e8 100644 --- a/src/compiler.h +++ b/src/compiler.h @@ -384,19 +384,20 @@ private: } void exprLambda() { - pkpy::Function_ func = pkpy::make_shared(); - func->name = ""; + pkpy::Function func; + func.name = ""; if(!match(TK(":"))){ _compile_f_args(func, false); consume(TK(":")); } - func->code = pkpy::make_shared(parser->src, func->name); - this->codes.push(func->code); + func.code = pkpy::make_shared(parser->src, func.name); + this->codes.push(func.code); EXPR_TUPLE(); emit(OP_RETURN_VALUE); - func->code->optimize(vm); + func.code->optimize(vm); this->codes.pop(); emit(OP_LOAD_FUNCTION, co()->add_const(vm->PyFunction(func))); + if(name_scope() == NAME_LOCAL) emit(OP_SETUP_CLOSURE); } void exprAssign() { @@ -961,7 +962,7 @@ __LISTCOMP: emit(OP_BUILD_CLASS, cls_name_idx); } - void _compile_f_args(pkpy::Function_ func, bool enable_type_hints){ + void _compile_f_args(pkpy::Function& func, bool enable_type_hints){ int state = 0; // 0 for args, 1 for *args, 2 for k=v, 3 for **kwargs do { if(state == 3) SyntaxError("**kwargs should be the last argument"); @@ -976,7 +977,7 @@ __LISTCOMP: consume(TK("@id")); const Str& name = parser->prev.str(); - if(func->has_name(name)) SyntaxError("duplicate argument name"); + if(func.has_name(name)) SyntaxError("duplicate argument name"); // eat type hints if(enable_type_hints && match(TK(":"))) consume(TK("@id")); @@ -985,16 +986,16 @@ __LISTCOMP: switch (state) { - case 0: func->args.push_back(name); break; - case 1: func->starred_arg = name; state+=1; break; + case 0: func.args.push_back(name); break; + case 1: func.starred_arg = name; state+=1; break; case 2: { consume(TK("=")); PyVarOrNull value = read_literal(); if(value == nullptr){ SyntaxError(Str("expect a literal, not ") + TK_STR(parser->curr.type)); } - func->kwargs[name] = value; - func->kwargs_order.push_back(name); + func.kwargs[name] = value; + func.kwargs_order.push_back(name); } break; case 3: SyntaxError("**kwargs is not supported yet"); break; } @@ -1006,23 +1007,23 @@ __LISTCOMP: if(match(TK("pass"))) return; consume(TK("def")); } - pkpy::Function_ func = pkpy::make_shared(); + pkpy::Function func; consume(TK("@id")); - func->name = parser->prev.str(); + func.name = parser->prev.str(); consume(TK("(")); if (!match(TK(")"))) { _compile_f_args(func, true); consume(TK(")")); } if(match(TK("->"))) consume(TK("@id")); // eat type hints - func->code = pkpy::make_shared(parser->src, func->name); - this->codes.push(func->code); + func.code = pkpy::make_shared(parser->src, func.name); + this->codes.push(func.code); compile_block_body(); - func->code->optimize(vm); + func.code->optimize(vm); this->codes.pop(); emit(OP_LOAD_FUNCTION, co()->add_const(vm->PyFunction(func))); if(name_scope() == NAME_LOCAL) emit(OP_SETUP_CLOSURE); - if(!is_compiling_class) emit(OP_STORE_NAME, co()->add_name(func->name, name_scope())); + if(!is_compiling_class) emit(OP_STORE_NAME, co()->add_name(func.name, name_scope())); } PyVarOrNull read_literal(){ diff --git a/src/obj.h b/src/obj.h index 09f160c2..1f1c3ac2 100644 --- a/src/obj.h +++ b/src/obj.h @@ -63,8 +63,6 @@ struct Slice { if(stop < start) stop = start; } }; - -typedef shared_ptr Function_; } class BaseIter { diff --git a/src/pocketpy.h b/src/pocketpy.h index 8c8e83d3..bdd10c15 100644 --- a/src/pocketpy.h +++ b/src/pocketpy.h @@ -591,7 +591,7 @@ void add_module_math(VM* vm){ void add_module_dis(VM* vm){ PyVar mod = vm->new_module("dis"); vm->bind_func<1>(mod, "dis", [](VM* vm, pkpy::Args& args) { - CodeObject_ code = vm->PyFunction_AS_C(args[0])->code; + CodeObject_ code = vm->PyFunction_AS_C(args[0]).code; (*vm->_stdout) << vm->disassemble(code); return vm->None; }); diff --git a/src/vm.h b/src/vm.h index 57dc2c08..841397af 100644 --- a/src/vm.h +++ b/src/vm.h @@ -150,12 +150,12 @@ public: if(kwargs.size() != 0) TypeError("native_function does not accept keyword arguments"); return f(this, args); } else if((*callable)->is_type(tp_function)){ - const pkpy::Function_& fn = PyFunction_AS_C((*callable)); + const pkpy::Function& fn = PyFunction_AS_C(*callable); pkpy::shared_ptr _locals = pkpy::make_shared(); pkpy::NameDict& locals = *_locals; int i = 0; - for(const auto& name : fn->args){ + for(const auto& name : fn.args){ if(i < args.size()){ locals.emplace(name, args[i++]); continue; @@ -163,15 +163,15 @@ public: TypeError("missing positional argument '" + name + "'"); } - locals.insert(fn->kwargs.begin(), fn->kwargs.end()); + locals.insert(fn.kwargs.begin(), fn.kwargs.end()); std::vector positional_overrided_keys; - if(!fn->starred_arg.empty()){ + if(!fn.starred_arg.empty()){ pkpy::List vargs; // handle *args while(i < args.size()) vargs.push_back(args[i++]); - locals.emplace(fn->starred_arg, PyTuple(std::move(vargs))); + locals.emplace(fn.starred_arg, PyTuple(std::move(vargs))); }else{ - for(const auto& key : fn->kwargs_order){ + for(const auto& key : fn.kwargs_order){ if(i < args.size()){ locals[key] = args[i++]; positional_overrided_keys.push_back(key); @@ -184,8 +184,8 @@ public: for(int i=0; ikwargs.contains(key)){ - TypeError(key.escape(true) + " is an invalid keyword argument for " + fn->name + "()"); + if(!fn.kwargs.contains(key)){ + TypeError(key.escape(true) + " is an invalid keyword argument for " + fn.name + "()"); } const PyVar& val = kwargs[i+1]; if(!positional_overrided_keys.empty()){ @@ -196,9 +196,9 @@ public: } locals[key] = val; } - PyVar _module = fn->_module != nullptr ? fn->_module : top_frame()->_module; - auto _frame = _new_frame(fn->code, _module, _locals, fn->_closure); - if(fn->code->is_generator){ + PyVar _module = fn._module != nullptr ? fn._module : top_frame()->_module; + auto _frame = _new_frame(fn.code, _module, _locals, fn._closure); + if(fn.code->is_generator){ return PyIter(pkpy::make_shared( this, std::move(_frame))); } @@ -512,7 +512,7 @@ public: PyVar obj = co->consts[i]; if(obj->is_type(tp_function)){ const auto& f = PyFunction_AS_C(obj); - ss << disassemble(f->code); + ss << disassemble(f.code); } } return Str(ss.str()); @@ -554,7 +554,7 @@ public: DEF_NATIVE(Float, f64, tp_float) DEF_NATIVE(List, pkpy::List, tp_list) DEF_NATIVE(Tuple, pkpy::Tuple, tp_tuple) - DEF_NATIVE(Function, pkpy::Function_, tp_function) + DEF_NATIVE(Function, pkpy::Function, tp_function) DEF_NATIVE(NativeFunc, pkpy::NativeFunc, tp_native_function) DEF_NATIVE(Iter, pkpy::shared_ptr, tp_native_iterator) DEF_NATIVE(BoundMethod, pkpy::BoundMethod, tp_bound_method) diff --git a/tests/_closure.py b/tests/_closure.py index f5d62b13..b78533a2 100644 --- a/tests/_closure.py +++ b/tests/_closure.py @@ -6,7 +6,9 @@ def f0(a, b): return f1 a = f0(1, 2) +b = f0(3, 4) assert a() == 3 +assert b() == 7 def f0(a, b): @@ -16,4 +18,10 @@ def f0(a, b): return f1 a = f0(1, 2) -assert a() == 7 \ No newline at end of file +assert a() == 7 + +def f3(x, y): + return lambda z: x + y + z + +a = f3(1, 2) +assert a(3) == 6 \ No newline at end of file