diff --git a/palme/model.py b/palme/model.py index c27db33..cc613a8 100644 --- a/palme/model.py +++ b/palme/model.py @@ -204,14 +204,13 @@ def forward(self, text_tokens, images): print("Shape after concatenation:", concatenated_input.shape) dense_layer = nn.Linear(concatenated_input.size(-1), concatenated_input.size(-1)) - input = dense_layer(concatenated_input) - + processed_input = dense_layer(concatenated_input) # Proceed with the forward propagation - # model_input = self.decoder(concatenated_input) - # print("After passing concatenated input through decoder:", model_input.shape) + model_input = self.decoder(processed_input) + print("After passing concatenated input through decoder:", model_input.shape) - output = self.decoder(input, passed_x=input)[0] + output = self.decoder(model_input) return output except Exception as error: