Skip to content

Commit

Permalink
Added new search API with auto iterator.
Browse files Browse the repository at this point in the history
  • Loading branch information
Alberto Paro committed Jun 8, 2011
1 parent 109f846 commit 068162c
Show file tree
Hide file tree
Showing 2 changed files with 207 additions and 76 deletions.
224 changes: 206 additions & 18 deletions pyes/es.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,11 @@ def _discovery(self):
self._init_connection()
return self.servers

def _send_request(self, method, path, body=None, params={}):
def _send_request(self, method, path, body=None, params=None, headers=None):
if params is None:
params = {}
if headers is None:
headers = {}
# prepare the request
if not path.startswith("/"):
path = "/" + path
Expand All @@ -200,7 +204,8 @@ def _send_request(self, method, path, body=None, params={}):
body = json.dumps(body, cls=self.encoder)
else:
body = ""
request = RestRequest(method=Method._NAMES_TO_VALUES[method.upper()], uri=path, parameters=params, headers={}, body=body)
request = RestRequest(method=Method._NAMES_TO_VALUES[method.upper()],
uri=path, parameters=params, headers=headers, body=body)
if self.dump_curl is not None:
self._dump_curl_request(request)

Expand Down Expand Up @@ -330,16 +335,19 @@ def get_indices(self, include_aliases=False):
this is an alias for.
"""
state = self.cluster_state()
status = self.status()
result = {}
indices = status['indices']
for index in sorted(indices.keys()):
info = indices[index]
indices_status = status['indices']
indices_metadata = state['metadata']['indices']
for index in sorted(indices_status.keys()):
info = indices_status[index]
metadata = indices_metadata[index]
num_docs = info['docs']['num_docs']
result[index] = dict(num_docs=num_docs)
if not include_aliases:
continue
for alias in info['aliases']:
for alias in metadata.get('aliases', []):
try:
alias_obj = result[alias]
except KeyError:
Expand Down Expand Up @@ -729,11 +737,10 @@ def force_bulk(self):
"""
Force executing of all bulk data
"""
if self.bulk_items == 0:
return
self._send_request("POST", "/_bulk", self.bulk_data.getvalue())
self.bulk_data = StringIO()
self.bulk_items = 0
if self.bulk_items!=0:
self._send_request("POST", "/_bulk", self.bulk_data.getvalue())
self.bulk_data = StringIO()
self.bulk_items = 0

def put_file(self, filename, index, doc_type, id=None):
"""
Expand Down Expand Up @@ -815,7 +822,7 @@ def get(self, index, doc_type, id, fields=None, routing=None, **get_params):
get_params["routing"] = routing
return self._send_request('GET', path, params=get_params)

def search(self, query, indexes=None, doc_types=None, **query_params):
def search_raw(self, query, indexes=None, doc_types=None, **query_params):
"""Execute a search against one or more indices to get the search hits.
`query` must be a Search object, a Query object, or a custom
Expand All @@ -840,6 +847,28 @@ def search(self, query, indexes=None, doc_types=None, **query_params):

return self._query_call("_search", body, indexes, doc_types, **query_params)

def search(self, query, indexes=None, doc_types=None, **query_params):
    """Prepare a search against one or more indices and return its ResultSet.

    `query` must be a Search object, a Query object, or a custom
    dictionary of search parameters using the query DSL to be passed
    directly.

    `doc_types` may be None, a single type name, or a list of type names.
    Extra keyword arguments are handed to the ResultSet as query parameters.

    Raises InvalidQuery when `query` is none of the supported kinds.
    """
    indexes = self._validate_indexes(indexes)

    # Normalise doc_types to a list: nothing -> [], one name -> [name].
    if doc_types is None:
        doc_types = []
    elif isinstance(doc_types, basestring):
        doc_types = [doc_types]

    # Objects exposing search() are converted by calling it.
    if hasattr(query, 'search'):
        query = query.search()

    is_supported = hasattr(query, 'to_search_json') or isinstance(query, dict)
    if not is_supported:
        raise pyes.exceptions.InvalidQuery("search() must be supplied with a Search or Query object, or a dict")

    return ResultSet(connection=self,
                     query=query,
                     indexes=indexes,
                     doc_types=doc_types,
                     query_params=query_params)

def scan(self, query, indexes=None, doc_types=None, scroll_timeout="10m", **query_params):
"""Return a generator which will scan against one or more indices and iterate over the search hits. (currently support only by ES Master)
Expand All @@ -857,6 +886,12 @@ def scan(self, query, indexes=None, doc_types=None, scroll_timeout="10m", **quer
break
yield results

def search_scroll(self, scroll_id, scroll_timeout="10m"):
    """
    Request the next batch of results for an existing scroll.

    scroll_id: sent as the body of the request
    scroll_timeout: sent as the "scroll" request parameter

    Returns whatever _send_request returns for the call.
    """
    scroll_params = {"scroll": scroll_timeout}
    return self._send_request('GET', "_search/scroll", scroll_id, scroll_params)

def reindex(self, query, indexes=None, doc_types=None, **query_params):
"""
Execute a search query against one or more indices and reindex the hits.
Expand Down Expand Up @@ -955,17 +990,15 @@ def create_percolator(self, index, name, query, **kwargs):
"""
path = self._make_path(['_percolator', index, name])
body = None

if hasattr(query, 'serialize'):
query = {'query': query.serialize()}

if isinstance(query, dict):
# A direct set of search parameters.
query.update(kwargs)
body = json.dumps(query, cls=self.encoder)
else:
if not isinstance(query, dict):
raise pyes.exceptions.InvalidQuery("create_percolator() must be supplied with a Query object or dict")
# A direct set of search parameters.
query.update(kwargs)
body = json.dumps(query, cls=self.encoder)

return self._send_request('PUT', path, body=body)

Expand Down Expand Up @@ -1015,3 +1048,158 @@ def encode_json(data):
""" Encode some json to dict"""
return json.dumps(data, cls=ESJsonEncoder)

class ResultSet(object):
    """
    Lazy, iterable view over the hits of a search.

    Nothing is sent to the server at construction time; the first request is
    issued when the results are first needed (``total``, ``len()``, iteration
    or access to an attribute of the ``hits`` section such as ``max_score``).
    While iterating, the next chunk of hits is fetched automatically once the
    current chunk is exhausted.
    """

    def __init__(self, connection, query, indexes=None, doc_types=None, query_params=None,
                 auto_fix_keys=False, auto_clean_highlight=False):
        """
        connection: object providing search_raw() and search_scroll()
        query: a dict of search parameters, or an object exposing
               to_search_json() together with start/size attributes
        indexes, doc_types: forwarded to connection.search_raw()
        query_params: extra keyword arguments for connection.search_raw()
        auto_fix_keys: remove the "_" from every key, useful for django views
        auto_clean_highlight: remove empty highlight entries
        """
        self.connection = connection
        self.indexes = indexes
        self.doc_types = doc_types
        self.query_params = query_params or {}
        self.scroller_parameters = {}
        self.scroller_id = None
        self._results = None
        self._total = None
        self.valid = False
        self.facets = {}
        self.auto_fix_keys = auto_fix_keys
        self.auto_clean_highlight = auto_clean_highlight
        self.query = query
        self.iterpos = 0  # keep track of iterator position inside the chunk
        self.start = 0
        self.chuck_size = 0
        self._store_query_data()

    def _store_query_data(self):
        """
        Try to detect the start position and the chunk size from the query.
        """
        if isinstance(self.query, dict):
            # "from"/"size" are the search request-body keys; "start" and
            # "chuck_size" are still honoured as fallbacks for older callers.
            self.start = self.query.get("from", self.query.get("start", 0))
            self.chuck_size = self.query.get("size", self.query.get("chuck_size", 10))
        else:
            self.start = getattr(self.query, "start", None) or 0
            self.chuck_size = getattr(self.query, "size", None) or 10

    def _do_search(self, auto_increment=False):
        """
        Fetch one chunk of results and refresh hits/facets.

        auto_increment: advance the start position by one chunk before
        searching (ignored once a scroll id is being followed).

        Raises InvalidQuery when the query is neither a dict nor an object
        exposing to_search_json().
        """
        self.iterpos = 0
        process_post_query = True  # used to skip results in first scan
        if self.scroller_id is None:
            if auto_increment:
                self.start += self.chuck_size

            if hasattr(self.query, 'to_search_json'):
                # Common case - a Search or Query object.
                self.query.start = self.start
                self.query.size = self.chuck_size
            elif isinstance(self.query, dict):
                # A direct set of search parameters: page with the
                # request-body keys and drop the non-standard ones.
                self.query.pop('start', None)
                self.query.pop('chuck_size', None)
                self.query['from'] = self.start
                self.query['size'] = self.chuck_size
            else:
                raise pyes.exceptions.InvalidQuery("search() must be supplied with a Search or Query object, or a dict")

            # search_raw receives the query itself and is responsible for
            # serialising it.
            self._results = self.connection.search_raw(self.query, indexes=self.indexes,
                                                       doc_types=self.doc_types, **self.query_params)
            if self.query_params.get('search_type') == "scan":
                # Move the scan-related parameters aside so that they are
                # remembered for the follow-up scroll requests.
                self.scroller_parameters['search_type'] = self.query_params.pop('search_type')
                if 'scroll' in self.query_params:
                    self.scroller_parameters['scroll'] = self.query_params.pop('scroll')
                if 'size' in self.query_params:
                    self.chuck_size = self.query_params.pop('size')
                    self.scroller_parameters['size'] = self.chuck_size
            if '_scroll_id' in self._results:
                # scan query, let's load the first bulk of data
                self.scroller_id = self._results['_scroll_id']
                self._do_search()
                # the nested call above already post-processed its results
                process_post_query = False
        else:
            self._results = self.connection.search_scroll(self.scroller_id,
                                                          self.scroller_parameters.get("scroll", "10m"))

        if process_post_query:
            self.facets = self._results.get('facets', {})
            if 'hits' in self._results:
                self.valid = True
                self.hits = self._results['hits']['hits']
            else:
                self.valid = False
                self.hits = []
            if self.auto_fix_keys:
                self.fix_keys()
            if self.auto_clean_highlight:
                self.clean_highlight()

    @property
    def total(self):
        """
        Total number of hits reported by the server (0 for invalid results).

        Triggers the search if it has not been executed yet; the value is
        cached after the first computation.
        """
        if self._results is None:
            self._do_search()
        if self._total is None:
            self._total = 0
            if self.valid:
                self._total = self._results.get("hits", {}).get('total', 0)
        return self._total

    def __len__(self):
        return self.total

    def fix_keys(self):
        """
        Remove the _ from the keys of the results
        """
        if not self.valid:
            return

        for hit in self._results['hits']['hits']:
            # Snapshot the keys: the dict is modified inside the loop.
            for key in list(hit.keys()):
                if key.startswith("_"):
                    hit[key[1:]] = hit.pop(key)

    def clean_highlight(self):
        """
        Remove the empty highlight
        """
        if not self.valid:
            return

        for hit in self._results['hits']['hits']:
            if 'highlight' in hit:
                hl = hit['highlight']
                # Collect first, delete afterwards, to avoid mutating the
                # dict while iterating over it.
                for key in [k for k, item in hl.items() if not item]:
                    del hl[key]

    def __getattr__(self, name):
        """
        Expose the entries of the "hits" section of the results (e.g.
        ``max_score``) as attributes, running the search first if needed.

        Raises AttributeError for unknown names.
        """
        # Private names never live in the results; refusing them here also
        # prevents infinite recursion when _results itself is not set yet.
        if name.startswith('_'):
            raise AttributeError(name)
        if self._results is None:
            self._do_search()
            # The search may just have created the attribute (e.g. hits).
            if name in self.__dict__:
                return self.__dict__[name]
        try:
            return self._results['hits'][name]
        except (KeyError, TypeError):
            raise AttributeError(name)

    def next(self):
        """
        Return the next hit, fetching a new chunk when the current one is
        exhausted.  Raises StopIteration when there are no more hits.
        """
        if self._results is None:
            self._do_search()
        if self.iterpos >= len(self.hits):
            # An empty or short chunk means the server has nothing more.
            if not self.hits or len(self.hits) < self.chuck_size:
                raise StopIteration
            self._do_search(auto_increment=True)
            if not self.hits:
                raise StopIteration
        res = self.hits[self.iterpos]
        self.iterpos += 1
        return res

    # Python 3 spelling of the iterator protocol.
    __next__ = next

    def __iter__(self):
        return self
59 changes: 1 addition & 58 deletions pyes/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-

__author__ = 'Alberto Paro'
__all__ = ['clean_string', 'ResultSet', "ESRange", "ESRangeOp", "string_b64encode", "string_b64decode"]
__all__ = ['clean_string', "ESRange", "ESRangeOp", "string_b64encode", "string_b64decode"]
import base64

def string_b64encode(s):
Expand Down Expand Up @@ -78,63 +78,6 @@ def clean_string(text):
return text.translate(UNI_SPECIAL_CHARS).strip()
return text.translate(None, STR_SPECIAL_CHARS).strip()

class ResultSet(object):
    """
    Thin wrapper around an already fetched search result dict.
    """

    def __init__(self, results, fix_keys=True, clean_highlight=True):
        """
        results: an es query results dict
        fix_keys: remove the "_" from every key, useful for django views
        clean_highlight: removed empty highlight
        """
        self._results = results
        self._total = None
        self.valid = False
        self.facets = results.get('facets', {})
        # Always define ``results`` so that reading it on a response without
        # hits yields an empty list instead of failing.
        self.results = []
        if 'hits' in results:
            self.valid = True
            self.results = results['hits']['hits']
        if fix_keys:
            self.fix_keys()
        if clean_highlight:
            self.clean_highlight()

    @property
    def total(self):
        """
        Total number of hits reported in the results (0 when there is no
        "hits" section); cached after the first computation.
        """
        if self._total is None:
            self._total = 0
            if self.valid:
                self._total = self._results.get("hits", {}).get('total', 0)
        return self._total

    def fix_keys(self):
        """
        Remove the _ from the keys of the results
        """
        if not self.valid:
            return

        for hit in self._results['hits']['hits']:
            # Snapshot the keys: the dict is modified inside the loop.
            for key in list(hit.keys()):
                if key.startswith("_"):
                    hit[key[1:]] = hit.pop(key)

    def clean_highlight(self):
        """
        Remove the empty highlight
        """
        if not self.valid:
            return

        for hit in self._results['hits']['hits']:
            if 'highlight' in hit:
                hl = hit['highlight']
                # Collect first, delete afterwards, to avoid mutating the
                # dict while iterating over it.
                for key in [k for k, item in hl.items() if not item]:
                    del hl[key]

    def __getattr__(self, name):
        """
        Expose the entries of the "hits" section (e.g. ``max_score``) as
        attributes.  Raises AttributeError for unknown names.
        """
        # Private names never live in the results; refusing them here also
        # prevents infinite recursion when _results itself is not set yet.
        if name.startswith('_'):
            raise AttributeError(name)
        try:
            return self._results['hits'][name]
        except (KeyError, TypeError):
            raise AttributeError(name)

def keys_to_string(data):
"""
Function to convert all the unicode keys in string keys
Expand Down

0 comments on commit 068162c

Please sign in to comment.