improve io

This commit is contained in:
blueloveTH 2024-03-28 15:48:59 +08:00
parent eaf231fd9d
commit 79cafcf32c
2 changed files with 74 additions and 19 deletions

View File

@ -12,12 +12,10 @@ namespace pkpy{
struct FileIO {
PY_CLASS(FileIO, io, FileIO)
Str file;
Str mode;
FILE* fp;
bool is_text;
bool is_text() const { return mode != "rb" && mode != "wb" && mode != "ab"; }
FileIO(VM* vm, std::string file, std::string mode);
FileIO(VM* vm, const Str& file, const Str& mode);
void close();
static void _register(VM* vm, PyObject* mod, PyObject* type);
};
@ -62,27 +60,34 @@ void FileIO::_register(VM* vm, PyObject* mod, PyObject* type){
vm->bind_constructor<3>(type, [](VM* vm, ArgsView args){
Type cls = PK_OBJ_GET(Type, args[0]);
return vm->heap.gcnew<FileIO>(cls, vm,
py_cast<Str&>(vm, args[1]).str(),
py_cast<Str&>(vm, args[2]).str());
py_cast<Str&>(vm, args[1]),
py_cast<Str&>(vm, args[2]));
});
vm->bind_method<0>(type, "read", [](VM* vm, ArgsView args){
FileIO& io = CAST(FileIO&, args[0]);
fseek(io.fp, 0, SEEK_END);
int buffer_size = ftell(io.fp);
vm->bind(type, "read(self, size=-1)", [](VM* vm, ArgsView args){
FileIO& io = PK_OBJ_GET(FileIO, args[0]);
i64 size = CAST(i64, args[1]);
i64 buffer_size;
if(size < 0){
long current = ftell(io.fp);
fseek(io.fp, 0, SEEK_END);
buffer_size = ftell(io.fp);
fseek(io.fp, current, SEEK_SET);
}else{
buffer_size = size;
}
unsigned char* buffer = new unsigned char[buffer_size];
fseek(io.fp, 0, SEEK_SET);
size_t actual_size = io_fread(buffer, 1, buffer_size, io.fp);
PK_ASSERT(actual_size <= buffer_size);
// in text mode, CR may be dropped, which may cause `actual_size < buffer_size`
Bytes b(buffer, actual_size);
if(io.is_text()) return VAR(b.str());
if(io.is_text) return VAR(b.str());
return VAR(std::move(b));
});
vm->bind_method<1>(type, "write", [](VM* vm, ArgsView args){
FileIO& io = CAST(FileIO&, args[0]);
if(io.is_text()){
FileIO& io = PK_OBJ_GET(FileIO, args[0]);
if(io.is_text){
Str& s = CAST(Str&, args[1]);
fwrite(s.data, 1, s.length(), io.fp);
}else{
@ -92,14 +97,30 @@ void FileIO::_register(VM* vm, PyObject* mod, PyObject* type){
return vm->None;
});
vm->bind_method<0>(type, "tell", [](VM* vm, ArgsView args){
FileIO& io = PK_OBJ_GET(FileIO, args[0]);
long pos = ftell(io.fp);
if(pos == -1) vm->IOError(strerror(errno));
return VAR(pos);
});
vm->bind_method<2>(type, "seek", [](VM* vm, ArgsView args){
FileIO& io = PK_OBJ_GET(FileIO, args[0]);
long offset = CAST(long, args[1]);
int whence = CAST(int, args[2]);
int ret = fseek(io.fp, offset, whence);
if(ret != 0) vm->IOError(strerror(errno));
return vm->None;
});
vm->bind_method<0>(type, "close", [](VM* vm, ArgsView args){
FileIO& io = CAST(FileIO&, args[0]);
FileIO& io = PK_OBJ_GET(FileIO, args[0]);
io.close();
return vm->None;
});
vm->bind_method<0>(type, "__exit__", [](VM* vm, ArgsView args){
FileIO& io = CAST(FileIO&, args[0]);
FileIO& io = PK_OBJ_GET(FileIO, args[0]);
io.close();
return vm->None;
});
@ -107,7 +128,8 @@ void FileIO::_register(VM* vm, PyObject* mod, PyObject* type){
vm->bind_method<0>(type, "__enter__", PK_LAMBDA(args[0]));
}
FileIO::FileIO(VM* vm, std::string file, std::string mode): file(file), mode(mode) {
FileIO::FileIO(VM* vm, const Str& file, const Str& mode){
this->is_text = mode.sv().find("b") == std::string::npos;
fp = io_fopen(file.c_str(), mode.c_str());
if(!fp) vm->IOError(strerror(errno));
}
@ -121,6 +143,11 @@ void FileIO::close(){
void add_module_io(VM* vm){
PyObject* mod = vm->new_module("io");
FileIO::register_class(vm, mod);
mod->attr().set("SEEK_SET", VAR(SEEK_SET));
mod->attr().set("SEEK_CUR", VAR(SEEK_CUR));
mod->attr().set("SEEK_END", VAR(SEEK_END));
vm->bind(vm->builtins, "open(path, mode='r')", [](VM* vm, ArgsView args){
PK_LOCAL_STATIC StrName m_io("io");
PK_LOCAL_STATIC StrName m_FileIO("FileIO");

View File

@ -13,6 +13,34 @@ a.close()
with open('123.txt', 'rt') as f:
assert f.read() == '123456'
with open('123.txt', 'rt') as f:
assert f.read(3) == '123'
assert f.tell() == 3
assert f.read(3) == '456'
assert f.tell() == 6
assert f.read(3) == '' # EOF
assert f.tell() == 6
with open('123.txt', 'rb') as f:
assert f.read(2) == b'12'
assert f.tell() == 2
assert f.read(2) == b'34'
assert f.tell() == 4
assert f.read(2) == b'56'
assert f.tell() == 6
assert f.read(2) == b'' # EOF
assert f.tell() == 6
# test fseek
with open('123.txt', 'rt') as f:
f.seek(0, io.SEEK_END)
assert f.tell() == 6
assert f.read() == ''
f.seek(3, io.SEEK_SET)
assert f.tell() == 3
assert f.read() == '456'
assert f.tell() == 6
with open('123.txt', 'a') as f:
f.write('测试')
@ -29,13 +57,13 @@ with open('123.bin', 'wb') as f:
f.write('123'.encode())
f.write('测试'.encode())
def f():
def f_():
with open('123.bin', 'rb') as f:
b = f.read()
assert isinstance(b, bytes)
assert b == '123测试'.encode()
f()
f_()
assert os.path.exists('123.bin')
os.remove('123.bin')