Skip to content

Commit

Permalink
added Input layer type
Browse files Browse the repository at this point in the history
  • Loading branch information
MrNothing committed Oct 16, 2019
1 parent 5cdd6ec commit 444071d
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
10 changes: 6 additions & 4 deletions Sources/scripts/keras_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,16 @@ def Run(self):

X_batch = np.asarray(batch[0])
Y_batch = np.asarray(batch[1])

X_batch = np.reshape(X_batch, [self._input.batch_size]+self.model.input_shape)

loss = self.model.instance.train_on_batch(X_batch, Y_batch)
infos = self.model.instance.train_on_batch(X_batch, Y_batch)
SetState(self.id, it/self.epochs)

#every N steps, send the state to the scene
if it % self.display_step == 0:
SetState(self.id, it/self.training_iterations)
SendChartData(self.id, "Loss", loss, "#ff0000")

SendChartData(self.id, "Loss", infos[0], "#ff0000")
Log("Loss: "+str(infos[0]))
if self._type=="image":
test_X = [X_batch[0]]
test_Y = [Y_batch[0]]
Expand Down
5 changes: 5 additions & 0 deletions Sources/scripts/keras_sequencial.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ def ParseInput(self, val, has_input=False):
return Activation(args[1])
elif typ=="Dropout":
return Dropout(float(args[1]))
elif typ=="Input":
sh = []
for i in range(len(args)-1):
sh += int(args[i+1])
return Input(shape=sh)
elif typ=="Flatten":
return Flatten()
elif typ=="MaxPooling2D":
Expand Down

0 comments on commit 444071d

Please sign in to comment.