diff --git a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test.py b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test.py index 839a680e84045e..63457a42adae69 100644 --- a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test.py +++ b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test.py @@ -55,17 +55,18 @@ to_string = lambda s: s.decode("utf-8") -def GetSingleEngineGraphDef(): +# TODO(aaroey): test graph with different dtypes. +def GetSingleEngineGraphDef(dtype=dtypes.float32): """Create a graph containing single segment.""" g = ops.Graph() with g.as_default(): inp = array_ops.placeholder( - dtype=dtypes.float32, shape=[None] + INPUT_DIMS[1:], name=INPUT_NAME) + dtype=dtype, shape=[None] + INPUT_DIMS[1:], name=INPUT_NAME) with g.device("/GPU:0"): conv_filter = constant_op.constant( [[[[1., 0.5, 4., 6., 0.5, 1.], [1., 0.5, 1., 1., 0.5, 1.]]]], name="weights", - dtype=dtypes.float32) + dtype=dtype) conv = nn.conv2d( input=inp, filter=conv_filter, @@ -73,7 +74,7 @@ def GetSingleEngineGraphDef(): padding="SAME", name="conv") bias = constant_op.constant( - [4., 1.5, 2., 3., 5., 7.], name="bias", dtype=dtypes.float32) + [4., 1.5, 2., 3., 5., 7.], name="bias", dtype=dtype) added = nn.bias_add(conv, bias, name="bias_add") relu = nn.relu(added, "relu") identity = array_ops.identity(relu, "identity") @@ -83,6 +84,7 @@ def GetSingleEngineGraphDef(): return g.as_graph_def() +# TODO(aaroey): test graph with different dtypes. def GetMultiEngineGraphDef(dtype=dtypes.float32): """Create a graph containing multiple segment.""" g = ops.Graph()