diff --git a/include/pocketpy/common/str.h b/include/pocketpy/common/str.h index 398d2ca9..6a708f6c 100644 --- a/include/pocketpy/common/str.h +++ b/include/pocketpy/common/str.h @@ -65,6 +65,7 @@ int c11__byte_index_to_unicode(const char* data, int n); bool c11__is_unicode_Lo_char(int c); int c11__u8_header(unsigned char c, bool suppress); +int c11__u8_value(int u8bytes, const char* data); typedef enum IntParsingResult { IntParsing_SUCCESS, diff --git a/src/common/str.c b/src/common/str.c index 1b5a9df0..8c7c3994 100644 --- a/src/common/str.c +++ b/src/common/str.c @@ -297,6 +297,26 @@ int c11__u8_header(unsigned char c, bool suppress) { return 0; } +int c11__u8_value(int u8bytes, const char* data) { + assert(u8bytes != 0); + if(u8bytes == 1) return (int)data[0]; + uint32_t value = 0; + for(int k = 0; k < u8bytes; k++) { + uint8_t b = data[k]; + if(k == 0) { + if(u8bytes == 2) + value = (b & 0b00011111) << 6; + else if(u8bytes == 3) + value = (b & 0b00001111) << 12; + else if(u8bytes == 4) + value = (b & 0b00000111) << 18; + } else { + value |= (b & 0b00111111) << (6 * (u8bytes - k - 1)); + } + } + return (int)value; +} + IntParsingResult c11__parse_uint(c11_sv text, int64_t* out, int base) { *out = 0; diff --git a/src/compiler/lexer.c b/src/compiler/lexer.c index 884efbeb..ffa2f632 100644 --- a/src/compiler/lexer.c +++ b/src/compiler/lexer.c @@ -225,21 +225,7 @@ static Error* eat_name(Lexer* self) { break; } } - // handle multibyte char - uint32_t value = 0; - for(int k = 0; k < u8bytes; k++) { - uint8_t b = self->curr_char[k]; - if(k == 0) { - if(u8bytes == 2) - value = (b & 0b00011111) << 6; - else if(u8bytes == 3) - value = (b & 0b00001111) << 12; - else if(u8bytes == 4) - value = (b & 0b00000111) << 18; - } else { - value |= (b & 0b00111111) << (6 * (u8bytes - k - 1)); - } - } + int value = c11__u8_value(u8bytes, self->curr_char); if(c11__is_unicode_Lo_char(value)) { self->curr_char += u8bytes; } else { diff --git a/src/public/modules.c b/src/public/modules.c index 9c021311..11037057 100644 --- a/src/public/modules.c +++ b/src/public/modules.c @@ -441,10 +441,15 @@ static bool builtins_ord(int argc, py_Ref argv) { PY_CHECK_ARGC(1); PY_CHECK_ARG_TYPE(0, tp_str); c11_sv sv = py_tosv(py_arg(0)); - if(sv.size != 1) { - return TypeError("ord() expected a character, but string of length %d found", sv.size); + if(c11_sv__u8_length(sv) != 1) { + return TypeError("ord() expected a character, but string of length %d found", c11_sv__u8_length(sv)); } - py_newint(py_retval(), sv.data[0]); + int u8bytes = c11__u8_header(sv.data[0], true); + if (u8bytes == 0) { + return ValueError("invalid char: %c", sv.data[0]); + } + int value = c11__u8_value(u8bytes, sv.data); + py_newint(py_retval(), value); return true; } diff --git a/src/public/py_str.c b/src/public/py_str.c index 279fa85a..769c12dd 100644 --- a/src/public/py_str.c +++ b/src/public/py_str.c @@ -639,6 +639,13 @@ static bool bytes_decode(int argc, py_Ref argv) { return true; } +static bool bytes__len__(int argc, py_Ref argv) { + PY_CHECK_ARGC(1); + c11_bytes* self = py_touserdata(&argv[0]); + py_newint(py_retval(), self->size); + return true; +} + py_Type pk_bytes__register() { py_Type type = pk_newtype("bytes", tp_object, NULL, NULL, false, true); // no need to dtor because the memory is controlled by the object @@ -650,6 +657,7 @@ py_Type pk_bytes__register() { py_bindmagic(tp_bytes, __ne__, bytes__ne__); py_bindmagic(tp_bytes, __add__, bytes__add__); py_bindmagic(tp_bytes, __hash__, bytes__hash__); + py_bindmagic(tp_bytes, __len__, bytes__len__); py_bindmethod(tp_bytes, "decode", bytes_decode); return type; diff --git a/tests/46_bytes.py b/tests/46_bytes.py index 091570b5..7c5663b5 100644 --- a/tests/46_bytes.py +++ b/tests/46_bytes.py @@ -9,6 +9,8 @@ assert b'' + b'' == b'' assert b'\xff\xee' != b'1234' assert b'\xff\xee' == b'\xff\xee' +assert len(b'\xff\xee') == 2 +assert len(b'') == 0 a = '测试123' assert a == a.encode().decode() diff --git a/tests/76_misc.py b/tests/76_misc.py index 9536d11e..0c57109f 100644 --- a/tests/76_misc.py +++ b/tests/76_misc.py @@ -34,3 +34,19 @@ for i in range(ord('a'), ord('z')+1): assert A.a == ord('a') assert A.z == ord('z') + +assert ord('测') == 27979 + +try: + assert ord('测试') + print("Should not reach here") + exit(1) +except TypeError: + pass + +try: + assert ord('12') + print("Should not reach here") + exit(1) +except TypeError: + pass \ No newline at end of file