Skip to content

Commit

Permalink
optimized for gradient
Browse files Browse the repository at this point in the history
  • Loading branch information
Jameshskelton committed Mar 23, 2022
1 parent 69b5307 commit c4877b0
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 12 deletions.
2 changes: 1 addition & 1 deletion glide_text2im/clip/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def _block_layout(
return np.tril(np.ones(2 * [self.block_size], dtype=np.bool))


@attr.s(eq=False, repr=False)
@attr.s(, repr=False)
class AttentionInfo:
n_heads: int = attr.ib()
ctx_blks_q: int = attr.ib()
Expand Down
18 changes: 9 additions & 9 deletions glide_text2im/clip/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
image_channel_stds = [68.50053285, 66.63215831, 70.32316309]


@attr.s(eq=False, repr=False)
@attr.s(, repr=False)
class TextEmbedding(nn.Module):
n_vocab: int = attr.ib()
n_context: int = attr.ib()
Expand All @@ -49,7 +49,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return F.embedding(x, self.w_voc) + self.w_pos[None, :, :]


@attr.s(eq=False, repr=False)
@attr.s(, repr=False)
class ImageEmbedding(nn.Module):
image_size: int = attr.ib()
patch_size: int = attr.ib()
Expand Down Expand Up @@ -129,7 +129,7 @@ def forward(self, x: torch.Tensor, t: Optional[torch.Tensor] = None) -> torch.Te
return self.ln(x)


@attr.s(eq=False, repr=False)
@attr.s(, repr=False)
class AttentionResblock(nn.Module):
n_state: int = attr.ib()
n_resblocks: int = attr.ib()
Expand Down Expand Up @@ -223,7 +223,7 @@ def forward(self, m: torch.Tensor) -> torch.Tensor:
return m + r


@attr.s(eq=False, repr=False)
@attr.s(, repr=False)
class FullyConnectedResblock(nn.Module):
"""
Not imported from other files because we retain Alec's original inits.
Expand Down Expand Up @@ -260,7 +260,7 @@ def forward(self, m: torch.Tensor) -> torch.Tensor:
return m + r


@attr.s(eq=False, repr=False)
@attr.s(, repr=False)
class TransformerBlock(nn.Module):
n_state: int = attr.ib()
n_resblocks: int = attr.ib()
Expand All @@ -282,7 +282,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.f_mlp(self.f_attn(x))


@attr.s(eq=False, repr=False)
@attr.s(, repr=False)
class TextFeatureExtractor(nn.Module):
n_state: int = attr.ib()
n_embd: int = attr.ib()
Expand Down Expand Up @@ -315,7 +315,7 @@ def forward(
return self.f(x[:, 0])


@attr.s(eq=False, repr=False)
@attr.s(, repr=False)
class ImageFeatureExtractor(nn.Module):
n_state: int = attr.ib()
n_embd: int = attr.ib()
Expand All @@ -335,7 +335,7 @@ def forward(self, x: torch.Tensor, return_probe_features: bool = False) -> torch
return self.f(x[:, 0])


@attr.s(eq=False, repr=False)
@attr.s(, repr=False)
class TextEncoder(nn.Module):
n_bpe_vocab: int = attr.ib()
max_text_len: int = attr.ib()
Expand Down Expand Up @@ -414,7 +414,7 @@ def forward(
return h


@attr.s(eq=False, repr=False)
@attr.s(, repr=False)
class ImageEncoder(nn.Module):
image_size: int = attr.ib()
patch_size: int = attr.ib()
Expand Down
4 changes: 2 additions & 2 deletions glide_text2im/clip/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def zero_key_bias_grad(x: torch.Tensor) -> torch.Tensor:
return ZeroKeyBiasGrad.apply(x)


@attr.s(eq=False, repr=False)
@attr.s(, repr=False)
class LayerNorm(nn.Module):
n_state: int = attr.ib()
eps: float = attr.ib(default=1e-6)
Expand All @@ -44,7 +44,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
)


@attr.s(eq=False, repr=False)
@attr.s(, repr=False)
class Affine(nn.Module):
n_in: int = attr.ib()
n_out: int = attr.ib()
Expand Down

0 comments on commit c4877b0

Please sign in to comment.