Skip to content

Commit

Permalink
Fix review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
aaroey committed Jun 25, 2018
1 parent f4a68ba commit d1a6a52
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions tensorflow/contrib/tensorrt/test/tf_trt_integration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,25 +55,26 @@
to_string = lambda s: s.decode("utf-8")


def GetSingleEngineGraphDef():
# TODO(aaroey): test graph with different dtypes.
def GetSingleEngineGraphDef(dtype=dtypes.float32):
"""Create a graph containing single segment."""
g = ops.Graph()
with g.as_default():
inp = array_ops.placeholder(
dtype=dtypes.float32, shape=[None] + INPUT_DIMS[1:], name=INPUT_NAME)
dtype=dtype, shape=[None] + INPUT_DIMS[1:], name=INPUT_NAME)
with g.device("/GPU:0"):
conv_filter = constant_op.constant(
[[[[1., 0.5, 4., 6., 0.5, 1.], [1., 0.5, 1., 1., 0.5, 1.]]]],
name="weights",
dtype=dtypes.float32)
dtype=dtype)
conv = nn.conv2d(
input=inp,
filter=conv_filter,
strides=[1, 2, 2, 1],
padding="SAME",
name="conv")
bias = constant_op.constant(
[4., 1.5, 2., 3., 5., 7.], name="bias", dtype=dtypes.float32)
[4., 1.5, 2., 3., 5., 7.], name="bias", dtype=dtype)
added = nn.bias_add(conv, bias, name="bias_add")
relu = nn.relu(added, "relu")
identity = array_ops.identity(relu, "identity")
Expand All @@ -83,6 +84,7 @@ def GetSingleEngineGraphDef():
return g.as_graph_def()


# TODO(aaroey): test graph with different dtypes.
def GetMultiEngineGraphDef(dtype=dtypes.float32):
"""Create a graph containing multiple segment."""
g = ops.Graph()
Expand Down

0 comments on commit d1a6a52

Please sign in to comment.