diff --git a/docs/index.rst b/docs/index.rst index 46c49307c..044eacb2a 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -12,6 +12,10 @@ API文档 :titlesonly: :glob: + introduction + quickstart + performance + api/ltp* diff --git a/ltp/ltp.py b/ltp/ltp.py index f42f0b99c..e66afe20f 100644 --- a/ltp/ltp.py +++ b/ltp/ltp.py @@ -58,17 +58,14 @@ def __init__(self, path: str = 'small', batch_size: int = 10, device=None, vocab self.device = torch.device('cuda') else: self.device = torch.device('cpu') - if os.path.exists(path): - ckpt = torch.load(os.path.join(path, "ltp.model"), map_location=self.device) - self.tokenizer = AutoTokenizer.from_pretrained(path, use_fast=True) - elif path in model_map: + if path in model_map or os.path.exists(path): cache_dir = kwargs.pop("cache_dir", LTP_CACHE) force_download = kwargs.pop("force_download", False) resume_download = kwargs.pop("resume_download", False) proxies = kwargs.pop("proxies", None) local_files_only = kwargs.pop("local_files_only", False) resolved_archive_path = cached_path( - model_map[path], + model_map.get(path, path), cache_dir=cache_dir, force_download=force_download, proxies=proxies, diff --git a/requirements.txt b/requirements.txt index e77d0680b..d72f232e0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -32,8 +32,8 @@ pretty_errors pytest #doc -sphinx -sphinx_rtd_theme +sphinx~=3.1.1 +sphinx_rtd_theme~=0.4.0 # Server tornado~=6.0.4