setup cmake system
Signed-off-by: szdytom <szdytom@qq.com>
This commit is contained in:
parent
059261edfd
commit
e7e9d8a48e
5
.gitignore
vendored
5
.gitignore
vendored
@ -12,6 +12,11 @@ __pycache__
|
|||||||
|
|
||||||
# PyTorch
|
# PyTorch
|
||||||
*.pth
|
*.pth
|
||||||
|
libtorch
|
||||||
|
|
||||||
|
# Xmake
|
||||||
|
.xmake
|
||||||
|
build/
|
||||||
|
|
||||||
# Data
|
# Data
|
||||||
*.db
|
*.db
|
||||||
|
17
CMakeLists.txt
Normal file
17
CMakeLists.txt
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
cmake_minimum_required(VERSION 3.24)
|
||||||
|
|
||||||
|
project(
|
||||||
|
Chectus
|
||||||
|
VERSION 0.1.0
|
||||||
|
DESCRIPTION "A chess engine that plays fun moves!"
|
||||||
|
LANGUAGES CXX
|
||||||
|
)
|
||||||
|
|
||||||
|
option("ENABLE_CUDA" "Enable CUDA support" OFF)
|
||||||
|
|
||||||
|
include("cmake/third_party.cmake")
|
||||||
|
|
||||||
|
set(CHECTUS_SRC src/main.cpp)
|
||||||
|
|
||||||
|
add_executable(chectus_engine ${CHECTUS_SRC})
|
||||||
|
target_link_libraries(chectus_engine PRIVATE fmt simdjson "${TORCH_LIBRARIES}")
|
@ -1,6 +0,0 @@
|
|||||||
import torch
|
|
||||||
|
|
||||||
scale = 410
|
|
||||||
|
|
||||||
def cpe2wdl(cpe):
|
|
||||||
return torch.sigmoid(cpe / scale)
|
|
@ -1,43 +0,0 @@
|
|||||||
import sqlite3
|
|
||||||
import torch
|
|
||||||
from torch.utils.data import Dataset, DataLoader
|
|
||||||
from fen2vec import parse_fen
|
|
||||||
from cpe2wdl import cpe2wdl
|
|
||||||
|
|
||||||
class ChessDataset(Dataset):
|
|
||||||
def __init__(self, db_path, table_name, transform=None):
|
|
||||||
self.db_path = db_path
|
|
||||||
self.table_name = table_name
|
|
||||||
self.transform = transform
|
|
||||||
self.conn = sqlite3.connect(self.db_path)
|
|
||||||
self.cursor = self.conn.cursor()
|
|
||||||
self.data = self._load_data()
|
|
||||||
self.conn.close() # Close connection after data is loaded
|
|
||||||
|
|
||||||
def _load_data(self):
|
|
||||||
self.cursor.execute(f"SELECT fen, cpe FROM {self.table_name}")
|
|
||||||
data = self.cursor.fetchall()
|
|
||||||
return data
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.data)
|
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
|
||||||
fen, cpe = self.data[idx]
|
|
||||||
if self.transform:
|
|
||||||
fen, cpe = self.transform(fen, cpe)
|
|
||||||
return fen, cpe
|
|
||||||
|
|
||||||
def create_dataloader(db_path='../lichessdb/evals.db', table_name='Train', batch_size=32, shuffle=True, transform=parse_fen, num_workers=0):
|
|
||||||
dataset = ChessDataset(db_path, table_name, transform)
|
|
||||||
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
|
|
||||||
return dataloader
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
train_loader = create_dataloader(batch_size=2, transform=None)
|
|
||||||
# Iterate through the DataLoader
|
|
||||||
for batch in train_loader:
|
|
||||||
fens, cpes = batch
|
|
||||||
print(fens, cpes)
|
|
||||||
break
|
|
||||||
|
|
@ -1,56 +0,0 @@
|
|||||||
import torch
|
|
||||||
from cpe2wdl import cpe2wdl
|
|
||||||
|
|
||||||
# Mapping of FEN piece characters to indices in the tensor
|
|
||||||
piece_to_index = {
|
|
||||||
'p': 5, 'n': 4, 'b': 3, 'r': 2, 'q': 1, 'k': 0
|
|
||||||
}
|
|
||||||
|
|
||||||
def parse_fen(fen, cpe):
|
|
||||||
board_tensor = torch.zeros((8, 8, 8), dtype=torch.float16)
|
|
||||||
|
|
||||||
parts = fen.split()
|
|
||||||
piece_placement = parts[0]
|
|
||||||
active_color = parts[1]
|
|
||||||
castling_rights = parts[2]
|
|
||||||
en_passant_target = parts[3]
|
|
||||||
|
|
||||||
current_side = 1 if active_color == 'w' else -1
|
|
||||||
|
|
||||||
rows = piece_placement.split('/')
|
|
||||||
for row_idx, row in enumerate(rows):
|
|
||||||
col_idx = 0
|
|
||||||
for char in row:
|
|
||||||
if char.isdigit():
|
|
||||||
col_idx += int(char)
|
|
||||||
else:
|
|
||||||
piece_side = current_side if char.isupper() else -current_side
|
|
||||||
board_tensor[row_idx, col_idx, piece_to_index[char.lower()]] = piece_side
|
|
||||||
col_idx += 1
|
|
||||||
|
|
||||||
if en_passant_target != '-':
|
|
||||||
file = ord(en_passant_target[0]) - ord('a')
|
|
||||||
rank = 8 - int(en_passant_target[1])
|
|
||||||
board_tensor[rank, file, 6] = 1
|
|
||||||
|
|
||||||
for char in castling_rights:
|
|
||||||
if char == 'K':
|
|
||||||
board_tensor[7, 7, 7] = 1
|
|
||||||
elif char == 'Q':
|
|
||||||
board_tensor[7, 0, 7] = 1
|
|
||||||
elif char == 'k':
|
|
||||||
board_tensor[0, 7, 7] = 1
|
|
||||||
elif char == 'q':
|
|
||||||
board_tensor[0, 0, 7] = 1
|
|
||||||
|
|
||||||
wdl = cpe2wdl(torch.tensor(cpe if current_side == 1 else -cpe, dtype=torch.float16))
|
|
||||||
if current_side == -1:
|
|
||||||
board_tensor = torch.flip(board_tensor, [0])
|
|
||||||
|
|
||||||
return board_tensor, wdl
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
fen = "r2qk2r/3n2p1/1pp1p3/3pPpb1/P2P1nBp/1NB4P/1PP2P2/R3QR1K w kq f6"
|
|
||||||
# fen = "r2qk2r/3n2p1/1pp1pP2/3p2b1/P2P1nBp/1NB4P/1PP2P2/R3QR1K b kq -"
|
|
||||||
tensor = parse_fen(fen, 200)
|
|
||||||
print(tensor)
|
|
@ -1,125 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from torchsummary import summary
|
|
||||||
|
|
||||||
class ChessPredictModelBaby(nn.Module):
|
|
||||||
def __init__(self):
|
|
||||||
super(ChessPredictModelBaby, self).__init__()
|
|
||||||
self.model = nn.Sequential(
|
|
||||||
nn.Conv2d(in_channels=8, out_channels=4, kernel_size=3, padding=1),
|
|
||||||
nn.ReLU(),
|
|
||||||
nn.MaxPool2d(kernel_size=2, stride=2),
|
|
||||||
|
|
||||||
nn.Conv2d(in_channels=4, out_channels=8, kernel_size=3, padding=1),
|
|
||||||
nn.ReLU(),
|
|
||||||
nn.AdaptiveAvgPool2d(output_size=1),
|
|
||||||
|
|
||||||
nn.Flatten(),
|
|
||||||
nn.Linear(8, 16),
|
|
||||||
nn.ReLU(),
|
|
||||||
nn.Linear(16, 1),
|
|
||||||
nn.Tanh()
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = self.model(x.permute(0, 3, 1, 2))
|
|
||||||
return x
|
|
||||||
|
|
||||||
class ChessPredictModelS(nn.Module):
|
|
||||||
def __init__(self):
|
|
||||||
super(ChessPredictModelS, self).__init__()
|
|
||||||
self.model = nn.Sequential(
|
|
||||||
nn.Conv2d(in_channels=8, out_channels=32, kernel_size=3, padding=1),
|
|
||||||
nn.ReLU(),
|
|
||||||
nn.MaxPool2d(kernel_size=2, stride=2),
|
|
||||||
|
|
||||||
nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
|
|
||||||
nn.ReLU(),
|
|
||||||
nn.MaxPool2d(kernel_size=2, stride=2),
|
|
||||||
|
|
||||||
nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
|
|
||||||
nn.ReLU(),
|
|
||||||
nn.AdaptiveAvgPool2d(output_size=1),
|
|
||||||
|
|
||||||
nn.Flatten(),
|
|
||||||
nn.Linear(128, 64),
|
|
||||||
nn.ReLU(),
|
|
||||||
nn.Linear(64, 1),
|
|
||||||
nn.Tanh()
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = self.model(x.permute(0, 3, 1, 2))
|
|
||||||
return x
|
|
||||||
|
|
||||||
class BasicBlock(nn.Module):
|
|
||||||
def __init__(self, in_channels, out_channels, stride=1):
|
|
||||||
super(BasicBlock, self).__init__()
|
|
||||||
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
|
|
||||||
self.bn1 = nn.BatchNorm2d(out_channels)
|
|
||||||
self.relu = nn.ReLU(inplace=True)
|
|
||||||
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
|
||||||
self.bn2 = nn.BatchNorm2d(out_channels)
|
|
||||||
|
|
||||||
self.downsample = None
|
|
||||||
if stride != 1 or in_channels != out_channels:
|
|
||||||
self.downsample = nn.Sequential(
|
|
||||||
nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
|
|
||||||
nn.BatchNorm2d(out_channels)
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
identity = x
|
|
||||||
|
|
||||||
out = self.conv1(x)
|
|
||||||
out = self.bn1(out)
|
|
||||||
out = self.relu(out)
|
|
||||||
|
|
||||||
out = self.conv2(out)
|
|
||||||
out = self.bn2(out)
|
|
||||||
|
|
||||||
if self.downsample is not None:
|
|
||||||
identity = self.downsample(x)
|
|
||||||
|
|
||||||
out += identity
|
|
||||||
out = self.relu(out)
|
|
||||||
|
|
||||||
return out
|
|
||||||
|
|
||||||
class ResNet(nn.Module):
|
|
||||||
def __init__(self, block, layers, num_classes=1):
|
|
||||||
super(ResNet, self).__init__()
|
|
||||||
self.in_channels = 64
|
|
||||||
|
|
||||||
self.model = nn.Sequential(
|
|
||||||
nn.Conv2d(8, 64, kernel_size=3, stride=1, padding=1),
|
|
||||||
nn.BatchNorm2d(64),
|
|
||||||
nn.ReLU(inplace=True),
|
|
||||||
self._make_layer(block, 64, layers[0], stride=1),
|
|
||||||
self._make_layer(block, 128, layers[1], stride=2),
|
|
||||||
self._make_layer(block, 256, layers[2], stride=2),
|
|
||||||
self._make_layer(block, 512, layers[3], stride=2),
|
|
||||||
nn.AdaptiveAvgPool2d((1, 1)),
|
|
||||||
nn.Flatten(),
|
|
||||||
nn.Linear(512, num_classes),
|
|
||||||
nn.Tanh()
|
|
||||||
)
|
|
||||||
|
|
||||||
def _make_layer(self, block, out_channels, blocks, stride=1):
|
|
||||||
layers = []
|
|
||||||
layers.append(block(self.in_channels, out_channels, stride))
|
|
||||||
self.in_channels = out_channels
|
|
||||||
for _ in range(1, blocks):
|
|
||||||
layers.append(block(out_channels, out_channels))
|
|
||||||
return nn.Sequential(*layers)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return self.model(x.permute(0, 3, 1, 2))
|
|
||||||
|
|
||||||
def resnet18():
|
|
||||||
return ResNet(BasicBlock, [2, 2, 2, 2])
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
model = ChessPredictModelS()
|
|
||||||
summary(model, (8, 8, 8))
|
|
@ -1,29 +0,0 @@
|
|||||||
import torch
|
|
||||||
from model import ChessPredictModelS
|
|
||||||
from fen2vec import parse_fen
|
|
||||||
|
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
||||||
|
|
||||||
def run(model_path='best_model.pth'):
|
|
||||||
model = ChessPredictModelS().half().to(device)
|
|
||||||
model.load_state_dict(torch.load(model_path, weights_only=True))
|
|
||||||
model.eval()
|
|
||||||
|
|
||||||
while True:
|
|
||||||
fen_str = input("FEN(q) >")
|
|
||||||
if fen_str.lower() == 'q':
|
|
||||||
break
|
|
||||||
|
|
||||||
try:
|
|
||||||
input_tensor, _wdl = parse_fen(fen_str, 0)
|
|
||||||
input_tensor = input_tensor.unsqueeze(0).to(device) # add batch dim
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
output = model(input_tensor)
|
|
||||||
print(f"$= {transformed_output[0, 0]} {output[0, 0]}")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error: {e}")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
model_path = 'best_model.pth'
|
|
||||||
run(model_path)
|
|
@ -1,82 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.optim as optim
|
|
||||||
import os
|
|
||||||
from model import resnet18, ChessPredictModelBaby, ChessPredictModelS
|
|
||||||
from dataloader import create_dataloader
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
||||||
|
|
||||||
def train(model, criterion, optimizer, train_loader, val_loader, num_epochs=10, patience=3, model_path='best_model.pth'):
|
|
||||||
best_loss = float('inf')
|
|
||||||
epochs_no_improve = 0
|
|
||||||
|
|
||||||
if os.path.exists(model_path):
|
|
||||||
model.load_state_dict(torch.load(model_path, weights_only=True))
|
|
||||||
print(f"Loaded saved model from {model_path}")
|
|
||||||
|
|
||||||
print('Started training')
|
|
||||||
for epoch in range(num_epochs):
|
|
||||||
model.train()
|
|
||||||
train_loss = 0.0
|
|
||||||
|
|
||||||
train_loader_tqdm = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Training]", unit="batch")
|
|
||||||
for inputs, labels in train_loader_tqdm:
|
|
||||||
inputs = inputs.to(device)
|
|
||||||
labels = labels.to(device)
|
|
||||||
optimizer.zero_grad()
|
|
||||||
outputs = model(inputs)
|
|
||||||
loss = criterion(outputs.squeeze(1), labels)
|
|
||||||
loss.backward()
|
|
||||||
optimizer.step()
|
|
||||||
train_loss += loss.item() * inputs.size(0)
|
|
||||||
|
|
||||||
train_loss /= len(train_loader.dataset)
|
|
||||||
|
|
||||||
model.eval()
|
|
||||||
val_loss = 0.0
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
val_loader_tqdm = tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Validation]", unit="batch")
|
|
||||||
for inputs, labels in val_loader_tqdm:
|
|
||||||
inputs = inputs.to(device)
|
|
||||||
labels = labels.to(device)
|
|
||||||
outputs = model(inputs)
|
|
||||||
loss = criterion(outputs.squeeze(1), labels)
|
|
||||||
val_loss += loss.item() * inputs.size(0)
|
|
||||||
|
|
||||||
val_loss /= len(val_loader.dataset)
|
|
||||||
|
|
||||||
print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')
|
|
||||||
|
|
||||||
# Check for overfitting
|
|
||||||
if val_loss < best_loss:
|
|
||||||
best_loss = val_loss
|
|
||||||
epochs_no_improve = 0
|
|
||||||
torch.save(model.state_dict(), model_path)
|
|
||||||
print('Model saved!')
|
|
||||||
else:
|
|
||||||
epochs_no_improve += 1
|
|
||||||
if epochs_no_improve == patience:
|
|
||||||
print('Early stopping!')
|
|
||||||
break
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
batch_size = 256
|
|
||||||
num_epochs = 50
|
|
||||||
learning_rate = 0.001
|
|
||||||
patience = 2
|
|
||||||
model_path = 'best_model.pth'
|
|
||||||
weight_decay = 0
|
|
||||||
|
|
||||||
print('Loading Data')
|
|
||||||
train_loader = create_dataloader(table_name='train', batch_size=batch_size, shuffle=True, num_workers=3)
|
|
||||||
val_loader = create_dataloader(table_name='test', batch_size=batch_size, shuffle=False, num_workers=3)
|
|
||||||
print('Loaded Data')
|
|
||||||
|
|
||||||
model = ChessPredictModelS().half().to(device)
|
|
||||||
criterion = nn.SmoothL1Loss()
|
|
||||||
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
|
|
||||||
|
|
||||||
train(model, criterion, optimizer, train_loader, val_loader, num_epochs=num_epochs, patience=patience, model_path=model_path)
|
|
51
cmake/fetch_torch.cmake
Normal file
51
cmake/fetch_torch.cmake
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
cmake_minimum_required(VERSION 3.27)
|
||||||
|
|
||||||
|
include(FetchContent)
|
||||||
|
|
||||||
|
set(LIBTORCH_PLATFORM "none" CACHE STRING "Determines libtorch platform version to download (CUDA11.8, CUDA12.1, CUDA12.8, ROCm6.1 or none).")
|
||||||
|
|
||||||
|
if(${LIBTORCH_PLATFORM} STREQUAL "none")
|
||||||
|
set(LIBTORCH_DEVICE "cpu")
|
||||||
|
elseif(${LIBTORCH_PLATFORM} STREQUAL "CUDA11.8")
|
||||||
|
set(LIBTORCH_DEVICE "cu118")
|
||||||
|
elseif(${LIBTORCH_PLATFORM} STREQUAL "CUDA12.1")
|
||||||
|
set(LIBTORCH_DEVICE "cu121")
|
||||||
|
elseif(${LIBTORCH_PLATFORM} STREQUAL "CUDA12.8")
|
||||||
|
set(LIBTORCH_DEVICE "cu128")
|
||||||
|
elseif(${LIBTORCH_PLATFORM} STREQUAL "ROCm6.1")
|
||||||
|
set(LIBTORCH_DEVICE "rocm6.1")
|
||||||
|
else()
|
||||||
|
message(FATAL_ERROR "Invalid libtorch platform, must be either CUDA11.8, CUDA12.1, CUDA12.8, ROCm6.1 or none.")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
set(PYTORCH_VERSION "2.4.0")
|
||||||
|
|
||||||
|
if(${CMAKE_SYSTEM_NAME} STREQUAL "Windows")
|
||||||
|
set(LIBTORCH_URL "${LIBTORCH_DEVICE}/libtorch-win-shared-with-deps-${PYTORCH_VERSION}%2B${LIBTORCH_DEVICE}.zip")
|
||||||
|
elseif(${CMAKE_SYSTEM_NAME} STREQUAL "Linux")
|
||||||
|
set(LIBTORCH_URL "${LIBTORCH_DEVICE}/libtorch-shared-with-deps-${PYTORCH_VERSION}%2B${LIBTORCH_DEVICE}.zip")
|
||||||
|
elseif(${CMAKE_SYSTEM_NAME} STREQUAL "Darwin")
|
||||||
|
if(NOT ${LIBTORCH_DEVICE} STREQUAL "cpu")
|
||||||
|
message(WARNING "MacOS binaries support CPU version only, using it instead.")
|
||||||
|
set(LIBTORCH_DEVICE "cpu")
|
||||||
|
endif()
|
||||||
|
set(LIBTORCH_URL "cpu/libtorch-macos-arm64-${PYTORCH_VERSION}.zip")
|
||||||
|
else()
|
||||||
|
message(FATAL_ERROR "Unsupported CMake System Name '${CMAKE_SYSTEM_NAME}' (expected 'Windows', 'Linux' or 'Darwin')")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
message(STATUS "Downloading libtorch version ${PYTORCH_VERSION} for ${LIBTORCH_DEVICE} on ${CMAKE_SYSTEM_NAME} from ${LIBTORCH_URL}...")
|
||||||
|
|
||||||
|
FetchContent_Declare(
|
||||||
|
libtorch
|
||||||
|
PREFIX libtorch
|
||||||
|
DOWNLOAD_DIR ${CMAKE_SOURCE_DIR}/libtorch
|
||||||
|
SOURCE_DIR ${CMAKE_SOURCE_DIR}/libtorch
|
||||||
|
URL "https://download.pytorch.org/libtorch/${LIBTORCH_URL}"
|
||||||
|
)
|
||||||
|
|
||||||
|
FetchContent_MakeAvailable(libtorch)
|
||||||
|
|
||||||
|
message(STATUS "Downloaded libtorch.")
|
||||||
|
|
||||||
|
find_package(Torch REQUIRED PATHS "${CMAKE_SOURCE_DIR}/libtorch")
|
27
cmake/third_party.cmake
Normal file
27
cmake/third_party.cmake
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
cmake_minimum_required(VERSION 3.27)
|
||||||
|
|
||||||
|
cmake_policy(VERSION 3.27)
|
||||||
|
include(FetchContent)
|
||||||
|
|
||||||
|
# fmtlib
|
||||||
|
message(STATUS "Downloading fmtlib...")
|
||||||
|
FetchContent_Declare(
|
||||||
|
fmt
|
||||||
|
URL "https://github.com/fmtlib/fmt/archive/refs/tags/11.0.2.zip"
|
||||||
|
FIND_PACKAGE_ARGS NAMES fmt
|
||||||
|
)
|
||||||
|
FetchContent_MakeAvailable(fmt)
|
||||||
|
message(STATUS "fmtlib downloaded.")
|
||||||
|
|
||||||
|
# simdjson
|
||||||
|
message(STATUS "Downloading simdjson...")
|
||||||
|
FetchContent_Declare(
|
||||||
|
simdjson
|
||||||
|
URL "https://github.com/simdjson/simdjson/archive/refs/tags/v3.9.5.tar.gz"
|
||||||
|
FIND_PACKAGE_ARGS NAMES simdjson
|
||||||
|
)
|
||||||
|
FetchContent_MakeAvailable(simdjson)
|
||||||
|
message(STATUS "simdjson downloaded.")
|
||||||
|
|
||||||
|
# libtorch
|
||||||
|
include("${CMAKE_CURRENT_LIST_DIR}/fetch_torch.cmake")
|
8
include/chectus_net/fen2vec.h
Normal file
8
include/chectus_net/fen2vec.h
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <string_view>
|
||||||
|
|
||||||
|
namespace chectus {
|
||||||
|
|
||||||
|
|
||||||
|
}
|
@ -1,34 +0,0 @@
|
|||||||
DROP TABLE IF EXISTS Train;
|
|
||||||
DROP TABLE IF EXISTS Test;
|
|
||||||
|
|
||||||
CREATE TABLE train (
|
|
||||||
id INTEGER PRIMARY KEY ASC AUTOINCREMENT
|
|
||||||
UNIQUE
|
|
||||||
NOT NULL,
|
|
||||||
fen TEXT UNIQUE
|
|
||||||
NOT NULL,
|
|
||||||
cpe INTEGER NOT NULL,
|
|
||||||
dep INTEGER NOT NULL,
|
|
||||||
nxt TEXT,
|
|
||||||
mate INTEGER,
|
|
||||||
flag INTEGER NOT NULL
|
|
||||||
DEFAULT (0),
|
|
||||||
ver INTEGER NOT NULL
|
|
||||||
DEFAULT (1)
|
|
||||||
);
|
|
||||||
|
|
||||||
CREATE TABLE test (
|
|
||||||
id INTEGER PRIMARY KEY ASC AUTOINCREMENT
|
|
||||||
UNIQUE
|
|
||||||
NOT NULL,
|
|
||||||
fen TEXT UNIQUE
|
|
||||||
NOT NULL,
|
|
||||||
cpe INTEGER NOT NULL,
|
|
||||||
dep INTEGER NOT NULL,
|
|
||||||
nxt TEXT,
|
|
||||||
mate INTEGER,
|
|
||||||
flag INTEGER NOT NULL
|
|
||||||
DEFAULT (0),
|
|
||||||
ver INTEGER NOT NULL
|
|
||||||
DEFAULT (1)
|
|
||||||
);
|
|
@ -5,16 +5,3 @@ if [ ! -f lichess_db_eval.jsonl ]; then
|
|||||||
wget https://database.lichess.org/lichess_db_eval.jsonl.zst
|
wget https://database.lichess.org/lichess_db_eval.jsonl.zst
|
||||||
zstd -d lichess_db_eval.jsonl.zst
|
zstd -d lichess_db_eval.jsonl.zst
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# Split
|
|
||||||
head -n 2000000 lichess_db_eval.jsonl > train2M.jsonl
|
|
||||||
tail -n 200000 lichess_db_eval.jsonl > test200K.jsonl
|
|
||||||
|
|
||||||
# Create database
|
|
||||||
if [ ! -f evals.db ]; then
|
|
||||||
sqlite3 evals.db < init.sql
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Processz
|
|
||||||
python3 process_data.py evals.db train2M.jsonl Train
|
|
||||||
python3 process_data.py evals.db test200K.jsonl Test
|
|
||||||
|
@ -1,47 +0,0 @@
|
|||||||
import sqlite3
|
|
||||||
import json
|
|
||||||
from tqdm import tqdm
|
|
||||||
import argparse
|
|
||||||
|
|
||||||
def main():
|
|
||||||
# Argument parsing
|
|
||||||
parser = argparse.ArgumentParser(description='Process and insert JSON data into SQLite database.')
|
|
||||||
parser.add_argument('db_name', type=str, help='Name of the SQLite database file.')
|
|
||||||
parser.add_argument('input_file', type=str, help='Name of the input JSON file.')
|
|
||||||
parser.add_argument('table_name', type=str, help='Name of the table in the database.')
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
# Connect to the SQLite database
|
|
||||||
conn = sqlite3.connect(args.db_name)
|
|
||||||
cursor = conn.cursor()
|
|
||||||
|
|
||||||
# Process and insert JSON data
|
|
||||||
with open(args.input_file, 'r', encoding='utf-8') as f:
|
|
||||||
for line in tqdm(f, desc='Processing data', unit=' lines'):
|
|
||||||
data = json.loads(line)
|
|
||||||
fen = data['fen']
|
|
||||||
best_eval = max(data['evals'], key=lambda x: x['depth'])
|
|
||||||
depth = best_eval['depth']
|
|
||||||
|
|
||||||
cur_player = fen.split(' ')[1]
|
|
||||||
|
|
||||||
# Safely get evaluation details
|
|
||||||
pvs = best_eval.get('pvs', [{}])[0]
|
|
||||||
cpe = pvs.get('cp', 20000 if cur_player == 'w' else -20000)
|
|
||||||
mate = pvs.get('mate')
|
|
||||||
nxt = pvs.get('line', '').split(' ')[0]
|
|
||||||
|
|
||||||
# Insert data into the table
|
|
||||||
cursor.execute(f'''
|
|
||||||
INSERT OR IGNORE INTO {args.table_name} (fen, cpe, dep, nxt, mate)
|
|
||||||
VALUES (?, ?, ?, ?, ?)
|
|
||||||
''', (fen, cpe, depth, nxt, mate))
|
|
||||||
|
|
||||||
# Commit the transaction and close the connection
|
|
||||||
conn.commit()
|
|
||||||
conn.close()
|
|
||||||
|
|
||||||
print("Data processing and insertion complete.")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
8
src/main.cpp
Normal file
8
src/main.cpp
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
#include <iostream>
|
||||||
|
#include <simdjson.h>
|
||||||
|
#include <torch/torch.h>
|
||||||
|
|
||||||
|
int main() {
|
||||||
|
torch::Tensor tensor = torch::rand({2, 3});
|
||||||
|
std::cout << tensor << std::endl;
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user