forked from PaddlePaddle/docs
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
【映射文档】add default_collate (PaddlePaddle#5996)
* model_convert add torch.is_nonzero .etc * model_convert add is_nonzero .etc * model_convert add is_nonzero .etc * model_convert add is_nonzero .etc * add xlogy.etc * add randn_like .etc * comment * add vdot etc. * add aminmax etc. * fix parameter * add bucketizr etc. * add bucketizr etc. * add bucketizr etc. * add sinc etc. * add sinc etc. * add sinc etc. * add cov etc. * add cov etc. * add cov etc. * add cov etc. * add unique etc. * add unique etc. * add backward etc. * add backward etc. * modify get_cuda_rng_state to get_rng_state * add scaler_tensor etc. * add scaler_tensor etc. * add scalar_tensor etc. * add default_collate * add default_collate
- Loading branch information
Showing
3 changed files
with
44 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
41 changes: 41 additions & 0 deletions
41
...t/convert_from_pytorch/api_difference/utils/torch.utils.data.default_collate.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
## [ 参数不一致 ]torch.utils.data.default_collate | ||
### [torch.utils.data.default_collate](https://pytorch.org/docs/stable/data.html?highlight=default_collate#torch.utils.data.default_collate) | ||
|
||
```python | ||
torch.utils.data.default_collate(batch) | ||
``` | ||
|
||
### [paddle.io.dataloader.collate.default_collate_fn] | ||
|
||
```python | ||
paddle.io.dataloader.collate.default_collate_fn(batch) | ||
``` | ||
|
||
返回参数类型不一致,需要转写。具体如下: | ||
### 参数映射 | ||
| PyTorch | PaddlePaddle | 备注 | | ||
| ------------- | ------------ | ------------------------------------------------------ | | ||
| batch | batch | 输入的用于组 batch 的数据。 | | ||
| 返回值 | 返回值 | 返回参数类型不一致,当 batch 的元素为 numpy.ndarray 或 number 时, Pytorch 默认返回 torch.tensor, Paddle 默认返回 numpy.ndarray。 | | ||
|
||
|
||
### 转写示例 | ||
#### 当 batch 的元素为 numpy.ndarray 或 number 时 | ||
```python | ||
# Pytorch 写法 | ||
y = torch.utils.data.default_collate(batch) | ||
|
||
# Paddle 写法 | ||
y = paddle.to_tensor(paddle.io.dataloader.collate.default_collate_fn(batch)) | ||
``` | ||
|
||
#### 当 batch 的元素为字典且字典的 value 为 numpy.ndarray 或 number 时 | ||
```python | ||
# Pytorch 写法 | ||
y = torch.utils.data.default_collate(batch) | ||
|
||
# Paddle 写法 | ||
y = paddle.io.dataloader.collate.default_collate_fn(batch) | ||
for k, v in y.items(): | ||
y[k] = paddle.to_tensor(v) | ||
``` |