PyTorch/XLA is a Python package that uses the XLA deep learning compiler to connect the PyTorch deep learning framework and Cloud TPUs. You can try it right now, for free, on a single Cloud TPU with Google Colab, and use it in production and on Cloud TPU Pods with Google Cloud.
Take a look at one of our Colab notebooks to quickly try different PyTorch networks running on Cloud TPUs and learn how to use Cloud TPUs as PyTorch devices:
- Getting Started with PyTorch on Cloud TPUs
- Training AlexNet on Fashion MNIST with a single Cloud TPU Core
- Training AlexNet on Fashion MNIST with multiple Cloud TPU Cores
- Fast Neural Style Transfer (NeurIPS 2019 Demo)
- Training A Simple Convolutional Network on MNIST
- Training a ResNet18 Network on CIFAR10
- ImageNet Inference with ResNet50
- Training DC-GAN using Colab Cloud TPU
The rest of this README covers:
- User Guide & Best Practices
- Running PyTorch on Cloud TPUs and GPUs (Google Cloud also runs networks faster than Google Colab)
- Available docker images and wheels
- Performance Profiling and Auto-Metrics Analysis
- Troubleshooting
- Providing Feedback
- Building and Contributing to PyTorch/XLA
- Additional Reads
Additional information on PyTorch/XLA, including a description of its semantics and functions, is available at PyTorch.org.
Our comprehensive user guides are available at:
- Documentation for the latest release
- Documentation for master branch
See the API Guide for best practices when writing networks that run on XLA devices (TPU, GPU, CPU and...).
Google Cloud offers TPU VMs for more transparent and easier access to the TPU hardware. This is our recommended way of running PyTorch/XLA on Cloud TPU. Please check out our Cloud TPU VM User Guide. To learn more about the Cloud TPU System Architecture, please check out this doc.
If a single TPU VM does not suit your requirements, consider using a TPU Pod. A TPU Pod is a collection of TPU devices connected by dedicated high-speed network interfaces. Please check out our Cloud TPU VM Pod User Guide.
The following pre-built Docker images are available. To run the containers, check this doc for TPU VMs and this doc for GPUs.
Version | Cloud TPU VMs Docker |
---|---|
2.0 | gcr.io/tpu-pytorch/xla:r2.0_3.8_tpuvm |
1.13 | gcr.io/tpu-pytorch/xla:r1.13_3.8_tpuvm |
nightly | gcr.io/tpu-pytorch/xla:nightly_3.8_tpuvm |
nightly at date | gcr.io/tpu-pytorch/xla:nightly_3.8_YYYYMMDD |
Version | GPU CUDA 11.8 + Python 3.8 Docker |
---|---|
2.0 | gcr.io/tpu-pytorch/xla:r2.0_3.8_cuda_11.8 |
nightly | gcr.io/tpu-pytorch/xla:nightly_3.8_cuda_11.8 |
nightly at date(>=20230210) | gcr.io/tpu-pytorch/xla:nightly_3.8_cuda_11.8_YYYYMMDD |
Version | GPU CUDA 11.7 + Python 3.8 Docker |
---|---|
2.0 | gcr.io/tpu-pytorch/xla:r2.0_3.8_cuda_11.7 |
nightly | gcr.io/tpu-pytorch/xla:nightly_3.8_cuda_11.7 |
nightly at date(>=20230210) | gcr.io/tpu-pytorch/xla:nightly_3.8_cuda_11.7_YYYYMMDD |
Version | GPU CUDA 11.2 + Python 3.8 Docker |
---|---|
1.13 | gcr.io/tpu-pytorch/xla:r1.13_3.8_cuda_11.2 |
Version | GPU CUDA 11.2 + Python 3.7 Docker |
---|---|
1.13 | gcr.io/tpu-pytorch/xla:r1.13_3.7_cuda_11.2 |
1.12 | gcr.io/tpu-pytorch/xla:r1.12_3.7_cuda_11.2 |
The GPU images above can be run on compute instances with GPUs.
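The GPU image names in the tables above appear to follow a regular pattern: a release prefix, the Python version, the CUDA version, and an optional date for nightly builds. The sketch below composes such a name. The function `cuda_image` and its parameters are hypothetical helpers written for illustration, not part of PyTorch/XLA, and the pattern is inferred only from the rows listed here, so a composed name is not guaranteed to exist in the registry.

```python
import datetime
from typing import Optional

REGISTRY = "gcr.io/tpu-pytorch/xla"
# The tables list dated nightly GPU images only from this date onward.
EARLIEST_DATED_NIGHTLY = datetime.date(2023, 2, 10)


def cuda_image(version: str, cuda: str, python: str = "3.8",
               date: Optional[datetime.date] = None) -> str:
    """Compose a GPU image name following the pattern seen in the tables.

    Hypothetical helper: it only formats a string and does not check
    that the image was actually published.
    """
    # Releases are prefixed with "r" (e.g. r2.0); nightly is used as-is.
    prefix = "nightly" if version == "nightly" else f"r{version}"
    tag = f"{prefix}_{python}_cuda_{cuda}"
    if date is not None:
        if version != "nightly":
            raise ValueError("dated tags are listed only for nightly images")
        if date < EARLIEST_DATED_NIGHTLY:
            raise ValueError("dated nightly GPU images start at 20230210")
        tag += f"_{date:%Y%m%d}"
    return f"{REGISTRY}:{tag}"


if __name__ == "__main__":
    print(cuda_image("2.0", "11.8"))
    print(cuda_image("nightly", "11.7", date=datetime.date(2023, 3, 1)))
```

The resulting string can be passed to `docker pull`. Only the combinations that appear in the tables are known to be available.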
Version | Cloud TPU VMs Wheel |
---|---|
2.0 | https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-2.0-cp38-cp38-linux_x86_64.whl |
1.13 | https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-1.13-cp38-cp38-linux_x86_64.whl |
1.12 | https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-1.12-cp38-cp38-linux_x86_64.whl |
1.11 | https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-1.11-cp38-cp38-linux_x86_64.whl |
1.10 | https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-1.10-cp38-cp38-linux_x86_64.whl |
nightly | https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-nightly-cp38-cp38-linux_x86_64.whl |
Note: For TPU Pod customers using XRT (our legacy runtime), we have custom wheels for torch, torchvision, and torch_xla at https://storage.googleapis.com/tpu-pytorch/wheels/xrt.
Package | Cloud TPU VMs Wheel (XRT on Pod, Legacy Only) |
---|---|
torch_xla | https://storage.googleapis.com/tpu-pytorch/wheels/xrt/torch_xla-2.0-cp38-cp38-linux_x86_64.whl |
torch | https://storage.googleapis.com/tpu-pytorch/wheels/xrt/torch-2.0-cp38-cp38-linux_x86_64.whl |
torchvision | https://storage.googleapis.com/tpu-pytorch/wheels/xrt/torchvision-2.0-cp38-cp38-linux_x86_64.whl |
Version | GPU Wheel + Python 3.8 |
---|---|
2.0 + CUDA 11.8 | https://storage.googleapis.com/tpu-pytorch/wheels/cuda/118/torch_xla-2.0-cp38-cp38-linux_x86_64.whl |
2.0 + CUDA 11.7 | https://storage.googleapis.com/tpu-pytorch/wheels/cuda/117/torch_xla-2.0-cp38-cp38-linux_x86_64.whl |
1.13 + CUDA 11.2 | https://storage.googleapis.com/tpu-pytorch/wheels/cuda/112/torch_xla-1.13-cp38-cp38-linux_x86_64.whl |
nightly + CUDA 11.7 | https://storage.googleapis.com/tpu-pytorch/wheels/cuda/117/torch_xla-nightly-cp38-cp38-linux_x86_64.whl |
nightly + CUDA 11.8 | https://storage.googleapis.com/tpu-pytorch/wheels/cuda/118/torch_xla-nightly-cp38-cp38-linux_x86_64.whl |
Version | GPU Wheel + Python 3.7 |
---|---|
1.13 | https://storage.googleapis.com/tpu-pytorch/wheels/cuda/112/torch_xla-1.13-cp37-cp37m-linux_x86_64.whl |
1.12 | https://storage.googleapis.com/tpu-pytorch/wheels/cuda/112/torch_xla-1.12-cp37-cp37m-linux_x86_64.whl |
1.11 | https://storage.googleapis.com/tpu-pytorch/wheels/cuda/112/torch_xla-1.11-cp37-cp37m-linux_x86_64.whl |
nightly | https://storage.googleapis.com/tpu-pytorch/wheels/cuda/112/torch_xla-nightly-cp37-cp37-linux_x86_64.whl |
Version | Colab TPU Wheel |
---|---|
2.0 | https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-2.0-cp39-cp39-linux_x86_64.whl |
1.13 | https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-1.13-cp37-cp37m-linux_x86_64.whl |
1.12 | https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-1.12-cp37-cp37m-linux_x86_64.whl |
1.11 | https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-1.11-cp37-cp37m-linux_x86_64.whl |
You can also add +yyyymmdd after torch_xla-nightly to get the nightly wheel of a specified date. To get the companion pytorch and torchvision nightly wheels, replace torch_xla with torch or torchvision in the wheel links above.
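As an illustration of the two substitutions described above (the date suffix and the package name), here is a minimal sketch that builds a nightly wheel link. The helper `nightly_wheel_url` and its `flavor` and `py_tag` parameters are hypothetical names introduced for this example. The URL layout is inferred from the nightly rows in the tables, and the sketch does not verify that a wheel exists for a given date.

```python
from typing import Optional

WHEEL_ROOT = "https://storage.googleapis.com/tpu-pytorch/wheels"
PACKAGES = ("torch_xla", "torch", "torchvision")


def nightly_wheel_url(package: str = "torch_xla", flavor: str = "tpuvm",
                      py_tag: str = "cp38",
                      date: Optional[str] = None) -> str:
    """Build a nightly wheel URL in the layout shown in the tables.

    flavor is the path segment after /wheels/, e.g. "tpuvm" or "cuda/118"
    (assumed from the table rows). date is an optional "yyyymmdd" string.
    """
    if package not in PACKAGES:
        raise ValueError(f"package must be one of {PACKAGES}")
    if date is not None and not (len(date) == 8 and date.isdigit()):
        raise ValueError("date must be in yyyymmdd form")
    # Append "+yyyymmdd" right after "nightly" when a date is requested.
    version = "nightly" if date is None else f"nightly+{date}"
    filename = f"{package}-{version}-{py_tag}-{py_tag}-linux_x86_64.whl"
    return f"{WHEEL_ROOT}/{flavor}/{filename}"


if __name__ == "__main__":
    print(nightly_wheel_url())
    print(nightly_wheel_url("torchvision", "cuda/118", date="20230301"))
```

The returned link can be passed to `pip3 install`. The example date is arbitrary and chosen only to show the format.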
For PyTorch/XLA release r2.0 and older, and when developing PyTorch/XLA, install the libtpu pip package with the following command:
pip3 install torch_xla[tpuvm]
This is only required on Cloud TPU VMs.
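To check whether the install step above took effect, a small probe like the following may help. It is only a sketch: it assumes the libtpu pip package exposes an importable module named `libtpu`, which this README does not state, and the helper `is_module_installed` is a hypothetical name.

```python
import importlib.util


def is_module_installed(module_name: str) -> bool:
    """Return True if a top-level module can be located without importing it."""
    return importlib.util.find_spec(module_name) is not None


if __name__ == "__main__":
    # Assumption: the libtpu pip package installs a module named "libtpu".
    if is_module_installed("libtpu"):
        print("libtpu appears to be installed")
    else:
        print("libtpu not found; on a Cloud TPU VM, run: "
              "pip3 install torch_xla[tpuvm]")
```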
PyTorch/XLA provides a set of performance profiling tools and auto-metrics analysis. See the following resources:
- Official tutorial
- Colab notebook
- Sample MNIST training script with profiling
- Utility script for capturing performance profiles
If PyTorch/XLA isn't performing as expected, see the troubleshooting guide, which has suggestions for debugging and optimizing your network(s).
The PyTorch/XLA team is always happy to hear from users and OSS contributors! The best way to reach out is by filing an issue on this GitHub repository. Questions, bug reports, feature requests, build issues, etc. are all welcome!
See the contribution guide.
This repository is jointly operated and maintained by Google, Facebook and a number of individual contributors listed in the CONTRIBUTORS file. For questions directed at Facebook, please send an email to [email protected]. For questions directed at Google, please send an email to [email protected]. For all other questions, please open up an issue in this repository here.
You can find additional useful reading materials in