diff --git a/src/ceval.h b/src/ceval.h index baa5d7d1..a928b5d1 100644 --- a/src/ceval.h +++ b/src/ceval.h @@ -14,9 +14,14 @@ PyVar VM::run_frame(Frame* frame){ case OP_LOAD_CONST: frame->push(frame->co->consts[byte.arg]); continue; case OP_LOAD_FUNCTION: { PyVar obj = frame->co->consts[byte.arg]; - setattr(obj, __module__, frame->_module); + auto& f = PyFunction_AS_C(obj); + f->_module = frame->_module; frame->push(obj); } continue; + case OP_SETUP_CLOSURE: { + auto& f = PyFunction_AS_C(frame->top()); + f->_closure = frame->_locals; + } continue; case OP_LOAD_NAME_REF: { frame->push(PyRef(NameRef(frame->co->names[byte.arg]))); } continue; diff --git a/src/common.h b/src/common.h index d2098a99..e858f634 100644 --- a/src/common.h +++ b/src/common.h @@ -34,7 +34,7 @@ #define UNREACHABLE() throw std::runtime_error( __FILE__ + std::string(":") + std::to_string(__LINE__) + " UNREACHABLE()!"); #endif -#define PK_VERSION "0.8.7" +#define PK_VERSION "0.8.8" typedef int64_t i64; typedef double f64; diff --git a/src/compiler.h b/src/compiler.h index 9121c780..60620158 100644 --- a/src/compiler.h +++ b/src/compiler.h @@ -976,7 +976,7 @@ __LISTCOMP: consume(TK("@id")); const Str& name = parser->prev.str(); - if(func->hasName(name)) SyntaxError("duplicate argument name"); + if(func->has_name(name)) SyntaxError("duplicate argument name"); // eat type hints if(enable_type_hints && match(TK(":"))) consume(TK("@id")); @@ -986,15 +986,15 @@ __LISTCOMP: switch (state) { case 0: func->args.push_back(name); break; - case 1: func->starredArg = name; state+=1; break; + case 1: func->starred_arg = name; state+=1; break; case 2: { consume(TK("=")); PyVarOrNull value = read_literal(); if(value == nullptr){ SyntaxError(Str("expect a literal, not ") + TK_STR(parser->curr.type)); } - func->kwArgs[name] = value; - func->kwArgsOrder.push_back(name); + func->kwargs[name] = value; + func->kwargs_order.push_back(name); } break; case 3: SyntaxError("**kwargs is not supported yet"); break; } @@ -1021,6 +1021,7 @@ __LISTCOMP: func->code->optimize(vm); this->codes.pop(); emit(OP_LOAD_FUNCTION, co()->add_const(vm->PyFunction(func))); + if(name_scope() == NAME_LOCAL) emit(OP_SETUP_CLOSURE); if(!is_compiling_class) emit(OP_STORE_NAME, co()->add_name(func->name, name_scope())); } diff --git a/src/frame.h b/src/frame.h index ee65af66..91d04923 100644 --- a/src/frame.h +++ b/src/frame.h @@ -12,14 +12,21 @@ struct Frame { const CodeObject_ co; PyVar _module; pkpy::shared_ptr _locals; + pkpy::shared_ptr _closure; const i64 id; std::stack>> s_try_block; inline pkpy::NameDict& f_locals() noexcept { return *_locals; } inline pkpy::NameDict& f_globals() noexcept { return _module->attr(); } - Frame(const CodeObject_ co, PyVar _module, pkpy::shared_ptr _locals) - : co(co), _module(_module), _locals(_locals), id(kFrameGlobalId++) { } + inline PyVar* f_closure_try_get(const Str& name) noexcept { + if(_closure == nullptr) return nullptr; + return _closure->try_get(name); + } + + Frame(const CodeObject_ co, PyVar _module, + pkpy::shared_ptr _locals, pkpy::shared_ptr _closure=nullptr) + : co(co), _module(_module), _locals(_locals), _closure(_closure), id(kFrameGlobalId++) { } inline const Bytecode& next_bytecode() { _ip = _next_ip++; diff --git a/src/memory.h b/src/memory.h index 921f0929..eeb3db16 100644 --- a/src/memory.h +++ b/src/memory.h @@ -18,14 +18,14 @@ namespace pkpy{ template class shared_ptr { - int* counter = nullptr; + int* counter; #define _t() ((T*)(counter + 1)) #define _inc_counter() if(counter) ++(*counter) #define _dec_counter() if(counter && --(*counter) == 0){ SpAllocator::dealloc(counter); } public: - shared_ptr() {} + shared_ptr() : counter(nullptr) {} shared_ptr(int* counter) : counter(counter) {} shared_ptr(const shared_ptr& other) : counter(other.counter) { _inc_counter(); diff --git a/src/obj.h b/src/obj.h index 944a67da..09f160c2 100644 --- a/src/obj.h +++ b/src/obj.h @@ -24,14 +24,18 @@ struct Function { Str name; CodeObject_ code; std::vector args; - Str starredArg; // empty if no *arg - pkpy::NameDict kwArgs; // empty if no k=v - std::vector kwArgsOrder; + Str starred_arg; // empty if no *arg + pkpy::NameDict kwargs; // empty if no k=v + std::vector kwargs_order; - bool hasName(const Str& val) const { + // runtime settings + PyVar _module; + pkpy::shared_ptr _closure; + + bool has_name(const Str& val) const { bool _0 = std::find(args.begin(), args.end(), val) != args.end(); - bool _1 = starredArg == val; - bool _2 = kwArgs.find(val) != kwArgs.end(); + bool _1 = starred_arg == val; + bool _2 = kwargs.find(val) != kwargs.end(); return _0 || _1 || _2; } }; @@ -99,8 +103,7 @@ struct Py_ : PyObject { Py_(Type type, T&& val): PyObject(type, sizeof(Py_)), _value(std::move(val)) { _init(); } inline void _init() noexcept { - if constexpr (std::is_same_v || std::is_same_v - || std::is_same_v || std::is_same_v) { + if constexpr (std::is_same_v || std::is_same_v) { _attr = new pkpy::NameDict(); }else{ _attr = nullptr; diff --git a/src/opcodes.h b/src/opcodes.h index c6d8d22b..283264f2 100644 --- a/src/opcodes.h +++ b/src/opcodes.h @@ -76,4 +76,6 @@ OPCODE(FAST_INDEX_REF) // a[x] OPCODE(INPLACE_BINARY_OP) OPCODE(INPLACE_BITWISE_OP) +OPCODE(SETUP_CLOSURE) + #endif \ No newline at end of file diff --git a/src/str.h b/src/str.h index bc08eab7..cc08223b 100644 --- a/src/str.h +++ b/src/str.h @@ -153,7 +153,6 @@ const Str __new__ = Str("__new__"); const Str __iter__ = Str("__iter__"); const Str __str__ = Str("__str__"); const Str __repr__ = Str("__repr__"); -const Str __module__ = Str("__module__"); const Str __getitem__ = Str("__getitem__"); const Str __setitem__ = Str("__setitem__"); const Str __delitem__ = Str("__delitem__"); diff --git a/src/vm.h b/src/vm.h index ff21045c..57dc2c08 100644 --- a/src/vm.h +++ b/src/vm.h @@ -163,15 +163,15 @@ public: TypeError("missing positional argument '" + name + "'"); } - locals.insert(fn->kwArgs.begin(), fn->kwArgs.end()); + locals.insert(fn->kwargs.begin(), fn->kwargs.end()); std::vector positional_overrided_keys; - if(!fn->starredArg.empty()){ + if(!fn->starred_arg.empty()){ pkpy::List vargs; // handle *args while(i < args.size()) vargs.push_back(args[i++]); - locals.emplace(fn->starredArg, PyTuple(std::move(vargs))); + locals.emplace(fn->starred_arg, PyTuple(std::move(vargs))); }else{ - for(const auto& key : fn->kwArgsOrder){ + for(const auto& key : fn->kwargs_order){ if(i < args.size()){ locals[key] = args[i++]; positional_overrided_keys.push_back(key); @@ -184,7 +184,7 @@ public: for(int i=0; ikwArgs.contains(key)){ + if(!fn->kwargs.contains(key)){ TypeError(key.escape(true) + " is an invalid keyword argument for " + fn->name + "()"); } const PyVar& val = kwargs[i+1]; @@ -196,10 +196,8 @@ public: } locals[key] = val; } - - PyVar* _m = (*callable)->attr().try_get(__module__); - PyVar _module = _m != nullptr ? *_m : top_frame()->_module; - auto _frame = _new_frame(fn->code, _module, _locals); + PyVar _module = fn->_module != nullptr ? fn->_module : top_frame()->_module; + auto _frame = _new_frame(fn->code, _module, _locals, fn->_closure); if(fn->code->is_generator){ return PyIter(pkpy::make_shared( this, std::move(_frame))); @@ -208,7 +206,7 @@ public: if(opCall) return _py_op_call; return _exec(); } - TypeError("'" + OBJ_NAME(_t(*callable)) + "' object is not callable"); + TypeError(OBJ_NAME(_t(*callable)).escape(true) + " object is not callable"); return None; } @@ -716,6 +714,8 @@ PyVar NameRef::get(VM* vm, Frame* frame) const{ PyVar* val; val = frame->f_locals().try_get(name()); if(val) return *val; + val = frame->f_closure_try_get(name()); + if(val) return *val; val = frame->f_globals().try_get(name()); if(val) return *val; val = vm->builtins->attr().try_get(name()); diff --git a/tests/_closure.py b/tests/_closure.py new file mode 100644 index 00000000..f5d62b13 --- /dev/null +++ b/tests/_closure.py @@ -0,0 +1,19 @@ +# only one level nested closure is implemented + +def f0(a, b): + def f1(): + return a + b + return f1 + +a = f0(1, 2) +assert a() == 3 + + +def f0(a, b): + def f1(): + a = 5 # use this first + return a + b + return f1 + +a = f0(1, 2) +assert a() == 7 \ No newline at end of file