Skip to content

Commit

Permalink
preparing for distribution
Browse files Browse the repository at this point in the history
  • Loading branch information
bmccann committed Aug 10, 2017
1 parent 580d448 commit e3ad9fe
Show file tree
Hide file tree
Showing 9 changed files with 105 additions and 82 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
*.pyc
5 changes: 3 additions & 2 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ RUN curl -o ~/miniconda.sh -O https://repo.continuum.io/miniconda/Miniconda3-la

ENV PATH /opt/conda/bin:$PATH
RUN conda install -c soumith pytorch=0.1.12 cuda80
RUN pip install -r https://raw.githubusercontent.com/pytorch/text/master/requirements.txt
RUN pip install git+https://github.com/pytorch/text.git

# Default to utf-8 encodings in python
# Can verify in container with:
Expand All @@ -28,3 +26,6 @@ RUN locale-gen en_US.UTF-8
ENV LANG en_US.UTF-8
ENV LANGUAGE en_US:en
ENV LC_ALL en_US.UTF-8

ADD ./ /cove/
RUN cd cove && pip install -r requirements.txt && python setup.py develop
27 changes: 16 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,26 @@ which takes in sequences of vectors pretrained with GloVe and outputs CoVe.

## Running with Docker

We have included a Dockerfile that covers all dependencies.
We typically use this code on a machine with a GPU,
so we use `nvidia-docker`.
Install [Docker](https://www.docker.com/get-docker).
Install [nvidia-docker](https://github.com/NVIDIA/nvidia-docker) if you would like to use with with a GPU.

Once you have installed [Docker](https://www.docker.com/get-docker),
pull the docker image with `docker pull bmccann/cove`.
Then you can use `nvidia-docker run -it -v /path/to/cove/:/cove cove` to start a docker container using that image.
Once the container is running,
you can use `nvidia-docker ps` to find the `container_name` and
`nvidia-docker exec -it container_name bash -c "cd cove && python example.py"` to run example.py.
```bash
docker pull bmccann/cove # pull the docker image
docker run -it cove # start a docker container
python /cove/test/example.py
```

## Running without Docker

You will need to install PyTorch and then run `pip install -r requirements.txt`
Run the example with `python example.py`.
Install [PyTorch](http://pytorch.org/).

```bash
git clone https://github.com/salesforce/cove.git # use ssh: [email protected]:salesforce/cove.git
cd cove
pip install -r requirements.txt
python setup.py develop
python test/example.py
```


## References
Expand Down
4 changes: 4 additions & 0 deletions cove/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .encoder import *


_all__ = ['MTLSTM']
57 changes: 57 additions & 0 deletions cove/encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import os

import torch
from torch import nn
from torch.nn.utils.rnn import pad_packed_sequence as unpack
from torch.nn.utils.rnn import pack_padded_sequence as pack
import torch.utils.model_zoo as model_zoo


model_urls = {
'wmt-lstm' : 'https://s3.amazonaws.com/research.metamind.io/cove/wmtlstm-b142a7f2.pth'
}

model_cache = os.path.join(os.path.dirname(os.path.realpath(__file__)), '.torch')


class MTLSTM(nn.Module):

def __init__(self, n_vocab=None, vectors=None, residual_embeddings=False):
"""Initialize an MTLSTM.
Arguments:
n_vocab (bool): If not None, initialize MTLSTM with an embedding matrix with n_vocab vectors
vectors (Float Tensor): If not None, initialize embedding matrix with specified vectors
residual_embedding (bool): If True, concatenate the input embeddings with MTLSTM outputs during forward
"""
super().__init__()
self.embed = False
if n_vocab is not None:
self.embed = True
self.vectors = nn.Embedding(n_vocab, 300)
if vectors is not None:
self.vectors.weight.data = vectors
self.rnn = nn.LSTM(300, 300, num_layers=2, bidirectional=True)
self.rnn.load_state_dict(model_zoo.load_url(model_urls['wmt-lstm'], model_dir=model_cache))
self.residual_embeddings = residual_embeddings

def forward(self, inputs, lengths, hidden=None):
"""A pretrained MT-LSTM (McCann et. al. 2017).
This LSTM was trained with 300d 840B GloVe on the WMT 2017 machine translation dataset.
Arguments:
inputs (Tensor): If MTLSTM handles embedding, a Long Tensor of size (batch_size, timesteps).
Otherwise, a Float Tensor of size (batch_size, timesteps, features).
lengths (Long Tensor): (batch_size, lengths) lenghts of each sequence for handling padding
hidden (Float Tensor): initial hidden state of the LSTM
"""
if self.embed:
inputs = self.vectors(inputs.t()).t()
lens, indices = torch.sort(lengths, 0, True)
outputs, hidden_t = self.rnn(pack(inputs[indices], lens.tolist(), batch_first=True), hidden)
outputs = unpack(outputs, batch_first=True)[0]
_, _indices = torch.sort(indices, 0)
outputs = outputs[_indices]
if self.residual_embeddings:
outputs = torch.cat([inputs, outputs], 2)
return outputs
31 changes: 0 additions & 31 deletions example.py

This file was deleted.

37 changes: 0 additions & 37 deletions mtlstm.py

This file was deleted.

2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
-r https://raw.githubusercontent.com/pytorch/text/master/requirements.txt
-e git://github.com/pytorch/text.git#egg=torchtext
git+https://github.com/pytorch/text.git
23 changes: 23 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#!/usr/bin/env python
from setuptools import setup, find_packages
from codecs import open
from os import path


with open(path.join(path.abspath(path.dirname(__file__)), 'README.md'), encoding='utf-8') as f:
long_description = f.read()

setup_info = dict(
name='cove',
version='1.0.0',
author='Bryan McCann',
author_email='[email protected]',
url='https://github.com/salesforce/cove',
description='Context Vectors for Deep Learning and NLP',
long_description=long_description,
license='BSD 3-Clause',
keywords='cove, context vectors, deep learning, natural language processing',
packages=find_packages()
)

setup(**setup_info)

0 comments on commit e3ad9fe

Please sign in to comment.