Train

Orchestrates training over epochs.

Classification (default): dataset yields (x, y) pairs, loss is cross-entropy, validation metric is accuracy.

Custom loss / metrics: models can override the following hooks to supply their own behaviour (e.g. detection models yielding multi-target batches):

  • model.compute_loss(outputs, *targets) — loss function
  • model.build_validation_metric(device) — initial validation state
  • model.validation_update(state, outputs, *targets) — per-batch update
  • model.validation_compute(state) — finalise metrics dict

When hooks are defined, Trainer unpacks each batch as inputs, *targets = batch and forwards every target to the hooks as *targets. Classification datasets that yield (x, y) work unchanged, since targets is then (y,).
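
To make the hook contract concrete, here is a minimal sketch of a model that defines all four hooks, plus a loop that calls them the way the description above suggests. It is illustrative only: ToyDetector and validate are hypothetical names, plain Python lists stand in for tensors so the example runs without PyTorch, and the real Trainer loop may differ. In particular, whether validation_update returns the state or mutates it in place is not stated above, so this sketch does both.

```python
class ToyDetector:
    """Hypothetical model with two targets per sample: a box offset and a label."""

    def __call__(self, inputs):
        # Stand-in forward pass: predict (offset, label) for each input.
        return [(x * 0.5, int(x > 1.0)) for x in inputs]

    def compute_loss(self, outputs, *targets):
        boxes, labels = targets
        box_err = sum(abs(o[0] - b) for o, b in zip(outputs, boxes))
        cls_err = sum(o[1] != l for o, l in zip(outputs, labels))
        return (box_err + cls_err) / len(outputs)

    def build_validation_metric(self, device):
        # `device` is unused here; a real model would place tensors on it.
        return {"correct": 0, "seen": 0}

    def validation_update(self, state, outputs, *targets):
        _, labels = targets
        state["correct"] += sum(o[1] == l for o, l in zip(outputs, labels))
        state["seen"] += len(outputs)
        return state  # assumption: the updated state is returned

    def validation_compute(self, state):
        return {"label_accuracy": state["correct"] / max(state["seen"], 1)}


def validate(model, batches, device="cpu"):
    """Hypothetical stand-in for the Trainer's validation loop."""
    state = model.build_validation_metric(device)
    total_loss = 0.0
    for batch in batches:
        inputs, *targets = batch  # (x, y) gives targets == [y]
        outputs = model(inputs)
        total_loss += model.compute_loss(outputs, *targets)
        state = model.validation_update(state, outputs, *targets)
    metrics = model.validation_compute(state)
    metrics["loss"] = total_loss / len(batches)
    return metrics


# Each batch is (inputs, box_targets, label_targets).
batches = [
    ([0.0, 2.0], [0.0, 1.0], [0, 1]),
    ([4.0, 1.0], [2.0, 0.0], [1, 1]),
]
metrics = validate(ToyDetector(), batches)
```

Because the loop only ever passes *targets through, the same code would handle a two-element (x, y) batch without changes; only the model's hooks need to know how many targets there are.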

Schedules: pass lr_schedule="cosine" for linear warmup followed by cosine decay, or pass any torch.optim.lr_scheduler class. With a class-based schedule, warmup_epochs adds a linear warmup in front of it; with "cosine", it modifies the built-in warmup+cosine setup.
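
For intuition, a warmup+cosine schedule can be written as a pure function of the epoch index. The sketch below uses one common formulation and is not necessarily the formula the Trainer implements: warmup_cosine_lr, base_lr, and min_lr are hypothetical names, and stepping once per epoch (rather than per iteration) is an assumption.

```python
import math


def warmup_cosine_lr(epoch, total_epochs, warmup_epochs, base_lr, min_lr=0.0):
    """Learning rate for a 0-indexed epoch: linear warmup, then cosine decay."""
    if warmup_epochs > 0 and epoch < warmup_epochs:
        # Ramp linearly so the last warmup epoch reaches base_lr.
        return base_lr * (epoch + 1) / warmup_epochs
    # Fraction of the decay phase completed so far, in [0, 1).
    decay_epochs = max(total_epochs - warmup_epochs, 1)
    progress = (epoch - warmup_epochs) / decay_epochs
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


lrs = [
    warmup_cosine_lr(e, total_epochs=10, warmup_epochs=2, base_lr=0.1)
    for e in range(10)
]
```

With these settings the rate climbs to base_lr over the first two epochs and then decays smoothly toward min_lr over the remaining eight.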