"""Training script for the chess prediction models defined in ``model``.

Builds train/validation loaders with ``dataloader.create_dataloader``, trains
``ChessPredictModelS`` with SmoothL1 loss and Adam, and keeps the weights with
the lowest validation loss on disk, stopping early when validation stalls.
"""
import os

import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

from dataloader import create_dataloader
from model import resnet18, ChessPredictModelBaby, ChessPredictModelS

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _run_epoch(model, criterion, loader, desc, optimizer=None):
    """Run one full pass over ``loader`` and return the mean per-sample loss.

    Args:
        model: network to evaluate (and train, if ``optimizer`` is given).
        criterion: loss called as ``criterion(outputs.squeeze(1), labels)``.
        loader: DataLoader yielding ``(inputs, labels)`` batches; its
            ``dataset`` must support ``len()``.
        desc: progress-bar label.
        optimizer: when given, the model is put in training mode and stepped
            after every batch; when ``None``, the model is put in eval mode and
            run with gradients disabled.

    Returns:
        Sum of ``batch_loss * batch_size`` divided by ``len(loader.dataset)``.
    """
    training = optimizer is not None
    model.train(training)
    total_loss = 0.0
    with torch.set_grad_enabled(training):
        for inputs, labels in tqdm(loader, desc=desc, unit="batch"):
            inputs = inputs.to(device)
            labels = labels.to(device)
            if training:
                optimizer.zero_grad()
            outputs = model(inputs)
            # squeeze(1) drops a singleton output dimension so predictions line
            # up with labels (assumes model output is (batch, 1) — TODO confirm).
            loss = criterion(outputs.squeeze(1), labels)
            if training:
                loss.backward()
                optimizer.step()
            # Assuming the criterion averages over the batch (the default
            # reduction for nn.SmoothL1Loss), re-weighting by batch size makes the
            # final division a true per-sample mean even for a short last batch.
            total_loss += loss.item() * inputs.size(0)
    return total_loss / len(loader.dataset)


def train(model, criterion, optimizer, train_loader, val_loader, num_epochs=10, patience=3, model_path='best_model.pth'):
    """Train ``model`` with early stopping, checkpointing the best weights.

    If ``model_path`` already exists, its weights are loaded first and scored
    on ``val_loader``. That score becomes the bar later epochs must beat, so
    resuming never overwrites a better checkpoint with a worse one.

    Args:
        model: network to train; expected to already live on ``device``.
        criterion: loss function applied to ``(outputs.squeeze(1), labels)``.
        optimizer: optimizer over ``model``'s parameters.
        train_loader: DataLoader of ``(inputs, labels)`` training batches.
        val_loader: DataLoader of ``(inputs, labels)`` validation batches.
        num_epochs: maximum number of training epochs.
        patience: stop after this many consecutive epochs without a
            validation-loss improvement.
        model_path: where the best ``state_dict`` is loaded from / saved to.

    Returns:
        The best validation loss seen (``inf`` if no epoch ran and there was
        no checkpoint to score).
    """
    best_loss = float('inf')
    epochs_no_improve = 0

    if os.path.exists(model_path):
        # map_location lets a checkpoint written on a GPU machine be resumed on a
        # CPU-only one (device falls back to CPU above).
        state_dict = torch.load(model_path, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
        print(f"Loaded saved model from {model_path}")
        # Without this, best_loss would start at inf and the first resumed epoch
        # would always overwrite the saved checkpoint, even if it is worse.
        best_loss = _run_epoch(model, criterion, val_loader, "Resumed model [Validation]")
        print(f'Resumed model Val Loss: {best_loss:.4f}')

    print('Started training')
    for epoch in range(num_epochs):
        tag = f"Epoch {epoch+1}/{num_epochs}"
        train_loss = _run_epoch(model, criterion, train_loader, f"{tag} [Training]", optimizer)
        val_loss = _run_epoch(model, criterion, val_loader, f"{tag} [Validation]")
        print(f'{tag}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

        # Check for overfitting: save on improvement, otherwise count toward
        # early stopping.
        if val_loss < best_loss:
            best_loss = val_loss
            epochs_no_improve = 0
            torch.save(model.state_dict(), model_path)
            print('Model saved!')
        else:
            epochs_no_improve += 1
            # >= (not ==) so that patience=0 stops on the first non-improving epoch.
            if epochs_no_improve >= patience:
                print('Early stopping!')
                break

    return best_loss


if __name__ == "__main__":
    batch_size = 256
    num_epochs = 50
    learning_rate = 0.001
    patience = 2
    model_path = 'best_model.pth'
    weight_decay = 0

    print('Loading Data')
    train_loader = create_dataloader(table_name='train', batch_size=batch_size, shuffle=True, num_workers=3)
    # NOTE: the 'test' table doubles as the validation set used for checkpoint
    # selection and early stopping.
    val_loader = create_dataloader(table_name='test', batch_size=batch_size, shuffle=False, num_workers=3)
    print('Loaded Data')

    # NOTE(review): .half() makes the weights and Adam's state pure fp16. Adam's
    # default eps (1e-8) is below fp16's smallest subnormal (~6e-8), and small
    # squared gradients can underflow to 0, which risks inf/NaN updates. Consider
    # fp32 weights with torch.autocast + GradScaler instead; this depends on the
    # dtype create_dataloader yields, which is not visible here.
    model = ChessPredictModelS().half().to(device)
    criterion = nn.SmoothL1Loss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    train(model, criterion, optimizer, train_loader, val_loader, num_epochs=num_epochs, patience=patience, model_path=model_path)