Update layers.py (remove redundant condition) (microsoft#85)
Infrared1029 authored Jul 10, 2023 · 1 parent 488c4c6 · commit 2e08c96
Showing 1 changed file with 8 additions and 10 deletions.
loralib/layers.py: 8 additions & 10 deletions
@@ -77,13 +77,12 @@ def train(self, mode: bool = True):
     def forward(self, x: torch.Tensor):
         if self.r > 0 and not self.merged:
             result = nn.Embedding.forward(self, x)
-            if self.r > 0:
-                after_A = F.embedding(
-                    x, self.lora_A.transpose(0, 1), self.padding_idx, self.max_norm,
-                    self.norm_type, self.scale_grad_by_freq, self.sparse
-                )
-                result += (after_A @ self.lora_B.transpose(0, 1)) * self.scaling
-            return result
+            after_A = F.embedding(
+                x, self.lora_A.transpose(0, 1), self.padding_idx, self.max_norm,
+                self.norm_type, self.scale_grad_by_freq, self.sparse
+            )
+            result += (after_A @ self.lora_B.transpose(0, 1)) * self.scaling
+            return result
         else:
             return nn.Embedding.forward(self, x)

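For context: the deleted inner check on self.r was redundant because the enclosing branch already requires self.r > 0, so the low-rank path always runs once that branch is entered. Below is a minimal sketch of what the simplified Embedding branch computes, using plain tensors and made-up sizes (num_embeddings=10, embedding_dim=8, r=4, scaling=1.0) instead of the library's module state; it is an illustration, not the library code.

import torch
import torch.nn.functional as F

# Hypothetical sizes, not taken from the commit.
num_embeddings, embedding_dim, r, scaling = 10, 8, 4, 1.0
weight = torch.randn(num_embeddings, embedding_dim)   # frozen base embedding table
lora_A = torch.zeros(r, num_embeddings)               # low-rank factor A, shape (r, num_embeddings)
lora_B = torch.randn(embedding_dim, r)                # low-rank factor B, shape (embedding_dim, r)

x = torch.tensor([[1, 2, 3]])                         # a small batch of token ids

result = F.embedding(x, weight)                       # base lookup, (1, 3, embedding_dim)
after_A = F.embedding(x, lora_A.transpose(0, 1))      # lookup into A^T, (1, 3, r)
result = result + (after_A @ lora_B.transpose(0, 1)) * scaling  # add scaled low-rank update

The lookup into lora_A.transpose(0, 1) picks out an r-dimensional row per token id, and the product with lora_B.transpose(0, 1) maps it back to embedding_dim before the scaled sum; zero-initializing one factor mirrors how LoRA keeps the update a no-op at initialization.
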
@@ -145,9 +144,8 @@ def forward(self, x: torch.Tensor):
         def T(w):
             return w.transpose(0, 1) if self.fan_in_fan_out else w
         if self.r > 0 and not self.merged:
-            result = F.linear(x, T(self.weight), bias=self.bias)
-            if self.r > 0:
-                result += (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
+            result = F.linear(x, T(self.weight), bias=self.bias)
+            result += (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
             return result
         else:
             return F.linear(x, T(self.weight), bias=self.bias)
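
The Linear change is the same simplification: with the outer guard in place, the low-rank term can be added unconditionally. A minimal sketch of the simplified branch with plain tensors, made-up sizes (in_features=16, out_features=32, r=4, scaling=0.5), and the dropout left out for brevity; again an illustration, not the library code.

import torch
import torch.nn.functional as F

# Hypothetical sizes, not taken from the commit.
in_features, out_features, r, scaling = 16, 32, 4, 0.5
weight = torch.randn(out_features, in_features)   # frozen base weight
bias = torch.zeros(out_features)
lora_A = torch.randn(r, in_features)              # low-rank factor A, shape (r, in_features)
lora_B = torch.zeros(out_features, r)             # low-rank factor B, shape (out_features, r)

x = torch.randn(2, in_features)                   # a small input batch

result = F.linear(x, weight, bias=bias)           # base projection, (2, out_features)
result = result + (x @ lora_A.transpose(0, 1) @ lora_B.transpose(0, 1)) * scaling  # add scaled low-rank update

Zero-initializing lora_B keeps the added term a no-op at initialization, so the sketch's output matches the frozen projection until the low-rank factors are trained.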
