diff --git a/include/pocketpy/tuplelist.h b/include/pocketpy/tuplelist.h index 8d08266d..8c82e15f 100644 --- a/include/pocketpy/tuplelist.h +++ b/include/pocketpy/tuplelist.h @@ -31,6 +31,7 @@ struct Tuple { PyObject** begin() const { return _args; } PyObject** end() const { return _args + _size; } + PyObject** data() const { return _args; } }; // a lightweight view for function args, it does not own the memory diff --git a/include/pocketpy/vm.h b/include/pocketpy/vm.h index c5275be4..9c6973ad 100644 --- a/include/pocketpy/vm.h +++ b/include/pocketpy/vm.h @@ -184,6 +184,8 @@ public: PyObject* py_json(PyObject* obj); PyObject* py_iter(PyObject* obj); + std::pair _cast_array(PyObject* obj); + PyObject* find_name_in_mro(Type cls, StrName name); bool isinstance(PyObject* obj, Type base); bool issubclass(Type cls, Type base); diff --git a/src/random.cpp b/src/random.cpp index 463ff1cb..2b06c856 100644 --- a/src/random.cpp +++ b/src/random.cpp @@ -23,6 +23,7 @@ struct Random{ Random& self = _CAST(Random&, args[0]); i64 a = CAST(i64, args[1]); i64 b = CAST(i64, args[2]); + if (a > b) vm->ValueError("randint(a, b): a must be less than or equal to b"); std::uniform_int_distribution dis(a, b); return VAR(dis(self.gen)); }); @@ -50,9 +51,36 @@ struct Random{ vm->bind_method<1>(type, "choice", [](VM* vm, ArgsView args) { Random& self = _CAST(Random&, args[0]); - const List& L = CAST(List&, args[1]); - std::uniform_int_distribution dis(0, L.size() - 1); - return L[dis(self.gen)]; + auto [data, size] = vm->_cast_array(args[1]); + if(size == 0) vm->IndexError("cannot choose from an empty sequence"); + std::uniform_int_distribution dis(0, size - 1); + return data[dis(self.gen)]; + }); + + vm->bind(type, "choices(self, population, weights=None, k=1)", [](VM* vm, ArgsView args) { + Random& self = _CAST(Random&, args[0]); + auto [data, size] = vm->_cast_array(args[1]); + if(size == 0) vm->IndexError("cannot choose from an empty sequence"); + std::vector cum_weights(size); + if(args[2] == vm->None){ + for(int i = 0; i < size; i++) cum_weights[i] = i + 1; + }else{ + auto [weights, weights_size] = vm->_cast_array(args[2]); + if(weights_size != size) vm->ValueError(_S("len(weights) != ", size)); + cum_weights[0] = CAST(f64, weights[0]); + for(int i = 1; i < size; i++){ + cum_weights[i] = cum_weights[i - 1] + CAST(f64, weights[i]); + } + } + if(cum_weights[size - 1] <= 0) vm->ValueError("total of weights must be greater than zero"); + int k = CAST(i64, args[3]); + List result(k); + for(int i = 0; i < k; i++){ + f64 r = std::uniform_real_distribution(0.0, cum_weights[size - 1])(self.gen); + int idx = std::lower_bound(cum_weights.begin(), cum_weights.end(), r) - cum_weights.begin(); + result[i] = data[idx]; + } + return VAR(std::move(result)); }); } }; @@ -67,6 +95,7 @@ void add_module_random(VM* vm){ mod->attr().set("randint", vm->getattr(instance, "randint")); mod->attr().set("shuffle", vm->getattr(instance, "shuffle")); mod->attr().set("choice", vm->getattr(instance, "choice")); + mod->attr().set("choices", vm->getattr(instance, "choices")); } } // namespace pkpy \ No newline at end of file diff --git a/src/vm.cpp b/src/vm.cpp index 5f8540d2..0351310e 100644 --- a/src/vm.cpp +++ b/src/vm.cpp @@ -117,6 +117,18 @@ namespace pkpy{ return nullptr; } + std::pair VM::_cast_array(PyObject* obj){ + if(is_non_tagged_type(obj, VM::tp_list)){ + List& list = PK_OBJ_GET(List, obj); + return {list.data(), list.size()}; + }else if(is_non_tagged_type(obj, VM::tp_tuple)){ + Tuple& tuple = PK_OBJ_GET(Tuple, obj); + return {tuple.data(), tuple.size()}; + } + TypeError(_S("expected list or tuple, got ", _type_name(this, _tp(obj)).escape())); + PK_UNREACHABLE(); + } + FrameId VM::top_frame(){ #if PK_DEBUG_EXTRA_CHECK if(callstack.empty()) PK_FATAL_ERROR(); diff --git a/tests/70_random.py b/tests/70_random.py index 4f13ee86..1adafe99 100644 --- a/tests/70_random.py +++ b/tests/70_random.py @@ -14,8 +14,31 @@ for _ in range(100): a = [1, 2, 3, 4] r.shuffle(a) -for i in range(100): +for i in range(10): assert r.choice(a) in a +for i in range(10): + assert r.choice(tuple(a)) in a +for i in range(10): + assert r.randint(1, 1) == 1 +# test choices +x = (1,) +res = r.choices(x, k=4) +assert (res == [1, 1, 1, 1]), res + +w = (1, 2, 3) +assert r.choices([1, 2, 3], (0.0, 0.0, 0.5)) == [3] + +try: + r.choices([1, 2, 3], (0.0, 0.0, 0.5, 0.5)) + exit(1) +except ValueError: + pass + +try: + r.choices([]) + exit(1) +except IndexError: + pass