add optimized opcodes for FOR_ITERs

This commit is contained in:
blueloveTH 2024-03-29 01:29:32 +08:00
parent 41e9900b37
commit e86baa2e2f
8 changed files with 77 additions and 15 deletions

View File

@ -3,7 +3,7 @@ output: .retype
url: https://pocketpy.dev
branding:
title: pocketpy
label: v1.4.3
label: v1.4.4
logo: "./static/logo.png"
favicon: "./static/logo.png"
meta:

View File

@ -21,7 +21,7 @@
#include <typeinfo>
#include <initializer_list>
#define PK_VERSION "1.4.3"
#define PK_VERSION "1.4.4"
#include "config.h"
#include "export.h"

View File

@ -105,6 +105,7 @@ struct CodeEmitContext{
void exit_block();
void emit_expr(); // clear the expression stack and generate bytecode
int emit_(Opcode opcode, uint16_t arg, int line, bool is_virtual=false);
void revert_last_emit_();
int emit_int(i64 value, int line);
void patch_jump(int index);
bool add_label(StrName name);
@ -113,6 +114,7 @@ struct CodeEmitContext{
int add_const_string(std::string_view);
int add_func_decl(FuncDecl_ decl);
void emit_store_name(NameScope scope, StrName name, int line);
void try_merge_for_iter_store(int);
};
struct NameExpr: Expr{

View File

@ -133,6 +133,9 @@ OPCODE(UNARY_INVERT)
/**************************/
OPCODE(GET_ITER)
OPCODE(FOR_ITER)
OPCODE(FOR_ITER_STORE_FAST)
OPCODE(FOR_ITER_STORE_GLOBAL)
OPCODE(FOR_ITER_YIELD_VALUE)
/**************************/
OPCODE(IMPORT_PATH)
OPCODE(POP_IMPORT_STAR)

View File

@ -719,9 +719,34 @@ __NEXT_STEP:;
if(_0 != StopIteration){
PUSH(_0);
}else{
frame->jump_abs_break(&s_data, byte.arg);
frame->jump_abs_break(&s_data, co->_get_block_codei(frame->_ip).end);
}
} DISPATCH();
TARGET(FOR_ITER_STORE_FAST){
PyObject* _0 = py_next(TOP());
if(_0 != StopIteration){
frame->_locals[byte.arg] = _0;
}else{
frame->jump_abs_break(&s_data, co->_get_block_codei(frame->_ip).end);
}
} DISPATCH()
TARGET(FOR_ITER_STORE_GLOBAL){
PyObject* _0 = py_next(TOP());
if(_0 != StopIteration){
frame->f_globals().set(StrName(byte.arg), _0);
}else{
frame->jump_abs_break(&s_data, co->_get_block_codei(frame->_ip).end);
}
} DISPATCH()
TARGET(FOR_ITER_YIELD_VALUE){
PyObject* _0 = py_next(TOP());
if(_0 != StopIteration){
PUSH(_0);
return PY_OP_YIELD;
}else{
frame->jump_abs_break(&s_data, co->_get_block_codei(frame->_ip).end);
}
} DISPATCH()
/*****************************************/
TARGET(IMPORT_PATH){
PyObject* _0 = co->consts[byte.arg];
@ -877,8 +902,16 @@ __NEXT_STEP:;
*p = VAR(CAST(i64, *p) - 1);
} DISPATCH();
/*****************************************/
static_assert(OP_DEC_GLOBAL == 133);
case 134: case 135: case 136: case 137: case 138: case 139: case 140: case 141: case 142: case 143: case 144: case 145: case 146: case 147: case 148: case 149: case 150: case 151: case 152: case 153: case 154: case 155: case 156: case 157: case 158: case 159: case 160: case 161: case 162: case 163: case 164: case 165: case 166: case 167: case 168: case 169: case 170: case 171: case 172: case 173: case 174: case 175: case 176: case 177: case 178: case 179: case 180: case 181: case 182: case 183: case 184: case 185: case 186: case 187: case 188: case 189: case 190: case 191: case 192: case 193: case 194: case 195: case 196: case 197: case 198: case 199: case 200: case 201: case 202: case 203: case 204: case 205: case 206: case 207: case 208: case 209: case 210: case 211: case 212: case 213: case 214: case 215: case 216: case 217: case 218: case 219: case 220: case 221: case 222: case 223: case 224: case 225: case 226: case 227: case 228: case 229: case 230: case 231: case 232: case 233: case 234: case 235: case 236: case 237: case 238: case 239: case 240: case 241: case 242: case 243: case 244: case 245: case 246: case 247: case 248: case 249: case 250: case 251: case 252: case 253: case 254: case 255: PK_UNREACHABLE() break;
static_assert(OP_DEC_GLOBAL == 136);
case 137: case 138: case 139: case 140: case 141: case 142: case 143: case 144: case 145: case 146: case 147: case 148: case 149:
case 150: case 151: case 152: case 153: case 154: case 155: case 156: case 157: case 158: case 159: case 160: case 161: case 162: case 163: case 164:
case 165: case 166: case 167: case 168: case 169: case 170: case 171: case 172: case 173: case 174: case 175: case 176: case 177: case 178: case 179:
case 180: case 181: case 182: case 183: case 184: case 185: case 186: case 187: case 188: case 189: case 190: case 191: case 192: case 193: case 194:
case 195: case 196: case 197: case 198: case 199: case 200: case 201: case 202: case 203: case 204: case 205: case 206: case 207: case 208: case 209:
case 210: case 211: case 212: case 213: case 214: case 215: case 216: case 217: case 218: case 219: case 220: case 221: case 222: case 223: case 224:
case 225: case 226: case 227: case 228: case 229: case 230: case 231: case 232: case 233: case 234: case 235: case 236: case 237: case 238: case 239:
case 240: case 241: case 242: case 243: case 244: case 245: case 246: case 247: case 248: case 249: case 250: case 251: case 252: case 253: case 254:
case 255: break;
}
}

View File

@ -50,15 +50,13 @@ namespace pkpy{
// json mode does not contain jump instructions, so it is safe to ignore this check
SyntaxError("maximum number of opcodes exceeded");
}
// pre-compute LOOP_BREAK and LOOP_CONTINUE and FOR_ITER
// pre-compute LOOP_BREAK and LOOP_CONTINUE
for(int i=0; i<codes.size(); i++){
Bytecode& bc = codes[i];
if(bc.op == OP_LOOP_CONTINUE){
bc.arg = ctx()->co->blocks[bc.arg].start;
}else if(bc.op == OP_LOOP_BREAK){
bc.arg = ctx()->co->blocks[bc.arg].get_break_end();
}else if(bc.op == OP_FOR_ITER){
bc.arg = ctx()->co->_get_block_codei(i).end;
}
}
// pre-compute func->is_simple
@ -658,9 +656,10 @@ __EAT_DOTS_END:
EXPR_TUPLE(); ctx()->emit_expr();
ctx()->emit_(OP_GET_ITER, BC_NOARG, BC_KEEPLINE);
CodeBlock* block = ctx()->enter_block(CodeBlockType::FOR_LOOP);
ctx()->emit_(OP_FOR_ITER, BC_NOARG, BC_KEEPLINE);
int for_codei = ctx()->emit_(OP_FOR_ITER, BC_NOARG, BC_KEEPLINE);
bool ok = vars->emit_store(ctx());
if(!ok) SyntaxError(); // this error occurs in `vars` instead of this line, but...nevermind
ctx()->try_merge_for_iter_store(for_codei);
compile_block_body();
ctx()->emit_(OP_LOOP_CONTINUE, ctx()->get_loop(), BC_KEEPLINE, true);
ctx()->exit_block();
@ -822,8 +821,7 @@ __EAT_DOTS_END:
ctx()->co->is_generator = true;
ctx()->emit_(OP_GET_ITER, BC_NOARG, kw_line);
ctx()->enter_block(CodeBlockType::FOR_LOOP);
ctx()->emit_(OP_FOR_ITER, BC_NOARG, kw_line);
ctx()->emit_(OP_YIELD_VALUE, BC_NOARG, kw_line);
ctx()->emit_(OP_FOR_ITER_YIELD_VALUE, BC_NOARG, kw_line);
ctx()->emit_(OP_LOOP_CONTINUE, ctx()->get_loop(), kw_line);
ctx()->exit_block();
consume_end_stmt();

View File

@ -60,6 +60,31 @@ namespace pkpy{
return i;
}
void CodeEmitContext::revert_last_emit_(){
co->codes.pop_back();
co->iblocks.pop_back();
co->lines.pop_back();
}
void CodeEmitContext::try_merge_for_iter_store(int i){
// [FOR_ITER, STORE_?, ]
if(co->codes[i].op != OP_FOR_ITER) return;
if(co->codes.size() - i != 2) return;
uint16_t arg = co->codes[i+1].arg;
if(co->codes[i+1].op == OP_STORE_FAST){
revert_last_emit_();
co->codes[i].op = OP_FOR_ITER_STORE_FAST;
co->codes[i].arg = arg;
return;
}
if(co->codes[i+1].op == OP_STORE_GLOBAL){
revert_last_emit_();
co->codes[i].op = OP_FOR_ITER_STORE_GLOBAL;
co->codes[i].arg = arg;
return;
}
}
int CodeEmitContext::emit_int(i64 value, int line){
bool allow_neg_int = is_negative_shift_well_defined() || value >= 0;
if(allow_neg_int && value >= -5 && value <= 16){
@ -370,10 +395,11 @@ namespace pkpy{
iter->emit_(ctx);
ctx->emit_(OP_GET_ITER, BC_NOARG, BC_KEEPLINE);
ctx->enter_block(CodeBlockType::FOR_LOOP);
ctx->emit_(OP_FOR_ITER, BC_NOARG, BC_KEEPLINE);
int for_codei = ctx->emit_(OP_FOR_ITER, BC_NOARG, BC_KEEPLINE);
bool ok = vars->emit_store(ctx);
// this error occurs in `vars` instead of this line, but...nevermind
PK_ASSERT(ok); // TODO: raise a SyntaxError instead
ctx->try_merge_for_iter_store(for_codei);
if(cond){
cond->emit_(ctx);
int patch = ctx->emit_(OP_POP_JUMP_IF_FALSE, BC_NOARG, BC_KEEPLINE);

View File

@ -573,10 +573,10 @@ static std::string _opcode_argstr(VM* vm, Bytecode byte, const CodeObject* co){
case OP_LOAD_NAME: case OP_LOAD_GLOBAL: case OP_LOAD_NONLOCAL: case OP_STORE_GLOBAL:
case OP_LOAD_ATTR: case OP_LOAD_METHOD: case OP_STORE_ATTR: case OP_DELETE_ATTR:
case OP_BEGIN_CLASS: case OP_GOTO:
case OP_DELETE_GLOBAL: case OP_INC_GLOBAL: case OP_DEC_GLOBAL: case OP_STORE_CLASS_ATTR:
case OP_DELETE_GLOBAL: case OP_INC_GLOBAL: case OP_DEC_GLOBAL: case OP_STORE_CLASS_ATTR: case OP_FOR_ITER_STORE_GLOBAL:
argStr += _S(" (", StrName(byte.arg).sv(), ")").sv();
break;
case OP_LOAD_FAST: case OP_STORE_FAST: case OP_DELETE_FAST: case OP_INC_FAST: case OP_DEC_FAST:
case OP_LOAD_FAST: case OP_STORE_FAST: case OP_DELETE_FAST: case OP_INC_FAST: case OP_DEC_FAST: case OP_FOR_ITER_STORE_FAST:
argStr += _S(" (", co->varnames[byte.arg].sv(), ")").sv();
break;
case OP_LOAD_FUNCTION:
@ -594,7 +594,7 @@ Str VM::disassemble(CodeObject_ co){
pod_vector<int> jumpTargets;
for(auto byte : co->codes){
if(byte.op == OP_JUMP_ABSOLUTE || byte.op == OP_POP_JUMP_IF_FALSE || byte.op == OP_SHORTCUT_IF_FALSE_OR_POP || byte.op == OP_FOR_ITER){
if(byte.op == OP_JUMP_ABSOLUTE || byte.op == OP_POP_JUMP_IF_FALSE || byte.op == OP_SHORTCUT_IF_FALSE_OR_POP){
jumpTargets.push_back(byte.arg);
}
if(byte.op == OP_GOTO){