init commit
Commit 2179ff225f
.gitignore (vendored, new file)
@@ -0,0 +1,19 @@
# Misc
*.bak

# Editor
*.swp
.vscode

# Python
venv
__pycache__
*.whl

# PyTorch
*.pth

# Data
*.db
*.jsonl
*.zst
chectus_net/cpe2wdl.py (new file)
@@ -0,0 +1,6 @@
import torch

# Centipawn scale for the sigmoid; +410 cp maps to sigmoid(1) ~ 0.73.
scale = 410

def cpe2wdl(cpe):
    # Map a centipawn evaluation to a win/draw/loss score in (0, 1).
    return torch.sigmoid(cpe / scale)
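For intuition about the scale constant, a quick sanity check of the centipawn-to-score mapping (a sketch; printed values are approximate):

import torch
from cpe2wdl import cpe2wdl

# 0 cp -> 0.50 (even), +410 cp -> sigmoid(1) ~ 0.73, +2000 cp -> ~ 0.99
for cpe in (0, 410, 2000):
    print(cpe, cpe2wdl(torch.tensor(float(cpe))).item())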
chectus_net/dataloader.py (new file)
@@ -0,0 +1,43 @@
import sqlite3
import torch
from torch.utils.data import Dataset, DataLoader
from fen2vec import parse_fen
from cpe2wdl import cpe2wdl


class ChessDataset(Dataset):
    def __init__(self, db_path, table_name, transform=None):
        self.db_path = db_path
        self.table_name = table_name
        self.transform = transform
        self.conn = sqlite3.connect(self.db_path)
        self.cursor = self.conn.cursor()
        self.data = self._load_data()
        self.conn.close()  # Close connection after data is loaded

    def _load_data(self):
        # Pull the whole table into memory so workers never touch SQLite.
        self.cursor.execute(f"SELECT fen, cpe FROM {self.table_name}")
        data = self.cursor.fetchall()
        return data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        fen, cpe = self.data[idx]
        if self.transform:
            fen, cpe = self.transform(fen, cpe)
        return fen, cpe


def create_dataloader(db_path='../lichessdb/evals.db', table_name='Train', batch_size=32, shuffle=True, transform=parse_fen, num_workers=0):
    dataset = ChessDataset(db_path, table_name, transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
    return dataloader


if __name__ == '__main__':
    train_loader = create_dataloader(batch_size=2, transform=None)
    # Iterate through the DataLoader
    for batch in train_loader:
        fens, cpes = batch
        print(fens, cpes)
        break
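With the default transform=parse_fen, each batch collates into stacked (8, 8, 8) float16 boards plus wdl targets. A minimal usage sketch (assumes evals.db exists at the default path):

from dataloader import create_dataloader

loader = create_dataloader(batch_size=4)  # default transform=parse_fen
boards, wdls = next(iter(loader))
print(boards.shape, boards.dtype)  # torch.Size([4, 8, 8, 8]) torch.float16
print(wdls.shape)                  # torch.Size([4])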
chectus_net/fen2vec.py (new file)
@@ -0,0 +1,56 @@
import torch
from cpe2wdl import cpe2wdl


# Mapping of FEN piece characters to indices in the tensor
piece_to_index = {
    'p': 5, 'n': 4, 'b': 3, 'r': 2, 'q': 1, 'k': 0
}


def parse_fen(fen, cpe):
    # Planes 0-5 hold pieces (+1 own side, -1 opponent),
    # plane 6 the en passant square, plane 7 the castling corners.
    board_tensor = torch.zeros((8, 8, 8), dtype=torch.float16)

    parts = fen.split()
    piece_placement = parts[0]
    active_color = parts[1]
    castling_rights = parts[2]
    en_passant_target = parts[3]

    current_side = 1 if active_color == 'w' else -1

    rows = piece_placement.split('/')
    for row_idx, row in enumerate(rows):
        col_idx = 0
        for char in row:
            if char.isdigit():
                col_idx += int(char)
            else:
                piece_side = current_side if char.isupper() else -current_side
                board_tensor[row_idx, col_idx, piece_to_index[char.lower()]] = piece_side
                col_idx += 1

    if en_passant_target != '-':
        file = ord(en_passant_target[0]) - ord('a')
        rank = 8 - int(en_passant_target[1])
        board_tensor[rank, file, 6] = 1

    for char in castling_rights:
        if char == 'K':
            board_tensor[7, 7, 7] = 1
        elif char == 'Q':
            board_tensor[7, 0, 7] = 1
        elif char == 'k':
            board_tensor[0, 7, 7] = 1
        elif char == 'q':
            board_tensor[0, 0, 7] = 1

    # Score from the side to move's perspective; flip the board vertically
    # so the side to move always plays "up".
    wdl = cpe2wdl(torch.tensor(cpe if current_side == 1 else -cpe, dtype=torch.float16))
    if current_side == -1:
        board_tensor = torch.flip(board_tensor, [0])

    return board_tensor, wdl


if __name__ == '__main__':
    fen = "r2qk2r/3n2p1/1pp1p3/3pPpb1/P2P1nBp/1NB4P/1PP2P2/R3QR1K w kq f6"
    # fen = "r2qk2r/3n2p1/1pp1pP2/3p2b1/P2P1nBp/1NB4P/1PP2P2/R3QR1K b kq -"
    board_tensor, wdl = parse_fen(fen, 200)
    print(board_tensor, wdl)
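As a check of the plane layout (channels 0-5 hold pieces k, q, r, b, n, p; 6 the en passant square; 7 the castling corners), a sketch using the standard start position:

from fen2vec import parse_fen

start = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq -"
board, wdl = parse_fen(start, 0)
print(board[7, 4, 0])  # own king on e1 -> 1.0
print(board[0, 4, 0])  # opponent king on e8 -> -1.0
print(board[7, 7, 7], board[0, 0, 7])  # castling corners -> 1.0
print(wdl)  # cpe2wdl(0) = 0.5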
chectus_net/model.py (new file)
@@ -0,0 +1,125 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary


class ChessPredictModelBaby(nn.Module):
    def __init__(self):
        super(ChessPredictModelBaby, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_channels=8, out_channels=4, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(in_channels=4, out_channels=8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(output_size=1),

            nn.Flatten(),
            nn.Linear(8, 16),
            nn.ReLU(),
            nn.Linear(16, 1),
            nn.Tanh()
        )

    def forward(self, x):
        # Input arrives as (N, 8, 8, 8) boards; move planes to the channel dim.
        x = self.model(x.permute(0, 3, 1, 2))
        return x


class ChessPredictModelS(nn.Module):
    def __init__(self):
        super(ChessPredictModelS, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_channels=8, out_channels=32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(output_size=1),

            nn.Flatten(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.model(x.permute(0, 3, 1, 2))
        return x


class BasicBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # 1x1 projection for the residual when its shape changes.
        self.downsample = None
        if stride != 1 or in_channels != out_channels:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=1):
        super(ResNet, self).__init__()
        self.in_channels = 64

        self.model = nn.Sequential(
            nn.Conv2d(8, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            self._make_layer(block, 64, layers[0], stride=1),
            self._make_layer(block, 128, layers[1], stride=2),
            self._make_layer(block, 256, layers[2], stride=2),
            self._make_layer(block, 512, layers[3], stride=2),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(512, num_classes),
            nn.Tanh()
        )

    def _make_layer(self, block, out_channels, blocks, stride=1):
        layers = []
        layers.append(block(self.in_channels, out_channels, stride))
        self.in_channels = out_channels
        for _ in range(1, blocks):
            layers.append(block(out_channels, out_channels))
        return nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x.permute(0, 3, 1, 2))


def resnet18():
    return ResNet(BasicBlock, [2, 2, 2, 2])


if __name__ == '__main__':
    model = ChessPredictModelS()
    summary(model, (8, 8, 8))
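A rough way to compare the three architectures is their parameter counts; a quick sketch (counts follow from the definitions above):

from model import ChessPredictModelBaby, ChessPredictModelS, resnet18

for m in (ChessPredictModelBaby(), ChessPredictModelS(), resnet18()):
    print(type(m).__name__, sum(p.numel() for p in m.parameters()))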
chectus_net/run.py (new file)
@@ -0,0 +1,28 @@
import torch
from model import ChessPredictModelS
from fen2vec import parse_fen


def run(model_path='best_model.pth'):
    model = ChessPredictModelS()
    model.load_state_dict(torch.load(model_path, weights_only=True))
    model.eval()

    while True:
        fen_str = input("FEN(q) >")
        if fen_str.lower() == 'q':
            break

        try:
            # parse_fen expects a cpe and returns (board, wdl); the wdl
            # label is irrelevant at inference time.
            input_tensor, _ = parse_fen(fen_str, 0)
            input_tensor = input_tensor.unsqueeze(0).float()  # add batch dim, match float32 weights

            with torch.no_grad():
                output = model(input_tensor)
                transformed_output = torch.atanh(output) * 10
                print(f"$= {transformed_output[0, 0]} {output[0, 0]}")
        except Exception as e:
            print(f"Error: {e}")


if __name__ == "__main__":
    model_path = 'best_model.pth'
    run(model_path)
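For scripted (non-interactive) use, the same steps can be invoked directly; a sketch, assuming best_model.pth was produced by train.py:

import torch
from model import ChessPredictModelS
from fen2vec import parse_fen

model = ChessPredictModelS()
model.load_state_dict(torch.load('best_model.pth', weights_only=True))
model.eval()

board, _ = parse_fen("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq -", 0)
with torch.no_grad():
    print(model(board.unsqueeze(0).float()).item())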
chectus_net/train.py (new file)
@@ -0,0 +1,82 @@
import torch
import torch.nn as nn
import torch.optim as optim
import os
from model import resnet18, ChessPredictModelBaby, ChessPredictModelS
from dataloader import create_dataloader
from tqdm import tqdm


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def train(model, criterion, optimizer, train_loader, val_loader, num_epochs=10, patience=3, model_path='best_model.pth'):
    best_loss = float('inf')
    epochs_no_improve = 0

    if os.path.exists(model_path):
        model.load_state_dict(torch.load(model_path, weights_only=True))
        print(f"Loaded saved model from {model_path}")

    print('Started training')
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0

        train_loader_tqdm = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Training]", unit="batch")
        for inputs, labels in train_loader_tqdm:
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs.squeeze(1), labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * inputs.size(0)

        train_loss /= len(train_loader.dataset)

        model.eval()
        val_loss = 0.0

        with torch.no_grad():
            val_loader_tqdm = tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Validation]", unit="batch")
            for inputs, labels in val_loader_tqdm:
                inputs = inputs.to(device)
                labels = labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs.squeeze(1), labels)
                val_loss += loss.item() * inputs.size(0)

        val_loss /= len(val_loader.dataset)

        print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

        # Early stopping: save on improvement, stop after `patience` stale epochs
        if val_loss < best_loss:
            best_loss = val_loss
            epochs_no_improve = 0
            torch.save(model.state_dict(), model_path)
            print('Model saved!')
        else:
            epochs_no_improve += 1
            if epochs_no_improve == patience:
                print('Early stopping!')
                break


if __name__ == "__main__":
    batch_size = 32
    num_epochs = 50
    learning_rate = 0.001
    patience = 2
    model_path = 'best_model_baby.pth'
    weight_decay = 0

    print('Loading Data')
    train_loader = create_dataloader(table_name='train', batch_size=batch_size, shuffle=True, num_workers=3)
    val_loader = create_dataloader(table_name='test', batch_size=batch_size, shuffle=False, num_workers=3)
    print('Loaded Data')

    model = ChessPredictModelBaby().half().to(device)
    criterion = nn.SmoothL1Loss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    train(model, criterion, optimizer, train_loader, val_loader, num_epochs=num_epochs, patience=patience, model_path=model_path)
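To score a saved checkpoint on the test split without training, the same pieces can be reused; a sketch, assuming best_model_baby.pth exists and mirrors the setup above:

import torch
import torch.nn as nn
from model import ChessPredictModelBaby
from dataloader import create_dataloader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ChessPredictModelBaby().half().to(device)
model.load_state_dict(torch.load('best_model_baby.pth', weights_only=True))
model.eval()

criterion = nn.SmoothL1Loss()
val_loader = create_dataloader(table_name='test', shuffle=False)
val_loss = 0.0
with torch.no_grad():
    for inputs, labels in val_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        val_loss += criterion(model(inputs).squeeze(1), labels).item() * inputs.size(0)
print(val_loss / len(val_loader.dataset))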
lichessdb/init.sql (new file)
@@ -0,0 +1,34 @@
DROP TABLE IF EXISTS train;
DROP TABLE IF EXISTS test;

CREATE TABLE train (
    id   INTEGER PRIMARY KEY ASC AUTOINCREMENT
                 UNIQUE
                 NOT NULL,
    fen  TEXT    UNIQUE
                 NOT NULL,
    cpe  INTEGER NOT NULL,
    dep  INTEGER NOT NULL,
    nxt  TEXT,
    mate INTEGER,
    flag INTEGER NOT NULL
                 DEFAULT (0),
    ver  INTEGER NOT NULL
                 DEFAULT (1)
);

CREATE TABLE test (
    id   INTEGER PRIMARY KEY ASC AUTOINCREMENT
                 UNIQUE
                 NOT NULL,
    fen  TEXT    UNIQUE
                 NOT NULL,
    cpe  INTEGER NOT NULL,
    dep  INTEGER NOT NULL,
    nxt  TEXT,
    mate INTEGER,
    flag INTEGER NOT NULL
                 DEFAULT (0),
    ver  INTEGER NOT NULL
                 DEFAULT (1)
);
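After prepare.sh has run, a quick Python check confirms the tables filled as expected (a sketch):

import sqlite3

conn = sqlite3.connect('evals.db')
cur = conn.cursor()
for table in ('train', 'test'):
    cur.execute(f"SELECT COUNT(*) FROM {table}")
    print(table, cur.fetchone()[0])
conn.close()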
lichessdb/prepare.sh (new file)
@@ -0,0 +1,20 @@
#!/bin/sh

# Downloads
if [ ! -f lichess_db_eval.jsonl ]; then
    wget https://database.lichess.org/lichess_db_eval.jsonl.zst
    zstd -d lichess_db_eval.jsonl.zst
fi

# Split
head -n 2000000 lichess_db_eval.jsonl > train2M.jsonl
tail -n 200000 lichess_db_eval.jsonl > test200K.jsonl

# Create database
if [ ! -f evals.db ]; then
    sqlite3 evals.db < init.sql
fi

# Process
python3 process_data.py evals.db train2M.jsonl Train
python3 process_data.py evals.db test200K.jsonl Test
lichessdb/process_data.py (new file)
@@ -0,0 +1,47 @@
import sqlite3
import json
from tqdm import tqdm
import argparse


def main():
    # Argument parsing
    parser = argparse.ArgumentParser(description='Process and insert JSON data into SQLite database.')
    parser.add_argument('db_name', type=str, help='Name of the SQLite database file.')
    parser.add_argument('input_file', type=str, help='Name of the input JSON file.')
    parser.add_argument('table_name', type=str, help='Name of the table in the database.')
    args = parser.parse_args()

    # Connect to the SQLite database
    conn = sqlite3.connect(args.db_name)
    cursor = conn.cursor()

    # Process and insert JSON data
    with open(args.input_file, 'r', encoding='utf-8') as f:
        for line in tqdm(f, desc='Processing data', unit=' lines'):
            data = json.loads(line)
            fen = data['fen']
            best_eval = max(data['evals'], key=lambda x: x['depth'])
            depth = best_eval['depth']

            cur_player = fen.split(' ')[1]

            # Safely get evaluation details; mate lines carry no 'cp',
            # so fall back to a large score for the side to move.
            pvs = best_eval.get('pvs', [{}])[0]
            cpe = pvs.get('cp', 20000 if cur_player == 'w' else -20000)
            mate = pvs.get('mate')
            nxt = pvs.get('line', '').split(' ')[0]

            # Insert data into the table
            cursor.execute(f'''
                INSERT OR IGNORE INTO {args.table_name} (fen, cpe, dep, nxt, mate)
                VALUES (?, ?, ?, ?, ?)
            ''', (fen, cpe, depth, nxt, mate))

    # Commit the transaction and close the connection
    conn.commit()
    conn.close()

    print("Data processing and insertion complete.")


if __name__ == "__main__":
    main()
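The script reads lichess eval records through the fields used above (fen, evals, depth, pvs, cp, mate, line). An illustrative record run through the same extraction logic (the record itself is made up; real lines carry several evals and longer pv lines):

import json

line = json.dumps({
    "fen": "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq -",
    "evals": [{"pvs": [{"cp": 20, "line": "e2e4 e7e5"}], "depth": 30}],
})
data = json.loads(line)
best = max(data["evals"], key=lambda x: x["depth"])
pvs = best.get("pvs", [{}])[0]
print(pvs.get("cp"), pvs.get("mate"), pvs.get("line", "").split(" ")[0])
# -> 20 None e2e4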