Fix
This commit is contained in:
parent
2179ff225f
commit
059261edfd
@ -2,8 +2,10 @@ import torch
|
|||||||
from model import ChessPredictModelS
|
from model import ChessPredictModelS
|
||||||
from fen2vec import parse_fen
|
from fen2vec import parse_fen
|
||||||
|
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
def run(model_path='best_model.pth'):
|
def run(model_path='best_model.pth'):
|
||||||
model = ChessPredictModelS()
|
model = ChessPredictModelS().half().to(device)
|
||||||
model.load_state_dict(torch.load(model_path, weights_only=True))
|
model.load_state_dict(torch.load(model_path, weights_only=True))
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
@ -13,12 +15,11 @@ def run(model_path='best_model.pth'):
|
|||||||
break
|
break
|
||||||
|
|
||||||
try:
|
try:
|
||||||
input_tensor = parse_fen(fen_str)
|
input_tensor, _wdl = parse_fen(fen_str, 0)
|
||||||
input_tensor = input_tensor.unsqueeze(0) # add batch dim
|
input_tensor = input_tensor.unsqueeze(0).to(device) # add batch dim
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
output = model(input_tensor)
|
output = model(input_tensor)
|
||||||
transformed_output = torch.atanh(output) * 10
|
|
||||||
print(f"$= {transformed_output[0, 0]} {output[0, 0]}")
|
print(f"$= {transformed_output[0, 0]} {output[0, 0]}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error: {e}")
|
print(f"Error: {e}")
|
||||||
|
@ -63,11 +63,11 @@ def train(model, criterion, optimizer, train_loader, val_loader, num_epochs=10,
|
|||||||
break
|
break
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
batch_size = 32
|
batch_size = 256
|
||||||
num_epochs = 50
|
num_epochs = 50
|
||||||
learning_rate = 0.001
|
learning_rate = 0.001
|
||||||
patience = 2
|
patience = 2
|
||||||
model_path = 'best_model_baby.pth'
|
model_path = 'best_model.pth'
|
||||||
weight_decay = 0
|
weight_decay = 0
|
||||||
|
|
||||||
print('Loading Data')
|
print('Loading Data')
|
||||||
@ -75,7 +75,7 @@ if __name__ == "__main__":
|
|||||||
val_loader = create_dataloader(table_name='test', batch_size=batch_size, shuffle=False, num_workers=3)
|
val_loader = create_dataloader(table_name='test', batch_size=batch_size, shuffle=False, num_workers=3)
|
||||||
print('Loaded Data')
|
print('Loaded Data')
|
||||||
|
|
||||||
model = ChessPredictModelBaby().half().to(device)
|
model = ChessPredictModelS().half().to(device)
|
||||||
criterion = nn.SmoothL1Loss()
|
criterion = nn.SmoothL1Loss()
|
||||||
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
|
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
|
||||||
|
|
||||||
|
0
lichessdb/prepare.sh
Normal file → Executable file
0
lichessdb/prepare.sh
Normal file → Executable file
Loading…
x
Reference in New Issue
Block a user