diff --git a/src/modules/random.c b/src/modules/random.c index b6be294f..32e5ed42 100644 --- a/src/modules/random.c +++ b/src/modules/random.c @@ -218,12 +218,21 @@ static bool Random_randint(int argc, py_Ref argv) { static bool Random_choice(int argc, py_Ref argv) { PY_CHECK_ARGC(2); mt19937* ud = py_touserdata(py_arg(0)); - py_TValue* p; - int length = pk_arrayview(py_arg(1), &p); - if(length == -1) return TypeError("choice(): argument must be a list or tuple"); - if(length == 0) return IndexError("cannot choose from an empty sequence"); - int index = mt19937__randint(ud, 0, length - 1); - py_assign(py_retval(), p + index); + if (py_isstr(py_arg(1))) { + c11_sv sv = py_tosv(py_arg(1)); + int length = c11_sv__u8_length(sv); + if(length == 0) return IndexError("cannot choose from an empty sequence"); + int index = mt19937__randint(ud, 0, length - 1); + c11_sv ch = c11_sv__u8_getitem(sv, index); + py_newstrv(py_retval(), ch); + } else { + py_TValue* p; + int length = pk_arrayview(py_arg(1), &p); + if(length == -1) return TypeError("choice(): argument must be a list, tuple or str"); + if(length == 0) return IndexError("cannot choose from an empty sequence"); + int index = mt19937__randint(ud, 0, length - 1); + py_assign(py_retval(), p + index); + } return true; } diff --git a/src/public/PyList.c b/src/public/PyList.c index 67e15d71..ba937307 100644 --- a/src/public/PyList.c +++ b/src/public/PyList.c @@ -263,8 +263,22 @@ static bool list_extend(int argc, py_Ref argv) { List* self = py_touserdata(py_arg(0)); py_TValue* p; int length = pk_arrayview(py_arg(1), &p); - if(length == -1) return TypeError("extend() argument must be a list or tuple"); - c11_vector__extend(self, p, length); + if(length >= 0) { + c11_vector__extend(self, p, length); + } else { + // get iterator + if (!py_iter(py_arg(1))) return false; + py_StackRef tmp_iter = py_pushtmp(); + py_assign(tmp_iter, py_retval()); + while(true) { + int res = py_next(tmp_iter); + if (res == 0) break; + if (res == -1) return false; + assert(res == 1); + c11_vector__push(py_TValue, self, *py_retval()); + } + py_pop(); + } py_newnone(py_retval()); return true; } diff --git a/tests/050_list.py b/tests/050_list.py index 5d61040f..109c9901 100644 --- a/tests/050_list.py +++ b/tests/050_list.py @@ -170,6 +170,13 @@ b = a.copy(); del b[-1:]; assert b == [1, 2, 3] b = a.copy(); del b[:]; assert b == [] assert a == [1, 2, 3, 4] +# test extend with iterable +c = [1] +c.extend('123') +assert c == [1, '1', '2', '3'] +c.extend(range(1, 6)) +assert c == [1, '1', '2', '3', 1, 2, 3, 4, 5] + # test cyclic reference # a = [] # a.append(0) diff --git a/tests/705_random.py b/tests/705_random.py index 274fe14e..db11ff72 100644 --- a/tests/705_random.py +++ b/tests/705_random.py @@ -20,6 +20,9 @@ for i in range(10): for i in range(10): assert r.choice(tuple(a)) in a +for i in range(10): + assert r.choice('hello') in 'hello' + for i in range(10): assert r.randint(1, 1) == 1