mirror of
				https://github.com/pocketpy/pocketpy
				synced 2025-10-20 19:40:18 +00:00 
			
		
		
		
	add closure impl
This commit is contained in:
		
							parent
							
								
									c4321b8f4b
								
							
						
					
					
						commit
						2285300ed3
					
				| @ -14,9 +14,14 @@ PyVar VM::run_frame(Frame* frame){ | |||||||
|         case OP_LOAD_CONST: frame->push(frame->co->consts[byte.arg]); continue; |         case OP_LOAD_CONST: frame->push(frame->co->consts[byte.arg]); continue; | ||||||
|         case OP_LOAD_FUNCTION: { |         case OP_LOAD_FUNCTION: { | ||||||
|             PyVar obj = frame->co->consts[byte.arg]; |             PyVar obj = frame->co->consts[byte.arg]; | ||||||
|             setattr(obj, __module__, frame->_module); |             auto& f = PyFunction_AS_C(obj); | ||||||
|  |             f->_module = frame->_module; | ||||||
|             frame->push(obj); |             frame->push(obj); | ||||||
|         } continue; |         } continue; | ||||||
|  |         case OP_SETUP_CLOSURE: { | ||||||
|  |             auto& f = PyFunction_AS_C(frame->top()); | ||||||
|  |             f->_closure = frame->_locals; | ||||||
|  |         } continue; | ||||||
|         case OP_LOAD_NAME_REF: { |         case OP_LOAD_NAME_REF: { | ||||||
|             frame->push(PyRef(NameRef(frame->co->names[byte.arg]))); |             frame->push(PyRef(NameRef(frame->co->names[byte.arg]))); | ||||||
|         } continue; |         } continue; | ||||||
|  | |||||||
| @ -34,7 +34,7 @@ | |||||||
| #define UNREACHABLE() throw std::runtime_error( __FILE__ + std::string(":") + std::to_string(__LINE__) + " UNREACHABLE()!"); | #define UNREACHABLE() throw std::runtime_error( __FILE__ + std::string(":") + std::to_string(__LINE__) + " UNREACHABLE()!"); | ||||||
| #endif | #endif | ||||||
| 
 | 
 | ||||||
| #define PK_VERSION "0.8.7" | #define PK_VERSION "0.8.8" | ||||||
| 
 | 
 | ||||||
| typedef int64_t i64; | typedef int64_t i64; | ||||||
| typedef double f64; | typedef double f64; | ||||||
|  | |||||||
| @ -976,7 +976,7 @@ __LISTCOMP: | |||||||
| 
 | 
 | ||||||
|             consume(TK("@id")); |             consume(TK("@id")); | ||||||
|             const Str& name = parser->prev.str(); |             const Str& name = parser->prev.str(); | ||||||
|             if(func->hasName(name)) SyntaxError("duplicate argument name"); |             if(func->has_name(name)) SyntaxError("duplicate argument name"); | ||||||
| 
 | 
 | ||||||
|             // eat type hints
 |             // eat type hints
 | ||||||
|             if(enable_type_hints && match(TK(":"))) consume(TK("@id")); |             if(enable_type_hints && match(TK(":"))) consume(TK("@id")); | ||||||
| @ -986,15 +986,15 @@ __LISTCOMP: | |||||||
|             switch (state) |             switch (state) | ||||||
|             { |             { | ||||||
|                 case 0: func->args.push_back(name); break; |                 case 0: func->args.push_back(name); break; | ||||||
|                 case 1: func->starredArg = name; state+=1; break; |                 case 1: func->starred_arg = name; state+=1; break; | ||||||
|                 case 2: { |                 case 2: { | ||||||
|                     consume(TK("=")); |                     consume(TK("=")); | ||||||
|                     PyVarOrNull value = read_literal(); |                     PyVarOrNull value = read_literal(); | ||||||
|                     if(value == nullptr){ |                     if(value == nullptr){ | ||||||
|                         SyntaxError(Str("expect a literal, not ") + TK_STR(parser->curr.type)); |                         SyntaxError(Str("expect a literal, not ") + TK_STR(parser->curr.type)); | ||||||
|                     } |                     } | ||||||
|                     func->kwArgs[name] = value; |                     func->kwargs[name] = value; | ||||||
|                     func->kwArgsOrder.push_back(name); |                     func->kwargs_order.push_back(name); | ||||||
|                 } break; |                 } break; | ||||||
|                 case 3: SyntaxError("**kwargs is not supported yet"); break; |                 case 3: SyntaxError("**kwargs is not supported yet"); break; | ||||||
|             } |             } | ||||||
| @ -1021,6 +1021,7 @@ __LISTCOMP: | |||||||
|         func->code->optimize(vm); |         func->code->optimize(vm); | ||||||
|         this->codes.pop(); |         this->codes.pop(); | ||||||
|         emit(OP_LOAD_FUNCTION, co()->add_const(vm->PyFunction(func))); |         emit(OP_LOAD_FUNCTION, co()->add_const(vm->PyFunction(func))); | ||||||
|  |         if(name_scope() == NAME_LOCAL) emit(OP_SETUP_CLOSURE); | ||||||
|         if(!is_compiling_class) emit(OP_STORE_NAME, co()->add_name(func->name, name_scope())); |         if(!is_compiling_class) emit(OP_STORE_NAME, co()->add_name(func->name, name_scope())); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  | |||||||
							
								
								
									
										11
									
								
								src/frame.h
									
									
									
									
									
								
							
							
						
						
									
										11
									
								
								src/frame.h
									
									
									
									
									
								
							| @ -12,14 +12,21 @@ struct Frame { | |||||||
|     const CodeObject_ co; |     const CodeObject_ co; | ||||||
|     PyVar _module; |     PyVar _module; | ||||||
|     pkpy::shared_ptr<pkpy::NameDict> _locals; |     pkpy::shared_ptr<pkpy::NameDict> _locals; | ||||||
|  |     pkpy::shared_ptr<pkpy::NameDict> _closure; | ||||||
|     const i64 id; |     const i64 id; | ||||||
|     std::stack<std::pair<int, std::vector<PyVar>>> s_try_block; |     std::stack<std::pair<int, std::vector<PyVar>>> s_try_block; | ||||||
| 
 | 
 | ||||||
|     inline pkpy::NameDict& f_locals() noexcept { return *_locals; } |     inline pkpy::NameDict& f_locals() noexcept { return *_locals; } | ||||||
|     inline pkpy::NameDict& f_globals() noexcept { return _module->attr(); } |     inline pkpy::NameDict& f_globals() noexcept { return _module->attr(); } | ||||||
| 
 | 
 | ||||||
|     Frame(const CodeObject_ co, PyVar _module, pkpy::shared_ptr<pkpy::NameDict> _locals) |     inline PyVar* f_closure_try_get(const Str& name) noexcept { | ||||||
|         : co(co), _module(_module), _locals(_locals), id(kFrameGlobalId++) { } |         if(_closure == nullptr) return nullptr; | ||||||
|  |         return _closure->try_get(name); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     Frame(const CodeObject_ co, PyVar _module, | ||||||
|  |         pkpy::shared_ptr<pkpy::NameDict> _locals, pkpy::shared_ptr<pkpy::NameDict> _closure=nullptr) | ||||||
|  |         : co(co), _module(_module), _locals(_locals), _closure(_closure), id(kFrameGlobalId++) { } | ||||||
| 
 | 
 | ||||||
|     inline const Bytecode& next_bytecode() { |     inline const Bytecode& next_bytecode() { | ||||||
|         _ip = _next_ip++; |         _ip = _next_ip++; | ||||||
|  | |||||||
| @ -18,14 +18,14 @@ namespace pkpy{ | |||||||
| 
 | 
 | ||||||
|     template <typename T> |     template <typename T> | ||||||
|     class shared_ptr { |     class shared_ptr { | ||||||
|         int* counter = nullptr; |         int* counter; | ||||||
| 
 | 
 | ||||||
| #define _t() ((T*)(counter + 1)) | #define _t() ((T*)(counter + 1)) | ||||||
| #define _inc_counter() if(counter) ++(*counter) | #define _inc_counter() if(counter) ++(*counter) | ||||||
| #define _dec_counter() if(counter && --(*counter) == 0){ SpAllocator<T>::dealloc(counter); } | #define _dec_counter() if(counter && --(*counter) == 0){ SpAllocator<T>::dealloc(counter); } | ||||||
| 
 | 
 | ||||||
|     public: |     public: | ||||||
|         shared_ptr() {} |         shared_ptr() : counter(nullptr) {} | ||||||
|         shared_ptr(int* counter) : counter(counter) {} |         shared_ptr(int* counter) : counter(counter) {} | ||||||
|         shared_ptr(const shared_ptr& other) : counter(other.counter) { |         shared_ptr(const shared_ptr& other) : counter(other.counter) { | ||||||
|             _inc_counter(); |             _inc_counter(); | ||||||
|  | |||||||
							
								
								
									
										19
									
								
								src/obj.h
									
									
									
									
									
								
							
							
						
						
									
										19
									
								
								src/obj.h
									
									
									
									
									
								
							| @ -24,14 +24,18 @@ struct Function { | |||||||
|     Str name; |     Str name; | ||||||
|     CodeObject_ code; |     CodeObject_ code; | ||||||
|     std::vector<Str> args; |     std::vector<Str> args; | ||||||
|     Str starredArg;                // empty if no *arg
 |     Str starred_arg;                // empty if no *arg
 | ||||||
|     pkpy::NameDict kwArgs;          // empty if no k=v
 |     pkpy::NameDict kwargs;          // empty if no k=v
 | ||||||
|     std::vector<Str> kwArgsOrder; |     std::vector<Str> kwargs_order; | ||||||
| 
 | 
 | ||||||
|     bool hasName(const Str& val) const { |     // runtime settings
 | ||||||
|  |     PyVar _module; | ||||||
|  |     pkpy::shared_ptr<pkpy::NameDict> _closure; | ||||||
|  | 
 | ||||||
|  |     bool has_name(const Str& val) const { | ||||||
|         bool _0 = std::find(args.begin(), args.end(), val) != args.end(); |         bool _0 = std::find(args.begin(), args.end(), val) != args.end(); | ||||||
|         bool _1 = starredArg == val; |         bool _1 = starred_arg == val; | ||||||
|         bool _2 = kwArgs.find(val) != kwArgs.end(); |         bool _2 = kwargs.find(val) != kwargs.end(); | ||||||
|         return _0 || _1 || _2; |         return _0 || _1 || _2; | ||||||
|     } |     } | ||||||
| }; | }; | ||||||
| @ -99,8 +103,7 @@ struct Py_ : PyObject { | |||||||
|     Py_(Type type, T&& val): PyObject(type, sizeof(Py_<T>)), _value(std::move(val)) { _init(); } |     Py_(Type type, T&& val): PyObject(type, sizeof(Py_<T>)), _value(std::move(val)) { _init(); } | ||||||
| 
 | 
 | ||||||
|     inline void _init() noexcept { |     inline void _init() noexcept { | ||||||
|         if constexpr (std::is_same_v<T, Dummy> || std::is_same_v<T, Type> |         if constexpr (std::is_same_v<T, Dummy> || std::is_same_v<T, Type>) { | ||||||
|         || std::is_same_v<T, pkpy::Function_> || std::is_same_v<T, pkpy::NativeFunc>) { |  | ||||||
|             _attr = new pkpy::NameDict(); |             _attr = new pkpy::NameDict(); | ||||||
|         }else{ |         }else{ | ||||||
|             _attr = nullptr; |             _attr = nullptr; | ||||||
|  | |||||||
| @ -76,4 +76,6 @@ OPCODE(FAST_INDEX_REF)       // a[x] | |||||||
| OPCODE(INPLACE_BINARY_OP) | OPCODE(INPLACE_BINARY_OP) | ||||||
| OPCODE(INPLACE_BITWISE_OP) | OPCODE(INPLACE_BITWISE_OP) | ||||||
| 
 | 
 | ||||||
|  | OPCODE(SETUP_CLOSURE) | ||||||
|  | 
 | ||||||
| #endif | #endif | ||||||
| @ -153,7 +153,6 @@ const Str __new__ = Str("__new__"); | |||||||
| const Str __iter__ = Str("__iter__"); | const Str __iter__ = Str("__iter__"); | ||||||
| const Str __str__ = Str("__str__"); | const Str __str__ = Str("__str__"); | ||||||
| const Str __repr__ = Str("__repr__"); | const Str __repr__ = Str("__repr__"); | ||||||
| const Str __module__ = Str("__module__"); |  | ||||||
| const Str __getitem__ = Str("__getitem__"); | const Str __getitem__ = Str("__getitem__"); | ||||||
| const Str __setitem__ = Str("__setitem__"); | const Str __setitem__ = Str("__setitem__"); | ||||||
| const Str __delitem__ = Str("__delitem__"); | const Str __delitem__ = Str("__delitem__"); | ||||||
|  | |||||||
							
								
								
									
										20
									
								
								src/vm.h
									
									
									
									
									
								
							
							
						
						
									
										20
									
								
								src/vm.h
									
									
									
									
									
								
							| @ -163,15 +163,15 @@ public: | |||||||
|                 TypeError("missing positional argument '" + name + "'"); |                 TypeError("missing positional argument '" + name + "'"); | ||||||
|             } |             } | ||||||
| 
 | 
 | ||||||
|             locals.insert(fn->kwArgs.begin(), fn->kwArgs.end()); |             locals.insert(fn->kwargs.begin(), fn->kwargs.end()); | ||||||
| 
 | 
 | ||||||
|             std::vector<Str> positional_overrided_keys; |             std::vector<Str> positional_overrided_keys; | ||||||
|             if(!fn->starredArg.empty()){ |             if(!fn->starred_arg.empty()){ | ||||||
|                 pkpy::List vargs;        // handle *args
 |                 pkpy::List vargs;        // handle *args
 | ||||||
|                 while(i < args.size()) vargs.push_back(args[i++]); |                 while(i < args.size()) vargs.push_back(args[i++]); | ||||||
|                 locals.emplace(fn->starredArg, PyTuple(std::move(vargs))); |                 locals.emplace(fn->starred_arg, PyTuple(std::move(vargs))); | ||||||
|             }else{ |             }else{ | ||||||
|                 for(const auto& key : fn->kwArgsOrder){ |                 for(const auto& key : fn->kwargs_order){ | ||||||
|                     if(i < args.size()){ |                     if(i < args.size()){ | ||||||
|                         locals[key] = args[i++]; |                         locals[key] = args[i++]; | ||||||
|                         positional_overrided_keys.push_back(key); |                         positional_overrided_keys.push_back(key); | ||||||
| @ -184,7 +184,7 @@ public: | |||||||
|              |              | ||||||
|             for(int i=0; i<kwargs.size(); i+=2){ |             for(int i=0; i<kwargs.size(); i+=2){ | ||||||
|                 const Str& key = PyStr_AS_C(kwargs[i]); |                 const Str& key = PyStr_AS_C(kwargs[i]); | ||||||
|                 if(!fn->kwArgs.contains(key)){ |                 if(!fn->kwargs.contains(key)){ | ||||||
|                     TypeError(key.escape(true) + " is an invalid keyword argument for " + fn->name + "()"); |                     TypeError(key.escape(true) + " is an invalid keyword argument for " + fn->name + "()"); | ||||||
|                 } |                 } | ||||||
|                 const PyVar& val = kwargs[i+1]; |                 const PyVar& val = kwargs[i+1]; | ||||||
| @ -196,10 +196,8 @@ public: | |||||||
|                 } |                 } | ||||||
|                 locals[key] = val; |                 locals[key] = val; | ||||||
|             } |             } | ||||||
| 
 |             PyVar _module = fn->_module != nullptr ? fn->_module : top_frame()->_module; | ||||||
|             PyVar* _m = (*callable)->attr().try_get(__module__); |             auto _frame = _new_frame(fn->code, _module, _locals, fn->_closure); | ||||||
|             PyVar _module = _m != nullptr ? *_m : top_frame()->_module; |  | ||||||
|             auto _frame = _new_frame(fn->code, _module, _locals); |  | ||||||
|             if(fn->code->is_generator){ |             if(fn->code->is_generator){ | ||||||
|                 return PyIter(pkpy::make_shared<BaseIter, Generator>( |                 return PyIter(pkpy::make_shared<BaseIter, Generator>( | ||||||
|                     this, std::move(_frame))); |                     this, std::move(_frame))); | ||||||
| @ -208,7 +206,7 @@ public: | |||||||
|             if(opCall) return _py_op_call; |             if(opCall) return _py_op_call; | ||||||
|             return _exec(); |             return _exec(); | ||||||
|         } |         } | ||||||
|         TypeError("'" + OBJ_NAME(_t(*callable)) + "' object is not callable"); |         TypeError(OBJ_NAME(_t(*callable)).escape(true) + " object is not callable"); | ||||||
|         return None; |         return None; | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
| @ -716,6 +714,8 @@ PyVar NameRef::get(VM* vm, Frame* frame) const{ | |||||||
|     PyVar* val; |     PyVar* val; | ||||||
|     val = frame->f_locals().try_get(name()); |     val = frame->f_locals().try_get(name()); | ||||||
|     if(val) return *val; |     if(val) return *val; | ||||||
|  |     val = frame->f_closure_try_get(name()); | ||||||
|  |     if(val) return *val; | ||||||
|     val = frame->f_globals().try_get(name()); |     val = frame->f_globals().try_get(name()); | ||||||
|     if(val) return *val; |     if(val) return *val; | ||||||
|     val = vm->builtins->attr().try_get(name()); |     val = vm->builtins->attr().try_get(name()); | ||||||
|  | |||||||
							
								
								
									
										19
									
								
								tests/_closure.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										19
									
								
								tests/_closure.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,19 @@ | |||||||
|  | # only one level nested closure is implemented | ||||||
|  | 
 | ||||||
|  | def f0(a, b): | ||||||
|  |     def f1(): | ||||||
|  |         return a + b | ||||||
|  |     return f1 | ||||||
|  | 
 | ||||||
|  | a = f0(1, 2) | ||||||
|  | assert a() == 3 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def f0(a, b): | ||||||
|  |     def f1(): | ||||||
|  |         a = 5   # use this first | ||||||
|  |         return a + b | ||||||
|  |     return f1 | ||||||
|  | 
 | ||||||
|  | a = f0(1, 2) | ||||||
|  | assert a() == 7 | ||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user