diff --git a/gpt4all-bindings/python/gpt4all/gpt4all.py b/gpt4all-bindings/python/gpt4all/gpt4all.py index ed6427312418..62af9503fe82 100644 --- a/gpt4all-bindings/python/gpt4all/gpt4all.py +++ b/gpt4all-bindings/python/gpt4all/gpt4all.py @@ -168,10 +168,6 @@ def retrieve_model( # If model file does not exist, download elif allow_download: - # Make sure valid model filename before attempting download - - if "url" not in config: - raise ValueError(f"Model filename not in model list: {model_filename}") url = config.pop("url", None) config["path"] = GPT4All.download_model( diff --git a/gpt4all-bindings/python/gpt4all/tests/test_gpt4all.py b/gpt4all-bindings/python/gpt4all/tests/test_gpt4all.py index fa798c0c3b5e..89e81086dbc8 100644 --- a/gpt4all-bindings/python/gpt4all/tests/test_gpt4all.py +++ b/gpt4all-bindings/python/gpt4all/tests/test_gpt4all.py @@ -1,5 +1,6 @@ import sys from io import StringIO +from pathlib import Path from gpt4all import GPT4All, Embed4All import time @@ -114,3 +115,15 @@ def test_empty_embedding(): embedder = Embed4All() with pytest.raises(ValueError): output = embedder.embed(text) + +def test_download_model(tmp_path: Path): + import gpt4all.gpt4all + old_default_dir = gpt4all.gpt4all.DEFAULT_MODEL_DIRECTORY + gpt4all.gpt4all.DEFAULT_MODEL_DIRECTORY = tmp_path # temporary pytest directory to ensure a download happens + try: + model = GPT4All(model_name='ggml-all-MiniLM-L6-v2-f16.bin') + model_path = tmp_path / model.config['filename'] + assert model_path.absolute() == Path(model.config['path']).absolute() + assert model_path.stat().st_size == int(model.config['filesize']) + finally: + gpt4all.gpt4all.DEFAULT_MODEL_DIRECTORY = old_default_dir