Skip to content

Commit

Permalink
split the input list of conv_operator into two inputs: image and filt…
Browse files Browse the repository at this point in the history
  • Loading branch information
Haonan authored and emailweixu committed Sep 21, 2016
1 parent b130ba7 commit 98bc889
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions python/paddle/trainer_config_helpers/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2667,7 +2667,7 @@ def __add_evaluator__(e):

return LayerOutput(name, LayerType.COST, parents=[input, label])

def conv_operator(input, filter_size, num_filters,
def conv_operator(img, filter, filter_size, num_filters,
num_channel=None, stride=1, padding=0, groups=1,
filter_size_y=None, stride_y=None, padding_y=None):
"""
Expand All @@ -2680,13 +2680,16 @@ def conv_operator(input, filter_size, num_filters,
.. code-block:: python
op = conv_operator(input=[layer1, layer2],
op = conv_operator(img=input1,
filter=input2,
filter_size=3.0,
num_filters=64,
num_channels=64)
:param input: Input layer.
:type input: LayerOutput|list|tuple
:param img: input image
:type img: LayerOutput
:param filter: input filter
:type filter: LayerOutput
:param filter_size: The x dimension of a filter kernel.
:type filter_size: int
:param filter_size_y: The y dimension of a filter kernel. Since
Expand All @@ -2708,14 +2711,13 @@ def conv_operator(input, filter_size, num_filters,
:return: A ConvOperator Object.
:rtype: ConvOperator
"""
assert isinstance(input, list) or isinstance(input, tuple)
if filter_size_y is None:
filter_size_y = filter_size
if stride_y is None:
stride_y = stride
if padding_y is None:
padding_y = padding
op = ConvOperator(input_layer_names=[x.name for x in input],
op = ConvOperator(input_layer_names=[img.name, filter.name],
num_filters = num_filter,
conv_conf=Conv(filter_size=filter_size,
padding=padding,
Expand All @@ -2725,7 +2727,7 @@ def conv_operator(input, filter_size, num_filters,
padding_y=padding_y,
stride_y=stride_y,
groups=groups))
op.origin = input
op.origin = [img, filter]
op.origin.operator = "conv_op"
return op

Expand Down

0 comments on commit 98bc889

Please sign in to comment.