
support pytorch embedding? #1

Open · kangkang59812 opened this issue May 20, 2019 · 6 comments

@kangkang59812

There's an error in torchsummary when summarizing an nn.Embedding layer.

@wassname (Contributor) commented Jun 6, 2019

The RNN example includes an embedding, so this package works for normal embeddings.

If you found an error, could you expand on what it was and how to replicate it?

@kangkang59812 (Author)

@wassname

from torch import nn
from torchsummary import summary

embedding = nn.Embedding(10, 3)
# torchsummary builds a random *float* input from this shape,
# but nn.Embedding expects Long indices:
summary(embedding.cuda(), (2, 4))

RuntimeError: Expected tensor for argument #1 'indices' to have scalar type Long; but got torch.cuda.FloatTensor instead (while checking arguments for embedding)

nn.Embedding in PyTorch needs Long-typed indices, but summary feeds it a float tensor.
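
The dtype requirement can be reproduced in plain PyTorch, without torchsummary at all (a minimal sketch; the index values are arbitrary as long as they stay below the vocabulary size of 10):

import torch
from torch import nn

embedding = nn.Embedding(10, 3)  # vocabulary of 10, 3-dim vectors

# Long (int64) indices work: each index selects one row of the weight matrix.
idx = torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]])  # shape (2, 4), dtype int64
print(embedding(idx).shape)  # torch.Size([2, 4, 3])

# A float input raises the same RuntimeError as above.
try:
    embedding(torch.zeros(2, 4))  # dtype float32
except RuntimeError as err:
    print(err)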

@nmhkahn (Owner) commented Jun 6, 2019

Hi, sorry, I've been very busy recently.
I'll look into how to handle this today or tomorrow and let you know once I find a way. Thanks!

@wassname (Contributor) commented Jun 6, 2019

@nmhkahn no worries, I thought I'd help you handle it.

@kangkang59812 I think it's because torchsummaryX needs an input tensor, not a shape (unlike torchsummary). So try:

import torch
from torch import nn
from torchsummaryX import summary

embedding = nn.Embedding(10, 3)
summary(embedding.cuda(), torch.zeros((2, 4)).cuda())

And it should work.

@sagjounkani

Got this error when I ran the above piece of code:

TypeError: rand(): argument 'size' must be tuple of ints, but found element of type Tensor at pos 2

@wassname (Contributor) commented Nov 17, 2019

Please try this code on the latest version:

import torch
from torch import nn
from torchsummaryX import summary

embedding = nn.Sequential(nn.Embedding(10, 3))
# nn.Embedding needs Long (int64) indices, so cast the dummy input:
summary(embedding, torch.zeros((2, 4)).long())

If that doesn't work, please post the full error stack and the versions of torch and torchsummaryX you are using, so we can replicate it.
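
One way to capture the versions (a sketch; torchsummaryX may not expose a __version__ attribute, in which case pip's package metadata still works):

import torch
print(torch.__version__)

# torchsummaryX may not define __version__, so read pip's metadata instead.
import pkg_resources
print(pkg_resources.get_distribution("torchsummaryX").version)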
