Skip to content

Commit

Permalink
Update a CNTK version on Travis (keras-team#10419)
Browse files Browse the repository at this point in the history
  • Loading branch information
taehoonlee authored and fchollet committed Jun 13, 2018
1 parent 4f90f95 commit a68c516
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 10 deletions.
4 changes: 2 additions & 2 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,9 @@ install:

# install cntk
- if [[ "$TRAVIS_PYTHON_VERSION" == "2.7" ]]; then
pip install https://cntk.ai/PythonWheel/CPU-Only/cntk-2.3.1-cp27-cp27mu-linux_x86_64.whl;
pip install https://cntk.ai/PythonWheel/CPU-Only/cntk-2.5.1-cp27-cp27mu-linux_x86_64.whl;
elif [[ "$TRAVIS_PYTHON_VERSION" == "3.6" ]]; then
pip install https://cntk.ai/PythonWheel/CPU-Only/cntk-2.3.1-cp36-cp36m-linux_x86_64.whl;
pip install https://cntk.ai/PythonWheel/CPU-Only/cntk-2.5.1-cp36-cp36m-linux_x86_64.whl;
fi

# install pydot for visualization tests
Expand Down
9 changes: 6 additions & 3 deletions keras/backend/cntk_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,11 +109,12 @@ def in_test_phase(x, alt, training=None):


def _convert_string_dtype(dtype):
# cntk only support float32 and float64
if dtype == 'float32':
return np.float32
elif dtype == 'float64':
return np.float64
elif dtype == 'float16':
return np.float16
else:
# cntk only running with float,
# try to cast to float to run the model
Expand All @@ -125,10 +126,12 @@ def _convert_dtype_string(dtype):
return 'float32'
elif dtype == np.float64:
return 'float64'
elif dtype == np.float16:
return 'float16'
else:
raise ValueError('CNTK Backend: Unsupported dtype: %s. '
'CNTK only supports float32 and '
'float64.' % dtype)
'CNTK only supports float32, float64, and '
'float16.' % dtype)


def variable(value, dtype=None, name=None, constraint=None):
Expand Down
6 changes: 1 addition & 5 deletions tests/keras/backend/backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1767,11 +1767,7 @@ def test_set_floatx(self):
def test_dtype(self):
assert K.dtype(K.variable(1, dtype='float64')) == 'float64'
assert K.dtype(K.variable(1, dtype='float32')) == 'float32'
if K.backend() == 'cntk':
with pytest.raises(ValueError):
K.variable(1, dtype='float16')
else:
assert K.dtype(K.variable(1, dtype='float16')) == 'float16'
assert K.dtype(K.variable(1, dtype='float16')) == 'float16'

def test_variable_support_bool_dtype(self):
# Github issue: 7819
Expand Down

0 comments on commit a68c516

Please sign in to comment.