Merge pull request microsoft#62 from microsoft/eval
Fix layer.eval() bug
edwardjhu authored Apr 25, 2023
2 parents e4a415e + dc5d174 commit 0771c1e
Showing 1 changed file with 66 additions and 74 deletions.
loralib/layers.py: 140 changes (66 additions & 74 deletions)
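
A note on the failure mode this commit most plausibly addresses (inferred from how torch.nn.Module dispatches, not stated in the commit message): nn.Module.eval() is simply self.train(False), and calling .eval() on a parent module reaches each child through its train() override rather than its eval() override. Merge logic that lives only in an overridden eval() therefore never runs when the whole model is switched to evaluation mode. A minimal sketch with a hypothetical Probe layer (not part of loralib):

import torch.nn as nn

class Probe(nn.Linear):
    # Records which override fires, so the dispatch path is visible.
    def train(self, mode: bool = True):
        print(f"Probe.train(mode={mode})")
        return super().train(mode)

    def eval(self):
        print("Probe.eval()")   # never reached through the parent below
        return super().eval()

model = nn.Sequential(Probe(4, 4))
model.eval()    # prints "Probe.train(mode=False)" only
model.train()   # prints "Probe.train(mode=True)"

Folding the merge/unmerge branches into train(mode), as the diff below does for every LoRA layer, makes the behavior the same whether the mode switch arrives via .eval(), .train(False), or a call on an ancestor module.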
@@ -61,29 +61,28 @@ def reset_parameters(self):
 
     def train(self, mode: bool = True):
         nn.Embedding.train(self, mode)
-        if self.merge_weights and self.merged:
-            # Make sure that the weights are not merged
-            if self.r > 0:
-                self.weight.data -= (self.lora_B @ self.lora_A).T * self.scaling
-            self.merged = False
-
-    def eval(self):
-        nn.Embedding.eval(self)
-        if self.merge_weights and not self.merged:
-            # Merge the weights and mark it
-            if self.r > 0:
-                self.weight.data += (self.lora_B @ self.lora_A) * self.scaling
-            self.merged = True
+        if mode:
+            if self.merge_weights and self.merged:
+                # Make sure that the weights are not merged
+                if self.r > 0:
+                    self.weight.data -= (self.lora_B @ self.lora_A).transpose(0, 1) * self.scaling
+                self.merged = False
+        else:
+            if self.merge_weights and not self.merged:
+                # Merge the weights and mark it
+                if self.r > 0:
+                    self.weight.data += (self.lora_B @ self.lora_A) * self.scaling
+                self.merged = True
 
     def forward(self, x: torch.Tensor):
         if self.r > 0 and not self.merged:
             result = nn.Embedding.forward(self, x)
             if self.r > 0:
                 after_A = F.embedding(
-                    x, self.lora_A.T, self.padding_idx, self.max_norm,
+                    x, self.lora_A.transpose(0, 1), self.padding_idx, self.max_norm,
                     self.norm_type, self.scale_grad_by_freq, self.sparse
                 )
-                result += (after_A @ self.lora_B.T) * self.scaling
+                result += (after_A @ self.lora_B.transpose(0, 1)) * self.scaling
             return result
         else:
             return nn.Embedding.forward(self, x)
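
For reference, a standalone numerical sketch (plain tensors, not the loralib layer) of the identity the Embedding branches rely on: adding the low-rank path after the lookup equals looking up in a weight shifted by (B A)^T * scaling, which is why train(True) can restore the pristine weight by subtracting that same term.

import torch
import torch.nn.functional as F

num_emb, dim, r, scaling = 10, 6, 2, 0.5
W = torch.randn(num_emb, dim)        # embedding weight, (num_embeddings, embedding_dim)
A = torch.randn(r, num_emb)          # lora_A analogue
B = torch.randn(dim, r)              # lora_B analogue
x = torch.randint(0, num_emb, (3,))

after_A = F.embedding(x, A.transpose(0, 1))                              # (3, r)
lora_path = F.embedding(x, W) + (after_A @ B.transpose(0, 1)) * scaling  # unmerged forward
merged = F.embedding(x, W + (B @ A).transpose(0, 1) * scaling)           # merged weight
assert torch.allclose(lora_path, merged, atol=1e-6)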
@@ -116,7 +115,7 @@ def __init__(
         self.weight.requires_grad = False
         self.reset_parameters()
         if fan_in_fan_out:
-            self.weight.data = self.weight.data.T
+            self.weight.data = self.weight.data.transpose(0, 1)
 
     def reset_parameters(self):
         nn.Linear.reset_parameters(self)
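
A short aside on the fan_in_fan_out flag touched in this hunk: some checkpoints (for example Hugging Face GPT-2's Conv1D projections) store the weight as (in_features, out_features) rather than nn.Linear's (out_features, in_features). The layer keeps that transposed layout so such checkpoints load cleanly, and the T(w) helper used throughout flips it back for F.linear. A tiny standalone check of that equivalence:

import torch
import torch.nn.functional as F

in_f, out_f = 4, 3
w_fan_in_fan_out = torch.randn(in_f, out_f)   # stored as (in, out), Conv1D-style
x = torch.randn(2, in_f)

y_ref = x @ w_fan_in_fan_out                             # how a Conv1D-style layer computes it
y_lin = F.linear(x, w_fan_in_fan_out.transpose(0, 1))    # what T(self.weight) achieves
assert torch.allclose(y_ref, y_lin, atol=1e-6)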
@@ -127,31 +126,28 @@ def reset_parameters(self):
 
     def train(self, mode: bool = True):
         def T(w):
-            return w.T if self.fan_in_fan_out else w
+            return w.transpose(0, 1) if self.fan_in_fan_out else w
         nn.Linear.train(self, mode)
-        if self.merge_weights and self.merged:
-            # Make sure that the weights are not merged
-            if self.r > 0:
-                self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
-            self.merged = False
-
-    def eval(self):
-        def T(w):
-            return w.T if self.fan_in_fan_out else w
-        nn.Linear.eval(self)
-        if self.merge_weights and not self.merged:
-            # Merge the weights and mark it
-            if self.r > 0:
-                self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
-            self.merged = True
+        if mode:
+            if self.merge_weights and self.merged:
+                # Make sure that the weights are not merged
+                if self.r > 0:
+                    self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
+                self.merged = False
+        else:
+            if self.merge_weights and not self.merged:
+                # Merge the weights and mark it
+                if self.r > 0:
+                    self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
+                self.merged = True
 
     def forward(self, x: torch.Tensor):
         def T(w):
-            return w.T if self.fan_in_fan_out else w
+            return w.transpose(0, 1) if self.fan_in_fan_out else w
         if self.r > 0 and not self.merged:
             result = F.linear(x, T(self.weight), bias=self.bias)
             if self.r > 0:
-                result += (self.lora_dropout(x) @ self.lora_A.T @ self.lora_B.T) * self.scaling
+                result += (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
             return result
         else:
             return F.linear(x, T(self.weight), bias=self.bias)
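
A minimal round-trip sketch of the intended behavior after this change, assuming loralib is importable and using its Linear layer with the default merge_weights=True (lora_B is zero-initialized, so it is nudged here to make the delta visible):

import torch
import loralib as lora

model = torch.nn.Sequential(lora.Linear(8, 8, r=2))
model[0].lora_B.data.normal_()      # make the low-rank delta non-zero
x = torch.randn(4, 8)

model.train()
y_unmerged = model(x)               # base weight plus the separate LoRA path

model.eval()                        # reaches the layer as train(False) and folds the delta into weight
y_merged = model(x)
assert torch.allclose(y_unmerged, y_merged, atol=1e-5)

model.train()                       # subtracts the delta back out; the base weight is restored

Before this commit, the model.eval() call above would have left the layer unmerged when the mode switch came from the parent module, since only the removed eval() override contained the merge branch.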
@@ -196,7 +192,7 @@ def __init__(
             self.lora_ind = self.lora_ind.view(-1)
         self.reset_parameters()
         if fan_in_fan_out:
-            self.weight.data = self.weight.data.T
+            self.weight.data = self.weight.data.transpose(0, 1)
 
     def reset_parameters(self):
         nn.Linear.reset_parameters(self)
@@ -215,37 +211,34 @@ def zero_pad(self, x):
 
     def train(self, mode: bool = True):
         def T(w):
-            return w.T if self.fan_in_fan_out else w
+            return w.transpose(0, 1) if self.fan_in_fan_out else w
         nn.Linear.train(self, mode)
-        if self.merge_weights and self.merged:
-            # Make sure that the weights are not merged
-            if self.r > 0 and any(self.enable_lora):
-                delta_w = F.conv1d(
-                    self.lora_A.data.unsqueeze(0),
-                    self.lora_B.data.unsqueeze(-1),
-                    groups=sum(self.enable_lora)
-                ).squeeze(0)
-                self.weight.data -= self.zero_pad(T(delta_w * self.scaling))
-            self.merged = False
-
-    def eval(self):
-        def T(w):
-            return w.T if self.fan_in_fan_out else w
-        nn.Linear.eval(self)
-        if self.merge_weights and not self.merged:
-            # Merge the weights and mark it
-            if self.r > 0 and any(self.enable_lora):
-                delta_w = F.conv1d(
-                    self.lora_A.data.unsqueeze(0),
-                    self.lora_B.data.unsqueeze(-1),
-                    groups=sum(self.enable_lora)
-                ).squeeze(0)
-                self.weight.data += self.zero_pad(T(delta_w * self.scaling))
-            self.merged = True
+        if mode:
+            if self.merge_weights and self.merged:
+                # Make sure that the weights are not merged
+                if self.r > 0 and any(self.enable_lora):
+                    delta_w = F.conv1d(
+                        self.lora_A.data.unsqueeze(0),
+                        self.lora_B.data.unsqueeze(-1),
+                        groups=sum(self.enable_lora)
+                    ).squeeze(0)
+                    self.weight.data -= self.zero_pad(T(delta_w * self.scaling))
+                self.merged = False
+        else:
+            if self.merge_weights and not self.merged:
+                # Merge the weights and mark it
+                if self.r > 0 and any(self.enable_lora):
+                    delta_w = F.conv1d(
+                        self.lora_A.data.unsqueeze(0),
+                        self.lora_B.data.unsqueeze(-1),
+                        groups=sum(self.enable_lora)
+                    ).squeeze(0)
+                    self.weight.data += self.zero_pad(T(delta_w * self.scaling))
+                self.merged = True
 
     def forward(self, x: torch.Tensor):
         def T(w):
-            return w.T if self.fan_in_fan_out else w
+            return w.transpose(0, 1) if self.fan_in_fan_out else w
         if self.merged:
             return F.linear(x, T(self.weight), bias=self.bias)
         else:
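
The grouped conv1d in the MergedLinear branches above is a compact way to multiply each enabled group's B block with its A block; a standalone sketch (shapes chosen for illustration) of that equivalence:

import torch
import torch.nn.functional as F

r, in_features, out_per_group, groups = 2, 8, 6, 2
lora_A = torch.randn(r * groups, in_features)        # A blocks stacked along dim 0
lora_B = torch.randn(out_per_group * groups, r)      # B blocks stacked along dim 0

delta_w = F.conv1d(
    lora_A.unsqueeze(0),          # (1, r*groups, in_features)
    lora_B.unsqueeze(-1),         # (out_per_group*groups, r, 1)
    groups=groups
).squeeze(0)                      # (out_per_group*groups, in_features)

manual = torch.cat([
    lora_B[g * out_per_group:(g + 1) * out_per_group] @ lora_A[g * r:(g + 1) * r]
    for g in range(groups)
])
assert torch.allclose(delta_w, manual, atol=1e-5)

This is the delta that zero_pad then places into the portion of the fused weight belonging to the enabled projections.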
@@ -300,17 +293,16 @@ def reset_parameters(self):
 
     def train(self, mode: bool = True):
         nn.Conv2d.train(self, mode)
-        if self.merge_weights and self.merged:
-            # Make sure that the weights are not merged
-            self.weight.data -= (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling
-            self.merged = False
-
-    def eval(self):
-        nn.Conv2d.eval(self)
-        if self.merge_weights and not self.merged:
-            # Merge the weights and mark it
-            self.weight.data += (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling
-            self.merged = True
+        if mode:
+            if self.merge_weights and self.merged:
+                # Make sure that the weights are not merged
+                self.weight.data -= (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling
+                self.merged = False
+        else:
+            if self.merge_weights and not self.merged:
+                # Merge the weights and mark it
+                self.weight.data += (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling
+                self.merged = True
 
     def forward(self, x: torch.Tensor):
         if self.r > 0 and not self.merged:
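
Finally, the reason folding (lora_B @ lora_A).view(self.weight.shape) into the Conv2d weight is safe: convolution is linear in its kernel, so a merged kernel produces the same output as keeping the low-rank branch separate. A standalone check with illustrative factor shapes (not necessarily loralib's internal ones):

import torch
import torch.nn.functional as F

out_c, in_c, k, r, scaling = 8, 4, 3, 2, 0.5
W = torch.randn(out_c, in_c, k, k)
B = torch.randn(out_c, r)                 # illustrative low-rank factors
A = torch.randn(r, in_c * k * k)
delta = (B @ A).view(W.shape) * scaling
x = torch.randn(1, in_c, 16, 16)

y_separate = F.conv2d(x, W) + F.conv2d(x, delta)   # LoRA branch kept separate
y_merged = F.conv2d(x, W + delta)                  # delta folded into the kernel
assert torch.allclose(y_separate, y_merged, atol=1e-4)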
