diff --git a/include/pocketpy/codeobject.h b/include/pocketpy/codeobject.h index 7aed2355..9ad1df12 100644 --- a/include/pocketpy/codeobject.h +++ b/include/pocketpy/codeobject.h @@ -123,12 +123,13 @@ struct CodeObject { struct FuncDecl { struct KwArg { - int key; // index in co->varnames + int index; // index in co->varnames + StrName key; // name of this argument PyObject* value; // default value }; CodeObject_ code; // code object of this function - pod_vector args; // indices in co->varnames - pod_vector kwargs; // indices in co->varnames + std::vector args; // indices in co->varnames + std::vector kwargs; // indices in co->varnames int starred_arg = -1; // index in co->varnames, -1 if no *arg int starred_kwarg = -1; // index in co->varnames, -1 if no **kwarg bool nested = false; // whether this function is nested @@ -136,6 +137,11 @@ struct FuncDecl { Str signature; // signature of this function Str docstring; // docstring of this function bool is_simple; + + int keyword_to_index(StrName key) const{ + for(const KwArg& item: kwargs) if(item.key == key) return item.index; + return -1; + } void _gc_mark() const; }; diff --git a/src/compiler.cpp b/src/compiler.cpp index 4a2ccc2d..36d94219 100644 --- a/src/compiler.cpp +++ b/src/compiler.cpp @@ -1007,7 +1007,7 @@ __EAT_DOTS_END: } } for(auto& kv: decl->kwargs){ - if(decl->code->varnames[kv.key] == name){ + if(decl->code->varnames[kv.index] == name){ SyntaxError("duplicate argument name"); } } @@ -1037,7 +1037,7 @@ __EAT_DOTS_END: if(value == nullptr){ SyntaxError(Str("default argument must be a literal")); } - decl->kwargs.push_back(FuncDecl::KwArg{index, value}); + decl->kwargs.push_back(FuncDecl::KwArg{index, name, value}); } break; case 3: decl->starred_kwarg = index; diff --git a/src/vm.cpp b/src/vm.cpp index 1036222e..595c3a5c 100644 --- a/src/vm.cpp +++ b/src/vm.cpp @@ -580,7 +580,7 @@ PyObject* VM::new_module(Str name, Str package) { } static std::string _opcode_argstr(VM* vm, Bytecode byte, const CodeObject* co){ - std::string argStr = byte.arg == BC_NOARG ? "" : std::to_string(byte.arg); + std::string argStr = std::to_string(byte.arg); switch(byte.op){ case OP_LOAD_CONST: case OP_FORMAT_STRING: case OP_IMPORT_PATH: if(vm != nullptr){ @@ -818,7 +818,7 @@ void VM::_prepare_py_call(PyObject** buffer, ArgsView args, ArgsView kwargs, con // set extra varnames to PY_NULL for(int j=i; jkwargs) buffer[kv.key] = kv.value; + for(auto& kv: decl->kwargs) buffer[kv.index] = kv.value; // handle *args if(decl->starred_arg != -1){ @@ -829,7 +829,7 @@ void VM::_prepare_py_call(PyObject** buffer, ArgsView args, ArgsView kwargs, con // kwdefaults override for(auto& kv: decl->kwargs){ if(i >= args.size()) break; - buffer[kv.key] = args[i++]; + buffer[kv.index] = args[i++]; } if(i < args.size()) TypeError(fmt("too many arguments", " (", decl->code->name, ')')); } @@ -844,16 +844,18 @@ void VM::_prepare_py_call(PyObject** buffer, ArgsView args, ArgsView kwargs, con for(int j=0; jvarnames_inv.try_get_likely_found(key); - if(index < 0){ + int index = decl->keyword_to_index(key); + // if key is an explicit key, set as local variable + if(index != -1){ + buffer[index] = kwargs[j+1]; + }else{ + // otherwise, set as **kwargs if possible if(vkwargs == nullptr){ TypeError(fmt(key.escape(), " is an invalid keyword argument for ", co->name, "()")); }else{ Dict& dict = _CAST(Dict&, vkwargs); dict.set(VAR(key.sv()), kwargs[j+1]); } - }else{ - buffer[index] = kwargs[j+1]; } } } diff --git a/tests/99_bugs.py b/tests/99_bugs.py index 07f7949a..b6563d92 100644 --- a/tests/99_bugs.py +++ b/tests/99_bugs.py @@ -78,4 +78,26 @@ def f(): ++g f(); f() -assert g == 3 \ No newline at end of file +assert g == 3 + + +def f(**kw): + x = 1 + y = 2 + return kw, x, y +assert f(x=4, z=1) == ({'x': 4, 'z': 1}, 1, 2) + +def g(**kw): + x, y = 1, 2 + return kw + +ret = g( + a=1, b=2, c=3, d=4, e=5, f=6, g=7, h=8, i=9, + j=10, k=11, l=12, m=13, n=14, o=15, p=16, q=17, + r=18, s=19, t=20, u=21, v=22, w=23, x=24, y=25, + z=26 +) +assert ret == {chr(i+97): i+1 for i in range(26)} + +assert g(**ret) == ret +assert g(**g(**ret)) == ret