Skip to content

Commit

Permalink
saika model input
Browse files Browse the repository at this point in the history
  • Loading branch information
Kye committed Aug 9, 2023
1 parent ff34785 commit 80561b9
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions palme/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,14 +204,13 @@ def forward(self, text_tokens, images):
print("Shape after concatenation:", concatenated_input.shape)

dense_layer = nn.Linear(concatenated_input.size(-1), concatenated_input.size(-1))
input = dense_layer(concatenated_input)

processed_input = dense_layer(concatenated_input)

# Proceed with the forward propagation
# model_input = self.decoder(concatenated_input)
# print("After passing concatenated input through decoder:", model_input.shape)
model_input = self.decoder(processed_input)
print("After passing concatenated input through decoder:", model_input.shape)

output = self.decoder(input, passed_x=input)[0]
output = self.decoder(model_input)
return output

except Exception as error:
Expand Down

0 comments on commit 80561b9

Please sign in to comment.