Improved inverting for attn masks (EricLBuehler#811)
EricLBuehler authored Oct 1, 2024

Verified: this commit was created on GitHub.com and signed with GitHub's verified signature.
1 parent ce02618 · commit 2ec3bcb
Showing 2 changed files with 3 additions and 8 deletions.
4 changes: 1 addition & 3 deletions mistralrs-core/src/vision_models/mllama/mod.rs
```diff
@@ -52,9 +52,7 @@ fn prepare_cross_attention_mask(
     cross_attn_mask = cross_attn_mask.unsqueeze(1)?;
 
     // Invert the mask
-    let inverted_cross_attn_mask = (1. - cross_attn_mask)?
-        .to_dtype(DType::F32)?
-        .to_dtype(dtype)?;
+    let inverted_cross_attn_mask = (1. - cross_attn_mask.to_dtype(DType::F32)?.to_dtype(dtype)?)?;
     const NEG_INF_VALUE: f64 = -1e15;
     cross_attn_mask = masked_fill(
         &inverted_cross_attn_mask,
```
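Before this change, the mask was inverted first and the F32 → dtype round-trip applied afterwards; the new code casts first and performs the `1. - x` inversion in the model's compute dtype. A minimal sketch of the resulting pattern, assuming the `candle_core` crate that mistral.rs builds on (the mask values and the `dtype` choice here are illustrative, not taken from the model):

```rust
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Toy cross-attention mask: 1.0 where attention is allowed, 0.0 elsewhere
    // (illustrative values; the real mask comes from the image/text layout).
    let cross_attn_mask = Tensor::new(&[[1f32, 0.], [1., 1.]], &device)?;
    let dtype = DType::BF16; // stand-in for the model's compute dtype

    // Cast first, then invert, as in the committed code: allowed positions
    // become 0.0 and blocked positions become 1.0, ready to be filled with
    // a large negative value such as -1e15.
    let inverted = (1. - cross_attn_mask.to_dtype(DType::F32)?.to_dtype(dtype)?)?;
    println!("{inverted}");
    Ok(())
}
```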
7 changes: 2 additions & 5 deletions mistralrs-core/src/vision_models/mllama/vision.rs
```diff
@@ -353,14 +353,11 @@ fn _prepare_aspect_ratio_attention_mask(
     )?;
 
     // Invert the mask
-    attention_mask = (1. - attention_mask)?;
+    attention_mask = (1. - attention_mask.to_dtype(DType::F32)?.to_dtype(dtype)?)?;
 
     // Reshape to 2d and create 4d attn mask
     // (batch_size, 1, max_num_tiles * target_length, max_num_tiles * target_length)
-    attention_mask = attention_mask
-        .reshape((bs, max_num_tiles * target_length, 1))?
-        .to_dtype(DType::F32)?
-        .to_dtype(dtype)?;
+    attention_mask = attention_mask.reshape((bs, max_num_tiles * target_length, 1))?;
     attention_mask =
         attention_mask.matmul(&attention_mask.transpose(D::Minus1, D::Minus2)?.mul(-1e15)?)?;
     attention_mask
```
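The vision-side hunk applies the same cast-then-invert pattern and drops the dtype conversions from the reshape, since the mask is already in the target dtype by the time the `matmul` builds the bias. That `matmul` is an outer product: with the inverted mask reshaped to (bs, len, 1), multiplying by its own transpose scaled by -1e15 yields a (bs, len, len) additive bias that is -1e15 wherever both the query and key positions are padding. A toy sketch of that construction, again assuming `candle_core` with illustrative values:

```rust
use candle_core::{Device, Result, Tensor, D};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Toy inverted tile mask for one sequence of length 4:
    // 0.0 marks a real position, 1.0 marks padding (already inverted).
    // Values and shapes are illustrative, not taken from the model.
    let mask = Tensor::new(&[0f32, 0., 1., 1.], &device)?.reshape((1, 4, 1))?;

    // Outer product with the transpose scaled by -1e15, mirroring the diff:
    // entry (i, j) is -1e15 only where both positions i and j are padding,
    // producing a (bs, len, len) additive attention bias.
    let bias = mask.matmul(&mask.transpose(D::Minus1, D::Minus2)?.mul(-1e15)?)?;
    println!("{bias}");
    Ok(())
}
```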
