29 lines
714 B
Python
29 lines
714 B
Python
import torch
|
|
from model import ChessPredictModelS
|
|
from fen2vec import parse_fen
|
|
|
|
def run(model_path='best_model.pth'):
    """Interactively evaluate FEN positions with a trained ChessPredictModelS.

    Loads the ``state_dict`` stored at *model_path* into a fresh
    ``ChessPredictModelS``, switches it to eval mode, then repeatedly prompts
    for a FEN string. Each FEN is encoded with ``parse_fen``, given a batch
    dimension, and run through the model without gradient tracking.

    Two numbers are printed per position: ``atanh(output) * 10`` followed by
    the raw model output. NOTE(review): this presumably inverts a
    ``tanh(score / 10)`` training target -- confirm against the training code.
    If the raw output is exactly +/-1, ``atanh`` yields +/-inf.

    Enter ``q`` (case-insensitive), or send EOF / Ctrl-C at the prompt, to quit.
    Blank lines are ignored. Any error while parsing or evaluating a position
    is printed and the loop continues.

    Args:
        model_path: Path to a state_dict saved with ``torch.save``.
    """
    model = ChessPredictModelS()
    # The model stays on the CPU (it is never moved), so load the checkpoint
    # onto the CPU too; without map_location a checkpoint saved from a GPU
    # run cannot be loaded on a CPU-only machine.
    state_dict = torch.load(model_path, map_location='cpu', weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()

    while True:
        try:
            fen_str = input("FEN(q) >").strip()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C at the prompt ends the session cleanly instead
            # of crashing with a traceback.
            print()
            break

        if fen_str.lower() == 'q':
            break
        if not fen_str:
            # Re-prompt on blank input rather than reporting a parse error.
            continue

        try:
            input_tensor = parse_fen(fen_str).unsqueeze(0)  # add batch dim

            with torch.no_grad():
                output = model(input_tensor)
                transformed_output = torch.atanh(output) * 10

            # .item() prints plain numbers instead of "tensor(...)" reprs.
            print(f"$= {transformed_output[0, 0].item()} {output[0, 0].item()}")
        except Exception as e:
            # Top-level REPL boundary: report the problem and keep prompting.
            print(f"Error: {e}")
|
|
|
|
if __name__ == "__main__":
    # Local import: only the command-line entry point needs argument parsing.
    import argparse

    parser = argparse.ArgumentParser(
        description="Interactively evaluate FEN positions with a trained model.",
    )
    # Optional positional argument; with no arguments the behaviour is the
    # same as before (loads 'best_model.pth').
    parser.add_argument(
        "model_path",
        nargs="?",
        default='best_model.pth',
        help="path to the saved state_dict (default: %(default)s)",
    )
    args = parser.parse_args()
    run(args.model_path)
|