chectus/chectus_net/dataloader.py
2024-07-28 14:48:02 +08:00

44 lines
1.3 KiB
Python

import sqlite3
import torch
from torch.utils.data import Dataset, DataLoader
from fen2vec import parse_fen
from cpe2wdl import cpe2wdl
class ChessDataset(Dataset):
    """In-memory dataset of ``(fen, cpe)`` rows read from a SQLite table.

    Every row is fetched once at construction time and the database
    connection is closed again before the constructor returns.  No connection
    or cursor is kept on the instance: ``sqlite3`` handles cannot be pickled,
    and ``DataLoader`` pickles the dataset when ``num_workers > 0`` and the
    worker processes are started with the ``spawn`` method.

    Args:
        db_path: Path of the SQLite database file.
        table_name: Table to read; the query selects its ``fen`` and ``cpe``
            columns.
        transform: Optional callable invoked as ``transform(fen, cpe)``; its
            result is unpacked into the ``(fen, cpe)`` pair returned for a
            sample.
    """

    def __init__(self, db_path, table_name, transform=None):
        self.db_path = db_path
        self.table_name = table_name
        self.transform = transform
        self.data = self._load_data()

    def _load_data(self):
        """Return all ``(fen, cpe)`` rows of the table as a list of tuples.

        Raises:
            sqlite3.Error: If the database cannot be opened or queried.
        """
        # NOTE: identifiers cannot be bound as SQL parameters, so the table
        # name is interpolated directly -- only pass trusted table names.
        query = f"SELECT fen, cpe FROM {self.table_name}"
        connection = sqlite3.connect(self.db_path)
        try:
            return connection.execute(query).fetchall()
        finally:
            # Release the database handle even when the query fails.
            connection.close()

    def __len__(self):
        """Return the number of rows loaded from the table."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(fen, cpe)`` pair at *idx*, transformed if configured."""
        fen, cpe = self.data[idx]
        if self.transform:
            fen, cpe = self.transform(fen, cpe)
        return fen, cpe
def create_dataloader(
    db_path='../lichessdb/evals.db',
    table_name='Train',
    batch_size=32,
    shuffle=True,
    transform=parse_fen,
    num_workers=0,
):
    """Build a ``ChessDataset`` for one table and wrap it in a ``DataLoader``.

    Args:
        db_path: Path of the SQLite database passed to ``ChessDataset``.
        table_name: Name of the table passed to ``ChessDataset``.
        batch_size: Forwarded to ``DataLoader``.
        shuffle: Forwarded to ``DataLoader``.
        transform: Per-sample transform passed to ``ChessDataset``; pass
            ``None`` to get the raw database values.
        num_workers: Forwarded to ``DataLoader``.

    Returns:
        The configured ``DataLoader``.
    """
    return DataLoader(
        ChessDataset(db_path, table_name, transform),
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )
if __name__ == '__main__':
    # Smoke check: load the default table without a transform and print only
    # the first batch, if the loader yields one.
    loader = create_dataloader(batch_size=2, transform=None)
    first_batch = next(iter(loader), None)
    if first_batch is not None:
        fens, cpes = first_batch
        print(fens, cpes)