Skip to content

Commit 3e1613d

Browse files
chatbot_tutorial.py: Solve the optimizer cuda call problem
If you don't configure this string of code, you will get an error when you iterate over the update from 4000_checkpoint.tar: ``` encoder_optimizer.step() ``` Error message: ``` exp_avg.mul_(beta1).add_(1 - beta1, grad) RuntimeError: expected backend CPU and dtype Float but got backend CUDA and dtype Float ``` Fix it: pytorch/pytorch#2830 ``` with torch.no_grad(): correct = 0 total = 0 for images, labels in test_loader: images = images.to(device) # missing line from original code labels = labels.to(device) # missing line from original code images = images.reshape(-1, 28 * 28) out = model(images) _, predicted = torch.max(out.data, 1) total += labels.size(0) correct += (predicted == labels).sum().item() ```
1 parent a7c3a8b commit 3e1613d

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

beginner_source/chatbot_tutorial.py

+11
Original file line numberDiff line numberDiff line change
@@ -1326,6 +1326,17 @@ def evaluateInput(encoder, decoder, searcher, voc):
13261326
print('Building optimizers ...')
13271327
encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
13281328
decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate * decoder_learning_ratio)
1329+
# If you have cuda, configure cuda to call
1330+
for state in encoder_optimizer.state.values():
1331+
for k, v in state.items():
1332+
if isinstance(v, torch.Tensor):
1333+
state[k] = v.cuda()
1334+
1335+
for state in decoder_optimizer.state.values():
1336+
for k, v in state.items():
1337+
if isinstance(v, torch.Tensor):
1338+
state[k] = v.cuda()
1339+
13291340
if loadFilename:
13301341
encoder_optimizer.load_state_dict(encoder_optimizer_sd)
13311342
decoder_optimizer.load_state_dict(decoder_optimizer_sd)

0 commit comments

Comments
 (0)