diff --git a/src/modules/random.c b/src/modules/random.c index 82660349..cbc5d00f 100644 --- a/src/modules/random.c +++ b/src/modules/random.c @@ -147,6 +147,21 @@ static bool Random__new__(int argc, py_Ref argv) { return TypeError("Random(): expected 0 or 1 arguments, got %d", argc - 1); } +static bool Random__init__(int argc, py_Ref argv) { + if(argc == 1) { + // do nothing + } else if(argc == 2) { + mt19937* ud = py_touserdata(py_arg(0)); + PY_CHECK_ARG_TYPE(1, tp_int); + py_i64 seed = py_toint(py_arg(1)); + mt19937__seed(ud, (uint32_t)seed); + } else { + return TypeError("Random(): expected 1 or 2 arguments, got %d"); + } + py_newnone(py_retval()); + return true; +} + static bool Random_seed(int argc, py_Ref argv) { PY_CHECK_ARGC(2); PY_CHECK_ARG_TYPE(1, tp_int); @@ -278,6 +293,7 @@ void pk__add_module_random() { py_Type type = py_newtype("Random", tp_object, mod, NULL); py_bindmagic(type, __new__, Random__new__); + py_bindmagic(type, __init__, Random__init__); py_bindmethod(type, "seed", Random_seed); py_bindmethod(type, "random", Random_random); py_bindmethod(type, "uniform", Random_uniform); diff --git a/tests/70_random.py b/tests/70_random.py index 03951181..274fe14e 100644 --- a/tests/70_random.py +++ b/tests/70_random.py @@ -66,4 +66,7 @@ assert (a, b, c) == (16, -418020281577586157, 76) seed(7) assert a == randint(1, 100) assert b == randint(-2**60, 1) -assert c == randint(50, 100) \ No newline at end of file +assert c == randint(50, 100) + +import random +assert random.Random(7).randint(1, 100) == a