Skip to content

Commit

Permalink
Update TF NN
Browse files Browse the repository at this point in the history
  • Loading branch information
carefree0910 committed May 18, 2017
1 parent 78c9055 commit 22339cc
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions NN/TF/Networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,6 @@ def get_rs(self, x, predict=True, pipe=False):

@NNTiming.timeit(level=4, prefix="[API] ")
def add(self, layer, *args, **kwargs):

# Init kwargs
kwargs["apply_bias"] = kwargs.get("apply_bias", True)
kwargs["position"] = kwargs.get("position", len(self._layers) + 1)
Expand Down Expand Up @@ -571,8 +570,9 @@ def _draw_detailed_network(self, radius=6, width=1200, height=800, padding=0.2,
input_x, input_y = np.meshgrid(xf, yf)
input_xs = np.c_[input_x.ravel().astype(np.float32), input_y.ravel().astype(np.float32)]

_activations = [activation.eval(feed_dict={self._tfx: input_xs}).T.reshape(units[i + 1], plot_num, plot_num)
for i, activation in enumerate(self._activations)]
_activations = self._sess.run(self._activations, {self._tfx: input_xs})
_activations = [activation.T.reshape(units[i + 1], plot_num, plot_num)
for i, activation in enumerate(_activations)]
_graphs = []
for j, activation in enumerate(_activations):
_graph_group = []
Expand Down Expand Up @@ -726,10 +726,10 @@ def build(self, units="load"):
raise BuildLayerError("At least 2 layers are needed")
_input_shape = (units[0], units[1])
self.initialize()
self.add(ReLU(_input_shape))
self.add("ReLU", _input_shape)
for unit_num in units[2:]:
self.add(ReLU((unit_num,)))
self.add(CrossEntropy((units[-1],)))
self.add("ReLU", (unit_num,))
self.add("CrossEntropy", (units[-1],))

@NNTiming.timeit(level=4, prefix="[API] ")
def split_data(self, x, y, x_test, y_test,
Expand Down

0 comments on commit 22339cc

Please sign in to comment.