Skip to content

Commit

Permalink
try enumerated shapes
Browse files Browse the repository at this point in the history
  • Loading branch information
philloooo committed Dec 12, 2023
1 parent 475041a commit 88979fd
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions scripts/convert2coreml.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,21 @@ def convert_encoder_to_tvm(model):
def convert_decoder_to_tvm(model):
model.eval()

tokens_shape = (1, 10)
tokens_shape = []
max_token = 448
# max_token = 448, max number of EnumeratedShapes supported by coreml = 128
segment = max_token//128 + 1
i = segment
while(i<max_token+segment):
tokens_shape.append([1, i])
i += segment

audio_shape = (1, 1500, 384)
token_data = (1000*torch.rand(tokens_shape)).long()
token_data = (1000*torch.rand(tokens_shape[0])).long()
audio_data = torch.rand(audio_shape)
traced_model = torch.jit.trace(model, (token_data, audio_data))

token_flexible_shape = ct.Shape(shape=(1,
ct.RangeDim(lower_bound=1, upper_bound=100000, default=1)))
token_flexible_shape = ct.EnumeratedShapes(shapes=tokens_shape, default=tokens_shape[0])


model = ct.convert(
Expand Down

0 comments on commit 88979fd

Please sign in to comment.