ryanleary/apex (forked from NVIDIA/apex)

A PyTorch Extension: Tools for easy mixed precision and distributed training in PyTorch

Introduction

This repo holds PyTorch modules and utilities that are experimental and under active development. It is not intended as a long-term or production solution; things placed here are meant to eventually move into upstream PyTorch.

Requirements

  • Python 3
  • PyTorch 0.3 or newer
  • CUDA 9

Quick Start

To build the extension, run the following command in the root directory of this project:

python setup.py install

To use the extension, simply run

import apex

and, optionally (if your use case requires it),

import apex._C as apex_backend

What's included

The current version of apex contains:

  1. Mixed precision utilities. These can be found here; examples of their use are available for the PyTorch ImageNet example and the PyTorch word language model example. (A sketch of the underlying technique follows this list.)
  2. Parallel utilities. These can be found here, and an example/walkthrough can be found here. (A usage sketch also follows this list.)
  • apex/parallel/distributed.py contains a simplified implementation of PyTorch's DistributedDataParallel that is optimized for use with NCCL in single-GPU-per-process mode.
  • apex/parallel/multiproc.py is a simple multi-process launcher for a single node/computer with multiple GPUs.
  3. A reparameterization function that lets you recursively apply a reparameterization to an entire module, including its child modules. (See the recursive-application sketch below.)
  4. An experimental, in-development flexible RNN API.
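The mixed precision utilities themselves are not reproduced here, but the technique they support is straightforward to sketch. Below is a minimal, hypothetical illustration in plain PyTorch of the usual pattern: FP16 model weights, an FP32 "master" copy of the parameters for the optimizer step, and a static loss scale to keep small gradients from underflowing in FP16. Every name in it is made up for the example; this is not apex's actual API.

import torch

def make_master_params(model):
    # FP32 copies of the model's (FP16) parameters; the optimizer updates these.
    return [p.detach().clone().float().requires_grad_(True)
            for p in model.parameters()]

def train_step(model, loss_fn, batch, target, master, optimizer, loss_scale=128.0):
    model.zero_grad()
    optimizer.zero_grad()
    loss = loss_fn(model(batch), target)
    # Scale the loss so small gradients survive the FP16 backward pass.
    (loss * loss_scale).backward()
    # Move gradients onto the FP32 master copies, undoing the scale.
    for mp, p in zip(master, model.parameters()):
        mp.grad = p.grad.detach().float() / loss_scale
    optimizer.step()  # the optimizer must have been constructed over `master`
    # Push the updated FP32 weights back into the FP16 model.
    with torch.no_grad():
        for mp, p in zip(master, model.parameters()):
            p.copy_(mp)
    return loss.item()

The FP32 master copy matters because a weight update that is small relative to the weight's magnitude can be lost entirely when accumulated in FP16 arithmetic.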
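For the parallel utilities, a typical single-GPU-per-process launch looks roughly like the sketch below. The class name follows from the description of apex/parallel/distributed.py above; the constructor arguments, the --local_rank flag, and the environment-variable initialization are assumptions about the launcher's conventions, not documented API.

import argparse
import torch
from apex.parallel import DistributedDataParallel  # simplified DDP from apex/parallel/distributed.py

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=0)  # hypothetical: set by the launcher
args = parser.parse_args()

# One process per GPU; NCCL is the backend this implementation is optimized for.
torch.cuda.set_device(args.local_rank)
torch.distributed.init_process_group(backend='nccl')  # expects MASTER_ADDR/MASTER_PORT etc. in the environment

model = torch.nn.Linear(1024, 1024).half().cuda()
model = DistributedDataParallel(model)  # gradients are all-reduced across processes during backward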
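The reparameterization function itself is not shown here, but the "recursively apply to an entire module" idea can be illustrated with PyTorch's built-in torch.nn.utils.weight_norm. The helper below is a hypothetical stand-in; apex's own function and signature may differ.

import torch.nn as nn
from torch.nn.utils import weight_norm

def apply_weight_norm_recursively(module, name='weight'):
    # Walk the module tree bottom-up and reparameterize every submodule
    # that still owns a plain Parameter called `name` (e.g. Linear, Conv).
    for child in module.children():
        apply_weight_norm_recursively(child, name)
    if isinstance(getattr(module, name, None), nn.Parameter):
        weight_norm(module, name)
    return module

# Example: both Linear layers get weight norm; the ReLU is skipped.
net = apply_weight_norm_recursively(
    nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5)))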
