[Eng] Fix issues with generation on MAC with mps enabled #325

Merged: 1 commit, May 24, 2023

Conversation

tongbaojia
Contributor

Currently, if SUNO_ENABLE_MPS=True is exported, audio generation doesn't work.

To test this:
export SUNO_ENABLE_MPS=True
then do:
python -m bark --text "tony is in town"

I eventually traced this to generate_fine. The cause is a dtype mismatch:
codebook_preds is int64 (the output of torch.multinomial).
in_buffer is int32.

On CPU, the column-slice assignment works.
On MPS, for some unknown reason, it does not: each write overwrites the first column instead of column nn.
Converting codebook_preds to int32 before the assignment fixes the issue.

+                    codebook_preds = torch.multinomial(probs[rel_start_fill_idx:1024], num_samples=1).reshape(-1)
+                print("check", rel_start_fill_idx, nn)
+                print(in_buffer, in_buffer.shape)
+                codebook_preds = codebook_preds #.to(torch.int32)
                 in_buffer[0, rel_start_fill_idx:, nn] = codebook_preds
+                print("after", in_buffer)
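
The debug diff above leaves the cast commented out in order to reproduce the bug. As a minimal sketch of the fix described in the text (not necessarily the exact merged change, which this page does not show), the predictions can be cast to the buffer's dtype before the slice assignment. The helper name write_codebook_column and the small demo shapes are hypothetical and chosen only for illustration; the demo runs on CPU, where the assignment works either way.

```python
import torch


def write_codebook_column(in_buffer, codebook_preds, rel_start_fill_idx, nn):
    """Write sampled codes into column nn of in_buffer (hypothetical helper).

    torch.multinomial returns int64, while in_buffer is int32, so the
    predictions are cast explicitly instead of relying on an implicit
    conversion inside the slice assignment.
    """
    in_buffer[0, rel_start_fill_idx:, nn] = codebook_preds.to(in_buffer.dtype)
    return in_buffer


# Small CPU-only demo with made-up shapes (the real buffer is [1, 1024, 8]).
torch.manual_seed(0)
in_buffer = torch.full((1, 6, 4), 1024, dtype=torch.int32)
probs = torch.softmax(torch.randn(6, 1024), dim=-1)
codebook_preds = torch.multinomial(probs, num_samples=1).reshape(-1)

write_codebook_column(in_buffer, codebook_preds, rel_start_fill_idx=0, nn=2)
print(codebook_preds.dtype, in_buffer.dtype)
print(in_buffer[0, :, 2])
```

Casting to in_buffer.dtype rather than hard-coding torch.int32 keeps the write correct if the buffer's dtype ever changes.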


----------


check 0 2
tensor([[[ 409,  993, 1024,  ..., 1024, 1024, 1024],
         [ 764,  956, 1024,  ..., 1024, 1024, 1024],
         [ 499,  942, 1024,  ..., 1024, 1024, 1024],
         ...,
         [1024, 1024, 1024,  ..., 1024, 1024, 1024],
         [1024, 1024, 1024,  ..., 1024, 1024, 1024],
         [1024, 1024, 1024,  ..., 1024, 1024, 1024]]], device='mps:0',
       dtype=torch.int32) torch.Size([1, 1024, 8])
after tensor([[[ 819,  993, 1024,  ..., 1024, 1024, 1024],
         [1000,  956, 1024,  ..., 1024, 1024, 1024],
         [ 345,  942, 1024,  ..., 1024, 1024, 1024],
         ...,
         [ 786, 1024, 1024,  ..., 1024, 1024, 1024],
         [ 786, 1024, 1024,  ..., 1024, 1024, 1024],
         [ 786, 1024, 1024,  ..., 1024, 1024, 1024]]], device='mps:0',
       dtype=torch.int32)
check 0 3
tensor([[[ 819,  993, 1024,  ..., 1024, 1024, 1024],
         [1000,  956, 1024,  ..., 1024, 1024, 1024],
         [ 345,  942, 1024,  ..., 1024, 1024, 1024],
         ...,
         [ 786, 1024, 1024,  ..., 1024, 1024, 1024],
         [ 786, 1024, 1024,  ..., 1024, 1024, 1024],
         [ 786, 1024, 1024,  ..., 1024, 1024, 1024]]], device='mps:0',
       dtype=torch.int32) torch.Size([1, 1024, 8])
after tensor([[[ 961,  993, 1024,  ..., 1024, 1024, 1024],
         [ 945,  956, 1024,  ..., 1024, 1024, 1024],
         [ 399,  942, 1024,  ..., 1024, 1024, 1024],
         ...,
         [1005, 1024, 1024,  ..., 1024, 1024, 1024],
         [ 506, 1024, 1024,  ..., 1024, 1024, 1024],
         [ 302, 1024, 1024,  ..., 1024, 1024, 1024]]], device='mps:0',
       dtype=torch.int32)
check 0 4
tensor([[[ 961,  993, 1024,  ..., 1024, 1024, 1024],
         [ 945,  956, 1024,  ..., 1024, 1024, 1024],
         [ 399,  942, 1024,  ..., 1024, 1024, 1024],
         ...,
         [1005, 1024, 1024,  ..., 1024, 1024, 1024],
         [ 506, 1024, 1024,  ..., 1024, 1024, 1024],
         [ 302, 1024, 1024,  ..., 1024, 1024, 1024]]], device='mps:0',
       dtype=torch.int32) torch.Size([1, 1024, 8])
after tensor([[[ 742,  993, 1024,  ..., 1024, 1024, 1024],
         [ 720,  956, 1024,  ..., 1024, 1024, 1024],
         [1013,  942, 1024,  ..., 1024, 1024, 1024],
         ...,
         [ 435, 1024, 1024,  ..., 1024, 1024, 1024],
         [ 408, 1024, 1024,  ..., 1024, 1024, 1024],
         [ 816, 1024, 1024,  ..., 1024, 1024, 1024]]], device='mps:0',
       dtype=torch.int32)
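
In the log above, as far as the truncated printout shows, each "after" tensor has new values in column 0 while the target column nn still holds 1024. A small diagnostic along the following lines could flag that condition without printing whole tensors. It is only a sketch: unfilled_columns is a hypothetical helper, and treating 1024 as the fill value is an assumption based on the printed tensors.

```python
import torch

PAD_VALUE = 1024  # assumption: the fill value seen in the printed tensors


def unfilled_columns(in_buffer, start_col, end_col, pad_value=PAD_VALUE):
    """Return columns in [start_col, end_col) that are still entirely padding."""
    return [
        nn
        for nn in range(start_col, end_col)
        if bool((in_buffer[0, :, nn] == pad_value).all())
    ]


# Tiny stand-in for the [1, 1024, 8] buffer from the log.
buf = torch.full((1, 5, 8), PAD_VALUE, dtype=torch.int32)
buf[0, :, 0] = torch.tensor([409, 764, 499, 12, 7], dtype=torch.int32)
buf[0, :, 1] = torch.tensor([993, 956, 942, 3, 5], dtype=torch.int32)

# Simulate the faulty behaviour: the write meant for column 2 lands in column 0.
buf[0, :, 0] = torch.tensor([819, 1000, 345, 786, 786], dtype=torch.int32)

# Column 2 should have been filled at this point, but it is still padding.
print(unfilled_columns(buf, 2, 3))  # prints [2]
```

Calling such a check right after the assignment for column nn would report nn itself whenever the write went to the wrong column.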

@tongbaojia
Contributor Author

tongbaojia commented May 24, 2023

FYI, this is what I get on current main
with Python 3.10.10, torch==2.0.1, torchaudio==2.0.2:

export SUNO_ENABLE_MPS=True
python -m bark --text "tony is in town [laughs] lol"
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:42<00:00,  2.33it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 37/37 [01:19<00:00,  2.15s/it]
Oops, an error occurred: index out of range in self

@gkucsko
Contributor

gkucsko commented May 24, 2023

amazing tyty!

@gkucsko gkucsko merged commit bfb7ebf into suno-ai:main May 24, 2023
@tongbaojia tongbaojia deleted the TT_refactors branch May 25, 2023 02:33