PyTorch Wrapper is a library that provides a systematic and extensible way to build, train, evaluate, and tune deep learning models with PyTorch. It also provides several ready-to-use modules and functions for fast model development.
Install from PyPI:

```bash
pip install pytorch-wrapper
```
or from source:

```bash
git clone https://github.com/jkoutsikakis/pytorch-wrapper.git
cd pytorch-wrapper
pip install .
```
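Either way, a quick import check verifies the installation (a minimal smoke test; it uses nothing beyond the package names that appear in the example below):

```python
# Smoke test: the package and the core System class should import cleanly.
import torch
import pytorch_wrapper as pw

print(torch.__version__)
print(pw.System)
```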
The basic usage pattern is the following:

```python
import torch
import pytorch_wrapper as pw

# torch.utils.data.DataLoader instances for each data split.
train_dataloader = ...
val_dataloader = ...
dev_dataloader = ...

# Metrics to compute during evaluation, keyed by name.
evaluators = {'acc': pw.evaluators.AccuracyEvaluator(), ...}

# Wraps a standard pointwise PyTorch loss so the training loop can use it
# (see the note after this block for swapping in a different criterion).
loss_wrapper = pw.loss_wrappers.GenericPointWiseLossWrapper(torch.nn.BCEWithLogitsLoss())

# Any torch.nn.Module.
model = ...

# The System wraps the model and moves it to the chosen device.
system = pw.System(model=model, device=torch.device('cuda'))
optimizer = torch.optim.Adam(system.model.parameters())

system.train(
    loss_wrapper,
    optimizer,
    train_data_loader=train_dataloader,
    evaluators=evaluators,
    evaluation_data_loaders={'val': val_dataloader},
    callbacks=[
        # Stop when 'acc' on the 'val' dataloader stops improving for 3
        # consecutive epochs, keeping the best weights in a temporary file.
        pw.training_callbacks.EarlyStoppingCriterionCallback(
            patience=3,
            evaluation_data_loader_key='val',
            evaluator_key='acc',
            tmp_best_state_filepath='current_best.weights'
        )
    ]
)
```
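The loss wrapper decouples the training loop from the criterion, so swapping losses is a one-line change. For example, a multi-class setup might wrap `torch.nn.CrossEntropyLoss` instead (a sketch using only the constructor shown above; the evaluators and the model's output shape would have to change accordingly):

```python
import torch
import pytorch_wrapper as pw

# Same wrapper pattern, different pointwise criterion (multi-class).
loss_wrapper = pw.loss_wrappers.GenericPointWiseLossWrapper(
    torch.nn.CrossEntropyLoss()
)
```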
After training, the system can be evaluated, used for prediction, and its weights saved and restored:

```python
results = system.evaluate(dev_dataloader, evaluators)  # results of the given evaluators
predictions = system.predict(dev_dataloader)

system.save_model_state('model.weights')  # persist the trained weights
system.load_model_state('model.weights')  # restore them later
```
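To make the placeholders above concrete, here is a minimal sketch of one way to supply them. It is illustrative, not the library's documented contract: the toy dataset, the `'input'`/`'target'` dict keys, and the model shapes are assumptions chosen to match the binary-classification example above, so check the docs for the exact batch format the built-in loss wrappers and evaluators expect.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

# Hypothetical toy dataset for binary classification. Each example is a
# dict; the 'input' and 'target' keys are an assumption here, chosen to
# match the defaults suggested by the example above.
class ToyDataset(Dataset):
    def __init__(self, n=1000, d=20):
        self.x = torch.randn(n, d)
        self.y = (self.x.sum(dim=1) > 0).float()

    def __len__(self):
        return len(self.x)

    def __getitem__(self, i):
        return {'input': self.x[i], 'target': self.y[i]}

train_dataloader = DataLoader(ToyDataset(), batch_size=32, shuffle=True)
val_dataloader = DataLoader(ToyDataset(200), batch_size=32)
dev_dataloader = DataLoader(ToyDataset(200), batch_size=32)

# A small feed-forward net emitting one logit per example, to match the
# BCEWithLogitsLoss used above.
model = nn.Sequential(
    nn.Linear(20, 64),
    nn.ReLU(),
    nn.Linear(64, 1),
    nn.Flatten(0),  # (batch, 1) -> (batch,) so logits align with targets
)
```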
The docs can be found at https://pytorch-wrapper.readthedocs.io/en/latest/.
There are also the following examples in notebook format: