Skip to content

Commit

Permalink
Fix multiple GPU inference.
Browse files Browse the repository at this point in the history
  • Loading branch information
haotian-liu committed Nov 8, 2023
1 parent a7f6cce commit 4a77fb4
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions llava/model/llava_arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,9 @@ def prepare_inputs_labels_for_multimodal(
image_features = self.encode_images(concat_images)
split_sizes = [image.shape[0] for image in images]
image_features = torch.split(image_features, split_sizes, dim=0)
image_features = [x.flatten(0, 1).to(concat_images.device) for x in image_features]
image_features = [x.flatten(0, 1).to(self.device) for x in image_features]
else:
image_features = self.encode_images(images)
image_features = self.encode_images(images).to(self.device)

# TODO: image start / end is not implemented here to support pretraining.
if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
Expand Down

0 comments on commit 4a77fb4

Please sign in to comment.