"""Load ``(fen, cpe)`` rows from a SQLite table and serve them via PyTorch."""

import sqlite3
from contextlib import closing
from typing import Callable, Optional

import torch
from torch.utils.data import Dataset, DataLoader

from fen2vec import parse_fen
from cpe2wdl import cpe2wdl


class ChessDataset(Dataset):
    """In-memory dataset of ``(fen, cpe)`` rows read from one SQLite table.

    All rows are fetched eagerly when the dataset is constructed; no
    database handle is kept afterwards.

    Args:
        db_path: Path of the SQLite database file.
        table_name: Table to read. It is interpolated directly into the
            SQL text (identifiers cannot be bound as parameters), so it
            must come from a trusted source.
        transform: Optional callable invoked as ``transform(fen, cpe)``;
            it must return a ``(fen, cpe)`` pair, which replaces the raw
            row values.

    Raises:
        sqlite3.Error: If the database cannot be opened or queried.
    """

    def __init__(
        self,
        db_path: str,
        table_name: str,
        transform: Optional[Callable] = None,
    ) -> None:
        self.db_path = db_path
        self.table_name = table_name
        self.transform = transform
        # The connection lives only inside _load_data: sqlite3 handles cannot
        # be pickled, and DataLoader pickles the dataset for worker processes
        # when they are started with the "spawn" method.
        self.data = self._load_data()

    def _load_data(self) -> list:
        """Return every ``(fen, cpe)`` row of the table as a list of tuples."""
        # NOTE: table_name is spliced into the statement as-is; see class docs.
        query = f"SELECT fen, cpe FROM {self.table_name}"
        # closing() guarantees the connection is released even if the query
        # raises (e.g. the table does not exist).
        with closing(sqlite3.connect(self.db_path)) as conn:
            return conn.execute(query).fetchall()

    def __len__(self) -> int:
        """Return the number of rows loaded from the table."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return row ``idx`` as ``(fen, cpe)``, transformed if configured."""
        fen, cpe = self.data[idx]
        if self.transform:
            fen, cpe = self.transform(fen, cpe)
        return fen, cpe


def create_dataloader(
    db_path='../lichessdb/evals.db',
    table_name='Train',
    batch_size=32,
    shuffle=True,
    transform=parse_fen,
    num_workers=0,
):
    """Build a ``DataLoader`` over a :class:`ChessDataset`.

    Args:
        db_path: Path of the SQLite database file.
        table_name: Table to read ``fen`` and ``cpe`` columns from.
        batch_size: Number of samples per batch.
        shuffle: Whether the loader reshuffles the data every epoch.
        transform: Callable applied to each ``(fen, cpe)`` row, or ``None``
            to yield the raw database values.
        num_workers: Number of loader worker processes (0 loads in the
            calling process).

    Returns:
        A ``torch.utils.data.DataLoader`` yielding batches of the dataset.
    """
    dataset = ChessDataset(db_path, table_name, transform)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )


if __name__ == '__main__':
    train_loader = create_dataloader(batch_size=2, transform=None)

    # Print the first batch only, as a quick smoke test of the loader.
    for batch in train_loader:
        fens, cpes = batch
        print(fens, cpes)
        break