# 2024-07-28 15:25:53 +08:00
#
# 30 lines
# 769 B
# Python

import torch
from model import ChessPredictModelS
from fen2vec import parse_fen
# Prefer a CUDA GPU whenever PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def run(model_path='best_model.pth', transform=None):
    """Interactively score FEN positions with a trained ChessPredictModelS.

    Builds a half-precision ``ChessPredictModelS`` on the module-level
    ``device``, loads the ``state_dict`` stored at *model_path*, then
    repeatedly prompts for a FEN string and prints the model's prediction.
    The loop ends when the user types ``q`` (any case) or closes the input
    stream (EOF / Ctrl-C at the prompt).

    Args:
        model_path: Path to a ``state_dict`` checkpoint readable by
            ``torch.load(..., weights_only=True)``.
        transform: Optional callable applied to the raw model output before
            printing. Defaults to the identity, in which case both printed
            numbers are the same.

    Errors raised while parsing a FEN or running the model are printed and
    the prompt continues. Errors while building or loading the model
    propagate to the caller.
    """
    model = ChessPredictModelS().half().to(device)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    # .half() converted the model's floating parameters to float16, so
    # floating-point inputs must be float16 too or the first layer fails
    # with a dtype mismatch.
    model_dtype = torch.float16

    while True:
        try:
            fen_str = input("FEN(q) >").strip()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C at the prompt ends the session without a traceback.
            print()
            break
        if not fen_str:
            # Blank line: just prompt again instead of reporting a parse error.
            continue
        if fen_str.lower() == 'q':
            break
        try:
            # The second return value of parse_fen is not used here.
            # NOTE(review): meaning of the literal 0 argument is defined in
            # fen2vec.parse_fen — confirm there.
            input_tensor, _wdl = parse_fen(fen_str, 0)
            input_tensor = input_tensor.unsqueeze(0).to(device)  # add batch dim
            if input_tensor.is_floating_point():
                # Only cast float features; integer inputs (if any) are left as-is.
                input_tensor = input_tensor.to(model_dtype)
            with torch.no_grad():
                output = model(input_tensor)
                transformed = output if transform is None else transform(output)
            print(f"$= {transformed[0, 0]} {output[0, 0]}")
        except Exception as e:
            # Top-level REPL boundary: report the problem and keep prompting.
            print(f"Error: {e}")
if __name__ == "__main__":
    # Local import: only the command-line entry point needs argument parsing.
    import argparse

    parser = argparse.ArgumentParser(
        description="Interactively score FEN positions with a trained model."
    )
    parser.add_argument(
        "model_path",
        nargs="?",
        default='best_model.pth',
        help="state_dict checkpoint to load (default: %(default)s)",
    )
    args = parser.parse_args()
    run(args.model_path)