Update quant_block.py
Signed-off-by: Haoxuan Wang <[email protected]>
hatchetProject authored Feb 13, 2024
1 parent dd198fd commit 51de3e5
Showing 1 changed file with 0 additions and 26 deletions.
qdiff/quant_block.py: 26 changes (0 additions, 26 deletions)
@@ -238,14 +238,6 @@ def cross_attn_forward(self, x, context=None, mask=None, timestep=None):
     v = self.to_v(context)
 
     q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
-    # Changed here 1
-    # old_q = q.clone()
-    # old_k = k.clone()
-    # old_v = v.clone()
-    # sim_old = einsum('b i d, b j d -> b i j', old_q, old_k) * self.scale
-    # attn_old = sim_old.softmax(dim=-1)
-    # out_old = einsum('b i j, b j d -> b i d', attn_old, old_v)
-    # out_old = rearrange(out_old, '(b h) n d -> b n (h d)', h=h)
 
     if self.use_act_quant:
         assert timestep is not None
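
For context, the deleted "Changed here 1" block kept clones of q, k, v and ran the standard attention computation on them, producing a full-precision reference output (out_old) that the later deleted blocks compared against the quantized path. A minimal runnable sketch of that reference path, using torch.einsum and einops as the file already does (the standalone function and its name are illustrative, not code from the repository):

from einops import rearrange
from torch import einsum

def full_precision_attention(q, k, v, scale, h):
    # q, k, v: (b*h, n, d) tensors, i.e. already rearranged as in the hunk above
    old_q, old_k, old_v = q.clone(), k.clone(), v.clone()
    # Standard scaled dot-product attention on the unquantized copies
    sim_old = einsum('b i d, b j d -> b i j', old_q, old_k) * scale
    attn_old = sim_old.softmax(dim=-1)
    out_old = einsum('b i j, b j d -> b i d', attn_old, old_v)
    # Back to (b, n, h*d) so it is directly comparable with the quantized output
    return rearrange(out_old, '(b h) n d -> b n (h d)', h=h)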
@@ -276,11 +268,6 @@ def cross_attn_forward(self, x, context=None, mask=None, timestep=None):
     # out = einsum('b i j, b j d -> b i d', attn, v)
     out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
 
-    # Changed here 2
-    # mean_out = torch.mean(torch.mean(out_old.detach(), dim=0), dim=0)
-    # _, max_idx = torch.topk(torch.abs(mean_out), 3)
-    # out[:, :, max_idx] = out_old[:, :, max_idx].cuda()
-
     return self.to_out(out)


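The deleted "Changed here 2" block was the other half of that experiment: it scored channels by the mean absolute value of the full-precision output and overwrote the top 3 channels of the quantized output with the reference values. A sketch of that splice as a helper (the function name is hypothetical; the deleted code wrote the result back with .cuda(), generalized here to the output's device):

import torch

def preserve_topk_channels(out, out_old, k=3):
    # out: quantized attention output, out_old: full-precision reference;
    # both have shape (b, n, c)
    # One score per channel: mean over the batch dim, then over tokens
    mean_out = torch.mean(torch.mean(out_old.detach(), dim=0), dim=0)
    # Indices of the k channels with the largest mean |activation|
    _, max_idx = torch.topk(torch.abs(mean_out), k)
    # Keep the full-precision values for exactly those channels
    out[:, :, max_idx] = out_old[:, :, max_idx].to(out.device)
    return out
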
@@ -353,20 +340,7 @@ def _forward(self, x, context=None, timestep=None):
 
         x = self.attn1(self.norm1(x), timestep=timestep) + x # Original 1
         x = self.attn2(self.norm2(x), context=context, timestep=timestep) + x # Original 2
 
-        # old_x = x.clone()
-        # mean_x = torch.mean(torch.mean(old_x.detach(), dim=0), dim=0)
-        # max_idx = torch.argsort(torch.abs(mean_x))[-2:]
-        # x[:, :, max_idx] = torch.zeros(x[:, :, max_idx].shape).cuda()
-        # x[:, :, max_idx] = (x[:, :, max_idx] / 2).cuda()
-
-
         x = self.ff(self.norm3(x)) + x # Original 3
 
-        # for n, m in self.ff.named_modules():
-        #     if isinstance(m, QuantModule):
-        #         m.set_quant_state(True, False)
-        # old_x = self.ff(self.norm3(old_x)) + old_x
-        # x[:, :, max_idx] = (old_x[:, :, max_idx]).cuda()
-
         return x

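The deleted block in _forward ran a similar probe at the transformer-block level: pick the two channels of the residual stream with the largest mean |activation|, re-run the feed-forward with weight quantization on but activation quantization off (set_quant_state(True, False) on every QuantModule in self.ff), and keep the higher-precision result for those channels. A sketch of that experiment, assuming QuantModule is importable from qdiff.quant_layer as elsewhere in the repo (the helper name is illustrative, and like the deleted code it leaves activation quantization disabled in self.ff afterwards):

import torch
from qdiff.quant_layer import QuantModule  # assumed import path

def ff_with_channel_splice(self, x, k=2):
    # Score channels by the mean |activation| of the residual stream
    old_x = x.clone()
    mean_x = torch.mean(torch.mean(old_x.detach(), dim=0), dim=0)
    max_idx = torch.argsort(torch.abs(mean_x))[-k:]

    # Fully quantized feed-forward path
    x = self.ff(self.norm3(x)) + x

    # Reference pass: weights quantized, activations in full precision
    for _, m in self.ff.named_modules():
        if isinstance(m, QuantModule):
            m.set_quant_state(True, False)
    old_x = self.ff(self.norm3(old_x)) + old_x

    # Splice the higher-precision values into the dominant channels
    x[:, :, max_idx] = old_x[:, :, max_idx]
    return x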
