
v0.6.0

@analokmaus released this on 31 Oct 03:57

What's New

  • TorchTrainer now accepts models that compute their loss internally (e.g., ArcFace, LLMs). For such models, pass criterion=None to TorchTrainer.train() and compute the loss inside your training hook.

Sample code

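def forward_train(self, trainer, inputs):
    # The model computes the loss itself (criterion=None), so it returns both values
    embeddings, loss = trainer.model(*inputs)
    # Return the loss for backprop and the embeddings detached from the graph
    return loss, embeddings.detach()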
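For illustration, here is a minimal, self-contained sketch, assuming PyTorch, of a model whose forward pass returns `(embeddings, loss)` so that it matches the hook above. `ModelWithLoss`, `LossInModelHook`, and the `SimpleNamespace` stand-in for the trainer are hypothetical names used only for this example; it uses a plain linear classifier with cross-entropy rather than a real ArcFace head. With TorchTrainer you would instead pass such a model together with `criterion=None` and let the trainer call the hook.

```python
from types import SimpleNamespace

import torch
import torch.nn as nn
import torch.nn.functional as F


class ModelWithLoss(nn.Module):
    """Hypothetical model that computes its own loss inside forward()."""

    def __init__(self, in_features: int, embed_dim: int, num_classes: int):
        super().__init__()
        self.backbone = nn.Linear(in_features, embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x, labels):
        embeddings = self.backbone(x)
        logits = self.head(embeddings)
        # The loss is computed inside the model, so no external criterion is needed
        loss = F.cross_entropy(logits, labels)
        return embeddings, loss


class LossInModelHook:
    # Hypothetical hook; only forward_train mirrors the release-note sample
    def forward_train(self, trainer, inputs):
        embeddings, loss = trainer.model(*inputs)
        return loss, embeddings.detach()


torch.manual_seed(0)
model = ModelWithLoss(in_features=8, embed_dim=4, num_classes=3)
trainer = SimpleNamespace(model=model)  # stand-in exposing only .model
x = torch.randn(5, 8)
y = torch.randint(0, 3, (5,))
loss, emb = LossInModelHook().forward_train(trainer, (x, y))
loss.backward()  # gradients flow through the loss computed inside the model
```

Because the batch tuple is unpacked with `*inputs`, the labels travel to the model alongside the features, which is what lets the loss live inside `forward()`.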