forked from facebookresearch/DrQA
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request facebookresearch#191 from philipmorrisintl/master
Elasticsearch integration
- Loading branch information
Showing
4 changed files
with
134 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright 2017-present, Facebook, Inc. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
"""Rank documents with an ElasticSearch index""" | ||
|
||
import logging | ||
import scipy.sparse as sp | ||
|
||
from multiprocessing.pool import ThreadPool | ||
from functools import partial | ||
from elasticsearch import Elasticsearch | ||
|
||
from . import utils | ||
from . import DEFAULTS | ||
from .. import tokenizers | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class ElasticDocRanker(object): | ||
""" Connect to an ElasticSearch index. | ||
Score pairs based on Elasticsearch | ||
""" | ||
|
||
def __init__(self, elastic_url=None, elastic_index=None, elastic_fields=None, elastic_field_doc_name=None, strict=True, elastic_field_content=None): | ||
""" | ||
Args: | ||
elastic_url: URL of the ElasticSearch server containing port | ||
elastic_index: Index name of ElasticSearch | ||
elastic_fields: Fields of the Elasticsearch index to search in | ||
elastic_field_doc_name: Field containing the name of the document (index) | ||
strict: fail on empty queries or continue (and return empty result) | ||
elastic_field_content: Field containing the content of document in plaint text | ||
""" | ||
# Load from disk | ||
elastic_url = elastic_url or DEFAULTS['elastic_url'] | ||
logger.info('Connecting to %s' % elastic_url) | ||
self.es = Elasticsearch(hosts=elastic_url) | ||
self.elastic_index = elastic_index | ||
self.elastic_fields = elastic_fields | ||
self.elastic_field_doc_name = elastic_field_doc_name | ||
self.elastic_field_content = elastic_field_content | ||
self.strict = strict | ||
|
||
# Elastic Ranker | ||
|
||
def get_doc_index(self, doc_id): | ||
"""Convert doc_id --> doc_index""" | ||
field_index = self.elastic_field_doc_name | ||
if isinstance(field_index, list): | ||
field_index = '.'.join(field_index) | ||
result = self.es.search(index=self.elastic_index, body={'query':{'match': | ||
{field_index: doc_id}}}) | ||
return result['hits']['hits'][0]['_id'] | ||
|
||
|
||
def get_doc_id(self, doc_index): | ||
"""Convert doc_index --> doc_id""" | ||
result = self.es.search(index=self.elastic_index, body={'query': { 'match': {"_id": doc_index}}}) | ||
source = result['hits']['hits'][0]['_source'] | ||
return utils.get_field(source, self.elastic_field_doc_name) | ||
|
||
def closest_docs(self, query, k=1): | ||
"""Closest docs by using ElasticSearch | ||
""" | ||
results = self.es.search(index=self.elastic_index, body={'size':k ,'query': | ||
{'multi_match': { | ||
'query': query, | ||
'type': 'most_fields', | ||
'fields': self.elastic_fields}}}) | ||
hits = results['hits']['hits'] | ||
doc_ids = [utils.get_field(row['_source'], self.elastic_field_doc_name) for row in hits] | ||
doc_scores = [row['_score'] for row in hits] | ||
return doc_ids, doc_scores | ||
|
||
def batch_closest_docs(self, queries, k=1, num_workers=None): | ||
"""Process a batch of closest_docs requests multithreaded. | ||
Note: we can use plain threads here as scipy is outside of the GIL. | ||
""" | ||
with ThreadPool(num_workers) as threads: | ||
closest_docs = partial(self.closest_docs, k=k) | ||
results = threads.map(closest_docs, queries) | ||
return results | ||
|
||
# Elastic DB | ||
|
||
def __enter__(self): | ||
return self | ||
|
||
def close(self): | ||
"""Close the connection to the database.""" | ||
self.es = None | ||
|
||
def get_doc_ids(self): | ||
"""Fetch all ids of docs stored in the db.""" | ||
results = self.es.search(index= self.elastic_index, body={ | ||
"query": {"match_all": {}}}) | ||
doc_ids = [utils.get_field(result['_source'], self.elastic_field_doc_name) for result in results['hits']['hits']] | ||
return doc_ids | ||
|
||
def get_doc_text(self, doc_id): | ||
"""Fetch the raw text of the doc for 'doc_id'.""" | ||
idx = self.get_doc_index(doc_id) | ||
result = self.es.get(index=self.elastic_index, doc_type='_doc', id=idx) | ||
return result if result is None else result['_source'][self.elastic_field_content] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters