This commit is contained in:
方而静 2024-07-28 15:25:53 +08:00
parent 2179ff225f
commit 059261edfd
3 changed files with 9 additions and 8 deletions

View File

@ -2,8 +2,10 @@ import torch
from model import ChessPredictModelS from model import ChessPredictModelS
from fen2vec import parse_fen from fen2vec import parse_fen
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def run(model_path='best_model.pth'): def run(model_path='best_model.pth'):
model = ChessPredictModelS() model = ChessPredictModelS().half().to(device)
model.load_state_dict(torch.load(model_path, weights_only=True)) model.load_state_dict(torch.load(model_path, weights_only=True))
model.eval() model.eval()
@ -13,12 +15,11 @@ def run(model_path='best_model.pth'):
break break
try: try:
input_tensor = parse_fen(fen_str) input_tensor, _wdl = parse_fen(fen_str, 0)
input_tensor = input_tensor.unsqueeze(0) # add batch dim input_tensor = input_tensor.unsqueeze(0).to(device) # add batch dim
with torch.no_grad(): with torch.no_grad():
output = model(input_tensor) output = model(input_tensor)
transformed_output = torch.atanh(output) * 10
print(f"$= {transformed_output[0, 0]} {output[0, 0]}") print(f"$= {transformed_output[0, 0]} {output[0, 0]}")
except Exception as e: except Exception as e:
print(f"Error: {e}") print(f"Error: {e}")

View File

@ -63,11 +63,11 @@ def train(model, criterion, optimizer, train_loader, val_loader, num_epochs=10,
break break
if __name__ == "__main__": if __name__ == "__main__":
batch_size = 32 batch_size = 256
num_epochs = 50 num_epochs = 50
learning_rate = 0.001 learning_rate = 0.001
patience = 2 patience = 2
model_path = 'best_model_baby.pth' model_path = 'best_model.pth'
weight_decay = 0 weight_decay = 0
print('Loading Data') print('Loading Data')
@ -75,7 +75,7 @@ if __name__ == "__main__":
val_loader = create_dataloader(table_name='test', batch_size=batch_size, shuffle=False, num_workers=3) val_loader = create_dataloader(table_name='test', batch_size=batch_size, shuffle=False, num_workers=3)
print('Loaded Data') print('Loaded Data')
model = ChessPredictModelBaby().half().to(device) model = ChessPredictModelS().half().to(device)
criterion = nn.SmoothL1Loss() criterion = nn.SmoothL1Loss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay) optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

0
lichessdb/prepare.sh Normal file → Executable file
View File