add a fast path for int

This commit is contained in:
blueloveTH 2023-04-09 18:01:54 +08:00
parent 4de81edd4a
commit da760c301c
6 changed files with 110 additions and 51 deletions

View File

@ -180,18 +180,54 @@ __NEXT_STEP:;
args[0] = frame->top(); // rhs
frame->top() = fast_call(BINARY_SPECIAL_METHODS[byte.arg], std::move(args));
} DISPATCH();
case OP_COMPARE_OP: {
Args args(2);
args[1] = frame->popx(); // lhs
args[0] = frame->top(); // rhs
frame->top() = fast_call(COMPARE_SPECIAL_METHODS[byte.arg], std::move(args));
} DISPATCH();
case OP_BITWISE_OP: {
Args args(2);
args[1] = frame->popx(); // lhs
args[0] = frame->top(); // rhs
frame->top() = fast_call(BITWISE_SPECIAL_METHODS[byte.arg], std::move(args));
} DISPATCH();
#define INT_BINARY_OP(op, func) \
if(is_both_int(frame->top(), frame->top_1())){ \
i64 b = _CAST(i64, frame->top()); \
i64 a = _CAST(i64, frame->top_1()); \
frame->pop(); \
frame->top() = VAR(a op b); \
}else{ \
Args args(2); \
args[1] = frame->popx(); \
args[0] = frame->top(); \
frame->top() = fast_call(func, std::move(args));\
} \
DISPATCH();
case OP_BINARY_ADD:
INT_BINARY_OP(+, __add__)
case OP_BINARY_SUB:
INT_BINARY_OP(-, __sub__)
case OP_BINARY_MUL:
INT_BINARY_OP(*, __mul__)
case OP_BINARY_FLOORDIV:
INT_BINARY_OP(/, __floordiv__)
case OP_BINARY_MOD:
INT_BINARY_OP(%, __mod__)
case OP_COMPARE_LT:
INT_BINARY_OP(<, __lt__)
case OP_COMPARE_LE:
INT_BINARY_OP(<=, __le__)
case OP_COMPARE_EQ:
INT_BINARY_OP(==, __eq__)
case OP_COMPARE_NE:
INT_BINARY_OP(!=, __ne__)
case OP_COMPARE_GT:
INT_BINARY_OP(>, __gt__)
case OP_COMPARE_GE:
INT_BINARY_OP(>=, __ge__)
case OP_BITWISE_LSHIFT:
INT_BINARY_OP(<<, __lshift__)
case OP_BITWISE_RSHIFT:
INT_BINARY_OP(>>, __rshift__)
case OP_BITWISE_AND:
INT_BINARY_OP(&, __and__)
case OP_BITWISE_OR:
INT_BINARY_OP(|, __or__)
case OP_BITWISE_XOR:
INT_BINARY_OP(^, __xor__)
#undef INT_BINARY_OP
case OP_IS_OP: {
PyObject* rhs = frame->popx();
PyObject* lhs = frame->top();

View File

@ -650,30 +650,30 @@ struct BinaryExpr: Expr{
lhs->emit(ctx);
rhs->emit(ctx);
switch (op) {
case TK("+"): ctx->emit(OP_BINARY_OP, 0, line); break;
case TK("-"): ctx->emit(OP_BINARY_OP, 1, line); break;
case TK("*"): ctx->emit(OP_BINARY_OP, 2, line); break;
case TK("+"): ctx->emit(OP_BINARY_ADD, BC_NOARG, line); break;
case TK("-"): ctx->emit(OP_BINARY_SUB, BC_NOARG, line); break;
case TK("*"): ctx->emit(OP_BINARY_MUL, BC_NOARG, line); break;
case TK("/"): ctx->emit(OP_BINARY_OP, 3, line); break;
case TK("//"): ctx->emit(OP_BINARY_OP, 4, line); break;
case TK("%"): ctx->emit(OP_BINARY_OP, 5, line); break;
case TK("//"): ctx->emit(OP_BINARY_FLOORDIV, BC_NOARG, line); break;
case TK("%"): ctx->emit(OP_BINARY_MOD, BC_NOARG, line); break;
case TK("**"): ctx->emit(OP_BINARY_OP, 6, line); break;
case TK("<"): ctx->emit(OP_COMPARE_OP, 0, line); break;
case TK("<="): ctx->emit(OP_COMPARE_OP, 1, line); break;
case TK("=="): ctx->emit(OP_COMPARE_OP, 2, line); break;
case TK("!="): ctx->emit(OP_COMPARE_OP, 3, line); break;
case TK(">"): ctx->emit(OP_COMPARE_OP, 4, line); break;
case TK(">="): ctx->emit(OP_COMPARE_OP, 5, line); break;
case TK("<"): ctx->emit(OP_COMPARE_LT, BC_NOARG, line); break;
case TK("<="): ctx->emit(OP_COMPARE_LE, BC_NOARG, line); break;
case TK("=="): ctx->emit(OP_COMPARE_EQ, BC_NOARG, line); break;
case TK("!="): ctx->emit(OP_COMPARE_NE, BC_NOARG, line); break;
case TK(">"): ctx->emit(OP_COMPARE_GT, BC_NOARG, line); break;
case TK(">="): ctx->emit(OP_COMPARE_GE, BC_NOARG, line); break;
case TK("in"): ctx->emit(OP_CONTAINS_OP, 0, line); break;
case TK("not in"): ctx->emit(OP_CONTAINS_OP, 1, line); break;
case TK("is"): ctx->emit(OP_IS_OP, 0, line); break;
case TK("is not"): ctx->emit(OP_IS_OP, 1, line); break;
case TK("<<"): ctx->emit(OP_BITWISE_OP, 0, line); break;
case TK(">>"): ctx->emit(OP_BITWISE_OP, 1, line); break;
case TK("&"): ctx->emit(OP_BITWISE_OP, 2, line); break;
case TK("|"): ctx->emit(OP_BITWISE_OP, 3, line); break;
case TK("^"): ctx->emit(OP_BITWISE_OP, 4, line); break;
case TK("<<"): ctx->emit(OP_BITWISE_LSHIFT, BC_NOARG, line); break;
case TK(">>"): ctx->emit(OP_BITWISE_RSHIFT, BC_NOARG, line); break;
case TK("&"): ctx->emit(OP_BITWISE_AND, BC_NOARG, line); break;
case TK("|"): ctx->emit(OP_BITWISE_OR, BC_NOARG, line); break;
case TK("^"): ctx->emit(OP_BITWISE_XOR, BC_NOARG, line); break;
default: UNREACHABLE();
}
}

View File

@ -41,8 +41,26 @@ OPCODE(BUILD_TUPLE)
OPCODE(BUILD_STRING)
/**************************/
OPCODE(BINARY_OP)
OPCODE(BINARY_ADD)
OPCODE(BINARY_SUB)
OPCODE(BINARY_MUL)
OPCODE(BINARY_FLOORDIV)
OPCODE(BINARY_MOD)
OPCODE(COMPARE_OP)
OPCODE(BITWISE_OP)
OPCODE(COMPARE_LT)
OPCODE(COMPARE_LE)
OPCODE(COMPARE_EQ)
OPCODE(COMPARE_NE)
OPCODE(COMPARE_GT)
OPCODE(COMPARE_GE)
OPCODE(BITWISE_LSHIFT)
OPCODE(BITWISE_RSHIFT)
OPCODE(BITWISE_AND)
OPCODE(BITWISE_OR)
OPCODE(BITWISE_XOR)
OPCODE(IS_OP)
OPCODE(CONTAINS_OP)
/**************************/

View File

@ -25,23 +25,23 @@ inline CodeObject_ VM::compile(Str source, Str filename, CompileMode mode) {
}
#define BIND_NUM_ARITH_OPT(name, op) \
_vm->_bind_methods<1>({"int","float"}, #name, [](VM* vm, Args& args){ \
_vm->_bind_methods<1>({"int","float"}, #name, [](VM* vm, Args& args){ \
if(is_both_int(args[0], args[1])){ \
return VAR(_CAST(i64, args[0]) op _CAST(i64, args[1])); \
return VAR(_CAST(i64, args[0]) op _CAST(i64, args[1])); \
}else{ \
return VAR(vm->num_to_float(args[0]) op vm->num_to_float(args[1])); \
return VAR(vm->num_to_float(args[0]) op vm->num_to_float(args[1])); \
} \
});
#define BIND_NUM_LOGICAL_OPT(name, op, is_eq) \
_vm->_bind_methods<1>({"int","float"}, #name, [](VM* vm, Args& args){ \
_vm->_bind_methods<1>({"int","float"}, #name, [](VM* vm, Args& args){ \
if(is_both_int(args[0], args[1])) \
return VAR(_CAST(i64, args[0]) op _CAST(i64, args[1])); \
if(!is_both_int_or_float(args[0], args[1])){ \
if constexpr(is_eq) return VAR(args[0] op args[1]); \
if constexpr(is_eq) return VAR(args[0] op args[1]); \
vm->TypeError("unsupported operand type(s) for " #op ); \
} \
if(is_both_int(args[0], args[1])) \
return VAR(_CAST(i64, args[0]) op _CAST(i64, args[1])); \
return VAR(vm->num_to_float(args[0]) op vm->num_to_float(args[1])); \
return VAR(vm->num_to_float(args[0]) op vm->num_to_float(args[1])); \
});

View File

@ -404,10 +404,13 @@ const StrName m_add = StrName::get("add");
const StrName __enter__ = StrName::get("__enter__");
const StrName __exit__ = StrName::get("__exit__");
const StrName COMPARE_SPECIAL_METHODS[] = {
StrName::get("__lt__"), StrName::get("__le__"), StrName::get("__eq__"),
StrName::get("__ne__"), StrName::get("__gt__"), StrName::get("__ge__")
};
const StrName __add__ = StrName::get("__add__");
const StrName __sub__ = StrName::get("__sub__");
const StrName __mul__ = StrName::get("__mul__");
// const StrName __truediv__ = StrName::get("__truediv__");
const StrName __floordiv__ = StrName::get("__floordiv__");
const StrName __mod__ = StrName::get("__mod__");
// const StrName __pow__ = StrName::get("__pow__");
const StrName BINARY_SPECIAL_METHODS[] = {
StrName::get("__add__"), StrName::get("__sub__"), StrName::get("__mul__"),
@ -415,9 +418,17 @@ const StrName BINARY_SPECIAL_METHODS[] = {
StrName::get("__mod__"), StrName::get("__pow__")
};
const StrName BITWISE_SPECIAL_METHODS[] = {
StrName::get("__lshift__"), StrName::get("__rshift__"),
StrName::get("__and__"), StrName::get("__or__"), StrName::get("__xor__")
};
const StrName __lt__ = StrName::get("__lt__");
const StrName __le__ = StrName::get("__le__");
const StrName __eq__ = StrName::get("__eq__");
const StrName __ne__ = StrName::get("__ne__");
const StrName __gt__ = StrName::get("__gt__");
const StrName __ge__ = StrName::get("__ge__");
const StrName __lshift__ = StrName::get("__lshift__");
const StrName __rshift__ = StrName::get("__rshift__");
const StrName __and__ = StrName::get("__and__");
const StrName __or__ = StrName::get("__or__");
const StrName __xor__ = StrName::get("__xor__");
} // namespace pkpy

View File

@ -592,12 +592,6 @@ inline Str VM::disassemble(CodeObject_ co){
case OP_BINARY_OP:
argStr += fmt(" (", BINARY_SPECIAL_METHODS[byte.arg], ")");
break;
case OP_COMPARE_OP:
argStr += fmt(" (", COMPARE_SPECIAL_METHODS[byte.arg], ")");
break;
case OP_BITWISE_OP:
argStr += fmt(" (", BITWISE_SPECIAL_METHODS[byte.arg], ")");
break;
}
ss << pad(argStr, 40); // may overflow
ss << co->blocks[byte.block].type;