Skip to content

Commit

Permalink
Fix build breakage due to soft torch import (ray-project#7790)
Browse files Browse the repository at this point in the history
  • Loading branch information
ericl authored Mar 29, 2020
1 parent e4bd5db commit d6255c3
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 3 deletions.
4 changes: 2 additions & 2 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ matrix:
- if [ $RAY_CI_LINUX_WHEELS_AFFECTED != "1" ]; then exit; fi

# Explicitly sleep 60 seconds for logs to go through
- ./ci/travis/test-wheels.sh || cat /tmp/ray/session_latest/logs/* || sleep 60 || false
- ./ci/travis/test-wheels.sh || cat /tmp/ray/session_latest/logs/* || (sleep 60 && false)
cache: false

# Build MacOS wheels.
Expand All @@ -168,7 +168,7 @@ matrix:
- if [ $RAY_CI_MACOS_WHEELS_AFFECTED != "1" ]; then exit; fi

# Explicitly sleep 60 seconds for logs to go through
- ./ci/travis/test-wheels.sh || cat /tmp/ray/session_latest/logs/* || sleep 60 || false
- ./ci/travis/test-wheels.sh || cat /tmp/ray/session_latest/logs/* || (sleep 60 && false)

# RLlib: Learning tests (from rllib/tuned_examples/regression_tests/*.yaml).
- os: linux
Expand Down
16 changes: 15 additions & 1 deletion rllib/utils/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,17 @@ def try_import_tfp(error=False):
return None


# Fake module for torch.nn.
class NNStub:
pass


# Fake class for torch.nn.Module to allow it to be inherited from.
class ModuleStub:
def __init__(self, *a, **kw):
raise ImportError("Could not import `torch`.")


def try_import_torch(error=False):
"""
Args:
Expand All @@ -118,7 +129,10 @@ def try_import_torch(error=False):
except ImportError as e:
if error:
raise e
return None, None

nn = NNStub()
nn.Module = ModuleStub
return None, nn


def get_variable(value, framework="tf", tf_name="unnamed-variable"):
Expand Down

0 comments on commit d6255c3

Please sign in to comment.