
Commit

change display_attention parameters' names
egliette committed Jul 13, 2023
1 parent 3ab3598 commit 8efd7a3
Showing 4 changed files with 12 additions and 8 deletions.
config.yml: 2 changes (1 addition, 1 deletion)
@@ -23,7 +23,7 @@ checkpoint:

# training hyperparameters
batch_size: 32
- total_epoch: 5
+ total_epoch: 10

clip: 1.0

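Note: the only config change is doubling total_epoch from 5 to 10. A minimal sketch of how such a key is typically consumed, assuming the repo's data_utils.get_config (used in inference_streamlit.py below) behaves roughly like yaml.safe_load; the helper's real behaviour is not shown in this diff.

import yaml

# Hypothetical stand-in for data_utils.get_config (assumed, not confirmed by this commit).
config = yaml.safe_load("""
batch_size: 32
total_epoch: 10
clip: 1.0
""")

for epoch in range(config["total_epoch"]):  # now 10 training passes instead of 5
    print(f"epoch {epoch + 1}/{config['total_epoch']}")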
inference_streamlit.py: 2 changes (1 addition, 1 deletion)
@@ -72,7 +72,6 @@ def add_bg_from_url():
unsafe_allow_html=True
)


def main(config_fpath="config.yml"):
config = data_utils.get_config(config_fpath)
for key, value in config.items():
@@ -112,6 +111,7 @@ def main(config_fpath="config.yml"):

st.header("Attention Matrix")
src_tokens = [token.lower() for token in src_tok.tokenize(input_text)]
+ src_tokens = [src_tok.vocab.bos_token] + src_tokens + [src_tok.vocab.eos_token]
fig = model_utils.display_attention(src_tokens, pred_tokens[1:],
attention, n_heads=1,
n_rows=1, n_cols=1, fig_size=(5, 5))
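Note: the added line exists because display_attention no longer prepends '<sos>'/'<eos>' labels itself (see utils/model_utils.py below), so the Streamlit caller must hand it the full source-token list. A small self-contained illustration; the real special-token strings come from src_tok.vocab and are only assumed here to be '<sos>' and '<eos>'.

# Illustration only: the bos/eos strings are assumptions; the real ones come from src_tok.vocab.
bos_token, eos_token = "<sos>", "<eos>"
src_tokens = [token.lower() for token in "I love cats".split()]

# The caller, not display_attention, is now responsible for the special tokens.
src_tokens = [bos_token] + src_tokens + [eos_token]
print(src_tokens)  # ['<sos>', 'i', 'love', 'cats', '<eos>']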
preprocess.py: 7 changes (6 additions, 1 deletion)
@@ -75,7 +75,12 @@ def main(config_fpath="config.yml"):

print("Load tokenizers...")
src_tok = EnTokenizer()
- tgt_tok = ViTokenizer()
+
+ # Vietnamese Multi-word Tokenizer
+ # tgt_tok = ViTokenizer()
+
+ # Vietnamese Word Tokenizer
+ tgt_tok = EnTokenizer()

print("Load DataLoaders")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
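Note: the commit swaps the target-side tokenizer from the Vietnamese multi-word tokenizer (ViTokenizer, now commented out) to plain word-level tokenization via EnTokenizer. The toy functions below only illustrate that difference, assuming the multi-word tokenizer joins multi-syllable Vietnamese words with underscores as common segmenters do; neither function is the repo's ViTokenizer or EnTokenizer.

def word_tokenize(sentence):
    # Plain word tokenization: split on whitespace (what the commit switches to).
    return sentence.split()

def multiword_tokenize(sentence):
    # Toy multi-word segmentation with a hypothetical one-entry compound dictionary.
    compounds = {("hà", "nội"): "hà_nội"}
    words, tokens, i = sentence.split(), [], 0
    while i < len(words):
        if tuple(words[i:i + 2]) in compounds:
            tokens.append(compounds[tuple(words[i:i + 2])])
            i += 2
        else:
            tokens.append(words[i])
            i += 1
    return tokens

print(word_tokenize("tôi sống ở hà nội"))       # ['tôi', 'sống', 'ở', 'hà', 'nội']
print(multiword_tokenize("tôi sống ở hà nội"))  # ['tôi', 'sống', 'ở', 'hà_nội']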
utils/model_utils.py: 9 changes (4 additions, 5 deletions)
@@ -66,8 +66,8 @@ def translate_sentence(sent, src_tok, tgt_tok, model, device, max_len=256):

return pred_tokens, attention

- def display_attention(sentence, translation, attention, n_heads = 8,
-                       n_rows = 4, n_cols = 2, fig_size=(15,25)):
+ def display_attention(src_tokens, pred_tokens, attention, n_heads = 8,
+                       n_rows=4, n_cols=2, fig_size=(15,25)):

assert n_rows * n_cols == n_heads

@@ -82,9 +82,8 @@ def display_attention(sentence, translation, attention, n_heads = 8,
cax = ax.matshow(_attention, cmap='bone')

ax.tick_params(labelsize=12)
- ax.set_xticklabels(['']+['<sos>']+[t.lower() for t in sentence]+['<eos>'],
-                    rotation=45)
- ax.set_yticklabels(['']+translation)
+ ax.set_xticklabels([''] + src_tokens)
+ ax.set_yticklabels([''] + pred_tokens)

ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
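Note: a self-contained sketch of the plotting core implied by the visible lines (matshow with the 'bone' colormap, caller-supplied tick labels, unit tick locators), using dummy attention values and a single head/subplot as in the Streamlit caller above. This is not the repo's full display_attention; the real attention weights come from translate_sentence.

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

src_tokens = ["<sos>", "i", "love", "cats", "<eos>"]   # caller supplies the special tokens
pred_tokens = ["tôi", "yêu", "mèo", "<eos>"]
attn = np.random.rand(len(pred_tokens), len(src_tokens))  # dummy single-head weights

fig, ax = plt.subplots(figsize=(5, 5))
ax.matshow(attn, cmap="bone")
ax.tick_params(labelsize=12)
ax.set_xticklabels([""] + src_tokens)   # labels arrive ready-made; no '<sos>'/'<eos>' padding here
ax.set_yticklabels([""] + pred_tokens)
ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
plt.show()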
