44 lines
1.3 KiB
Python
44 lines
1.3 KiB
Python
import sqlite3
|
|
import torch
|
|
from torch.utils.data import Dataset, DataLoader
|
|
from fen2vec import parse_fen
|
|
from cpe2wdl import cpe2wdl
|
|
|
|
class ChessDataset(Dataset):
    """Map-style dataset over the ``(fen, cpe)`` rows of one SQLite table.

    Every row is read into memory once, at construction time; no database
    handle is kept afterwards. Because the instance then holds only plain
    Python data (plus the optional ``transform``), it can be pickled and
    handed to ``DataLoader`` worker processes.

    Args:
        db_path: Path of the SQLite database file.
        table_name: Name of the table to read. It is interpolated directly
            into the SQL text, so it must come from trusted code, never
            from user input.
        transform: Optional callable invoked as ``transform(fen, cpe)``; it
            must return a ``(fen, cpe)`` pair that replaces the raw row.
    """

    def __init__(self, db_path, table_name, transform=None):
        self.db_path = db_path
        self.table_name = table_name
        self.transform = transform
        self.data = self._load_data()

    def _load_data(self):
        """Return all ``(fen, cpe)`` rows of the table as a list of tuples.

        The connection is local to this call and is closed even when the
        query fails, so a bad table name does not leak a database handle.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()
            cursor.execute(f"SELECT fen, cpe FROM {self.table_name}")
            return cursor.fetchall()
        finally:
            conn.close()

    def __len__(self):
        """Return the number of rows loaded from the table."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return row ``idx`` as ``(fen, cpe)``, transformed if configured."""
        fen, cpe = self.data[idx]
        if self.transform:
            fen, cpe = self.transform(fen, cpe)
        return fen, cpe
|
|
|
|
def create_dataloader(db_path='../lichessdb/evals.db', table_name='Train', batch_size=32, shuffle=True, transform=parse_fen, num_workers=0):
    """Build a ``DataLoader`` over a ``ChessDataset`` read from SQLite.

    Args:
        db_path: Path of the SQLite database file, passed to ``ChessDataset``.
        table_name: Table to read, passed to ``ChessDataset``.
        batch_size: Forwarded to ``DataLoader``.
        shuffle: Forwarded to ``DataLoader``.
        transform: Per-sample callable passed to ``ChessDataset``; defaults
            to ``parse_fen``. Pass ``None`` to get the raw rows.
        num_workers: Forwarded to ``DataLoader``.

    Returns:
        The configured ``DataLoader``.
    """
    return DataLoader(
        ChessDataset(db_path, table_name, transform),
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )
|
|
|
|
if __name__ == '__main__':
    # Smoke test: build a loader over the default database without any
    # transform and print the first batch of raw (fen, cpe) values.
    train_loader = create_dataloader(batch_size=2, transform=None)
    for fens, cpes in train_loader:
        print(fens, cpes)
        break  # one batch is enough to confirm the pipeline works
|
|
|