mirror of
https://github.com/pocketpy/pocketpy
synced 2025-10-20 11:30:18 +00:00
add random.choices
This commit is contained in:
parent
c140847a3a
commit
7095db428c
@ -31,6 +31,7 @@ struct Tuple {
|
|||||||
|
|
||||||
PyObject** begin() const { return _args; }
|
PyObject** begin() const { return _args; }
|
||||||
PyObject** end() const { return _args + _size; }
|
PyObject** end() const { return _args + _size; }
|
||||||
|
PyObject** data() const { return _args; }
|
||||||
};
|
};
|
||||||
|
|
||||||
// a lightweight view for function args, it does not own the memory
|
// a lightweight view for function args, it does not own the memory
|
||||||
|
@ -184,6 +184,8 @@ public:
|
|||||||
PyObject* py_json(PyObject* obj);
|
PyObject* py_json(PyObject* obj);
|
||||||
PyObject* py_iter(PyObject* obj);
|
PyObject* py_iter(PyObject* obj);
|
||||||
|
|
||||||
|
std::pair<PyObject**, int> _cast_array(PyObject* obj);
|
||||||
|
|
||||||
PyObject* find_name_in_mro(Type cls, StrName name);
|
PyObject* find_name_in_mro(Type cls, StrName name);
|
||||||
bool isinstance(PyObject* obj, Type base);
|
bool isinstance(PyObject* obj, Type base);
|
||||||
bool issubclass(Type cls, Type base);
|
bool issubclass(Type cls, Type base);
|
||||||
|
@ -23,6 +23,7 @@ struct Random{
|
|||||||
Random& self = _CAST(Random&, args[0]);
|
Random& self = _CAST(Random&, args[0]);
|
||||||
i64 a = CAST(i64, args[1]);
|
i64 a = CAST(i64, args[1]);
|
||||||
i64 b = CAST(i64, args[2]);
|
i64 b = CAST(i64, args[2]);
|
||||||
|
if (a > b) vm->ValueError("randint(a, b): a must be less than or equal to b");
|
||||||
std::uniform_int_distribution<i64> dis(a, b);
|
std::uniform_int_distribution<i64> dis(a, b);
|
||||||
return VAR(dis(self.gen));
|
return VAR(dis(self.gen));
|
||||||
});
|
});
|
||||||
@ -50,9 +51,36 @@ struct Random{
|
|||||||
|
|
||||||
vm->bind_method<1>(type, "choice", [](VM* vm, ArgsView args) {
|
vm->bind_method<1>(type, "choice", [](VM* vm, ArgsView args) {
|
||||||
Random& self = _CAST(Random&, args[0]);
|
Random& self = _CAST(Random&, args[0]);
|
||||||
const List& L = CAST(List&, args[1]);
|
auto [data, size] = vm->_cast_array(args[1]);
|
||||||
std::uniform_int_distribution<i64> dis(0, L.size() - 1);
|
if(size == 0) vm->IndexError("cannot choose from an empty sequence");
|
||||||
return L[dis(self.gen)];
|
std::uniform_int_distribution<i64> dis(0, size - 1);
|
||||||
|
return data[dis(self.gen)];
|
||||||
|
});
|
||||||
|
|
||||||
|
vm->bind(type, "choices(self, population, weights=None, k=1)", [](VM* vm, ArgsView args) {
|
||||||
|
Random& self = _CAST(Random&, args[0]);
|
||||||
|
auto [data, size] = vm->_cast_array(args[1]);
|
||||||
|
if(size == 0) vm->IndexError("cannot choose from an empty sequence");
|
||||||
|
std::vector<f64> cum_weights(size);
|
||||||
|
if(args[2] == vm->None){
|
||||||
|
for(int i = 0; i < size; i++) cum_weights[i] = i + 1;
|
||||||
|
}else{
|
||||||
|
auto [weights, weights_size] = vm->_cast_array(args[2]);
|
||||||
|
if(weights_size != size) vm->ValueError(_S("len(weights) != ", size));
|
||||||
|
cum_weights[0] = CAST(f64, weights[0]);
|
||||||
|
for(int i = 1; i < size; i++){
|
||||||
|
cum_weights[i] = cum_weights[i - 1] + CAST(f64, weights[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if(cum_weights[size - 1] <= 0) vm->ValueError("total of weights must be greater than zero");
|
||||||
|
int k = CAST(i64, args[3]);
|
||||||
|
List result(k);
|
||||||
|
for(int i = 0; i < k; i++){
|
||||||
|
f64 r = std::uniform_real_distribution<f64>(0.0, cum_weights[size - 1])(self.gen);
|
||||||
|
int idx = std::lower_bound(cum_weights.begin(), cum_weights.end(), r) - cum_weights.begin();
|
||||||
|
result[i] = data[idx];
|
||||||
|
}
|
||||||
|
return VAR(std::move(result));
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -67,6 +95,7 @@ void add_module_random(VM* vm){
|
|||||||
mod->attr().set("randint", vm->getattr(instance, "randint"));
|
mod->attr().set("randint", vm->getattr(instance, "randint"));
|
||||||
mod->attr().set("shuffle", vm->getattr(instance, "shuffle"));
|
mod->attr().set("shuffle", vm->getattr(instance, "shuffle"));
|
||||||
mod->attr().set("choice", vm->getattr(instance, "choice"));
|
mod->attr().set("choice", vm->getattr(instance, "choice"));
|
||||||
|
mod->attr().set("choices", vm->getattr(instance, "choices"));
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace pkpy
|
} // namespace pkpy
|
12
src/vm.cpp
12
src/vm.cpp
@ -117,6 +117,18 @@ namespace pkpy{
|
|||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::pair<PyObject**, int> VM::_cast_array(PyObject* obj){
|
||||||
|
if(is_non_tagged_type(obj, VM::tp_list)){
|
||||||
|
List& list = PK_OBJ_GET(List, obj);
|
||||||
|
return {list.data(), list.size()};
|
||||||
|
}else if(is_non_tagged_type(obj, VM::tp_tuple)){
|
||||||
|
Tuple& tuple = PK_OBJ_GET(Tuple, obj);
|
||||||
|
return {tuple.data(), tuple.size()};
|
||||||
|
}
|
||||||
|
TypeError(_S("expected list or tuple, got ", _type_name(this, _tp(obj)).escape()));
|
||||||
|
PK_UNREACHABLE();
|
||||||
|
}
|
||||||
|
|
||||||
FrameId VM::top_frame(){
|
FrameId VM::top_frame(){
|
||||||
#if PK_DEBUG_EXTRA_CHECK
|
#if PK_DEBUG_EXTRA_CHECK
|
||||||
if(callstack.empty()) PK_FATAL_ERROR();
|
if(callstack.empty()) PK_FATAL_ERROR();
|
||||||
|
@ -14,8 +14,31 @@ for _ in range(100):
|
|||||||
a = [1, 2, 3, 4]
|
a = [1, 2, 3, 4]
|
||||||
r.shuffle(a)
|
r.shuffle(a)
|
||||||
|
|
||||||
for i in range(100):
|
for i in range(10):
|
||||||
assert r.choice(a) in a
|
assert r.choice(a) in a
|
||||||
|
|
||||||
|
for i in range(10):
|
||||||
|
assert r.choice(tuple(a)) in a
|
||||||
|
|
||||||
|
for i in range(10):
|
||||||
|
assert r.randint(1, 1) == 1
|
||||||
|
|
||||||
|
# test choices
|
||||||
|
x = (1,)
|
||||||
|
res = r.choices(x, k=4)
|
||||||
|
assert (res == [1, 1, 1, 1]), res
|
||||||
|
|
||||||
|
w = (1, 2, 3)
|
||||||
|
assert r.choices([1, 2, 3], (0.0, 0.0, 0.5)) == [3]
|
||||||
|
|
||||||
|
try:
|
||||||
|
r.choices([1, 2, 3], (0.0, 0.0, 0.5, 0.5))
|
||||||
|
exit(1)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
r.choices([])
|
||||||
|
exit(1)
|
||||||
|
except IndexError:
|
||||||
|
pass
|
||||||
|
Loading…
x
Reference in New Issue
Block a user