From 4de7a16af4e37d64da614be50ba4dcef8668b999 Mon Sep 17 00:00:00 2001 From: anmoltyagi18 Date: Wed, 1 Apr 2026 19:26:48 +0530 Subject: [PATCH] fix(bytes): make bytes object iterable - fixes #450 --- include/pocketpy/interpreter/vm.h | 1 + include/pocketpy/pocketpy.h | 1 + src/bindings/py_str.c | 28 ++++++++++++ src/interpreter/vm.c | 1 + test_bytes_iter.py | 76 +++++++++++++++++++++++++++++++ 5 files changed, 107 insertions(+) create mode 100644 test_bytes_iter.py diff --git a/include/pocketpy/interpreter/vm.h b/include/pocketpy/interpreter/vm.h index 1b3126ac..769d95c6 100644 --- a/include/pocketpy/interpreter/vm.h +++ b/include/pocketpy/interpreter/vm.h @@ -133,6 +133,7 @@ void pk_number__register(); py_Type pk_str__register(); py_Type pk_str_iterator__register(); py_Type pk_bytes__register(); +py_Type pk_bytes_iterator__register(); py_Type pk_dict__register(); py_Type pk_dict_items__register(); py_Type pk_list__register(); diff --git a/include/pocketpy/pocketpy.h b/include/pocketpy/pocketpy.h index b012234a..077ecd0f 100644 --- a/include/pocketpy/pocketpy.h +++ b/include/pocketpy/pocketpy.h @@ -840,6 +840,7 @@ enum py_PredefinedType { tp_BaseException, tp_Exception, tp_bytes, + tp_bytes_iterator, tp_namedict, tp_locals, tp_code, diff --git a/src/bindings/py_str.c b/src/bindings/py_str.c index de5c556c..71261b91 100644 --- a/src/bindings/py_str.c +++ b/src/bindings/py_str.c @@ -674,6 +674,33 @@ py_Type pk_str_iterator__register() { return type; } +static bool bytes__iter__(int argc, py_Ref argv) { + PY_CHECK_ARGC(1); + int* ud = py_newobject(py_retval(), tp_bytes_iterator, 1, sizeof(int)); + *ud = 0; + py_setslot(py_retval(), 0, argv); // keep a reference to the bytes object + return true; +} + +bool bytes_iterator__next__(int argc, py_Ref argv) { + PY_CHECK_ARGC(1); + int* ud = py_touserdata(&argv[0]); + int size; + unsigned char* data = py_tobytes(py_getslot(argv, 0), &size); + if(*ud == size) return StopIteration(); + py_newint(py_retval(), data[*ud]); // return the byte value as an integer (0-255) + *ud += 1; + return true; +} + +py_Type pk_bytes_iterator__register() { + py_Type type = pk_newtype("bytes_iterator", tp_object, NULL, NULL, false, true); + + py_bindmagic(type, __iter__, pk_wrapper__self); + py_bindmagic(type, __next__, bytes_iterator__next__); + return type; +} + static bool bytes__new__(int argc, py_Ref argv) { if(argc == 1) { py_newbytes(py_retval(), 0); @@ -808,6 +835,7 @@ py_Type pk_bytes__register() { py_bindmagic(tp_bytes, __add__, bytes__add__); py_bindmagic(tp_bytes, __hash__, bytes__hash__); py_bindmagic(tp_bytes, __len__, bytes__len__); + py_bindmagic(tp_bytes, __iter__, bytes__iter__); py_bindmethod(tp_bytes, "decode", bytes_decode); return type; diff --git a/src/interpreter/vm.c b/src/interpreter/vm.c index f5573780..aa945e27 100644 --- a/src/interpreter/vm.c +++ b/src/interpreter/vm.c @@ -156,6 +156,7 @@ void VM__ctor(VM* self) { validate(tp_BaseException, pk_BaseException__register()); validate(tp_Exception, pk_Exception__register()); validate(tp_bytes, pk_bytes__register()); + validate(tp_bytes_iterator, pk_bytes_iterator__register()); validate(tp_namedict, pk_namedict__register()); validate(tp_locals, pk_newtype("locals", tp_object, NULL, NULL, false, true)); validate(tp_code, pk_code__register()); diff --git a/test_bytes_iter.py b/test_bytes_iter.py new file mode 100644 index 00000000..78fd629d --- /dev/null +++ b/test_bytes_iter.py @@ -0,0 +1,76 @@ +#!/usr/bin/env python3 +# Test script for bytes iteration fix (issue #450) + +print("Testing bytes iteration...") +print() + +# Test 1: Basic bytes creation and representation +text = "Hello" +byte_data = text.encode() +print("Test 1: Basic bytes creation and representation") +print(f"Text: {text}") +print(f"Bytes: {byte_data}") +print() + +# Test 2: list(bytes) +print("Test 2: list(byte_data):") +try: + result = list(byte_data) + print(f"Result: {result}") + expected = [72, 101, 108, 108, 111] + assert result == expected, f"Expected {expected}, got {result}" + print("✓ PASS") +except Exception as e: + print(f"✗ FAIL: {e}") +print() + +# Test 3: for loop iteration +print("Test 3: for loop iteration") +try: + result = [] + for x in byte_data: + result.append(x) + print(f"Result: {result}") + expected = [72, 101, 108, 108, 111] + assert result == expected, f"Expected {expected}, got {result}" + print("✓ PASS") +except Exception as e: + print(f"✗ FAIL: {e}") +print() + +# Test 4: bytes indexing (should still work) +print("Test 4: bytes indexing") +try: + result = byte_data[0] + print(f"byte_data[0] = {result}") + assert result == 72, f"Expected 72, got {result}" + print("✓ PASS") +except Exception as e: + print(f"✗ FAIL: {e}") +print() + +# Test 5: len(bytes) +print("Test 5: len(byte_data)") +try: + result = len(byte_data) + print(f"len(byte_data) = {result}") + assert result == 5, f"Expected 5, got {result}" + print("✓ PASS") +except Exception as e: + print(f"✗ FAIL: {e}") +print() + +# Test 6: bytes slicing +print("Test 6: bytes slicing") +try: + result = byte_data[1:3] + result_list = list(result) + print(f"list(byte_data[1:3]) = {result_list}") + expected = [101, 108] + assert result_list == expected, f"Expected {expected}, got {result_list}" + print("✓ PASS") +except Exception as e: + print(f"✗ FAIL: {e}") +print() + +print("All tests completed!")