Skip to content

Commit

Permalink
use images instead of image to match model.forward kwarg
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian committed Oct 28, 2023
1 parent e61aa3f commit b8fb1e9
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions llava/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,7 +654,7 @@ def modality_lengths(self):
length_list = []
for sample in self.list_data_dict:
cur_len = sum(len(conv['value'].split()) for conv in sample['conversations'])
cur_len = cur_len if 'image' in sample else -cur_len
cur_len = cur_len if 'images' in sample else -cur_len
length_list.append(cur_len)
return length_list

Expand Down Expand Up @@ -700,11 +700,11 @@ def expand2square(pil_img, background_color):

# image exist in the data
if 'image' in self.list_data_dict[i]:
data_dict['image'] = image
data_dict['images'] = image
elif self.data_args.is_multimodal:
# image does not exist in the data, but the model is multimodal
crop_size = self.data_args.image_processor.crop_size
data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width'])
data_dict['images'] = torch.zeros(3, crop_size['height'], crop_size['width'])
return data_dict


Expand Down Expand Up @@ -732,8 +732,8 @@ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
)

if 'image' in instances[0]:
images = [instance['image'] for instance in instances]
if 'images' in instances[0]:
images = [instance['images'] for instance in instances]
if all(x is not None and x.shape == images[0].shape for x in images):
batch['images'] = torch.stack(images)
else:
Expand Down

0 comments on commit b8fb1e9

Please sign in to comment.