Skip to content

Commit

Permalink
change
Browse files Browse the repository at this point in the history
  • Loading branch information
Yazhou-Z committed Oct 24, 2021
1 parent e19eb7a commit 3267013
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions CNN_multichannel.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,9 @@ def forward(self, x):
return logit

batch_size = 2
learning_rate = 1e-6
learning_rate = 8e-5
num_epoches = 40000
valid_loss_min = 2

model = Cnn1d(7)

Expand All @@ -168,9 +169,8 @@ def forward(self, x):

loss_curve = []
tr_acc = []
valid_loss_min = 2
# train
epoch = 1
epoch = 0
while epoch < num_epoches:
train_acc = 0
for i in range(len(X_tensor)):
Expand Down Expand Up @@ -207,6 +207,8 @@ def forward(self, x):

# print('epoch: {}, loss: {:.4}'.format(epoch, print_loss), 'step: ', i + 1)

epoch += 1

# calculate accuracy
acc = train_acc / len(X_tensor)
tr_acc.append(acc)
Expand All @@ -215,7 +217,7 @@ def forward(self, x):

if epoch % 10 == 0:
print('epoch: {}, loss: {:.4}, acc: {:.4}'.format(epoch, print_loss, acc))
print(out, '->', pred, ':', label - 1, loss)
# print(out, '->', pred, ':', label - 1, loss)

# create checkpoint variable and add important data
checkpoint = {
Expand All @@ -231,7 +233,6 @@ def forward(self, x):
torch.save(model.state_dict(), 'model_best')
valid_loss_min = print_loss

epoch += 1

plt.plot(loss_curve)
fig2 = plt.gcf()
Expand Down

0 comments on commit 3267013

Please sign in to comment.