add closure impl

This commit is contained in:
blueloveTH 2023-02-19 02:20:23 +08:00
parent c4321b8f4b
commit 2285300ed3
10 changed files with 65 additions and 29 deletions

View File

@ -14,9 +14,14 @@ PyVar VM::run_frame(Frame* frame){
case OP_LOAD_CONST: frame->push(frame->co->consts[byte.arg]); continue;
case OP_LOAD_FUNCTION: {
PyVar obj = frame->co->consts[byte.arg];
setattr(obj, __module__, frame->_module);
auto& f = PyFunction_AS_C(obj);
f->_module = frame->_module;
frame->push(obj);
} continue;
case OP_SETUP_CLOSURE: {
auto& f = PyFunction_AS_C(frame->top());
f->_closure = frame->_locals;
} continue;
case OP_LOAD_NAME_REF: {
frame->push(PyRef(NameRef(frame->co->names[byte.arg])));
} continue;

View File

@ -34,7 +34,7 @@
#define UNREACHABLE() throw std::runtime_error( __FILE__ + std::string(":") + std::to_string(__LINE__) + " UNREACHABLE()!");
#endif
#define PK_VERSION "0.8.7"
#define PK_VERSION "0.8.8"
typedef int64_t i64;
typedef double f64;

View File

@ -976,7 +976,7 @@ __LISTCOMP:
consume(TK("@id"));
const Str& name = parser->prev.str();
if(func->hasName(name)) SyntaxError("duplicate argument name");
if(func->has_name(name)) SyntaxError("duplicate argument name");
// eat type hints
if(enable_type_hints && match(TK(":"))) consume(TK("@id"));
@ -986,15 +986,15 @@ __LISTCOMP:
switch (state)
{
case 0: func->args.push_back(name); break;
case 1: func->starredArg = name; state+=1; break;
case 1: func->starred_arg = name; state+=1; break;
case 2: {
consume(TK("="));
PyVarOrNull value = read_literal();
if(value == nullptr){
SyntaxError(Str("expect a literal, not ") + TK_STR(parser->curr.type));
}
func->kwArgs[name] = value;
func->kwArgsOrder.push_back(name);
func->kwargs[name] = value;
func->kwargs_order.push_back(name);
} break;
case 3: SyntaxError("**kwargs is not supported yet"); break;
}
@ -1021,6 +1021,7 @@ __LISTCOMP:
func->code->optimize(vm);
this->codes.pop();
emit(OP_LOAD_FUNCTION, co()->add_const(vm->PyFunction(func)));
if(name_scope() == NAME_LOCAL) emit(OP_SETUP_CLOSURE);
if(!is_compiling_class) emit(OP_STORE_NAME, co()->add_name(func->name, name_scope()));
}

View File

@ -12,14 +12,21 @@ struct Frame {
const CodeObject_ co;
PyVar _module;
pkpy::shared_ptr<pkpy::NameDict> _locals;
pkpy::shared_ptr<pkpy::NameDict> _closure;
const i64 id;
std::stack<std::pair<int, std::vector<PyVar>>> s_try_block;
inline pkpy::NameDict& f_locals() noexcept { return *_locals; }
inline pkpy::NameDict& f_globals() noexcept { return _module->attr(); }
Frame(const CodeObject_ co, PyVar _module, pkpy::shared_ptr<pkpy::NameDict> _locals)
: co(co), _module(_module), _locals(_locals), id(kFrameGlobalId++) { }
inline PyVar* f_closure_try_get(const Str& name) noexcept {
if(_closure == nullptr) return nullptr;
return _closure->try_get(name);
}
Frame(const CodeObject_ co, PyVar _module,
pkpy::shared_ptr<pkpy::NameDict> _locals, pkpy::shared_ptr<pkpy::NameDict> _closure=nullptr)
: co(co), _module(_module), _locals(_locals), _closure(_closure), id(kFrameGlobalId++) { }
inline const Bytecode& next_bytecode() {
_ip = _next_ip++;

View File

@ -18,14 +18,14 @@ namespace pkpy{
template <typename T>
class shared_ptr {
int* counter = nullptr;
int* counter;
#define _t() ((T*)(counter + 1))
#define _inc_counter() if(counter) ++(*counter)
#define _dec_counter() if(counter && --(*counter) == 0){ SpAllocator<T>::dealloc(counter); }
public:
shared_ptr() {}
shared_ptr() : counter(nullptr) {}
shared_ptr(int* counter) : counter(counter) {}
shared_ptr(const shared_ptr& other) : counter(other.counter) {
_inc_counter();

View File

@ -24,14 +24,18 @@ struct Function {
Str name;
CodeObject_ code;
std::vector<Str> args;
Str starredArg; // empty if no *arg
pkpy::NameDict kwArgs; // empty if no k=v
std::vector<Str> kwArgsOrder;
Str starred_arg; // empty if no *arg
pkpy::NameDict kwargs; // empty if no k=v
std::vector<Str> kwargs_order;
bool hasName(const Str& val) const {
// runtime settings
PyVar _module;
pkpy::shared_ptr<pkpy::NameDict> _closure;
bool has_name(const Str& val) const {
bool _0 = std::find(args.begin(), args.end(), val) != args.end();
bool _1 = starredArg == val;
bool _2 = kwArgs.find(val) != kwArgs.end();
bool _1 = starred_arg == val;
bool _2 = kwargs.find(val) != kwargs.end();
return _0 || _1 || _2;
}
};
@ -99,8 +103,7 @@ struct Py_ : PyObject {
Py_(Type type, T&& val): PyObject(type, sizeof(Py_<T>)), _value(std::move(val)) { _init(); }
inline void _init() noexcept {
if constexpr (std::is_same_v<T, Dummy> || std::is_same_v<T, Type>
|| std::is_same_v<T, pkpy::Function_> || std::is_same_v<T, pkpy::NativeFunc>) {
if constexpr (std::is_same_v<T, Dummy> || std::is_same_v<T, Type>) {
_attr = new pkpy::NameDict();
}else{
_attr = nullptr;

View File

@ -76,4 +76,6 @@ OPCODE(FAST_INDEX_REF) // a[x]
OPCODE(INPLACE_BINARY_OP)
OPCODE(INPLACE_BITWISE_OP)
OPCODE(SETUP_CLOSURE)
#endif

View File

@ -153,7 +153,6 @@ const Str __new__ = Str("__new__");
const Str __iter__ = Str("__iter__");
const Str __str__ = Str("__str__");
const Str __repr__ = Str("__repr__");
const Str __module__ = Str("__module__");
const Str __getitem__ = Str("__getitem__");
const Str __setitem__ = Str("__setitem__");
const Str __delitem__ = Str("__delitem__");

View File

@ -163,15 +163,15 @@ public:
TypeError("missing positional argument '" + name + "'");
}
locals.insert(fn->kwArgs.begin(), fn->kwArgs.end());
locals.insert(fn->kwargs.begin(), fn->kwargs.end());
std::vector<Str> positional_overrided_keys;
if(!fn->starredArg.empty()){
if(!fn->starred_arg.empty()){
pkpy::List vargs; // handle *args
while(i < args.size()) vargs.push_back(args[i++]);
locals.emplace(fn->starredArg, PyTuple(std::move(vargs)));
locals.emplace(fn->starred_arg, PyTuple(std::move(vargs)));
}else{
for(const auto& key : fn->kwArgsOrder){
for(const auto& key : fn->kwargs_order){
if(i < args.size()){
locals[key] = args[i++];
positional_overrided_keys.push_back(key);
@ -184,7 +184,7 @@ public:
for(int i=0; i<kwargs.size(); i+=2){
const Str& key = PyStr_AS_C(kwargs[i]);
if(!fn->kwArgs.contains(key)){
if(!fn->kwargs.contains(key)){
TypeError(key.escape(true) + " is an invalid keyword argument for " + fn->name + "()");
}
const PyVar& val = kwargs[i+1];
@ -196,10 +196,8 @@ public:
}
locals[key] = val;
}
PyVar* _m = (*callable)->attr().try_get(__module__);
PyVar _module = _m != nullptr ? *_m : top_frame()->_module;
auto _frame = _new_frame(fn->code, _module, _locals);
PyVar _module = fn->_module != nullptr ? fn->_module : top_frame()->_module;
auto _frame = _new_frame(fn->code, _module, _locals, fn->_closure);
if(fn->code->is_generator){
return PyIter(pkpy::make_shared<BaseIter, Generator>(
this, std::move(_frame)));
@ -208,7 +206,7 @@ public:
if(opCall) return _py_op_call;
return _exec();
}
TypeError("'" + OBJ_NAME(_t(*callable)) + "' object is not callable");
TypeError(OBJ_NAME(_t(*callable)).escape(true) + " object is not callable");
return None;
}
@ -716,6 +714,8 @@ PyVar NameRef::get(VM* vm, Frame* frame) const{
PyVar* val;
val = frame->f_locals().try_get(name());
if(val) return *val;
val = frame->f_closure_try_get(name());
if(val) return *val;
val = frame->f_globals().try_get(name());
if(val) return *val;
val = vm->builtins->attr().try_get(name());

19
tests/_closure.py Normal file
View File

@ -0,0 +1,19 @@
# only one level nested closure is implemented
def f0(a, b):
def f1():
return a + b
return f1
a = f0(1, 2)
assert a() == 3
def f0(a, b):
def f1():
a = 5 # use this first
return a + b
return f1
a = f0(1, 2)
assert a() == 7