jkoutsikakis/pytorch-wrapper

PyTorch Wrapper

PyTorch Wrapper is a library that provides a systematic and extensible way to build, train, evaluate, and tune deep learning models using PyTorch. It also provides several ready-to-use modules and functions for fast model development.

Installation

From PyPI

pip install pytorch-wrapper

From Source

git clone https://github.com/jkoutsikakis/pytorch-wrapper.git
cd pytorch-wrapper
pip install .
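
After either installation route, a quick check that the package is importable can save debugging time later. The snippet below is a minimal sketch using only the standard library; `is_importable` is a hypothetical helper written for this example, and the module name `pytorch_wrapper` is taken from the import statement in the usage pattern in this README.

```python
import importlib.util


def is_importable(module_name):
    """Return True if the import system can locate `module_name`.

    Hypothetical helper for illustration; not part of pytorch_wrapper.
    """
    return importlib.util.find_spec(module_name) is not None


if __name__ == "__main__":
    # Reports whether the installation above made the package visible
    # to the current Python environment.
    print("pytorch_wrapper importable:", is_importable("pytorch_wrapper"))
```

If this reports False, the install most likely went into a different Python environment than the one running the script.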

Basic abstract usage pattern

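import torch
import pytorch_wrapper as pw

# Placeholders: supply your own data loaders here.
train_dataloader = ...
val_dataloader = ...
dev_dataloader = ...

# Evaluators are keyed by name; callbacks refer to them by these keys.
evaluators = { 'acc': pw.evaluators.AccuracyEvaluator(), ... }
# Wrap a standard PyTorch loss for use by the System during training.
loss_wrapper = pw.loss_wrappers.GenericPointWiseLossWrapper(torch.nn.BCEWithLogitsLoss())

# Placeholder: the model to train.
model = ...

# The System wraps the model and is told which device to use.
system = pw.System(model=model, device=torch.device('cuda'))

optimizer = torch.optim.Adam(system.model.parameters())

# Train while evaluating on the 'val' loader. Early stopping monitors
# the 'acc' evaluator on that loader with a patience of 3.
system.train(
    loss_wrapper,
    optimizer,
    train_data_loader=train_dataloader,
    evaluators=evaluators,
    evaluation_data_loaders={'val': val_dataloader},
    callbacks=[
        pw.training_callbacks.EarlyStoppingCriterionCallback(
            patience=3,
            evaluation_data_loader_key='val',
            evaluator_key='acc',
            tmp_best_state_filepath='current_best.weights'
        )
    ]
)

# Evaluate and predict on dev_dataloader.
results = system.evaluate(dev_dataloader, evaluators)

predictions = system.predict(dev_dataloader)

# Save the model state to disk and load it back.
system.save_model_state('model.weights')
system.load_model_state('model.weights')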
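
The `patience=3` argument in the pattern above refers to a patience-based early stopping rule. The following is a minimal pure-Python sketch of that general technique, not the library's implementation: `epochs_until_stop` is a hypothetical helper, and it assumes that one score is produced per epoch, that higher scores are better, and that a tie does not count as an improvement. The library's callback may differ on any of these points.

```python
def epochs_until_stop(val_scores, patience=3):
    """Simulate a patience-based early stopping rule over a score history.

    Hypothetical helper for illustration; not part of pytorch_wrapper.
    Returns (stop_epoch, best_epoch), both zero-based indices.
    """
    best_score = float("-inf")
    best_epoch = None
    epochs_without_improvement = 0

    for epoch, score in enumerate(val_scores):
        if score > best_score:
            # Strict improvement: remember it and reset the counter.
            best_score = score
            best_epoch = epoch
            epochs_without_improvement = 0
        else:
            # No improvement (ties included, by assumption).
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch, best_epoch

    # Patience never ran out; training ends at the last epoch.
    return len(val_scores) - 1, best_epoch


if __name__ == "__main__":
    scores = [0.70, 0.75, 0.74, 0.75, 0.73, 0.80]
    stop_epoch, best_epoch = epochs_until_stop(scores, patience=3)
    print("stopped at epoch", stop_epoch, "with best epoch", best_epoch)
```

In this sketch the run stops before reaching the final, higher score, which illustrates the usual trade-off: a small patience saves training time but can stop ahead of a late improvement. Keeping the best state on disk, as `tmp_best_state_filepath` suggests, lets the best-scoring weights be kept rather than the last ones.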

Docs & Examples

The docs can be found here.

There are also the following examples in notebook format:

  1. Two Spiral Task
  2. Image Classification Task
  3. Tuning Image Classifier
  4. Text Classification Task
  5. Token Classification Task
  6. Text Classification Task using BERT
  7. Custom Callback
  8. Custom Loss Wrapper
  9. Custom Evaluator