init commit

方而静 2024-07-28 14:48:02 +08:00
commit 2179ff225f
Signed by: szTom
GPG Key ID: 072D999D60C6473C
10 changed files with 460 additions and 0 deletions

.gitignore vendored Normal file
@@ -0,0 +1,19 @@
# Misc
*.bak
# Editor
*.swp
.vscode
# Python
venv
__pycache__
*.whl
# PyTorch
*.pth
# Data
*.db
*.jsonl
*.zst

chectus_net/cpe2wdl.py Normal file
@@ -0,0 +1,6 @@
import torch

# Sigmoid scale: larger |cpe| saturates toward a sure win/loss
scale = 410


def cpe2wdl(cpe):
    # Map a centipawn evaluation to a win expectancy in (0, 1)
    return torch.sigmoid(cpe / scale)

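A quick sanity check of this mapping (a sketch, not part of the commit; values are approximate): cpe 0 maps to exactly 0.5, and with scale = 410 a one-pawn advantage only nudges the win expectancy to about 0.56.

# Sketch: probe cpe2wdl at a few centipawn values
import torch
from cpe2wdl import cpe2wdl

for cpe in [0, 100, 410, 2000]:
    print(cpe, cpe2wdl(torch.tensor(float(cpe))).item())
# expected: 0 -> 0.5, 100 -> ~0.56, 410 -> ~0.73, 2000 -> ~0.99
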
chectus_net/dataloader.py Normal file
@@ -0,0 +1,43 @@
import sqlite3
import torch
from torch.utils.data import Dataset, DataLoader
from fen2vec import parse_fen


class ChessDataset(Dataset):
    def __init__(self, db_path, table_name, transform=None):
        self.db_path = db_path
        self.table_name = table_name
        self.transform = transform
        self.conn = sqlite3.connect(self.db_path)
        self.cursor = self.conn.cursor()
        self.data = self._load_data()
        self.conn.close()  # Close connection once all rows are in memory

    def _load_data(self):
        self.cursor.execute(f"SELECT fen, cpe FROM {self.table_name}")
        return self.cursor.fetchall()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        fen, cpe = self.data[idx]
        if self.transform:
            fen, cpe = self.transform(fen, cpe)
        return fen, cpe


def create_dataloader(db_path='../lichessdb/evals.db', table_name='Train',
                      batch_size=32, shuffle=True, transform=parse_fen,
                      num_workers=0):
    dataset = ChessDataset(db_path, table_name, transform)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                      num_workers=num_workers)


if __name__ == '__main__':
    train_loader = create_dataloader(batch_size=2, transform=None)
    # Iterate through the DataLoader; without a transform, items are raw rows
    for batch in train_loader:
        fens, cpes = batch
        print(fens, cpes)
        break

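With the default transform=parse_fen, each batch arrives as stacked board tensors and WDL targets instead of raw rows. A minimal sketch (not part of the commit; assumes evals.db was built by lichessdb/prepare.sh):

# Sketch: batches under the default parse_fen transform
from dataloader import create_dataloader

loader = create_dataloader(batch_size=2)
boards, wdls = next(iter(loader))
print(boards.shape, wdls.shape)  # expect torch.Size([2, 8, 8, 8]) and torch.Size([2])
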
chectus_net/fen2vec.py Normal file
@@ -0,0 +1,56 @@
import torch
from cpe2wdl import cpe2wdl

# Mapping of FEN piece characters to channel indices in the tensor
piece_to_index = {
    'p': 5, 'n': 4, 'b': 3, 'r': 2, 'q': 1, 'k': 0
}


def parse_fen(fen, cpe):
    # Channels 0-5 hold pieces (+1 for the side to move, -1 for the opponent),
    # channel 6 marks the en passant target, channel 7 the castling rights.
    board_tensor = torch.zeros((8, 8, 8), dtype=torch.float16)
    parts = fen.split()
    piece_placement = parts[0]
    active_color = parts[1]
    castling_rights = parts[2]
    en_passant_target = parts[3]
    current_side = 1 if active_color == 'w' else -1
    rows = piece_placement.split('/')
    for row_idx, row in enumerate(rows):
        col_idx = 0
        for char in row:
            if char.isdigit():
                col_idx += int(char)
            else:
                piece_side = current_side if char.isupper() else -current_side
                board_tensor[row_idx, col_idx, piece_to_index[char.lower()]] = piece_side
                col_idx += 1
    if en_passant_target != '-':
        file = ord(en_passant_target[0]) - ord('a')
        rank = 8 - int(en_passant_target[1])
        board_tensor[rank, file, 6] = 1
    for char in castling_rights:
        if char == 'K':
            board_tensor[7, 7, 7] = 1
        elif char == 'Q':
            board_tensor[7, 0, 7] = 1
        elif char == 'k':
            board_tensor[0, 7, 7] = 1
        elif char == 'q':
            board_tensor[0, 0, 7] = 1
    # Evaluations are stored from White's perspective; negate when Black moves
    wdl = cpe2wdl(torch.tensor(cpe if current_side == 1 else -cpe, dtype=torch.float16))
    if current_side == -1:
        # Flip the board so the side to move always plays up the ranks
        board_tensor = torch.flip(board_tensor, [0])
    return board_tensor, wdl


if __name__ == '__main__':
    fen = "r2qk2r/3n2p1/1pp1p3/3pPpb1/P2P1nBp/1NB4P/1PP2P2/R3QR1K w kq f6"
    # fen = "r2qk2r/3n2p1/1pp1pP2/3p2b1/P2P1nBp/1NB4P/1PP2P2/R3QR1K b kq -"
    tensor, wdl = parse_fen(fen, 200)
    print(tensor, wdl)

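To see which planes a position fills, a small inspection sketch (not part of the commit; assumes the channel layout above, 0-5 = k, q, r, b, n, p, 6 = en passant, 7 = castling):

# Sketch: count occupied cells per channel for the starting position
from fen2vec import parse_fen

tensor, wdl = parse_fen("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq -", 0)
for ch, name in enumerate(['k', 'q', 'r', 'b', 'n', 'p', 'en passant', 'castling']):
    print(name, int(tensor[:, :, ch].abs().sum()))
# expect 2 kings, 2 queens, 4 each of rooks/bishops/knights, 16 pawns, 0, 4
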
chectus_net/model.py Normal file
@@ -0,0 +1,125 @@
import torch
import torch.nn as nn
from torchsummary import summary


class ChessPredictModelBaby(nn.Module):
    def __init__(self):
        super(ChessPredictModelBaby, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_channels=8, out_channels=4, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=4, out_channels=8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(output_size=1),
            nn.Flatten(),
            nn.Linear(8, 16),
            nn.ReLU(),
            nn.Linear(16, 1),
            nn.Tanh()
        )

    def forward(self, x):
        # Boards arrive as (N, 8, 8, 8) in HWC order; permute to NCHW for Conv2d
        x = self.model(x.permute(0, 3, 1, 2))
        return x


class ChessPredictModelS(nn.Module):
    def __init__(self):
        super(ChessPredictModelS, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_channels=8, out_channels=32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(output_size=1),
            nn.Flatten(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.model(x.permute(0, 3, 1, 2))
        return x


class BasicBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = None
        if stride != 1 or in_channels != out_channels:
            # Project the residual when the block changes resolution or width
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        out = self.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=1):
        super(ResNet, self).__init__()
        self.in_channels = 64
        self.model = nn.Sequential(
            nn.Conv2d(8, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            self._make_layer(block, 64, layers[0], stride=1),
            self._make_layer(block, 128, layers[1], stride=2),
            self._make_layer(block, 256, layers[2], stride=2),
            self._make_layer(block, 512, layers[3], stride=2),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(512, num_classes),
            nn.Tanh()
        )

    def _make_layer(self, block, out_channels, blocks, stride=1):
        layers = []
        layers.append(block(self.in_channels, out_channels, stride))
        self.in_channels = out_channels
        for _ in range(1, blocks):
            layers.append(block(out_channels, out_channels))
        return nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x.permute(0, 3, 1, 2))


def resnet18():
    return ResNet(BasicBlock, [2, 2, 2, 2])


if __name__ == '__main__':
    model = ChessPredictModelS()
    summary(model, (8, 8, 8))

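All three networks take boards in the (N, 8, 8, 8) HWC layout produced by parse_fen and permute internally, so a minimal shape check (a sketch, not part of the commit) is:

# Sketch: confirm every model maps (N, 8, 8, 8) boards to an (N, 1) score
import torch
from model import ChessPredictModelBaby, ChessPredictModelS, resnet18

x = torch.randn(2, 8, 8, 8)
for net in (ChessPredictModelBaby(), ChessPredictModelS(), resnet18()):
    print(type(net).__name__, net(x).shape)  # expect torch.Size([2, 1])
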
chectus_net/run.py Normal file
@@ -0,0 +1,28 @@
import torch
from model import ChessPredictModelS
from fen2vec import parse_fen


def run(model_path='best_model.pth'):
    model = ChessPredictModelS()
    model.load_state_dict(torch.load(model_path, weights_only=True))
    model.eval()
    while True:
        fen_str = input("FEN(q) >")
        if fen_str.lower() == 'q':
            break
        try:
            # parse_fen expects a cpe as well; pass a dummy 0 and keep the board
            input_tensor, _ = parse_fen(fen_str, 0)
            # add batch dim; cast to float32 to match the model's parameters
            input_tensor = input_tensor.unsqueeze(0).float()
            with torch.no_grad():
                output = model(input_tensor)
            # Undo the final Tanh to recover an unbounded score
            transformed_output = torch.atanh(output) * 10
            print(f"$= {transformed_output[0, 0]} {output[0, 0]}")
        except Exception as e:
            print(f"Error: {e}")


if __name__ == "__main__":
    model_path = 'best_model.pth'
    run(model_path)

chectus_net/train.py Normal file
@@ -0,0 +1,82 @@
import torch
import torch.nn as nn
import torch.optim as optim
import os
from model import resnet18, ChessPredictModelBaby, ChessPredictModelS
from dataloader import create_dataloader
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def train(model, criterion, optimizer, train_loader, val_loader, num_epochs=10, patience=3, model_path='best_model.pth'):
    best_loss = float('inf')
    epochs_no_improve = 0
    if os.path.exists(model_path):
        model.load_state_dict(torch.load(model_path, weights_only=True))
        print(f"Loaded saved model from {model_path}")
    print('Started training')
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        train_loader_tqdm = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Training]", unit="batch")
        for inputs, labels in train_loader_tqdm:
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs.squeeze(1), labels)
            loss.backward()
            optimizer.step()
            # Accumulate per-sample loss so the epoch average is batch-size independent
            train_loss += loss.item() * inputs.size(0)
        train_loss /= len(train_loader.dataset)
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            val_loader_tqdm = tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Validation]", unit="batch")
            for inputs, labels in val_loader_tqdm:
                inputs = inputs.to(device)
                labels = labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs.squeeze(1), labels)
                val_loss += loss.item() * inputs.size(0)
        val_loss /= len(val_loader.dataset)
        print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')
        # Early stopping: save on improvement, stop after `patience` stale epochs
        if val_loss < best_loss:
            best_loss = val_loss
            epochs_no_improve = 0
            torch.save(model.state_dict(), model_path)
            print('Model saved!')
        else:
            epochs_no_improve += 1
            if epochs_no_improve == patience:
                print('Early stopping!')
                break


if __name__ == "__main__":
    batch_size = 32
    num_epochs = 50
    learning_rate = 0.001
    patience = 2
    model_path = 'best_model_baby.pth'
    weight_decay = 0
    print('Loading Data')
    train_loader = create_dataloader(table_name='train', batch_size=batch_size, shuffle=True, num_workers=3)
    val_loader = create_dataloader(table_name='test', batch_size=batch_size, shuffle=False, num_workers=3)
    print('Loaded Data')
    model = ChessPredictModelBaby().half().to(device)
    criterion = nn.SmoothL1Loss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    train(model, criterion, optimizer, train_loader, val_loader, num_epochs=num_epochs, patience=patience, model_path=model_path)

lichessdb/init.sql Normal file
@@ -0,0 +1,34 @@
DROP TABLE IF EXISTS Train;
DROP TABLE IF EXISTS Test;

CREATE TABLE Train (
    id   INTEGER PRIMARY KEY ASC AUTOINCREMENT UNIQUE NOT NULL,
    fen  TEXT    UNIQUE NOT NULL,
    cpe  INTEGER NOT NULL,
    dep  INTEGER NOT NULL,
    nxt  TEXT,
    mate INTEGER,
    flag INTEGER NOT NULL DEFAULT (0),
    ver  INTEGER NOT NULL DEFAULT (1)
);

CREATE TABLE Test (
    id   INTEGER PRIMARY KEY ASC AUTOINCREMENT UNIQUE NOT NULL,
    fen  TEXT    UNIQUE NOT NULL,
    cpe  INTEGER NOT NULL,
    dep  INTEGER NOT NULL,
    nxt  TEXT,
    mate INTEGER,
    flag INTEGER NOT NULL DEFAULT (0),
    ver  INTEGER NOT NULL DEFAULT (1)
);

lichessdb/prepare.sh Normal file
@@ -0,0 +1,20 @@
#!/bin/sh
# Download
if [ ! -f lichess_db_eval.jsonl ]; then
    wget https://database.lichess.org/lichess_db_eval.jsonl.zst
    zstd -d lichess_db_eval.jsonl.zst
fi
# Split: first 2M lines for training, last 200K for testing
# (assumes the dump has well over 2.2M lines, so the slices don't overlap)
head -n 2000000 lichess_db_eval.jsonl > train2M.jsonl
tail -n 200000 lichess_db_eval.jsonl > test200K.jsonl
# Create database
if [ ! -f evals.db ]; then
    sqlite3 evals.db < init.sql
fi
# Process
python3 process_data.py evals.db train2M.jsonl Train
python3 process_data.py evals.db test200K.jsonl Test

lichessdb/process_data.py Normal file
@@ -0,0 +1,47 @@
import sqlite3
import json
from tqdm import tqdm
import argparse


def main():
    # Argument parsing
    parser = argparse.ArgumentParser(description='Process and insert JSON data into SQLite database.')
    parser.add_argument('db_name', type=str, help='Name of the SQLite database file.')
    parser.add_argument('input_file', type=str, help='Name of the input JSONL file.')
    parser.add_argument('table_name', type=str, help='Name of the table in the database.')
    args = parser.parse_args()
    # Connect to the SQLite database
    conn = sqlite3.connect(args.db_name)
    cursor = conn.cursor()
    # Process and insert JSON data
    with open(args.input_file, 'r', encoding='utf-8') as f:
        for line in tqdm(f, desc='Processing data', unit=' lines'):
            data = json.loads(line)
            fen = data['fen']
            # Keep the deepest available evaluation for each position
            best_eval = max(data['evals'], key=lambda x: x['depth'])
            depth = best_eval['depth']
            cur_player = fen.split(' ')[1]
            # Safely get evaluation details; forced-mate entries carry no 'cp',
            # so fall back to a large sentinel score
            pvs = best_eval.get('pvs', [{}])[0]
            cpe = pvs.get('cp', 20000 if cur_player == 'w' else -20000)
            mate = pvs.get('mate')
            nxt = pvs.get('line', '').split(' ')[0]
            # Insert data into the table
            cursor.execute(f'''
                INSERT OR IGNORE INTO {args.table_name} (fen, cpe, dep, nxt, mate)
                VALUES (?, ?, ?, ?, ?)
            ''', (fen, cpe, depth, nxt, mate))
    # Commit the transaction and close the connection
    conn.commit()
    conn.close()
    print("Data processing and insertion complete.")


if __name__ == "__main__":
    main()
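
For reference, each line of the eval dump is a JSON record roughly of this shape (an illustrative example with invented values; see database.lichess.org for the authoritative format):

{"fen": "2bq1rk1/pr3ppn/1p2p3/7P/2pP1B1P/2P5/PPQ2PB1/R3R1K1 w - -", "evals": [{"pvs": [{"cp": 311, "line": "g2e4 f7f5"}], "knodes": 92000, "depth": 24}]}

Records whose best line is a forced mate carry "mate" instead of "cp" in the pv, which is what the sentinel fallback above handles.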