Skip to content

Commit 8dc2236

Browse files
author
吴高升
authored
[PaddlePaddle] set get_dataloader_workers to 4 (d2l-ai#1203)
* [PaddlePaddle] set get_dataloader_workers to 4 * Rerun lr scheduler.md * Try to resolve Termination signal * Remove extra comments
1 parent 0b8045d commit 8dc2236

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

chapter_linear-networks/image-classification-dataset.md

+4-1
Original file line numberDiff line numberDiff line change
@@ -256,11 +256,12 @@ batch_size = 256
256256
257257
def get_dataloader_workers(): #@save
258258
"""使用4个进程来读取数据"""
259-
return 4 if ('cpu' in paddle.device.get_device()) else 0
259+
return 4
260260
261261
train_iter = paddle.io.DataLoader(dataset=mnist_train,
262262
batch_size=batch_size,
263263
shuffle=True,
264+
return_list=True,
264265
num_workers=get_dataloader_workers())
265266
```
266267

@@ -348,9 +349,11 @@ def load_data_fashion_mnist(batch_size, resize=None):
348349
return (paddle.io.DataLoader(dataset=mnist_train,
349350
batch_size=batch_size,
350351
shuffle=True,
352+
return_list=True,
351353
num_workers=get_dataloader_workers()),
352354
paddle.io.DataLoader(dataset=mnist_test,
353355
batch_size=batch_size,
356+
return_list=True,
354357
shuffle=True,
355358
num_workers=get_dataloader_workers()))
356359
```

d2l/paddle.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ def get_dataloader_workers():
190190
"""使用4个进程来读取数据
191191
192192
Defined in :numref:`sec_fashion_mnist`"""
193-
return 4 if ('cpu' in paddle.device.get_device()) else 0
193+
return 4
194194

195195
def load_data_fashion_mnist(batch_size, resize=None):
196196
"""下载Fashion-MNIST数据集,然后将其加载到内存中

0 commit comments

Comments
 (0)