
Commit dd3cae3
Pass LoRA rank to LoRALinearLayer (huggingface#2191)
asadm authored Feb 1, 2023
1 parent f73d0b6 commit dd3cae3
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions src/diffusers/models/cross_attention.py
@@ -296,10 +296,10 @@ class LoRACrossAttnProcessor(nn.Module):
     def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
         super().__init__()
 
-        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size)
-        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
-        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
-        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size)
+        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
+        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
+        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
+        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
 
     def __call__(
         self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
@@ -408,10 +408,10 @@ class LoRAXFormersCrossAttnProcessor(nn.Module):
     def __init__(self, hidden_size, cross_attention_dim, rank=4):
         super().__init__()
 
-        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size)
-        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
-        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
-        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size)
+        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
+        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
+        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
+        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
 
     def __call__(
         self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
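The change only forwards the existing `rank` argument: before this commit, both processors silently dropped the user-supplied rank and built every LoRA layer at the default rank of 4. Below is a minimal, self-contained sketch of a rank-parameterised LoRA linear layer to illustrate what the argument controls. Only the `(in_features, out_features, rank)` call signature is taken from the diff above; the down/up decomposition and the initialisation shown here are illustrative assumptions, not code copied from the diffusers source.

```python
# Hypothetical sketch of a LoRA linear layer whose bottleneck width is set by `rank`.
# The down/up decomposition and initialisation are illustrative assumptions; only the
# (in_features, out_features, rank) signature comes from the diff above.
import torch
import torch.nn as nn


class LoRALinearLayer(nn.Module):
    def __init__(self, in_features, out_features, rank=4):
        super().__init__()
        # Low-rank update: a (rank x in_features) "down" projection followed by an
        # (out_features x rank) "up" projection, so the number of trainable
        # parameters scales with `rank` rather than in_features * out_features.
        self.down = nn.Linear(in_features, rank, bias=False)
        self.up = nn.Linear(rank, out_features, bias=False)
        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)  # start as a no-op update

    def forward(self, hidden_states):
        return self.up(self.down(hidden_states))


# With the fix, a non-default rank actually reaches the layer, as in the patched
# processors (e.g. LoRACrossAttnProcessor(hidden_size, cross_attention_dim, rank=8)):
layer = LoRALinearLayer(320, 320, rank=8)
out = layer(torch.randn(2, 77, 320))
print(out.shape)                # torch.Size([2, 77, 320])
print(layer.down.weight.shape)  # torch.Size([8, 320]) -- rank-8 bottleneck
```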
