diff --git a/src/modules/lz4.c b/src/modules/lz4.c index 2e01ac59..cbb3c692 100644 --- a/src/modules/lz4.c +++ b/src/modules/lz4.c @@ -2,6 +2,7 @@ #include #include +#include "pocketpy/common/utils.h" #include "pocketpy/pocketpy.h" #include "lz4/lib/lz4.h" @@ -10,13 +11,13 @@ static bool lz4_compress(int argc, py_Ref argv) { PY_CHECK_ARG_TYPE(0, tp_bytes); int src_size; const void* src = py_tobytes(argv, &src_size); - int dst_capacity = LZ4_compressBound(src_size); - char* p = (char*)py_newbytes(py_retval(), sizeof(int) + dst_capacity); - memcpy(p, &src_size, sizeof(int)); - char* dst = p + sizeof(int); + uint32_t dst_capacity = LZ4_compressBound(src_size); + char* p = (char*)py_newbytes(py_retval(), sizeof(uint32_t) + dst_capacity); + memcpy(p, &src_size, sizeof(uint32_t)); + char* dst = p + sizeof(uint32_t); int dst_size = LZ4_compress_default(src, dst, src_size, dst_capacity); if(dst_size <= 0) return ValueError("LZ4 compression failed"); - py_bytes_resize(py_retval(), sizeof(int) + dst_size); + py_bytes_resize(py_retval(), sizeof(uint32_t) + dst_size); return true; } @@ -24,15 +25,15 @@ static bool lz4_decompress(int argc, py_Ref argv) { PY_CHECK_ARGC(1); PY_CHECK_ARG_TYPE(0, tp_bytes); int total_size; - const int* p = (int*)py_tobytes(argv, &total_size); + const uint32_t* p = (uint32_t*)py_tobytes(argv, &total_size); const char* src = (const char*)(p + 1); - if(total_size < sizeof(int)) return ValueError("invalid LZ4 data"); - int uncompressed_size = *p; - if(uncompressed_size < 0) return ValueError("invalid LZ4 data"); + if(total_size < sizeof(uint32_t)) return ValueError("invalid LZ4 data"); + uint32_t uncompressed_size = *p; + if(uncompressed_size >= INT32_MAX) return ValueError("invalid LZ4 data"); char* dst = (char*)py_newbytes(py_retval(), uncompressed_size); - int dst_size = LZ4_decompress_safe(src, dst, total_size - sizeof(int), uncompressed_size); + int dst_size = LZ4_decompress_safe(src, dst, total_size - sizeof(uint32_t), uncompressed_size); if(dst_size < 0) return ValueError("LZ4 decompression failed"); - assert(dst_size == uncompressed_size); + c11__rtassert(dst_size == uncompressed_size); return true; }