Skip to content

Commit

Permalink
More Fixed Autonaming, In-place ReLu, Ease of use
Browse files Browse the repository at this point in the history
  • Loading branch information
hctomkins committed Aug 8, 2015
1 parent 4e14d88 commit 3b15d96
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 10 deletions.
2 changes: 1 addition & 1 deletion CaffeGenerate.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,7 @@ def findsocket(socketname,node): #Given a node, find the position of a ce
def autotop(node,socket,orderpass=0): #Assigns an arbitrary top name to a node
print('autotop')
if isinplace(node) and not orderpass:
top = nodebefore(node).name + str(socket)
top = autobottom(node,0,orderpass=0)
else:
top = node.name + str(socket)
return top
Expand Down
33 changes: 24 additions & 9 deletions CaffeNodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def calcsize(self, context,axis='x'):
fcsizes.extend([0])
passes.extend([0])
reversals.extend([0])
poolsizes.extend([node.kernel])
poolsizes.extend([node.kernel_size])
poolstrides.extend([node.stride])
offsets.extend([1])
node = node.inputs[0].links[0].from_node
Expand Down Expand Up @@ -341,13 +341,24 @@ def draw_color(self, context, node):
return (0.0, 0.8, 0.8, 0.5)


class params_p_g(bpy.types.PropertyGroup):
class params_p_gw(bpy.types.PropertyGroup):
name = bpy.props.StringProperty(name='Shared name')
lr_mult = bpy.props.FloatProperty(default=1.0)
decay_mult = bpy.props.FloatProperty(default=1.0)

def draw(self, context, layout):
layout.prop(self, "name")
#layout.prop(self, "name")
layout.prop(self, "lr_mult")
layout.prop(self, "decay_mult")


class params_p_gb(bpy.types.PropertyGroup):
name = bpy.props.StringProperty(name='Shared name')
lr_mult = bpy.props.FloatProperty(default=2.0)
decay_mult = bpy.props.FloatProperty(default=0.0)

def draw(self, context, layout):
#layout.prop(self, "name")
layout.prop(self, "lr_mult")
layout.prop(self, "decay_mult")

Expand All @@ -357,8 +368,8 @@ def poll(cls, ntree):
return ntree.bl_idname == 'CaffeNodeTree'

extra_params = bpy.props.BoolProperty(name='Extra Parameters', default=False)
weight_params = bpy.props.PointerProperty(type=params_p_g)
bias_params = bpy.props.PointerProperty(type=params_p_g)
weight_params = bpy.props.PointerProperty(type=params_p_gw)
bias_params = bpy.props.PointerProperty(type=params_p_gb)

phases = [("TRAIN", "TRAIN", "Train only"),
("TEST", "TEST", "Test only"),
Expand Down Expand Up @@ -494,6 +505,8 @@ def draw_buttons(self, context, layout):
layout.prop(self, "rand_skip")
elif self.db_type == 'HDF5Data':
layout.prop(self, "shuffle")
layout.prop(self, "height")
layout.prop(self, "width")
else:
layout.prop(self, "rand_skip")
layout.prop(self, "height")
Expand Down Expand Up @@ -1115,7 +1128,7 @@ class ReLuNode(Node, CaffeTreeNode):
# === Optional Functions ===
def init(self, context):
self.inputs.new('ImageSocketType', "Input image")
self.outputs.new('OutputSocketType', "Rectified output")
self.outputs.new('InPlaceOutputSocketType', "Rectified output")


# Copy function to initialize a copied node from an existing one.
Expand Down Expand Up @@ -1294,7 +1307,7 @@ class DropoutNode(Node, CaffeTreeNode):
# === Optional Functions ===
def init(self, context):
self.inputs.new('NAFlatSocketType', "Input image")
self.outputs.new('OutputSocketType', "Output image")
self.outputs.new('InPlaceOutputSocketType', "Output image")


# Copy function to initialize a copied node from an existing one.
Expand Down Expand Up @@ -1821,7 +1834,8 @@ def poll(cls, context):

def register():
bpy.utils.register_class(filler_p_g)
bpy.utils.register_class(params_p_g)
bpy.utils.register_class(params_p_gw)
bpy.utils.register_class(params_p_gb)
bpy.utils.register_class(slice_point_p_g)
bpy.utils.register_class(OutputSocket)
bpy.utils.register_class(CaffeTree)
Expand Down Expand Up @@ -1865,7 +1879,8 @@ def unregister():
nodeitems_utils.unregister_node_categories("CUSTOM_NODES")

bpy.utils.unregister_class(filler_p_g)
bpy.utils.unregister_class(params_p_g)
bpy.utils.unregister_class(params_p_gw)
bpy.utils.unregister_class(params_p_gb)
bpy.utils.unregister_class(slice_point_p_g)
bpy.utils.unregister_class(OutputSocket)
bpy.utils.unregister_class(InPlaceOutputSocket)
Expand Down

0 comments on commit 3b15d96

Please sign in to comment.