A temporary fix for multi-GPU inference of the MPT model in haotian-liu/LLaVA
haotian-liu committed Aug 2, 2023
1 parent ce6c419 commit fc10062
Showing 1 changed file with 2 additions and 1 deletion.
llava/model/language_model/llava_mpt.py: 2 additions & 1 deletion
@@ -75,7 +75,8 @@ def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tu
 
         input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)
         outputs = self.transformer(input_ids=input_ids, inputs_embeds=inputs_embeds, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
-        logits = F.linear(outputs.last_hidden_state, self.transformer.wte.weight)
+        # FIXME: this is a hack to fix the multiple gpu inference issue in https://github.com/haotian-liu/LLaVA/issues/338
+        logits = F.linear(outputs.last_hidden_state.to(self.transformer.wte.weight.device), self.transformer.wte.weight)
         if self.logit_scale is not None:
             if self.logit_scale == 0:
                 warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.')
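Why the one-line change helps: when the MPT weights are sharded across several GPUs (for example via accelerate-style `device_map="auto"` loading), the last transformer block can sit on a different device than the tied embedding weight `self.transformer.wte.weight`, so the original `F.linear` call fails with a device-mismatch error. The sketch below reproduces the pattern; the two-device placement, tensor shapes, and vocabulary size are illustrative assumptions, not taken from the commit.

```python
import torch
import torch.nn.functional as F

# Illustrative two-device placement (assumes at least 2 GPUs; falls back to
# CPU so the sketch still runs on a single-device machine).
two_gpus = torch.cuda.device_count() >= 2
dev_a = torch.device("cuda:0") if two_gpus else torch.device("cpu")
dev_b = torch.device("cuda:1") if two_gpus else torch.device("cpu")

# Stand-ins for outputs.last_hidden_state and self.transformer.wte.weight;
# the shapes and the 50432 vocab size are made up for illustration.
hidden = torch.randn(1, 16, 4096, device=dev_b)
wte_weight = torch.randn(50432, 4096, device=dev_a)

# Before the commit: F.linear(hidden, wte_weight) raises
# "Expected all tensors to be on the same device" whenever dev_a != dev_b.

# The commit's workaround: move the activations onto the weight's device
# before the output projection.
logits = F.linear(hidden.to(wte_weight.device), wte_weight)
print(logits.shape, logits.device)  # torch.Size([1, 16, 50432]) on dev_a
```

Moving the activations to the weight's device, rather than the weight to the activations, copies only the (batch, seq, hidden) tensor instead of the much larger vocabulary-sized projection matrix, which is why this hack is cheap enough for inference.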
