diff --git a/src/public/py_str.c b/src/public/py_str.c index d53f20d7..98e09512 100644 --- a/src/public/py_str.c +++ b/src/public/py_str.c @@ -436,6 +436,28 @@ static bool _py_str__rjust(int argc, py_Ref argv) { return _py_str__widthjust_impl(false, argc, argv); } +static bool _py_str__find(int argc, py_Ref argv) { + if(argc > 3) return TypeError("find() takes at most 3 arguments"); + int start = 0; + if(argc == 3) { + PY_CHECK_ARG_TYPE(2, tp_int); + start = py_toint(py_arg(2)); + } + c11_string* self = py_touserdata(&argv[0]); + PY_CHECK_ARG_TYPE(1, tp_str); + c11_string* sub = py_touserdata(&argv[1]); + int res = c11_sv__index2(c11_string__sv(self), c11_string__sv(sub), start); + py_newint(py_retval(), res); + return true; +} + +static bool _py_str__index(int argc, py_Ref argv) { + bool ok = _py_str__find(argc, argv); + if(!ok) return false; + if(py_toint(py_retval()) == -1) return ValueError("substring not found"); + return true; +} + py_Type pk_str__register() { pk_VM* vm = pk_current_vm; py_Type type = pk_VM__new_type(vm, "str", tp_object, NULL, false); @@ -474,6 +496,8 @@ py_Type pk_str__register() { py_bindmethod(tp_str, "zfill", _py_str__zfill); py_bindmethod(tp_str, "ljust", _py_str__ljust); py_bindmethod(tp_str, "rjust", _py_str__rjust); + py_bindmethod(tp_str, "find", _py_str__find); + py_bindmethod(tp_str, "index", _py_str__index); return type; } diff --git a/tests/04_str.py b/tests/04_str.py index de293689..d092e7bc 100644 --- a/tests/04_str.py +++ b/tests/04_str.py @@ -108,6 +108,23 @@ assert str(num) == '6' 测试 = "test" assert 测试 == "test" +a = 'abcd' +assert list(a) == ['a', 'b', 'c', 'd'] +a = '测试' +assert list(a) == ['测', '试'] +a = 'a测b试c' +assert list(a) == ['a', '测', 'b', '试', 'c'] +a = 'a测b试' +assert list(a) == ['a', '测', 'b', '试'] +a = '测b试c' +assert list(a) == ['测', 'b', '试', 'c'] +a = '测b' +assert list(a) == ['测', 'b'] +a = 'b' +assert list(a) == ['b'] +a = '测' +assert list(a) == ['测'] + # 3rd slice a = "Hello, World!" assert a[::-1] == "!dlroW ,olleH" @@ -126,6 +143,18 @@ assert '\x30\x31\x32' == '012' assert '\b\b\b' == '\x08\x08\x08' assert repr('\x1f\x1e\x1f') == '\'\\x1f\\x1e\\x1f\'' +a = '123' +assert a.index('2') == 1 +assert a.index('1') == 0 +assert a.index('3') == 2 +assert a.index('23') == 1 + +assert a.index('2', 1) == 1 +assert a.index('1', 0) == 0 + +assert a.find('1') == 0 +assert a.find('1', 1) == -1 + b = list("Hello, World!") assert b == ['H', 'e', 'l', 'l', 'o', ',', ' ', 'W', 'o', 'r', 'l', 'd', '!'] assert b[::-1] == ['!', 'd', 'l', 'r', 'o', 'W', ' ', ',', 'o', 'l', 'l', 'e', 'H'] @@ -146,34 +175,6 @@ assert hex(256) == '0x100' assert hex(257) == '0x101' assert hex(17) == '0x11' -a = '123' -assert a.index('2') == 1 -assert a.index('1') == 0 -assert a.index('3') == 2 - -assert a.index('2', 1) == 1 -assert a.index('1', 0) == 0 - -assert a.find('1') == 0 -assert a.find('1', 1) == -1 - -a = 'abcd' -assert list(a) == ['a', 'b', 'c', 'd'] -a = '测试' -assert list(a) == ['测', '试'] -a = 'a测b试c' -assert list(a) == ['a', '测', 'b', '试', 'c'] -a = 'a测b试' -assert list(a) == ['a', '测', 'b', '试'] -a = '测b试c' -assert list(a) == ['测', 'b', '试', 'c'] -a = '测b' -assert list(a) == ['测', 'b'] -a = 'b' -assert list(a) == ['b'] -a = '测' -assert list(a) == ['测'] - def test(*seq): return s1.join(seq) assert test("r", "u", "n", "o", "o", "b") == "r-u-n-o-o-b"