import torch
# Standard train/validate loop: one training pass, one validation pass,
# LR-scheduler step on the validation metric, and best-model checkpointing.
best_val_loss = float('inf')
for epoch in range(1, num_epochs + 1):
    # --- training phase ---
    model.train()
    train_loss = 0.0
    for batch in train_loader:
        inputs, targets = [item.to(device) for item in batch]
        optimizer.zero_grad(set_to_none=True)
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        # Clip gradients (L2 norm capped at 1.0) before the optimizer step.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        # Weight by batch size so the epoch aggregate is a per-sample sum
        # even when the last batch is smaller.
        train_loss += loss.item() * inputs.size(0)
    # Fix: the original accumulated train_loss but never averaged or
    # reported it. Normalize to a per-sample mean for the epoch summary.
    train_loss /= len(train_loader.dataset)

    # --- validation phase ---
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in valid_loader:
            inputs, targets = [item.to(device) for item in batch]
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            val_loss += loss.item() * inputs.size(0)
    val_loss /= len(valid_loader.dataset)

    # Checkpoint whenever validation loss improves.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save({'model_state': model.state_dict()}, 'best_model.pt')

    # scheduler.step(val_loss) — metric argument suggests a
    # ReduceLROnPlateau-style scheduler (TODO confirm scheduler type).
    scheduler.step(val_loss)
    print(f'epoch={epoch} train_loss={train_loss:.4f} val_loss={val_loss:.4f}')
The training loop is where research code either becomes maintainable or turns into a mess. I keep it explicit: train phase, validation phase, scheduler step, metric tracking, and checkpoint saving. That structure pays off immediately when experiments fail halfway through or need to be resumed on another machine.