From e86baa2e2f067765d006666fb388a6842318def7 Mon Sep 17 00:00:00 2001 From: blueloveTH Date: Fri, 29 Mar 2024 01:29:32 +0800 Subject: [PATCH] add optimized opcodes for `FOR_ITER`s --- docs/retype.yml | 2 +- include/pocketpy/common.h | 2 +- include/pocketpy/expr.h | 2 ++ include/pocketpy/opcodes.h | 3 +++ src/ceval.cpp | 39 +++++++++++++++++++++++++++++++++++--- src/compiler.cpp | 10 ++++------ src/expr.cpp | 28 ++++++++++++++++++++++++++- src/vm.cpp | 6 +++--- 8 files changed, 77 insertions(+), 15 deletions(-) diff --git a/docs/retype.yml b/docs/retype.yml index 215dff05..334974c0 100644 --- a/docs/retype.yml +++ b/docs/retype.yml @@ -3,7 +3,7 @@ output: .retype url: https://pocketpy.dev branding: title: pocketpy - label: v1.4.3 + label: v1.4.4 logo: "./static/logo.png" favicon: "./static/logo.png" meta: diff --git a/include/pocketpy/common.h b/include/pocketpy/common.h index 4bd2c098..c1ea3120 100644 --- a/include/pocketpy/common.h +++ b/include/pocketpy/common.h @@ -21,7 +21,7 @@ #include #include -#define PK_VERSION "1.4.3" +#define PK_VERSION "1.4.4" #include "config.h" #include "export.h" diff --git a/include/pocketpy/expr.h b/include/pocketpy/expr.h index d030c2ad..b5df6f58 100644 --- a/include/pocketpy/expr.h +++ b/include/pocketpy/expr.h @@ -105,6 +105,7 @@ struct CodeEmitContext{ void exit_block(); void emit_expr(); // clear the expression stack and generate bytecode int emit_(Opcode opcode, uint16_t arg, int line, bool is_virtual=false); + void revert_last_emit_(); int emit_int(i64 value, int line); void patch_jump(int index); bool add_label(StrName name); @@ -113,6 +114,7 @@ struct CodeEmitContext{ int add_const_string(std::string_view); int add_func_decl(FuncDecl_ decl); void emit_store_name(NameScope scope, StrName name, int line); + void try_merge_for_iter_store(int); }; struct NameExpr: Expr{ diff --git a/include/pocketpy/opcodes.h b/include/pocketpy/opcodes.h index bb98b568..269cf094 100644 --- a/include/pocketpy/opcodes.h +++ b/include/pocketpy/opcodes.h @@ -133,6 +133,9 @@ OPCODE(UNARY_INVERT) /**************************/ OPCODE(GET_ITER) OPCODE(FOR_ITER) +OPCODE(FOR_ITER_STORE_FAST) +OPCODE(FOR_ITER_STORE_GLOBAL) +OPCODE(FOR_ITER_YIELD_VALUE) /**************************/ OPCODE(IMPORT_PATH) OPCODE(POP_IMPORT_STAR) diff --git a/src/ceval.cpp b/src/ceval.cpp index 5fb54671..e559f779 100644 --- a/src/ceval.cpp +++ b/src/ceval.cpp @@ -719,9 +719,34 @@ __NEXT_STEP:; if(_0 != StopIteration){ PUSH(_0); }else{ - frame->jump_abs_break(&s_data, byte.arg); + frame->jump_abs_break(&s_data, co->_get_block_codei(frame->_ip).end); } } DISPATCH(); + TARGET(FOR_ITER_STORE_FAST){ + PyObject* _0 = py_next(TOP()); + if(_0 != StopIteration){ + frame->_locals[byte.arg] = _0; + }else{ + frame->jump_abs_break(&s_data, co->_get_block_codei(frame->_ip).end); + } + } DISPATCH() + TARGET(FOR_ITER_STORE_GLOBAL){ + PyObject* _0 = py_next(TOP()); + if(_0 != StopIteration){ + frame->f_globals().set(StrName(byte.arg), _0); + }else{ + frame->jump_abs_break(&s_data, co->_get_block_codei(frame->_ip).end); + } + } DISPATCH() + TARGET(FOR_ITER_YIELD_VALUE){ + PyObject* _0 = py_next(TOP()); + if(_0 != StopIteration){ + PUSH(_0); + return PY_OP_YIELD; + }else{ + frame->jump_abs_break(&s_data, co->_get_block_codei(frame->_ip).end); + } + } DISPATCH() /*****************************************/ TARGET(IMPORT_PATH){ PyObject* _0 = co->consts[byte.arg]; @@ -877,8 +902,16 @@ __NEXT_STEP:; *p = VAR(CAST(i64, *p) - 1); } DISPATCH(); /*****************************************/ - static_assert(OP_DEC_GLOBAL == 133); - case 134: case 135: case 136: case 137: case 138: case 139: case 140: case 141: case 142: case 143: case 144: case 145: case 146: case 147: case 148: case 149: case 150: case 151: case 152: case 153: case 154: case 155: case 156: case 157: case 158: case 159: case 160: case 161: case 162: case 163: case 164: case 165: case 166: case 167: case 168: case 169: case 170: case 171: case 172: case 173: case 174: case 175: case 176: case 177: case 178: case 179: case 180: case 181: case 182: case 183: case 184: case 185: case 186: case 187: case 188: case 189: case 190: case 191: case 192: case 193: case 194: case 195: case 196: case 197: case 198: case 199: case 200: case 201: case 202: case 203: case 204: case 205: case 206: case 207: case 208: case 209: case 210: case 211: case 212: case 213: case 214: case 215: case 216: case 217: case 218: case 219: case 220: case 221: case 222: case 223: case 224: case 225: case 226: case 227: case 228: case 229: case 230: case 231: case 232: case 233: case 234: case 235: case 236: case 237: case 238: case 239: case 240: case 241: case 242: case 243: case 244: case 245: case 246: case 247: case 248: case 249: case 250: case 251: case 252: case 253: case 254: case 255: PK_UNREACHABLE() break; + static_assert(OP_DEC_GLOBAL == 136); + case 137: case 138: case 139: case 140: case 141: case 142: case 143: case 144: case 145: case 146: case 147: case 148: case 149: + case 150: case 151: case 152: case 153: case 154: case 155: case 156: case 157: case 158: case 159: case 160: case 161: case 162: case 163: case 164: + case 165: case 166: case 167: case 168: case 169: case 170: case 171: case 172: case 173: case 174: case 175: case 176: case 177: case 178: case 179: + case 180: case 181: case 182: case 183: case 184: case 185: case 186: case 187: case 188: case 189: case 190: case 191: case 192: case 193: case 194: + case 195: case 196: case 197: case 198: case 199: case 200: case 201: case 202: case 203: case 204: case 205: case 206: case 207: case 208: case 209: + case 210: case 211: case 212: case 213: case 214: case 215: case 216: case 217: case 218: case 219: case 220: case 221: case 222: case 223: case 224: + case 225: case 226: case 227: case 228: case 229: case 230: case 231: case 232: case 233: case 234: case 235: case 236: case 237: case 238: case 239: + case 240: case 241: case 242: case 243: case 244: case 245: case 246: case 247: case 248: case 249: case 250: case 251: case 252: case 253: case 254: + case 255: break; } } diff --git a/src/compiler.cpp b/src/compiler.cpp index 1fff3aed..9b6bc251 100644 --- a/src/compiler.cpp +++ b/src/compiler.cpp @@ -50,15 +50,13 @@ namespace pkpy{ // json mode does not contain jump instructions, so it is safe to ignore this check SyntaxError("maximum number of opcodes exceeded"); } - // pre-compute LOOP_BREAK and LOOP_CONTINUE and FOR_ITER + // pre-compute LOOP_BREAK and LOOP_CONTINUE for(int i=0; ico->blocks[bc.arg].start; }else if(bc.op == OP_LOOP_BREAK){ bc.arg = ctx()->co->blocks[bc.arg].get_break_end(); - }else if(bc.op == OP_FOR_ITER){ - bc.arg = ctx()->co->_get_block_codei(i).end; } } // pre-compute func->is_simple @@ -658,9 +656,10 @@ __EAT_DOTS_END: EXPR_TUPLE(); ctx()->emit_expr(); ctx()->emit_(OP_GET_ITER, BC_NOARG, BC_KEEPLINE); CodeBlock* block = ctx()->enter_block(CodeBlockType::FOR_LOOP); - ctx()->emit_(OP_FOR_ITER, BC_NOARG, BC_KEEPLINE); + int for_codei = ctx()->emit_(OP_FOR_ITER, BC_NOARG, BC_KEEPLINE); bool ok = vars->emit_store(ctx()); if(!ok) SyntaxError(); // this error occurs in `vars` instead of this line, but...nevermind + ctx()->try_merge_for_iter_store(for_codei); compile_block_body(); ctx()->emit_(OP_LOOP_CONTINUE, ctx()->get_loop(), BC_KEEPLINE, true); ctx()->exit_block(); @@ -822,8 +821,7 @@ __EAT_DOTS_END: ctx()->co->is_generator = true; ctx()->emit_(OP_GET_ITER, BC_NOARG, kw_line); ctx()->enter_block(CodeBlockType::FOR_LOOP); - ctx()->emit_(OP_FOR_ITER, BC_NOARG, kw_line); - ctx()->emit_(OP_YIELD_VALUE, BC_NOARG, kw_line); + ctx()->emit_(OP_FOR_ITER_YIELD_VALUE, BC_NOARG, kw_line); ctx()->emit_(OP_LOOP_CONTINUE, ctx()->get_loop(), kw_line); ctx()->exit_block(); consume_end_stmt(); diff --git a/src/expr.cpp b/src/expr.cpp index 76d5981d..132cbcac 100644 --- a/src/expr.cpp +++ b/src/expr.cpp @@ -60,6 +60,31 @@ namespace pkpy{ return i; } + void CodeEmitContext::revert_last_emit_(){ + co->codes.pop_back(); + co->iblocks.pop_back(); + co->lines.pop_back(); + } + + void CodeEmitContext::try_merge_for_iter_store(int i){ + // [FOR_ITER, STORE_?, ] + if(co->codes[i].op != OP_FOR_ITER) return; + if(co->codes.size() - i != 2) return; + uint16_t arg = co->codes[i+1].arg; + if(co->codes[i+1].op == OP_STORE_FAST){ + revert_last_emit_(); + co->codes[i].op = OP_FOR_ITER_STORE_FAST; + co->codes[i].arg = arg; + return; + } + if(co->codes[i+1].op == OP_STORE_GLOBAL){ + revert_last_emit_(); + co->codes[i].op = OP_FOR_ITER_STORE_GLOBAL; + co->codes[i].arg = arg; + return; + } + } + int CodeEmitContext::emit_int(i64 value, int line){ bool allow_neg_int = is_negative_shift_well_defined() || value >= 0; if(allow_neg_int && value >= -5 && value <= 16){ @@ -370,10 +395,11 @@ namespace pkpy{ iter->emit_(ctx); ctx->emit_(OP_GET_ITER, BC_NOARG, BC_KEEPLINE); ctx->enter_block(CodeBlockType::FOR_LOOP); - ctx->emit_(OP_FOR_ITER, BC_NOARG, BC_KEEPLINE); + int for_codei = ctx->emit_(OP_FOR_ITER, BC_NOARG, BC_KEEPLINE); bool ok = vars->emit_store(ctx); // this error occurs in `vars` instead of this line, but...nevermind PK_ASSERT(ok); // TODO: raise a SyntaxError instead + ctx->try_merge_for_iter_store(for_codei); if(cond){ cond->emit_(ctx); int patch = ctx->emit_(OP_POP_JUMP_IF_FALSE, BC_NOARG, BC_KEEPLINE); diff --git a/src/vm.cpp b/src/vm.cpp index 327b0bd2..68b2e387 100644 --- a/src/vm.cpp +++ b/src/vm.cpp @@ -573,10 +573,10 @@ static std::string _opcode_argstr(VM* vm, Bytecode byte, const CodeObject* co){ case OP_LOAD_NAME: case OP_LOAD_GLOBAL: case OP_LOAD_NONLOCAL: case OP_STORE_GLOBAL: case OP_LOAD_ATTR: case OP_LOAD_METHOD: case OP_STORE_ATTR: case OP_DELETE_ATTR: case OP_BEGIN_CLASS: case OP_GOTO: - case OP_DELETE_GLOBAL: case OP_INC_GLOBAL: case OP_DEC_GLOBAL: case OP_STORE_CLASS_ATTR: + case OP_DELETE_GLOBAL: case OP_INC_GLOBAL: case OP_DEC_GLOBAL: case OP_STORE_CLASS_ATTR: case OP_FOR_ITER_STORE_GLOBAL: argStr += _S(" (", StrName(byte.arg).sv(), ")").sv(); break; - case OP_LOAD_FAST: case OP_STORE_FAST: case OP_DELETE_FAST: case OP_INC_FAST: case OP_DEC_FAST: + case OP_LOAD_FAST: case OP_STORE_FAST: case OP_DELETE_FAST: case OP_INC_FAST: case OP_DEC_FAST: case OP_FOR_ITER_STORE_FAST: argStr += _S(" (", co->varnames[byte.arg].sv(), ")").sv(); break; case OP_LOAD_FUNCTION: @@ -594,7 +594,7 @@ Str VM::disassemble(CodeObject_ co){ pod_vector jumpTargets; for(auto byte : co->codes){ - if(byte.op == OP_JUMP_ABSOLUTE || byte.op == OP_POP_JUMP_IF_FALSE || byte.op == OP_SHORTCUT_IF_FALSE_OR_POP || byte.op == OP_FOR_ITER){ + if(byte.op == OP_JUMP_ABSOLUTE || byte.op == OP_POP_JUMP_IF_FALSE || byte.op == OP_SHORTCUT_IF_FALSE_OR_POP){ jumpTargets.push_back(byte.arg); } if(byte.op == OP_GOTO){