Skip to content

Commit

Permalink
new forward
Browse files Browse the repository at this point in the history
  • Loading branch information
Kye committed Aug 9, 2023
1 parent b69abd4 commit 53c7834
Showing 1 changed file with 65 additions and 22 deletions.
87 changes: 65 additions & 22 deletions palme/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,38 +129,81 @@ def __init__(self,
except Exception as e:
print(f"Error initlizing palme components: {e}")

# def forward(self, text_tokens, images):
# try:
# # if text_tokens.dtype != torch.long:
# # text_tokens = text_tokens.long()
# print(text_tokens.shape)
# images = self.vit_model(pixel_values=images)["last_hidden_state"]
# # print(images.shape)
# images = self.perceive(images).squeeze(1)
# # print(images.shape)
# images = self.image_proj(images)
# # print(images.shape)

# # images = images.unsqueeze(2) # Adjust to [1, 64, 1, 50304]
# # images = F.interpolate(images, size=(114, 50304)) # Reshape to [1, 114, 1, 50304]
# # images = images.squeeze(2) # Return to [1, 114, 50304]

# # print(images.shape)

# model_input = self.decoder(text_tokens)
# print(model_input.shape)

# # model_input = torch.cat([model_input[:, 0:2], images, model_input[:, 2:]], dim=1)
# # model_input = torch.cat([model_input[:, 0:2], images, model_input[:, 2:]], dim=1)
# model_input = torch.cat([model_input[:, 0:2], images, model_input[:, 2:-2]], dim=1)
# # print(model_input.shape)

# model_input = self.decoder(model_input)
# # print(model_input.shape)

# output = self.decoder(model_input, passed_x=model_input)[0]
# return output

# except Exception as e:
# print(f"Error duing forward pass: {e}")
# return None
def forward(self, text_tokens, images):
    """Run one multimodal forward pass over paired text tokens and images.

    Images are encoded by the VIT backbone, resampled, projected into the
    decoder's embedding space, and fused with the embedded text tokens along
    the feature (last) dimension before being decoded.

    Args:
        text_tokens: tensor of text inputs for the decoder
            (assumed shape [batch, seq, ...] — TODO confirm against caller).
        images: pixel-value tensor accepted by ``self.vit_model``.

    Returns:
        The decoder output tensor, or ``None`` if any step fails.
    """
    try:
        print("Initial text tokens shape:", text_tokens.shape)

        # Encode images with the VIT backbone; keep the final hidden states.
        images = self.vit_model(pixel_values=images)["last_hidden_state"]
        print("Images after VIT model:", images.shape)

        # Resample to a fixed set of latents; squeeze the singleton dim
        # the resampler introduces.
        images = self.perceive(images).squeeze(1)
        print("Images after PerceiverResampler:", images.shape)

        # Project image features into the decoder's embedding space.
        images = self.image_proj(images)
        print("Images after image_proj:", images.shape)

        # Embed the text tokens with the decoder.
        model_input = self.decoder(text_tokens)
        print("Text tokens after decoding:", model_input.shape)

        # Both modalities must agree on [batch, seq] before fusing along
        # the feature axis: [1, S, X] + [1, S, Y] -> [1, S, X+Y].
        if model_input.shape[:2] != images.shape[:2]:
            raise ValueError("Mismatched dimensions between images and text tokens")

        # Concatenate text and image features along the last dimension.
        concatenated_input = torch.cat([model_input, images], dim=-1)
        print("Shape after concatenation:", concatenated_input.shape)

        # Decode the fused representation.
        model_input = self.decoder(concatenated_input)
        print("After passing concatenated input through decoder:", model_input.shape)

        output = self.decoder(model_input, passed_x=model_input)[0]
        return output

    except Exception as e:
        # Best-effort: surface the failure to the console but keep the
        # caller alive rather than propagating the exception.
        print(f"Error during forward pass: {e}")
        return None

0 comments on commit 53c7834

Please sign in to comment.