mirror of
				https://github.com/pocketpy/pocketpy
				synced 2025-10-20 19:40:18 +00:00 
			
		
		
		
	add random.choices
				
					
				
			This commit is contained in:
		
							parent
							
								
									c140847a3a
								
							
						
					
					
						commit
						7095db428c
					
				| @ -31,6 +31,7 @@ struct Tuple { | |||||||
| 
 | 
 | ||||||
|     PyObject** begin() const { return _args; } |     PyObject** begin() const { return _args; } | ||||||
|     PyObject** end() const { return _args + _size; } |     PyObject** end() const { return _args + _size; } | ||||||
|  |     PyObject** data() const { return _args; } | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| // a lightweight view for function args, it does not own the memory
 | // a lightweight view for function args, it does not own the memory
 | ||||||
|  | |||||||
| @ -184,6 +184,8 @@ public: | |||||||
|     PyObject* py_json(PyObject* obj); |     PyObject* py_json(PyObject* obj); | ||||||
|     PyObject* py_iter(PyObject* obj); |     PyObject* py_iter(PyObject* obj); | ||||||
| 
 | 
 | ||||||
|  |     std::pair<PyObject**, int> _cast_array(PyObject* obj); | ||||||
|  | 
 | ||||||
|     PyObject* find_name_in_mro(Type cls, StrName name); |     PyObject* find_name_in_mro(Type cls, StrName name); | ||||||
|     bool isinstance(PyObject* obj, Type base); |     bool isinstance(PyObject* obj, Type base); | ||||||
|     bool issubclass(Type cls, Type base); |     bool issubclass(Type cls, Type base); | ||||||
|  | |||||||
| @ -23,6 +23,7 @@ struct Random{ | |||||||
|             Random& self = _CAST(Random&, args[0]); |             Random& self = _CAST(Random&, args[0]); | ||||||
|             i64 a = CAST(i64, args[1]); |             i64 a = CAST(i64, args[1]); | ||||||
|             i64 b = CAST(i64, args[2]); |             i64 b = CAST(i64, args[2]); | ||||||
|  |             if (a > b) vm->ValueError("randint(a, b): a must be less than or equal to b"); | ||||||
|             std::uniform_int_distribution<i64> dis(a, b); |             std::uniform_int_distribution<i64> dis(a, b); | ||||||
|             return VAR(dis(self.gen)); |             return VAR(dis(self.gen)); | ||||||
|         }); |         }); | ||||||
| @ -50,9 +51,36 @@ struct Random{ | |||||||
| 
 | 
 | ||||||
|         vm->bind_method<1>(type, "choice", [](VM* vm, ArgsView args) { |         vm->bind_method<1>(type, "choice", [](VM* vm, ArgsView args) { | ||||||
|             Random& self = _CAST(Random&, args[0]); |             Random& self = _CAST(Random&, args[0]); | ||||||
|             const List& L = CAST(List&, args[1]); |             auto [data, size] = vm->_cast_array(args[1]); | ||||||
|             std::uniform_int_distribution<i64> dis(0, L.size() - 1); |             if(size == 0) vm->IndexError("cannot choose from an empty sequence"); | ||||||
|             return L[dis(self.gen)]; |             std::uniform_int_distribution<i64> dis(0, size - 1); | ||||||
|  |             return data[dis(self.gen)]; | ||||||
|  |         }); | ||||||
|  | 
 | ||||||
|  |         vm->bind(type, "choices(self, population, weights=None, k=1)", [](VM* vm, ArgsView args) { | ||||||
|  |             Random& self = _CAST(Random&, args[0]); | ||||||
|  |             auto [data, size] = vm->_cast_array(args[1]); | ||||||
|  |             if(size == 0) vm->IndexError("cannot choose from an empty sequence"); | ||||||
|  |             std::vector<f64> cum_weights(size); | ||||||
|  |             if(args[2] == vm->None){ | ||||||
|  |                 for(int i = 0; i < size; i++) cum_weights[i] = i + 1; | ||||||
|  |             }else{ | ||||||
|  |                 auto [weights, weights_size] = vm->_cast_array(args[2]); | ||||||
|  |                 if(weights_size != size) vm->ValueError(_S("len(weights) != ", size)); | ||||||
|  |                 cum_weights[0] = CAST(f64, weights[0]); | ||||||
|  |                 for(int i = 1; i < size; i++){ | ||||||
|  |                     cum_weights[i] = cum_weights[i - 1] + CAST(f64, weights[i]); | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |             if(cum_weights[size - 1] <= 0) vm->ValueError("total of weights must be greater than zero"); | ||||||
|  |             int k = CAST(i64, args[3]); | ||||||
|  |             List result(k); | ||||||
|  |             for(int i = 0; i < k; i++){ | ||||||
|  |                 f64 r = std::uniform_real_distribution<f64>(0.0, cum_weights[size - 1])(self.gen); | ||||||
|  |                 int idx = std::lower_bound(cum_weights.begin(), cum_weights.end(), r) - cum_weights.begin(); | ||||||
|  |                 result[i] = data[idx]; | ||||||
|  |             } | ||||||
|  |             return VAR(std::move(result)); | ||||||
|         }); |         }); | ||||||
|     } |     } | ||||||
| }; | }; | ||||||
| @ -67,6 +95,7 @@ void add_module_random(VM* vm){ | |||||||
|     mod->attr().set("randint", vm->getattr(instance, "randint")); |     mod->attr().set("randint", vm->getattr(instance, "randint")); | ||||||
|     mod->attr().set("shuffle", vm->getattr(instance, "shuffle")); |     mod->attr().set("shuffle", vm->getattr(instance, "shuffle")); | ||||||
|     mod->attr().set("choice", vm->getattr(instance, "choice")); |     mod->attr().set("choice", vm->getattr(instance, "choice")); | ||||||
|  |     mod->attr().set("choices", vm->getattr(instance, "choices")); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| }   // namespace pkpy
 | }   // namespace pkpy
 | ||||||
							
								
								
									
										12
									
								
								src/vm.cpp
									
									
									
									
									
								
							
							
						
						
									
										12
									
								
								src/vm.cpp
									
									
									
									
									
								
							| @ -117,6 +117,18 @@ namespace pkpy{ | |||||||
|         return nullptr; |         return nullptr; | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |     std::pair<PyObject**, int> VM::_cast_array(PyObject* obj){ | ||||||
|  |         if(is_non_tagged_type(obj, VM::tp_list)){ | ||||||
|  |             List& list = PK_OBJ_GET(List, obj); | ||||||
|  |             return {list.data(), list.size()}; | ||||||
|  |         }else if(is_non_tagged_type(obj, VM::tp_tuple)){ | ||||||
|  |             Tuple& tuple = PK_OBJ_GET(Tuple, obj); | ||||||
|  |             return {tuple.data(), tuple.size()}; | ||||||
|  |         } | ||||||
|  |         TypeError(_S("expected list or tuple, got ", _type_name(this, _tp(obj)).escape())); | ||||||
|  |         PK_UNREACHABLE(); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     FrameId VM::top_frame(){ |     FrameId VM::top_frame(){ | ||||||
| #if PK_DEBUG_EXTRA_CHECK | #if PK_DEBUG_EXTRA_CHECK | ||||||
|         if(callstack.empty()) PK_FATAL_ERROR(); |         if(callstack.empty()) PK_FATAL_ERROR(); | ||||||
|  | |||||||
| @ -14,8 +14,31 @@ for _ in range(100): | |||||||
| a = [1, 2, 3, 4] | a = [1, 2, 3, 4] | ||||||
| r.shuffle(a) | r.shuffle(a) | ||||||
| 
 | 
 | ||||||
| for i in range(100): | for i in range(10): | ||||||
|     assert r.choice(a) in a |     assert r.choice(a) in a | ||||||
| 
 | 
 | ||||||
|  | for i in range(10): | ||||||
|  |     assert r.choice(tuple(a)) in a | ||||||
| 
 | 
 | ||||||
|  | for i in range(10): | ||||||
|  |     assert r.randint(1, 1) == 1 | ||||||
| 
 | 
 | ||||||
|  | # test choices | ||||||
|  | x = (1,) | ||||||
|  | res = r.choices(x, k=4) | ||||||
|  | assert (res == [1, 1, 1, 1]), res | ||||||
|  | 
 | ||||||
|  | w = (1, 2, 3) | ||||||
|  | assert r.choices([1, 2, 3], (0.0, 0.0, 0.5)) == [3] | ||||||
|  | 
 | ||||||
|  | try: | ||||||
|  |     r.choices([1, 2, 3], (0.0, 0.0, 0.5, 0.5)) | ||||||
|  |     exit(1) | ||||||
|  | except ValueError: | ||||||
|  |     pass | ||||||
|  | 
 | ||||||
|  | try: | ||||||
|  |     r.choices([]) | ||||||
|  |     exit(1) | ||||||
|  | except IndexError: | ||||||
|  |     pass | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user