v0.6.0
What's New
- TorchTrainer can now accept models that include their own loss function (e.g., ArcFace, LLMs). To use one, pass criterion = None to TorchTrainer.train() and calculate the loss inside your training hook.
Sample code
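def forward_train(self, trainer, inputs):
    # The model computes the loss itself and returns it alongside its outputs
    embeddings, loss = trainer.model(*inputs)
    # Return the loss first; detach the embeddings so they carry no autograd graph
    return loss, embeddings.detach()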
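The following minimal sketch assumes PyTorch and shows the kind of model this change targets. It uses a hypothetical ModelWithLoss whose forward() returns (embeddings, loss), and a hypothetical MyHook class that carries the forward_train method from the sample above. A SimpleNamespace stands in for a real TorchTrainer: it exposes only .model, which is the only attribute the hook reads. Cross-entropy is used as an arbitrary built-in loss. This is not an ArcFace implementation, and the library's actual hook base class is not shown.

```python
from types import SimpleNamespace

import torch
import torch.nn as nn
import torch.nn.functional as F


class ModelWithLoss(nn.Module):
    """Hypothetical model that computes its own loss inside forward()."""

    def __init__(self, in_features=16, embed_dim=8, num_classes=4):
        super().__init__()
        self.encoder = nn.Linear(in_features, embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x, target):
        embeddings = self.encoder(x)
        logits = self.head(embeddings)
        # The loss lives inside the model, so the trainer needs no criterion
        loss = F.cross_entropy(logits, target)
        return embeddings, loss


class MyHook:
    """Hypothetical hook; only forward_train mirrors the sample code."""

    def forward_train(self, trainer, inputs):
        embeddings, loss = trainer.model(*inputs)
        return loss, embeddings.detach()


# Usage: a stand-in trainer that only provides .model (assumption)
torch.manual_seed(0)
model = ModelWithLoss()
trainer = SimpleNamespace(model=model)
inputs = (torch.randn(5, 16), torch.randint(0, 4, (5,)))
loss, embeddings = MyHook().forward_train(trainer, inputs)
loss.backward()  # gradients flow through the loss computed inside the model
```

Because the loss is returned by the model itself, the hook simply unpacks it. Detaching the embeddings keeps any stored outputs from holding onto the autograd graph.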