Skip to content

Commit

Permalink
Added section on exporting weights from PyTorch using safetensors
Browse files Browse the repository at this point in the history
  • Loading branch information
bokenator authored Jun 3, 2023
1 parent 084d378 commit 015fb11
Showing 1 changed file with 63 additions and 0 deletions.
63 changes: 63 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,70 @@ This should print the top 5 imagenet categories for the image. The code for this
println!("{:50} {:5.2}%", class, 100.0 * probability)
}
```
### Importing Pre-Trained Weights from PyTorch Using SafeTensors

`safetensors` is a new simple format by HuggingFace for storing tensors. It does not rely on Python's `pickle` module, and therefore the tensors are not bound to the specific classes and the exact directory structure used when the model is saved. It is also zero-copy, which means that reading the file will require no more memory than the original file.

For more information on `safetensors`, please check out https://github.com/huggingface/safetensors

#### Installing `safetensors`

You can install `safetensors` via the pip manager:

```
pip install safetensors
```

#### Exporting weights in PyTorch

```python
import torchvision
from safetensors import torch as stt

model = torchvision.models.resnet18(pretrained=True)
stt.save_file(model.state_dict(), 'resnet18.safetensors')
```

*Note: the filename of the export must be named with a `.safetensors` suffix for it to be properly decoded by `tch`.*

#### Importing weights in `tch`

```rust
use anyhow::Result;
use tch::{
Device,
Kind,
nn::VarStore,
vision::{
imagenet,
resnet::resnet18,
}
};

fn main() -> Result<()> {
// Create the model and load the pre-trained weights
let mut vs = VarStore::new(Device::cuda_if_available());
let model = resnet18(&vs.root(), 1000);
vs.load("resnet18.safetensors")?;

// Load the image file and resize it to the usual imagenet dimension of 224x224.
let image = imagenet::load_image_and_resize224("dog.jpg")?
.to_device(vs.device());

// Apply the forward pass of the model to get the logits
let output = image
.unsqueeze(0)
.apply_t(&model, false)
.softmax(-1, Kind::Float);

// Print the top 5 categories for this image.
for (probability, class) in imagenet::top(&output, 5).iter() {
println!("{:50} {:5.2}%", class, 100.0 * probability)
}

Ok(())
}
```

Further examples include:
* A simplified version of
Expand Down

0 comments on commit 015fb11

Please sign in to comment.