Fix test errors
Cheng Guo committed Aug 3, 2020
1 parent 4838f29 commit 3cdea3d
Showing 4 changed files with 6 additions and 7 deletions.
1 change: 0 additions & 1 deletion setup.py
@@ -14,7 +14,6 @@
     url="https://github.com/guocheng2018/transformer-encoder",
     packages=setuptools.find_packages(),
     python_requires=">=3.5",
-    install_requires=["pytorch>=1.0.0"],
     classifiers=[
         "Programming Language :: Python :: 3",
         "License :: OSI Approved :: MIT License",
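Note: the PyTorch library is published on PyPI as "torch", not "pytorch", so the removed requirement named the wrong distribution. This commit simply drops the line; a corrected entry, shown only as a hypothetical alternative and not part of this commit, would be:

install_requires=["torch>=1.0.0"],  # hypothetical fix: "torch" is the actual PyPI name; this commit removes the requirement instead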
2 changes: 1 addition & 1 deletion tests/test_encoder.py
@@ -12,7 +12,7 @@
 
 
 def test_encoder():
-    enc = TransformerEncoder(n_layers, d_model, d_ff, n_heads, dropout)
+    enc = TransformerEncoder(d_model, d_ff, n_heads=n_heads, n_layers=n_layers, dropout=dropout)
     x = torch.randn(batch_size, max_len, d_model)
     mask = torch.randn(batch_size, max_len).ge(0)
     out = enc(x, mask)
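The updated call implies that TransformerEncoder now takes d_model and d_ff as its leading positional parameters, with n_heads, n_layers, and dropout accepted as keywords; the old test passed n_layers first, which no longer matches. A minimal usage sketch under that assumption (import path and hyperparameter values are illustrative, not taken from this diff):

import torch
from transformer_encoder import TransformerEncoder  # assumed import path

d_model, d_ff, n_heads, n_layers, dropout = 512, 2048, 8, 6, 0.1
enc = TransformerEncoder(d_model, d_ff, n_heads=n_heads, n_layers=n_layers, dropout=dropout)

x = torch.randn(2, 10, d_model)   # (batch_size, max_len, d_model)
mask = torch.randn(2, 10).ge(0)   # boolean padding mask
out = enc(x, mask)                # output keeps the input shape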
8 changes: 4 additions & 4 deletions tests/test_optim.py
@@ -9,12 +9,12 @@
 dropout = 0.1
 n_layers = 6
 
-factor = 1
-warmup = 20
+scale_factor = 1
+warmup_steps = 20
 
 
 def test_optim():
-    enc = TransformerEncoder(n_layers, d_model, d_ff, n_heads, dropout)
-    opt = WarmupOptimizer(d_model, factor, warmup, optim.Adam(enc.parameters()))
+    enc = TransformerEncoder(d_model, d_ff, n_heads=n_heads, n_layers=n_layers, dropout=dropout)
+    opt = WarmupOptimizer(optim.Adam(enc.parameters()), d_model, scale_factor, warmup_steps)
     assert type(opt.rate(step=1)) is float  # step starts from 1
     opt.step()
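The renamed constants (scale_factor, warmup_steps) and the rate(step=...) method suggest the inverse-square-root warmup schedule from "Attention Is All You Need". A sketch of that schedule, written from the paper's formula rather than from this repository's source:

def warmup_rate(step, d_model=512, scale_factor=1.0, warmup_steps=20):
    # Learning rate rises roughly linearly for the first warmup_steps updates,
    # then decays as 1/sqrt(step):
    # rate = scale_factor * d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5)
    return scale_factor * (d_model ** -0.5) * min(step ** -0.5, step * warmup_steps ** -1.5)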
2 changes: 1 addition & 1 deletion tests/test_pe.py
@@ -9,7 +9,7 @@
 
 
 def test_pe():
-    PE = PositionalEncoding(d_model, dropout, max_len)
+    PE = PositionalEncoding(d_model, dropout=dropout, max_len=max_len)
     embeds = torch.randn(batch_size, max_len, d_model)  # (batch_size, max_len, d_model)
     out = PE(embeds)
     assert embeds.size() == out.size()
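PositionalEncoding is now constructed with dropout and max_len passed as keywords. Assuming it implements the standard sinusoidal encoding (the usual formulation, not verified against this repository's source), the module adds a fixed table to the embeddings, which is why the output shape matches the input shape:

import math
import torch

def sinusoidal_table(max_len, d_model):
    # Assumes even d_model.
    # pe[pos, 2i]   = sin(pos / 10000**(2i / d_model))
    # pe[pos, 2i+1] = cos(pos / 10000**(2i / d_model))
    position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float) * (-math.log(10000.0) / d_model))
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe  # (max_len, d_model), broadcast-added to (batch_size, max_len, d_model) embeddings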
