Skip to content

Commit

Permalink
Modify gated unit tests to fix Fairseq OSS (facebookresearch#2059)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebookresearch#2059

test_ensemble_sequence_generator and test_export_ensemble_model are green on fbcode master but Pytorch 1.5 release cut happened before the TorchScript fix, so updating the gate to 1.6
Remove quantization test from fairseq as FBGEMMS is binded at OSS side. Will add the test back in fbtranslate but land this first to fix OSS side failures.

Reviewed By: myleott

Differential Revision: D21231873

fbshipit-source-id: 8a2ad7dbed118ca8e3f4c351c399a82fd9740445
  • Loading branch information
cndn authored and facebook-github-bot committed Apr 24, 2020
1 parent 38f35cc commit b1af3e3
Showing 1 changed file with 2 additions and 13 deletions.
15 changes: 2 additions & 13 deletions tests/test_sequence_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def test_export_transformer(self):
torch.jit.script(model)

@unittest.skipIf(
torch.__version__ < "1.5.0", "Targeting OSS scriptability for the 1.5 release"
torch.__version__ < "1.6.0", "Targeting OSS scriptability for the 1.6 release"
)
def test_ensemble_sequence_generator(self):
model = self.transformer_model
Expand All @@ -128,22 +128,11 @@ def test_ensemble_sequence_generator(self):
scripted_model = torch.jit.script(generator)
self._test_save_and_load(scripted_model)

@unittest.skipIf(
torch.__version__ < "1.5.0", "Targeting OSS scriptability for the 1.5 release"
)
def test_quantized_ensemble_sequence_generator(self):
model = torch.quantization.quantize_dynamic(
self.transformer_model, {torch.nn.Linear}, dtype=torch.qint8, inplace=True
)
generator = SequenceGenerator([model], self.task.tgt_dict, beam_size=2)
scripted_model = torch.jit.script(generator)
self._test_save_and_load(scripted_model)


class TestJitEnsemble(TestJitSequenceGeneratorBase):

@unittest.skipIf(
torch.__version__ < "1.5.0", "Targeting OSS scriptability for the 1.5 release"
torch.__version__ < "1.6.0", "Targeting OSS scriptability for the 1.6 release"
)
def test_export_ensemble_model(self):
model = self.transformer_model
Expand Down

0 comments on commit b1af3e3

Please sign in to comment.