Commit
address first round of cr comments
jeanfad committed Jun 2, 2016
1 parent f0944da commit 08c71d2
Showing 3 changed files with 24 additions and 25 deletions.
42 changes: 20 additions & 22 deletions contrib/Python/cntk/ops/__init__.py
@@ -408,7 +408,7 @@ def times(left, right, output_rank=1, name=None):
# CNTK uses column vectors and column major representation, thus we reverse
# params
op = Times(right, left, outputRank=output_rank, name=name)
- #wrap_numpy_arrays(op)
+ wrap_numpy_arrays(op)
op.rank = op.x.rank + op.y.rank - 2
return op

@@ -586,7 +586,7 @@ def relu(x, name=None):
from cntk.ops.cntk2 import Relu
op = Relu(x, name=name)
wrap_numpy_arrays(op)
- op.rank = 0
+ op.rank = op._.rank
return op

def sigmoid(x, name=None):
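The `op.rank = op._.rank` change above (repeated for sigmoid, tanh and softmax below) propagates the rank of the single input operand instead of hard-coding 0; `op._` is assumed here to be the node's input attribute. A minimal numpy sketch of the underlying fact that elementwise ops preserve rank:

    import numpy as np

    x = np.ones((2, 3))                       # rank-2 input
    assert np.tanh(x).ndim == x.ndim          # elementwise op keeps the rank
    assert np.maximum(x, 0).ndim == x.ndim    # relu-style clamp, same rank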
@@ -609,7 +609,7 @@ def sigmoid(x, name=None):
from cntk.ops.cntk2 import Sigmoid
op = Sigmoid(x, name=name)
wrap_numpy_arrays(op)
- op.rank = 0
+ op.rank = op._.rank
return op

def tanh(x, name=None):
@@ -631,7 +631,7 @@ def tanh(x, name=None):
from cntk.ops.cntk2 import Tanh
op = Tanh(x, name=name)
wrap_numpy_arrays(op)
- op.rank = 0
+ op.rank = op._.rank
return op

def softmax(x, name=None):
@@ -658,7 +658,7 @@ def softmax(x, name=None):
from cntk.ops.cntk2 import Softmax
op = Softmax(x)
wrap_numpy_arrays(op)
- op.rank = 0
+ op.rank = op._.rank
return op

def exp(x, name=None):
@@ -699,7 +699,7 @@ def log(x, name=None):
CNTK returns -85.1 for log(x) if `x` is negative or zero. The reason is that
it uses 1e-37 (whose natural logarithm is -85.1) as the smallest float
number for `log`, because this is the only guaranteed precision across
- platforms. This will be changed to op = `NaN` and `-inf`.
+ platforms. This will be changed to return `NaN` and `-inf`.
"""
from cntk.ops.cntk2 import Log
op = Log(x, name=name)
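A quick numpy sketch of the clamping behaviour described in the docstring above (assuming the same 1e-37 floor; CNTK's internal implementation may differ):

    import numpy as np

    x = np.array([0.0, -1.0, 1.0])
    np.log(np.maximum(x, 1e-37))   # ~[-85, -85, 0.] - finite values instead of -inf / nan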
@@ -724,7 +724,7 @@ def sqrt(x, name=None):
Note:
CNTK returns zero for sqrt of negative nubmers, this will be changed to
- op = NaN
+ retrun NaN
"""
from cntk.ops.cntk2 import Sqrt
op = Sqrt(x, name=name)
@@ -774,8 +774,8 @@ def abs(x, name=None):

def cond(flag, value_if_true, value_if_false, name=None):
"""
- op = either value_if_true or value_if_false based on the value of flag.
- If flag != 0 value_if_true is op =ed, otherwise value_if_false.
+ return either value_if_true or value_if_false based on the value of flag.
+ If flag != 0 value_if_true is returned, otherwise value_if_false.
Behaves analogously to numpy.where(...).
Example:
@@ -794,7 +794,7 @@ def cond(flag, value_if_true, value_if_false, name=None):
from cntk.ops.cntk1 import If
op = If(flag, value_if_true, value_if_false, name = name)
wrap_numpy_arrays(op)
- op.rank = max(op.cond.rank,max(op.thenVal.rank,op.elseVal.rank))
+ op.rank = max(op.cond.rank,op.thenVal.rank,op.elseVal.rank)
return op

################################################################################
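For reference, the numpy.where analogy mentioned in the cond docstring above, and the fact that Python's built-in max accepts any number of arguments (which is what the simplified rank line relies on):

    import numpy as np

    flag = np.array([1, 0, 1])
    np.where(flag != 0, [10, 20, 30], [-1, -2, -3])   # -> array([10, -2, 30])

    assert max(1, max(2, 3)) == max(1, 2, 3)          # nested max is redundant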
@@ -808,7 +808,7 @@ def future_value(shape, x, time_step=1, default_hidden_activation=0.1, name=None
the next logical sample. The `time_step` parameter is the number of steps
to look into the future and is 1 by default. If there is no future value (i.e.
the current sample is the last one in the tensor) then the `default_hidden_activation`
- value is op =ed which is 0.1 by default.
+ value is returned which is 0.1 by default.
Example:
>>> data = np.array([[1,2,3,4],[5,6,7,8],[9,10,11,12]])
@@ -833,7 +833,7 @@ def future_value(shape, x, time_step=1, default_hidden_activation=0.1, name=None
from cntk.ops.cntk1 import FutureValue
op = FutureValue(shape, x, time_step, default_hidden_activation, name = name)
wrap_numpy_arrays(op)
- op.rank = 0 if np.isscalar(shape) else len(shape)
+ op.rank = np.ndim(shape)
return op

def past_value(shape, x, time_step=1, default_hidden_activation=0.1, name=None):
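The `np.ndim(shape)` one-liner used here (and in several places below) treats its argument as array-like, so it is worth noting how it behaves on typical shape arguments; a reference sketch, not part of the diff:

    import numpy as np

    np.ndim(5)        # 0 -- a scalar shape has no axes
    np.ndim((5,))     # 1 -- a one-element shape tuple
    np.ndim((5, 3))   # 1 -- the tuple is seen as a 1-D array-like,
                      #      whereas len((5, 3)) == 2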
@@ -843,7 +843,7 @@ def past_value(shape, x, time_step=1, default_hidden_activation=0.1, name=None):
the previous logical sample. The `time_step` parameter is the number of steps
to look into the past and is 1 by default. If there is no past value (i.e.
the current sample is the first one in the tensor) then the `default_hidden_activation`
- value is op =ed which is 0.1 by default.
+ value is returned which is 0.1 by default.
Example:
>>> data = np.array([[1,2,3,4],[5,6,7,8],[9,10,11,12]])
@@ -868,7 +868,7 @@ def past_value(shape, x, time_step=1, default_hidden_activation=0.1, name=None):
from cntk.ops.cntk1 import PastValue
op = PastValue(shape, x, time_step, default_hidden_activation, name = name)
wrap_numpy_arrays(op)
- op.rank = 0 if np.isscalar(shape) else len(shape)
+ op.rank = np.ndim(shape)
return op

################################################################################
@@ -903,15 +903,13 @@ def reshape(x, shape, name=None):
shape = tuple(reversed(shape))
op = NewReshape(x, shape, 0, 0, name = name)
wrap_numpy_arrays(op)
- op.rank = 0 if np.isscalar(shape) else len(shape)
+ op.rank = np.ndim(shape)
return op

def transpose_dimensions(x, axis1, axis2, name=None):
"""
Reverses two axes of the tensor. The output tensor has the same data but with
- axis1 and axis2 swaped.
- The backward pass propagates the received gradient for the output-shape to the input shape.
+ axis1 and axis2 swapped.
Note:
axes are zero-based as in Numpy, in contrast to CNTK, where 1 is the first axis.
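The zero-based axis convention noted above matches numpy's swapaxes; a small illustration (not the CNTK op itself):

    import numpy as np

    a = np.zeros((2, 3, 4))
    np.swapaxes(a, 0, 2).shape   # (4, 3, 2) -- axes 0 and 2 exchanged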
@@ -1080,7 +1078,7 @@ def input(shape, dynamic_axis='', name=None):
shape = tuple(reversed(shape))
op = Input(shape, dynamicAxis=dynamic_axis, name=name)

- op.rank = 0 if np.isscalar(shape) else len(shape)
+ op.rank = np.ndim(shape)
return op

def sparse_input_numpy(indices, values, shape, alias=None, dynamic_axis='', name=None):
@@ -1154,7 +1152,7 @@ def sparse_input(shape, dynamic_axis='', name=None):
# cntk uses column major, thus we reverse the shape
shape = tuple(reversed(shape))
op = SparseInput(shape, dynamicAxis=dynamic_axis, name=name)
- op.rank = 0 if np.isscalar(shape) else len(shape)
+ op.rank = np.ndim(shape)
return op

def parameter(shape=None, value=None, learning_rate_multiplier=1.0,
@@ -1191,7 +1189,7 @@ def parameter(shape=None, value=None, learning_rate_multiplier=1.0,
learningRateMultiplier=learning_rate_multiplier,
name=name)

- op.rank = 0 if np.isscalar(shape) else len(shape)
+ op.rank = np.ndim(shape)
return op

"""
@@ -1233,7 +1231,7 @@ def parameter(shape=None, value=None, learning_rate_multiplier=1.0,
init='fromLiteral',
initFromLiteral=s.getvalue().decode())

- op.rank = 0 if np.isscalar(param_shape) else len(param_shape)
+ op.rank = np.ndim(param_shape)
return op

def constant(value, name=None):
6 changes: 3 additions & 3 deletions contrib/Python/cntk/ops/tests/linear_test.py
@@ -234,8 +234,8 @@ def numpy_op(x):


TIMES_PAIRS = [
- #([[30.]], [[10.]]),
- #([[1.5, 2.1]], [[10.], [20.]]),
+ ([[30.]], [[10.]]),
+ ([[1.5, 2.1]], [[10.], [20.]]),
([[100., 200.]], [[10.], [20.]]),
]
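For reference, the matrix products the re-enabled pairs correspond to (a plain numpy sketch, not the test helper itself):

    import numpy as np

    np.dot([[30.]], [[10.]])                  # [[300.]]
    np.dot([[1.5, 2.1]], [[10.], [20.]])      # [[57.]]   (1.5*10 + 2.1*20)
    np.dot([[100., 200.]], [[10.], [20.]])    # [[5000.]]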

@@ -266,7 +266,7 @@ def test_op_times(left_operand, right_operand, device_id, precision,
right_as_input = times(constant(left_operand), b)

unittest_helper(left_as_input, None, expected, device_id=device_id,
- precision=precision, clean_up=False, backward_pass=False)
+ precision=precision, clean_up=True, backward_pass=False)

unittest_helper(right_as_input, None, expected, device_id=device_id,
precision=precision, clean_up=True, backward_pass=False)
1 change: 1 addition & 0 deletions contrib/Python/cntk/utils/_fetch_ops.py
@@ -271,6 +271,7 @@ def __init__(self, cond, thenVal, elseVal, op_name='BS.Boolean.If', name=None):
self.thenVal = thenVal
self.elseVal = elseVal
self.params_with_defaults = []
+ self.inputs = ['cond', 'thenVal', 'elseVal']
"""

