Skip to content

Commit

Permalink
extra insurance in case eos id is not there
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Dec 15, 2022
1 parent 067ac32 commit 683dd98
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 1 deletion.
2 changes: 2 additions & 0 deletions dalle2_pytorch/dalle2_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,7 @@ def embed_text(self, text):
is_eos_id = (text == self.eos_id)
text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0
text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True)
text_mask = text_mask & (text != 0)
assert not self.cleared

text_embed = self.clip.encode_text(text)
Expand Down Expand Up @@ -434,6 +435,7 @@ def embed_text(self, text):
is_eos_id = (text == self.eos_id)
text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0
text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True)
text_mask = text_mask & (text != 0)
assert not self.cleared

text_embed = self.clip.encode_text(text)
Expand Down
2 changes: 1 addition & 1 deletion dalle2_pytorch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '1.11.2'
__version__ = '1.11.4'

0 comments on commit 683dd98

Please sign in to comment.