Train
Orchestrates training over epochs.
Classification (default): dataset yields (x, y) pairs, loss is
cross-entropy, validation metric is accuracy.
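The default loss and metric can be written out in a few lines. The pure-Python sketch below assumes the model outputs one row of unnormalised logits per sample. In PyTorch the same quantities are usually computed with torch.nn.functional.cross_entropy(outputs, y) and (outputs.argmax(dim=1) == y).float().mean(). The text above does not say whether Trainer calls those exact functions.

```python
import math
from typing import Sequence


def cross_entropy(logits: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Mean negative log-likelihood of each true class under softmax(logits)."""
    total = 0.0
    for row, y in zip(logits, labels):
        m = max(row)  # subtract the row max so exp() cannot overflow
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_sum_exp - row[y]
    return total / len(labels)


def accuracy(logits: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Fraction of rows whose highest logit sits at the true class index."""
    correct = sum(
        1
        for row, y in zip(logits, labels)
        if max(range(len(row)), key=row.__getitem__) == y
    )
    return correct / len(labels)
```

For example, uniform logits over three classes give a loss of log(3) whatever the label is.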
Custom loss / metrics: models can override the following hooks to supply their own behaviour (e.g. detection models yielding multi-target batches):
- model.compute_loss(outputs, *targets): loss function
- model.build_validation_metric(device): initial validation state
- model.validation_update(state, outputs, *targets): per-batch update
- model.validation_compute(state): finalise metrics dict
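A model that needs its own loss or metrics only has to provide these four hooks. The sketch below is a hypothetical two-target regressor whose batches are (inputs, centres, widths). The names ToyBoxRegressor, centres and widths are illustrative, loosely in the spirit of a detection model. Only the hook names and their argument order come from the text above. The sketch uses plain Python lists instead of tensors. It also both mutates and returns the state, because the text does not say whether Trainer uses the return value of validation_update.

```python
from typing import Dict, List, Sequence, Tuple

# Hypothetical model output: (predicted centres, predicted widths), one value per sample.
Outputs = Tuple[List[float], List[float]]


class ToyBoxRegressor:
    """Hypothetical model whose batches are (inputs, centres, widths)."""

    def compute_loss(
        self, outputs: Outputs, centres: Sequence[float], widths: Sequence[float]
    ) -> float:
        # Sum of the mean squared errors of both heads.
        pred_c, pred_w = outputs
        n = len(centres)
        mse_c = sum((p - t) ** 2 for p, t in zip(pred_c, centres)) / n
        mse_w = sum((p - t) ** 2 for p, t in zip(pred_w, widths)) / n
        return mse_c + mse_w

    def build_validation_metric(self, device) -> Dict[str, float]:
        # `device` is ignored here; a torch model could allocate accumulator tensors on it.
        return {"abs_err_c": 0.0, "abs_err_w": 0.0, "count": 0}

    def validation_update(
        self, state: Dict[str, float], outputs: Outputs,
        centres: Sequence[float], widths: Sequence[float],
    ) -> Dict[str, float]:
        # Accumulate absolute errors and the sample count for one batch.
        pred_c, pred_w = outputs
        state["abs_err_c"] += sum(abs(p - t) for p, t in zip(pred_c, centres))
        state["abs_err_w"] += sum(abs(p - t) for p, t in zip(pred_w, widths))
        state["count"] += len(centres)
        return state  # assumption: returning the state is harmless if Trainer ignores it

    def validation_compute(self, state: Dict[str, float]) -> Dict[str, float]:
        # Turn the accumulated sums into per-sample mean absolute errors.
        n = max(state["count"], 1)
        return {"mae_centre": state["abs_err_c"] / n, "mae_width": state["abs_err_w"] / n}
```

Keeping the validation state as running sums means the final metric is exact across batches of different sizes. Averaging per-batch means would not be.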
When hooks are defined, Trainer unpacks each batch as
inputs, *targets = batch and passes the targets on to compute_loss and
validation_update. Classification datasets yielding (x, y) still work
unchanged, since targets then holds just y.
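The batch-handling rule can be sketched as follows. split_batch and batch_loss are hypothetical helpers. Detecting hooks with hasattr is an assumption about how Trainer decides whether hooks are defined. The point of the sketch is that an (x, y) batch leaves exactly one target, so the hook path and the default path share one unpacking step.

```python
from typing import Any, Callable, Sequence, Tuple


def split_batch(batch: Sequence[Any]) -> Tuple[Any, Tuple[Any, ...]]:
    """Unpack with `inputs, *targets = batch`; targets is returned as a tuple."""
    inputs, *targets = batch
    return inputs, tuple(targets)


def batch_loss(model: Any, batch: Sequence[Any], default_loss: Callable[[Any, Any], Any]) -> Any:
    """Route one batch to the model's compute_loss hook if present (assumed check: hasattr)."""
    inputs, targets = split_batch(batch)
    outputs = model(inputs)
    if hasattr(model, "compute_loss"):
        # Hook path: every target is forwarded positionally.
        return model.compute_loss(outputs, *targets)
    # Default classification path: an (x, y) batch leaves exactly one target.
    (y,) = targets
    return default_loss(outputs, y)
```

In a real training loop, default_loss might be torch.nn.functional.cross_entropy, and model(inputs) a forward pass over tensors.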
Schedules: pass lr_schedule="cosine" for a warmup followed by cosine
decay, or pass any torch.optim.lr_scheduler class. With a class-based
schedule, warmup_epochs prepends a linear warmup; with "cosine", it
configures the warmup phase of that schedule.
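Both schedule styles can be expressed as per-epoch learning-rate multipliers. The sketch below is a minimal illustration under stated assumptions. The helpers cosine_factor, with_linear_warmup and step_factor are hypothetical. The warmup ramp (epoch + 1) / warmup_epochs and the decay of cosine to exactly 0 are illustrative choices, not Trainer's documented formula. step_factor reproduces StepLR's closed-form rule, gamma ** (epoch // step_size).

```python
import math
from typing import Callable

Factor = Callable[[int], float]  # maps a 0-based epoch to an LR multiplier


def cosine_factor(total_epochs: int) -> Factor:
    """Cosine decay from 1 to 0 over `total_epochs`, then held at 0."""
    def factor(epoch: int) -> float:
        progress = min(epoch / max(total_epochs, 1), 1.0)
        return 0.5 * (1.0 + math.cos(math.pi * progress))
    return factor


def step_factor(step_size: int, gamma: float) -> Factor:
    """Multiply by `gamma` every `step_size` epochs (StepLR's closed form)."""
    return lambda epoch: gamma ** (epoch // step_size)


def with_linear_warmup(base: Factor, warmup_epochs: int) -> Factor:
    """Prepend a linear warmup, then hand over to `base` starting at its epoch 0."""
    def factor(epoch: int) -> float:
        if epoch < warmup_epochs:
            # (epoch + 1) keeps the first epoch's learning rate above zero.
            return (epoch + 1) / warmup_epochs
        return base(epoch - warmup_epochs)
    return factor
```

Each resulting function could be passed as lr_lambda to torch.optim.lr_scheduler.LambdaLR and stepped once per epoch. PyTorch also ships LinearLR and SequentialLR for chaining a warmup in front of an existing scheduler object. The text above does not say which mechanism Trainer uses internally.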