Skip to content

Commit

Permalink
Default dtype made floatx() instead of np.float32 (keras-team#12332)
Browse files Browse the repository at this point in the history
* Replace obsolete link in comment

* Default to floatx() instead of np.float32

* Change in random_binomial
  • Loading branch information
abhaikollara authored and gabrieldemarmiesse committed Feb 22, 2019
1 parent 979e880 commit 0264236
Showing 1 changed file with 5 additions and 9 deletions.
14 changes: 5 additions & 9 deletions keras/backend/cntk_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def variable(value, dtype=None, name=None, constraint=None):
value = value.astype(dtype)

# TODO: remove the conversion when cntk supports int32, int64
# https://docs.microsoft.com/en-us/python/api/cntk.variables.parameter
# https://www.cntk.ai/pythondocs/cntk.variables.html#cntk.variables.Parameter
dtype = 'float32' if 'int' in str(dtype) else dtype

v = C.parameter(shape=shape,
Expand Down Expand Up @@ -386,7 +386,7 @@ def random_binomial(shape, p=0.0, dtype=None, seed=None):
# ensure that randomness is conditioned by the Numpy RNG
seed = np.random.randint(10e7)
if dtype is None:
dtype = np.float32
dtype = floatx()
else:
dtype = _convert_string_dtype(dtype)

Expand Down Expand Up @@ -420,14 +420,12 @@ def random_uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):

def random_uniform_variable(shape, low, high,
dtype=None, name=None, seed=None):
if dtype is None:
dtype = floatx()
if seed is None:
# ensure that randomness is conditioned by the Numpy RNG
seed = np.random.randint(10e3)

if dtype is None:
dtype = np.float32
dtype = floatx()
else:
dtype = _convert_string_dtype(dtype)

Expand All @@ -452,13 +450,11 @@ def random_normal_variable(
dtype=None,
name=None,
seed=None):
if dtype is None:
dtype = floatx()
if seed is None:
# ensure that randomness is conditioned by the Numpy RNG
seed = np.random.randint(10e7)
if dtype is None:
dtype = np.float32
dtype = floatx()
else:
dtype = _convert_string_dtype(dtype)

Expand Down Expand Up @@ -497,7 +493,7 @@ def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
if seed is None:
seed = np.random.randint(1, 10e6)
if dtype is None:
dtype = np.float32
dtype = floatx()
else:
dtype = _convert_string_dtype(dtype)

Expand Down

0 comments on commit 0264236

Please sign in to comment.