83 lines
2.6 KiB
Python
83 lines
2.6 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
import os
|
|
from model import resnet18, ChessPredictModelBaby, ChessPredictModelS
|
|
from dataloader import create_dataloader
|
|
from tqdm import tqdm
|
|
|
|
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
|
|
def _mean_loss(total_loss, num_samples, split):
    """Return ``total_loss / num_samples``, rejecting a loader that yielded no samples."""
    if num_samples == 0:
        raise ValueError(f"{split} loader produced no samples")
    return total_loss / num_samples


def _train_one_epoch(model, criterion, optimizer, loader, desc):
    """Run one optimisation pass over *loader* and return the mean per-sample loss."""
    model.train()
    total_loss = 0.0
    num_samples = 0

    for inputs, labels in tqdm(loader, desc=desc, unit="batch"):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        # squeeze(1) drops the second axis of the model output before comparing with labels
        # (presumably the output is (batch, 1) -- confirm against the model definition).
        loss = criterion(outputs.squeeze(1), labels)
        loss.backward()
        optimizer.step()

        # Weight each batch loss by its size so a short final batch does not skew the mean.
        batch_size = inputs.size(0)
        total_loss += loss.item() * batch_size
        num_samples += batch_size

    return _mean_loss(total_loss, num_samples, "train")


def _evaluate(model, criterion, loader, desc):
    """Return the mean per-sample loss of *model* over *loader* without updating weights."""
    model.eval()
    total_loss = 0.0
    num_samples = 0

    with torch.no_grad():
        for inputs, labels in tqdm(loader, desc=desc, unit="batch"):
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs.squeeze(1), labels)

            batch_size = inputs.size(0)
            total_loss += loss.item() * batch_size
            num_samples += batch_size

    return _mean_loss(total_loss, num_samples, "validation")


def train(model, criterion, optimizer, train_loader, val_loader, num_epochs=10, patience=3, model_path='best_model.pth'):
    """Train *model* with early stopping, checkpointing the best weights to *model_path*.

    If *model_path* already exists, its state dict is loaded into *model* and the
    loaded weights are scored on *val_loader* first; that score becomes the
    baseline a new epoch has to beat before the checkpoint file is overwritten.

    Args:
        model: Module to optimise. It is not moved here; batches are moved to the
            module-level ``device``.
        criterion: Loss callable invoked as ``criterion(outputs.squeeze(1), labels)``.
        optimizer: Optimizer stepped once per training batch.
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        val_loader: Iterable of ``(inputs, labels)`` validation batches.
        num_epochs: Maximum number of epochs to run.
        patience: Number of consecutive epochs without a validation improvement
            after which training stops.
        model_path: File the best ``state_dict`` is loaded from and saved to.

    Raises:
        ValueError: If a loader yields no samples.
    """
    best_loss = float('inf')
    epochs_no_improve = 0

    if os.path.exists(model_path):
        # map_location lets a checkpoint written on a GPU host be restored on a CPU-only one.
        state_dict = torch.load(model_path, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
        print(f"Loaded saved model from {model_path}")

        # Score the restored weights so a worse first epoch cannot clobber the saved checkpoint.
        checkpoint_loss = _evaluate(model, criterion, val_loader, "Checkpoint [Validation]")
        print(f'Checkpoint Val Loss: {checkpoint_loss:.4f}')
        # A NaN baseline compares False here and is ignored, leaving best_loss at +inf.
        if checkpoint_loss < best_loss:
            best_loss = checkpoint_loss

    print('Started training')
    for epoch in range(num_epochs):
        epoch_tag = f"Epoch {epoch+1}/{num_epochs}"

        train_loss = _train_one_epoch(model, criterion, optimizer, train_loader, f"{epoch_tag} [Training]")
        val_loss = _evaluate(model, criterion, val_loader, f"{epoch_tag} [Validation]")

        print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

        # Early stopping: keep the best weights, give up after `patience` stale epochs.
        if val_loss < best_loss:
            best_loss = val_loss
            epochs_no_improve = 0
            torch.save(model.state_dict(), model_path)
            print('Model saved!')
        else:
            epochs_no_improve += 1
            # >= rather than == so that patience <= 0 still terminates on the first stale epoch.
            if epochs_no_improve >= patience:
                print('Early stopping!')
                break
|
|
|
|
if __name__ == "__main__":
    # Run configuration: checkpoint file, loop limits, and optimizer settings.
    model_path = 'best_model_baby.pth'
    batch_size, num_epochs, patience = 32, 50, 2
    learning_rate, weight_decay = 0.001, 0

    print('Loading Data')
    # Both loaders share the batch size and worker count; only table and shuffling differ.
    loader_options = dict(batch_size=batch_size, num_workers=3)
    train_loader = create_dataloader(table_name='train', shuffle=True, **loader_options)
    # The 'test' table is what gets passed to train() as the validation loader.
    val_loader = create_dataloader(table_name='test', shuffle=False, **loader_options)
    print('Loaded Data')

    # NOTE(review): the model is cast to half precision before the optimizer is built,
    # so Adam updates fp16 parameters directly -- confirm this is intended.
    model = ChessPredictModelBaby()
    model = model.half().to(device)
    criterion = nn.SmoothL1Loss()
    optimizer = optim.Adam(
        model.parameters(),
        lr=learning_rate,
        weight_decay=weight_decay,
    )

    train(
        model,
        criterion,
        optimizer,
        train_loader,
        val_loader,
        num_epochs=num_epochs,
        patience=patience,
        model_path=model_path,
    )
|