Skip to content

Commit

Permalink
Add logging
Browse files Browse the repository at this point in the history
  • Loading branch information
kazemnejad committed Jun 14, 2022
1 parent e677c60 commit c65330e
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 0 deletions.
23 changes: 23 additions & 0 deletions lm_eval/models/opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,17 @@


class CustomOPTLearnedPositionalEmbedding(OPTLearnedPositionalEmbedding):
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int = 1):
super(OPTLearnedPositionalEmbedding, self).__init__(
num_embeddings, embedding_dim, padding_idx
)
self.onnx_trace = False
if self.padding_idx is not None:
self.max_positions = self.num_embeddings - self.padding_idx - 1
else:
self.max_positions = self.num_embeddings
self.has_logged = False

def forward(self, attention_mask: Tensor, positions: Optional[Tensor] = None):
if positions is None:
attention_mask = attention_mask.long()
Expand All @@ -34,6 +45,11 @@ def forward(self, attention_mask: Tensor, positions: Optional[Tensor] = None):
attention_mask.bool(), positions, self.padding_idx
).long()

if not self.has_logged:
print("-------> positions:")
print(positions)
self.has_logged = True

return F.embedding(
positions,
self.weight,
Expand Down Expand Up @@ -405,9 +421,11 @@ def __init__(
# self.gpt2 = nn.DataParallel(self.gpt2)

self.pretrained = pretrained
self.has_logged = False

self._load_opt_model()


def _load_opt_model(self):
weights_path = huggingface_hub.snapshot_download(self.pretrained)
files = os.listdir(weights_path)
Expand Down Expand Up @@ -544,4 +562,9 @@ def _model_call(self, inps: torch.Tensor):
else:
position_ids = None

if not self.has_logged:
print("-------> inputs:")
print(self.tokenizer.batch_decode(inps))
self.has_logged = True

return self.gpt2(input_ids=inps, position_ids=position_ids)[0][:, :, :50257]
1 change: 1 addition & 0 deletions run_all_phase_shift_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
print(python_run_script)
try:
os.system(python_run_script)
print("\n"*5)
time.sleep(5)
except Exception as exp:
print(exp)
Expand Down

0 comments on commit c65330e

Please sign in to comment.