From 2179ff225ffaf6c14bd729a019aebb777c032fce Mon Sep 17 00:00:00 2001
From: szdytom
Date: Sun, 28 Jul 2024 14:48:02 +0800
Subject: [PATCH] init commit

---
 .gitignore                |  19 ++++
 chectus_net/cpe2wdl.py    |   6 ++
 chectus_net/dataloader.py |  43 ++++++++++
 chectus_net/fen2vec.py    |  56 +++++++++++++
 chectus_net/model.py      | 125 ++++++++++++++++++++++++++
 chectus_net/run.py        |  28 ++++++
 chectus_net/train.py      |  82 ++++++++++++++++
 lichessdb/init.sql        |  34 +++++++
 lichessdb/prepare.sh      |  20 ++++
 lichessdb/process_data.py |  47 ++++++++++
 10 files changed, 460 insertions(+)
 create mode 100644 .gitignore
 create mode 100644 chectus_net/cpe2wdl.py
 create mode 100644 chectus_net/dataloader.py
 create mode 100644 chectus_net/fen2vec.py
 create mode 100644 chectus_net/model.py
 create mode 100644 chectus_net/run.py
 create mode 100644 chectus_net/train.py
 create mode 100644 lichessdb/init.sql
 create mode 100644 lichessdb/prepare.sh
 create mode 100644 lichessdb/process_data.py

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..7ef578d
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,19 @@
+# Misc
+*.bak
+
+# Editor
+*.swp
+.vscode
+
+# Python
+venv
+__pycache__
+*.whl
+
+# PyTorch
+*.pth
+
+# Data
+*.db
+*.jsonl
+*.zst
diff --git a/chectus_net/cpe2wdl.py b/chectus_net/cpe2wdl.py
new file mode 100644
index 0000000..1f6cc50
--- /dev/null
+++ b/chectus_net/cpe2wdl.py
@@ -0,0 +1,6 @@
+import torch
+
+scale = 410  # logistic width: +scale centipawns maps to ~73% expected score
+
+def cpe2wdl(cpe):
+    return torch.sigmoid(cpe / scale)
diff --git a/chectus_net/dataloader.py b/chectus_net/dataloader.py
new file mode 100644
index 0000000..7667517
--- /dev/null
+++ b/chectus_net/dataloader.py
@@ -0,0 +1,43 @@
+import sqlite3
+import torch
+from torch.utils.data import Dataset, DataLoader
+from fen2vec import parse_fen
+from cpe2wdl import cpe2wdl
+
+class ChessDataset(Dataset):
+    def __init__(self, db_path, table_name, transform=None):
+        self.db_path = db_path
+        self.table_name = table_name
+        self.transform = transform
+        self.conn = sqlite3.connect(self.db_path)
+        self.cursor = self.conn.cursor()
+        self.data = self._load_data()
+        self.conn.close()  # Close connection after data is loaded
+
+    def _load_data(self):
+        self.cursor.execute(f"SELECT fen, cpe FROM {self.table_name}")
+        data = self.cursor.fetchall()
+        return data
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, idx):
+        fen, cpe = self.data[idx]
+        if self.transform:
+            fen, cpe = self.transform(fen, cpe)
+        return fen, cpe
+
+def create_dataloader(db_path='../lichessdb/evals.db', table_name='Train', batch_size=32, shuffle=True, transform=parse_fen, num_workers=0):
+    dataset = ChessDataset(db_path, table_name, transform)
+    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
+    return dataloader
+
+if __name__ == '__main__':
+    train_loader = create_dataloader(batch_size=2, transform=None)
+    # Iterate through the DataLoader
+    for batch in train_loader:
+        fens, cpes = batch
+        print(fens, cpes)
+        break
+
diff --git a/chectus_net/fen2vec.py b/chectus_net/fen2vec.py
new file mode 100644
index 0000000..25246a2
--- /dev/null
+++ b/chectus_net/fen2vec.py
@@ -0,0 +1,56 @@
+import torch
+from cpe2wdl import cpe2wdl
+
+# Mapping of FEN piece characters to indices in the tensor
+piece_to_index = {
+    'p': 5, 'n': 4, 'b': 3, 'r': 2, 'q': 1, 'k': 0
+}
+
+def parse_fen(fen, cpe):
+    board_tensor = torch.zeros((8, 8, 8), dtype=torch.float16)
+
+    parts = fen.split()
+    piece_placement = parts[0]
+    active_color = parts[1]
+    castling_rights = parts[2]
+    en_passant_target = parts[3]
+
+    current_side = 1 if active_color == 'w' else -1
+
+    rows = piece_placement.split('/')
+    for row_idx, row in enumerate(rows):
+        col_idx = 0
+        for char in row:
+            if char.isdigit():
+                col_idx += int(char)
+            else:
+                piece_side = current_side if char.isupper() else -current_side
+                board_tensor[row_idx, col_idx, piece_to_index[char.lower()]] = piece_side
+                col_idx += 1
+
+    if en_passant_target != '-':
+        file = ord(en_passant_target[0]) - ord('a')
+        rank = 8 - int(en_passant_target[1])
+        board_tensor[rank, file, 6] = 1
+
+    for char in castling_rights:
+        if char == 'K':
+            board_tensor[7, 7, 7] = 1
+        elif char == 'Q':
+            board_tensor[7, 0, 7] = 1
+        elif char == 'k':
+            board_tensor[0, 7, 7] = 1
+        elif char == 'q':
+            board_tensor[0, 0, 7] = 1
+
+    wdl = cpe2wdl(torch.tensor(cpe if current_side == 1 else -cpe, dtype=torch.float16))
+    if current_side == -1:
+        board_tensor = torch.flip(board_tensor, [0])
+
+    return board_tensor, wdl
+
+if __name__ == '__main__':
+    fen = "r2qk2r/3n2p1/1pp1p3/3pPpb1/P2P1nBp/1NB4P/1PP2P2/R3QR1K w kq f6"
+    # fen = "r2qk2r/3n2p1/1pp1pP2/3p2b1/P2P1nBp/1NB4P/1PP2P2/R3QR1K b kq -"
+    tensor, wdl = parse_fen(fen, 200)
+    print(tensor, wdl)
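
A quick sanity check of the two building blocks above: the sigmoid squash in
cpe2wdl and the 8x8x8 board encoding in parse_fen (channels 0-5 hold pieces,
signed from the side to move's perspective; channel 6 marks the en passant
square; channel 7 marks castling rights). A minimal sketch, assuming it runs
inside chectus_net/ so the modules import directly:

    import torch
    from cpe2wdl import cpe2wdl
    from fen2vec import parse_fen

    # +/-410 cp maps to roughly a 73%/27% expected score.
    for cp in (-800, -410, 0, 410, 800):
        print(cp, float(cpe2wdl(torch.tensor(float(cp)))))

    # Start position, white to move, a slight white edge of 20 cp.
    board, wdl = parse_fen("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq -", 20)
    print(board.shape, board.dtype, float(wdl))  # (8, 8, 8), torch.float16, ~0.51
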
diff --git a/chectus_net/model.py b/chectus_net/model.py
new file mode 100644
index 0000000..6f137ed
--- /dev/null
+++ b/chectus_net/model.py
@@ -0,0 +1,125 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchsummary import summary
+
+class ChessPredictModelBaby(nn.Module):
+    def __init__(self):
+        super(ChessPredictModelBaby, self).__init__()
+        self.model = nn.Sequential(
+            nn.Conv2d(in_channels=8, out_channels=4, kernel_size=3, padding=1),
+            nn.ReLU(),
+            nn.MaxPool2d(kernel_size=2, stride=2),
+
+            nn.Conv2d(in_channels=4, out_channels=8, kernel_size=3, padding=1),
+            nn.ReLU(),
+            nn.AdaptiveAvgPool2d(output_size=1),
+
+            nn.Flatten(),
+            nn.Linear(8, 16),
+            nn.ReLU(),
+            nn.Linear(16, 1),
+            nn.Tanh()
+        )
+
+    def forward(self, x):
+        x = self.model(x.permute(0, 3, 1, 2))
+        return x
+
+class ChessPredictModelS(nn.Module):
+    def __init__(self):
+        super(ChessPredictModelS, self).__init__()
+        self.model = nn.Sequential(
+            nn.Conv2d(in_channels=8, out_channels=32, kernel_size=3, padding=1),
+            nn.ReLU(),
+            nn.MaxPool2d(kernel_size=2, stride=2),
+
+            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
+            nn.ReLU(),
+            nn.MaxPool2d(kernel_size=2, stride=2),
+
+            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
+            nn.ReLU(),
+            nn.AdaptiveAvgPool2d(output_size=1),
+
+            nn.Flatten(),
+            nn.Linear(128, 64),
+            nn.ReLU(),
+            nn.Linear(64, 1),
+            nn.Tanh()
+        )
+
+    def forward(self, x):
+        x = self.model(x.permute(0, 3, 1, 2))
+        return x
+
+class BasicBlock(nn.Module):
+    def __init__(self, in_channels, out_channels, stride=1):
+        super(BasicBlock, self).__init__()
+        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
+        self.bn1 = nn.BatchNorm2d(out_channels)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+        self.bn2 = nn.BatchNorm2d(out_channels)
+
+        self.downsample = None
+        if stride != 1 or in_channels != out_channels:
+            self.downsample = nn.Sequential(
+                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
+                nn.BatchNorm2d(out_channels)
+            )
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+
+        return out
+
+class ResNet(nn.Module):
+    def __init__(self, block, layers, num_classes=1):
+        super(ResNet, self).__init__()
+        self.in_channels = 64
+
+        self.model = nn.Sequential(
+            nn.Conv2d(8, 64, kernel_size=3, stride=1, padding=1),
+            nn.BatchNorm2d(64),
+            nn.ReLU(inplace=True),
+            self._make_layer(block, 64, layers[0], stride=1),
+            self._make_layer(block, 128, layers[1], stride=2),
+            self._make_layer(block, 256, layers[2], stride=2),
+            self._make_layer(block, 512, layers[3], stride=2),
+            nn.AdaptiveAvgPool2d((1, 1)),
+            nn.Flatten(),
+            nn.Linear(512, num_classes),
+            nn.Tanh()
+        )
+
+    def _make_layer(self, block, out_channels, blocks, stride=1):
+        layers = []
+        layers.append(block(self.in_channels, out_channels, stride))
+        self.in_channels = out_channels
+        for _ in range(1, blocks):
+            layers.append(block(out_channels, out_channels))
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        return self.model(x.permute(0, 3, 1, 2))
+
+def resnet18():
+    return ResNet(BasicBlock, [2, 2, 2, 2])
+
+if __name__ == '__main__':
+    model = ChessPredictModelS()
+    summary(model, (8, 8, 8))
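
All three architectures take a batch of channels-last (N, 8, 8, 8) boards,
permute to channels-first internally, and emit one Tanh-bounded score per
position. A shape smoke test, as a sketch (run from chectus_net/):

    import torch
    from model import ChessPredictModelBaby, ChessPredictModelS, resnet18

    x = torch.randn(4, 8, 8, 8)  # batch of 4 fake board encodings
    for net in (ChessPredictModelBaby(), ChessPredictModelS(), resnet18()):
        net.eval()  # put the ResNet's BatchNorm layers in inference mode
        with torch.no_grad():
            y = net(x)
        print(type(net).__name__, tuple(y.shape))  # each prints (4, 1)
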
diff --git a/chectus_net/run.py b/chectus_net/run.py
new file mode 100644
index 0000000..33d7d09
--- /dev/null
+++ b/chectus_net/run.py
@@ -0,0 +1,28 @@
+import torch
+from model import ChessPredictModelS
+from fen2vec import parse_fen
+
+def run(model_path='best_model.pth'):
+    model = ChessPredictModelS()
+    model.load_state_dict(torch.load(model_path, weights_only=True))
+    model.eval()
+
+    while True:
+        fen_str = input("FEN(q) >")
+        if fen_str.lower() == 'q':
+            break
+
+        try:
+            input_tensor, _ = parse_fen(fen_str, 0)  # parse_fen also returns a WDL target; pass a dummy cp of 0
+            input_tensor = input_tensor.float().unsqueeze(0)  # add batch dim, upcast fp16 encoding to the model's dtype
+
+            with torch.no_grad():
+                output = model(input_tensor)
+                transformed_output = torch.atanh(output) * 10
+            print(f"$= {transformed_output[0, 0]} {output[0, 0]}")
+        except Exception as e:
+            print(f"Error: {e}")
+
+if __name__ == "__main__":
+    model_path = 'best_model.pth'
+    run(model_path)
diff --git a/chectus_net/train.py b/chectus_net/train.py
new file mode 100644
index 0000000..3f91191
--- /dev/null
+++ b/chectus_net/train.py
@@ -0,0 +1,82 @@
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import os
+from model import resnet18, ChessPredictModelBaby, ChessPredictModelS
+from dataloader import create_dataloader
+from tqdm import tqdm
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+def train(model, criterion, optimizer, train_loader, val_loader, num_epochs=10, patience=3, model_path='best_model.pth'):
+    best_loss = float('inf')
+    epochs_no_improve = 0
+
+    if os.path.exists(model_path):
+        model.load_state_dict(torch.load(model_path, weights_only=True))
+        print(f"Loaded saved model from {model_path}")
+
+    print('Started training')
+    for epoch in range(num_epochs):
+        model.train()
+        train_loss = 0.0
+
+        train_loader_tqdm = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Training]", unit="batch")
+        for inputs, labels in train_loader_tqdm:
+            inputs = inputs.to(device)
+            labels = labels.to(device)
+            optimizer.zero_grad()
+            outputs = model(inputs)
+            loss = criterion(outputs.squeeze(1), labels)
+            loss.backward()
+            optimizer.step()
+            train_loss += loss.item() * inputs.size(0)
+
+        train_loss /= len(train_loader.dataset)
+
+        model.eval()
+        val_loss = 0.0
+
+        with torch.no_grad():
+            val_loader_tqdm = tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Validation]", unit="batch")
+            for inputs, labels in val_loader_tqdm:
+                inputs = inputs.to(device)
+                labels = labels.to(device)
+                outputs = model(inputs)
+                loss = criterion(outputs.squeeze(1), labels)
+                val_loss += loss.item() * inputs.size(0)
+
+        val_loss /= len(val_loader.dataset)
+
+        print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')
+
+        # Early stopping: save on improvement, stop after `patience` stale epochs
+        if val_loss < best_loss:
+            best_loss = val_loss
+            epochs_no_improve = 0
+            torch.save(model.state_dict(), model_path)
+            print('Model saved!')
+        else:
+            epochs_no_improve += 1
+            if epochs_no_improve == patience:
+                print('Early stopping!')
+                break
+
+if __name__ == "__main__":
+    batch_size = 32
+    num_epochs = 50
+    learning_rate = 0.001
+    patience = 2
+    model_path = 'best_model_baby.pth'
+    weight_decay = 0
+
+    print('Loading Data')
+    train_loader = create_dataloader(table_name='train', batch_size=batch_size, shuffle=True, num_workers=3)
+    val_loader = create_dataloader(table_name='test', batch_size=batch_size, shuffle=False, num_workers=3)
+    print('Loaded Data')
+
+    model = ChessPredictModelBaby().half().to(device)
+    criterion = nn.SmoothL1Loss()
+    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
+
+    train(model, criterion, optimizer, train_loader, val_loader, num_epochs=num_epochs, patience=patience, model_path=model_path)
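
Since train() keeps only the best-by-validation weights, re-scoring a saved
checkpoint outside the training loop mirrors the validation pass above. A
hypothetical helper, not part of this commit:

    import torch
    import torch.nn as nn
    from model import ChessPredictModelBaby
    from dataloader import create_dataloader

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ChessPredictModelBaby().half().to(device)
    model.load_state_dict(torch.load('best_model_baby.pth', weights_only=True))
    model.eval()

    criterion = nn.SmoothL1Loss()
    loader = create_dataloader(table_name='test', batch_size=64, shuffle=False)
    total = 0.0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            total += criterion(model(inputs).squeeze(1), labels).item() * inputs.size(0)
    print(f"test SmoothL1 loss: {total / len(loader.dataset):.4f}")
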
diff --git a/lichessdb/init.sql b/lichessdb/init.sql
new file mode 100644
index 0000000..c45f2c8
--- /dev/null
+++ b/lichessdb/init.sql
@@ -0,0 +1,34 @@
+DROP TABLE IF EXISTS Train;
+DROP TABLE IF EXISTS Test;
+
+CREATE TABLE train (
+    id   INTEGER PRIMARY KEY ASC AUTOINCREMENT
+                 UNIQUE
+                 NOT NULL,
+    fen  TEXT    UNIQUE
+                 NOT NULL,
+    cpe  INTEGER NOT NULL,
+    dep  INTEGER NOT NULL,
+    nxt  TEXT,
+    mate INTEGER,
+    flag INTEGER NOT NULL
+                 DEFAULT (0),
+    ver  INTEGER NOT NULL
+                 DEFAULT (1)
+);
+
+CREATE TABLE test (
+    id   INTEGER PRIMARY KEY ASC AUTOINCREMENT
+                 UNIQUE
+                 NOT NULL,
+    fen  TEXT    UNIQUE
+                 NOT NULL,
+    cpe  INTEGER NOT NULL,
+    dep  INTEGER NOT NULL,
+    nxt  TEXT,
+    mate INTEGER,
+    flag INTEGER NOT NULL
+                 DEFAULT (0),
+    ver  INTEGER NOT NULL
+                 DEFAULT (1)
+);
diff --git a/lichessdb/prepare.sh b/lichessdb/prepare.sh
new file mode 100644
index 0000000..e849e9e
--- /dev/null
+++ b/lichessdb/prepare.sh
@@ -0,0 +1,20 @@
+#!/bin/sh
+
+# Downloads
+if [ ! -f lichess_db_eval.jsonl ]; then
+    wget https://database.lichess.org/lichess_db_eval.jsonl.zst
+    zstd -d lichess_db_eval.jsonl.zst
+fi
+
+# Split
+head -n 2000000 lichess_db_eval.jsonl > train2M.jsonl
+tail -n 200000 lichess_db_eval.jsonl > test200K.jsonl
+
+# Create database
+if [ ! -f evals.db ]; then
+    sqlite3 evals.db < init.sql
+fi
+
+# Process
+python3 process_data.py evals.db train2M.jsonl Train
+python3 process_data.py evals.db test200K.jsonl Test
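
Once prepare.sh finishes, the splits can be checked straight from evals.db.
Note that SQLite table names are case-insensitive, so the Train/Test arguments
above land in the train/test tables created by init.sql. A sketch:

    import sqlite3

    conn = sqlite3.connect('evals.db')
    for table in ('train', 'test'):
        (n,) = conn.execute(f"SELECT COUNT(*) FROM {table}").fetchone()
        print(table, n)  # at most 2M / 200K; INSERT OR IGNORE skips duplicate FENs
    print(conn.execute("SELECT fen, cpe, dep FROM train LIMIT 2").fetchall())
    conn.close()
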
diff --git a/lichessdb/process_data.py b/lichessdb/process_data.py
new file mode 100644
index 0000000..1f10b3a
--- /dev/null
+++ b/lichessdb/process_data.py
@@ -0,0 +1,47 @@
+import sqlite3
+import json
+from tqdm import tqdm
+import argparse
+
+def main():
+    # Argument parsing
+    parser = argparse.ArgumentParser(description='Process and insert JSON data into SQLite database.')
+    parser.add_argument('db_name', type=str, help='Name of the SQLite database file.')
+    parser.add_argument('input_file', type=str, help='Name of the input JSON file.')
+    parser.add_argument('table_name', type=str, help='Name of the table in the database.')
+    args = parser.parse_args()
+
+    # Connect to the SQLite database
+    conn = sqlite3.connect(args.db_name)
+    cursor = conn.cursor()
+
+    # Process and insert JSON data
+    with open(args.input_file, 'r', encoding='utf-8') as f:
+        for line in tqdm(f, desc='Processing data', unit=' lines'):
+            data = json.loads(line)
+            fen = data['fen']
+            best_eval = max(data['evals'], key=lambda x: x['depth'])
+            depth = best_eval['depth']
+
+            cur_player = fen.split(' ')[1]
+
+            # Read the top principal variation of the deepest eval
+            pvs = best_eval.get('pvs', [{}])[0]
+            cpe = pvs.get('cp', 20000 if cur_player == 'w' else -20000)  # no 'cp' means a forced mate; use a large sentinel
+            mate = pvs.get('mate')
+            nxt = pvs.get('line', '').split(' ')[0]
+
+            # Insert data into the table
+            cursor.execute(f'''
+                INSERT OR IGNORE INTO {args.table_name} (fen, cpe, dep, nxt, mate)
+                VALUES (?, ?, ?, ?, ?)
+            ''', (fen, cpe, depth, nxt, mate))
+
+    # Commit the transaction and close the connection
+    conn.commit()
+    conn.close()
+
+    print("Data processing and insertion complete.")
+
+if __name__ == "__main__":
+    main()
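
For reference, each line of lichess_db_eval.jsonl is one JSON object; the
script reads only fen, evals[].depth, and the first entry of evals[].pvs with
its cp (or mate) and line. An invented example in that shape (values are
illustrative, not real lichess data):

    line = {
        "fen": "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq -",
        "evals": [
            {"depth": 36, "pvs": [{"cp": 20, "line": "e2e4 e7e5 g1f3"}]},
            {"depth": 22, "pvs": [{"cp": 25, "line": "d2d4 d7d5"}]},
        ],
    }
    # max(..., key=depth) picks the depth-36 eval, so the INSERT receives
    # cpe=20, dep=36, nxt='e2e4', mate=None.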