Skip to content

Commit

Permalink
Enable Travis on Trax.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 308385447
  • Loading branch information
afrozenator authored and copybara-github committed Apr 25, 2020
1 parent 6922375 commit a960735
Show file tree
Hide file tree
Showing 5 changed files with 157 additions and 12 deletions.
23 changes: 23 additions & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
sudo: required
language: python
cache: pip
git:
depth: 3
quiet: true
services:
- docker
python:
- "3.6"
env:
global:
- TF_LATEST="2.1.*"
matrix:
- TF_VERSION="2.1.*"
install:
- ./oss_scripts/oss_pip_install.sh
script:
- ./oss_scripts/oss_tests.sh

# - if [[ "$TRAVIS_PYTHON_VERSION" == "2.7" ]]; then
# pylint -j 2 trax;
# fi
40 changes: 40 additions & 0 deletions oss_scripts/oss_pip_install.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright 2020 The Trax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

#!/bin/bash

set -v # print commands as they're executed
set -e # fail and exit on any command erroring

: "${TF_VERSION:?}"

# Make sure we have the latest pip and setuptools installed.
pip install -q -U pip
pip install -q -U setuptools

# Make sure we have the latest version of numpy - avoid problems we were
# seeing with Python 3
pip install -q -U numpy
pip install -q "tensorflow==$TF_VERSION"

# Just print the version again to make sure.
python -c 'import tensorflow as tf; print(tf.__version__)'

# First ensure that the base dependencies are sufficient for a full import
pip install -q -e .

# Then install the test dependencies
pip install -q -e .[tests]
# Make sure to install the atari extras for gym
pip install "gym[atari]"
80 changes: 80 additions & 0 deletions oss_scripts/oss_tests.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Copyright 2020 The Trax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

#!/bin/bash

set -v # print commands as they're executed

# Instead of exiting on any failure with "set -e", we'll call set_status after
# each command and exit $STATUS at the end.
STATUS=0
function set_status() {
local last_status=$?
if [[ $last_status -ne 0 ]]
then
echo "<<<<<<FAILED>>>>>> Exit code: $last_status"
fi
STATUS=$(($last_status || $STATUS))
}

# Check env vars set
echo "${TF_VERSION:?}" && \
echo "${TF_LATEST:?}" && \
echo "${TRAVIS_PYTHON_VERSION:?}"
set_status
if [[ $STATUS -ne 0 ]]
then
exit $STATUS
fi

# Check import.
python -c "import trax"
set_status

# Check notebooks.
# TODO(afrozm): Add more.
jupyter nbconvert --ExecutePreprocessor.kernel_name=python3 \
--ExecutePreprocessor.timeout=600 --to notebook --execute \
trax/intro.ipynb;
set_status

# Check tests, separate out directories for easy triage.
pytest --disable-warnings trax/layers
set_status

pytest --disable-warnings trax/math
set_status

pytest --disable-warnings trax/models
set_status

pytest --disable-warnings trax/optimizers
set_status

pytest --disable-warnings trax/rl
set_status

pytest --disable-warnings trax/supervised
set_status

pytest --disable-warnings \
--ignore=trax/layers \
--ignore=trax/math \
--ignore=trax/models \
--ignore=trax/optimizers \
--ignore=trax/rl \
--ignore=trax/supervised
set_status

exit $STATUS
22 changes: 12 additions & 10 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,28 +35,30 @@
license='Apache 2.0',
packages=find_packages(),
install_requires=[
'absl-py',
'funcsigs',
'gin-config',
'gym',
'gym==0.14.0',
'jax',
'jaxlib',
'numpy',
'scipy',
'six',
'jax',
'jaxlib',
'tensor2tensor',
'tensorflow-datasets',
'absl-py',
'funcsigs'
],
extras_require={
'tensorflow': ['tensorflow>=1.14.0'],
'tensorflow_gpu': ['tensorflow-gpu>=1.14.0'],
'tensorflow': ['tensorflow>=1.15.0'],
'tensorflow_gpu': ['tensorflow-gpu>=1.15.0'],
'tests': [
'attrs',
'pytest',
'mock',
'pylint',
'jupyter',
'matplotlib',
'mock',
'parameterized',
'pylint',
'pytest',
'wrapt==1.11.*',
],
},
classifiers=[
Expand Down
4 changes: 2 additions & 2 deletions trax/models/research/bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"""BERT."""

import jax
from tensorflow.train import load_checkpoint
import tensorflow as tf

from trax import layers as tl
from trax.math import numpy as np
Expand Down Expand Up @@ -144,7 +144,7 @@ def new_weights_and_state(self, input_signature):
return weights, state

print('Loading pre-trained weights from', self.init_checkpoint)
ckpt = load_checkpoint(self.init_checkpoint)
ckpt = tf.train.load_checkpoint(self.init_checkpoint)

def reshape_qkv(name):
x = ckpt.get_tensor(name)
Expand Down

0 comments on commit a960735

Please sign in to comment.