
Rust bindings for the C++ API of PyTorch.

License

Apache-2.0 (LICENSE-APACHE) and MIT (LICENSE-MIT)

NOBLES5E/tch-rs


torch-rust

Very experimental Rust bindings for PyTorch. The code generation for the C API on top of libtorch comes from ocaml-torch.

Instructions

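Set LIBTORCH to the libtorch directory and add its lib subdirectory to LD_LIBRARY_PATH, then run the basics example:

    LD_LIBRARY_PATH=/.../libtorch/lib LIBTORCH=/.../libtorch cargo run --example basics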

Examples

The following code trains a linear classifier on MNIST as a proof of concept.

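    // Load MNIST from the `data` directory; `unwrap` panics if loading fails.
    let m = vision::Mnist::load_dir(std::path::Path::new("data")).unwrap();
    // Weights and biases start at zero and track gradients.
    // IMAGE_DIM and LABELS are constants that are not shown in this excerpt.
    let mut ws = Tensor::zeros(&[IMAGE_DIM, LABELS], Kind::Float).set_requires_grad(true);
    let mut bs = Tensor::zeros(&[LABELS], Kind::Float).set_requires_grad(true);
    for epoch in 1..200 {
        // Forward pass and negative log-likelihood loss on the full training set.
        let logits = m.train_images.mm(&ws) + &bs;
        let loss = logits.log_softmax(-1).nll_loss(&m.train_labels);
        // Clear the previous gradients, then backpropagate.
        ws.zero_grad();
        bs.zero_grad();
        loss.backward();
        // Plain gradient descent step (learning rate 1), not recorded by autograd.
        no_grad(|| {
            ws += ws.grad() * (-1);
            bs += bs.grad() * (-1);
        });
        // Accuracy on the test set: fraction of argmax predictions equal to the labels.
        let test_logits = m.test_images.mm(&ws) + &bs;
        let test_accuracy = test_logits
            .argmax(-1)
            .eq(&m.test_labels)
            .to_kind(Kind::Float)
            .mean()
            .double_value(&[]);
        println!(
            "epoch: {:4} train loss: {:8.5} test acc: {:5.2}%",
            epoch,
            loss.double_value(&[]),
            100. * test_accuracy
        );
    }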

Run it with the following command:

    LD_LIBRARY_PATH=/.../libtorch/lib LIBTORCH=/.../libtorch cargo run --example mnist
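
To make the quantities printed by the training loop more concrete, here is a minimal sketch in plain Rust (standard library only, no libtorch) of what the `log_softmax(-1).nll_loss(...)` chain and the `argmax`/`eq`/`mean` chain compute. The function names `log_softmax`, `nll_loss`, `argmax` and `accuracy` below are illustrative helpers written for this sketch, not the bindings' API, and the loss is assumed to use the mean reduction that PyTorch applies by default.

```rust
/// Numerically stable log-softmax over one row of logits.
/// Subtracting the row maximum before exponentiating avoids overflow.
fn log_softmax(logits: &[f64]) -> Vec<f64> {
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let log_sum_exp = max + logits.iter().map(|x| (x - max).exp()).sum::<f64>().ln();
    logits.iter().map(|x| x - log_sum_exp).collect()
}

/// Mean negative log-likelihood of the correct label over a batch.
/// `log_probs` holds one log-softmax row per example.
fn nll_loss(log_probs: &[Vec<f64>], labels: &[usize]) -> f64 {
    let total: f64 = log_probs
        .iter()
        .zip(labels)
        .map(|(row, &label)| -row[label])
        .sum();
    total / labels.len() as f64
}

/// Index of the largest value in a non-empty row (the predicted class).
fn argmax(row: &[f64]) -> usize {
    let mut best = 0;
    for (i, &v) in row.iter().enumerate() {
        if v > row[best] {
            best = i;
        }
    }
    best
}

/// Fraction of examples whose predicted class equals the label.
fn accuracy(logits: &[Vec<f64>], labels: &[usize]) -> f64 {
    let correct = logits
        .iter()
        .zip(labels)
        .filter(|pair| argmax(pair.0) == *pair.1)
        .count();
    correct as f64 / labels.len() as f64
}
```

One consequence worth noting: because the weights and biases start at zero, every logit in the first forward pass is zero, so each class gets the same probability. If LABELS is 10 (MNIST has ten digit classes), the first reported training loss should therefore be close to ln(10), roughly 2.30.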
