Skip to content

Commit

Permalink
Updated Range Annotation Iteration (asyml#849)
Browse files Browse the repository at this point in the history
* Transform get method with bisect

* Serialization bug fix

* Deletion Order Change

* modified bisect range logic

* removing unused import

* incorrect type for range annotation

* removing extra entry

* unused argument

* Changed co_iterator_annotation_like interface

* improving code quality

* doc fix

* range annotation name change

* covering iteration edge cases

* docs fix

* Consistent range argument

* mypy fixes

* Updated Docstring

* type fixes

* Black changes
  • Loading branch information
Pushkar-Bhuse authored Jul 18, 2022
1 parent 984e79c commit 64eb43e
Show file tree
Hide file tree
Showing 5 changed files with 328 additions and 88 deletions.
4 changes: 2 additions & 2 deletions forte/data/base_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,15 +291,15 @@ def get(
self,
type_name: str,
include_sub_type: bool,
range_annotation: Optional[Tuple[int]] = None,
range_span: Optional[Tuple[int, int]] = None,
) -> Iterator[List]:
r"""This function fetches entries from the data store of
type ``type_name``.
Args:
type_name: The index of the list in ``self.__elements``.
include_sub_type: A boolean to indicate whether get its subclass.
range_annotation: A tuple that contains the begin and end indices
range_span: A tuple that contains the begin and end indices
of the searching range of annotation-like entries.
Returns:
Expand Down
2 changes: 1 addition & 1 deletion forte/data/data_pack.py
Original file line number Diff line number Diff line change
Expand Up @@ -1513,7 +1513,7 @@ def require_annotations(entry_class=Annotation) -> bool:
for entry_data in self._data_store.get(
type_name=get_full_module_name(entry_type_),
include_sub_type=include_sub_type,
range_annotation=range_annotation # type: ignore
range_span=range_annotation # type: ignore
and (range_annotation.begin, range_annotation.end),
):
entry: Entry = self.get_entry(tid=entry_data[TID_INDEX])
Expand Down
209 changes: 160 additions & 49 deletions forte/data/data_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def __init__(
Currently, DataStore supports storing data structures with linear span
(e.g. Annotation), and relational data structures (e.g Link and Group).
Future extension of the class may support data structures with 2-d range
(e.g. bounding boxes).
(e.g. bounding boxes).
Internally, we store every entry in a variable ``__elements``, which is
a nested list: a list of ``entry lists``.
Expand Down Expand Up @@ -258,12 +258,11 @@ def __getstate__(self):
to save space.
"""
state = super().__getstate__()
for k in state["_DataStore__elements"]:
state["_DataStore__elements"] = {}
for k in self.__elements:
# build the full `_type_attributes`
self._get_type_info(k)
state["_DataStore__elements"][k] = list(
state["_DataStore__elements"][k]
)
state["_DataStore__elements"][k] = list(self.__elements[k])
state.pop("_DataStore__tid_ref_dict")
state.pop("_DataStore__tid_idx_dict")
state.pop("_DataStore__deletion_count")
Expand Down Expand Up @@ -742,8 +741,8 @@ def _add_entry_raw(
):
"""
This function add raw entry in DataStore object
based on corresponding type name
and sort them based on entry type.
based on corresponding type name and sort them
based on entry type.
Args:
entry_type: entry's type which decides the sorting of entry.
Expand Down Expand Up @@ -1183,8 +1182,54 @@ def get_length(self, type_name: str) -> int:
delete_count = self.__deletion_count.get(type_name, 0)
return len(self.__elements[type_name]) - delete_count

def _get_bisect_range(
self, search_list: SortedList, range_span: Tuple[int, int]
) -> Optional[List]:
"""
Perform binary search on the specified list for target entry class.
Entry class can be a subtype of
:class:`~forte.data.ontology.top.Annotation`
or :class:`~forte.data.ontology.top.AudioAnnotation`. This function
finds the the elements in the `Annotation` or `AudioAnnotation`
sorted list whose begin and end index falls within `range_span`.
Args:
search_list: A `SortedList` object on which the binary search
will be carried out.
range_span: a tuple that indicates the start and end index
of the range in which we want to get required entries
Returns:
List of entries to fetch
"""

# Check if there are any entries within the given range
if (
search_list[0][constants.BEGIN_INDEX] > range_span[1]
or search_list[-1][constants.END_INDEX] < range_span[0]
):
return None

result_list = []

begin_index = search_list.bisect_left([range_span[0], range_span[0]])

for idx in range(begin_index, len(search_list)):
if search_list[idx][constants.BEGIN_INDEX] > range_span[1]:
break

if search_list[idx][constants.END_INDEX] <= range_span[1]:
result_list.append(search_list[idx])

if len(result_list) == 0:
return None

return result_list

def co_iterator_annotation_like(
self, type_names: List[str]
self,
type_names: List[str],
range_span: Optional[Tuple[int, int]] = None,
) -> Iterator[List]:
r"""
Given two or more type names, iterate their entry lists from beginning
Expand All @@ -1203,21 +1248,62 @@ def co_iterator_annotation_like(
The precedence of those values indicates their priority in the min heap
ordering.
For example, if two entries have both the same begin and end field,
then their order is
Lastly, the `range_span` argument determines the start
and end position of the span range within which entries of specified by
`type_name` need to be fetched. For example, if two entries have both
the same begin and end field, then their order is
decided by the order of user input type_name (the type that first
appears in the target type list will return first).
For entries that have the exact same `begin`, `end` and `type_name`,
the order will be determined arbitrarily.
For example, let's say we have two entry types,
:class:`~ft.onto.base_ontology.Sentence` and
:class:`~ft.onto.base_ontology.EntityMention`.
Each type has two entries. The two entries of type `Sentence` ranges from span
`(0,5)` and `(6,10)`. Similarly, the two entries of type `EntityMention` has span
`(0,3)` and `(15,20)`.
.. code-block:: python
# function signature
entries = list(
co_iterator_annotation_like(
type_names = [
"ft.onto.base_ontology.Sentence",
"ft.onto.base_ontology.EntityMention"
],
range_span = (0,12)
)
)
# Fetching result
result = [
all_anno.append([type(anno).__name__, anno.begin, anno.end])
for all_anno in entries
]
# return
result = [
['Sentence', 0, 5],
['EntityMention', 0, 5],
['Sentence', 6, 10]
]
From this we can see how `range_span` affects which
entries will be fetched and also how the function chooses the order
in which entries are fetched.
Args:
type_names: a list of string type names
range_span: a tuple that indicates the start and end index
of the range in which we want to get required entries
Returns:
An iterator of entry elements.
"""

n = len(type_names)
# suppose the length of type_names is N and the length of entry list of
# one type is M
# then the time complexity of using min-heap to iterate
Expand All @@ -1227,17 +1313,44 @@ def co_iterator_annotation_like(
# it avoids empty entry lists or non-existent entry list
first_entries = []

for tn in type_names:
# For every entry type, store the entries that fall within the required
# range.When range_end and range_begin are None, we fetch all entries of
# each type (mentioned in type_names). But when range_end and range_end
# is specified, we find the list of entries that fall within the range
# and only iterate through them
all_entries_range = {}

# This list stores the types of entries that have atleast one entry to
# fetch. The order of the types in this list is the same as the order
# followed by them in type_names.
valid_type_names = []

if range_span is not None:
for tn in type_names:
possible_entries = self._get_bisect_range(
self.__elements[tn], range_span
)
if possible_entries is not None:
all_entries_range[tn] = possible_entries
valid_type_names.append(tn)

else:
try:
first_entries.append(self.__elements[tn][0])
except KeyError as e: # self.__elements[tn] will be caught here.
for tn in type_names:
all_entries_range[tn] = self.__elements[tn]
valid_type_names = type_names
except KeyError as e: # all_entries_range[tn] will be caught here.
raise ValueError(
f"Input argument `type_names` to the function contains"
f" a type name [{tn}], which is not recognized."
f" Please input available ones in this DataStore"
f" object: {list(self.__elements.keys())}"
) from e
except IndexError as e: # self.__elements[tn][0] will be caught here.

for tn in valid_type_names:
try:
first_entries.append(all_entries_range[tn][0])
except IndexError as e: # all_entries_range[tn][0] will be caught here.
raise ValueError(
f"Entry list of type name, {tn} which is"
" one list item of input argument `type_names`,"
Expand All @@ -1247,25 +1360,25 @@ def co_iterator_annotation_like(
) from e

# record the current entry index for elements
# pointers[i] is the index of entry at (i)th sorted entry lists
pointers = [0] * n
# pointers[tn] is the index of entry of type tn
pointers = {key: 0 for key in all_entries_range}

# compare tuple (begin, end, order of type name in input argument
# type_names)
# we initialize a MinHeap with the first entry of all sorted entry lists
# in self.__elements
# in all_entries_range
# the metric of comparing entry order is represented by the tuple
# (begin index of entry, end index of entry,
# the index of the entry type name in input parameter ``type_names``)
h: List[Tuple[Tuple[int, int, int], str]] = []
for p_idx in range(n):
for p_idx, entry in enumerate(first_entries):
entry_tuple = (
(
first_entries[p_idx][constants.BEGIN_INDEX],
first_entries[p_idx][constants.END_INDEX],
entry[constants.BEGIN_INDEX],
entry[constants.END_INDEX],
p_idx,
),
first_entries[p_idx][constants.ENTRY_TYPE_INDEX],
entry[constants.ENTRY_TYPE_INDEX],
)
heappush(
h,
Expand All @@ -1282,23 +1395,24 @@ def co_iterator_annotation_like(
# `the current entry` means the entry that
# popped entry_tuple represents.
# `the current entry list` means the entry
# list (values of self.__elements) where `the current entry`
# list (values of all_entries_range) where `the current entry`
# locates at.

# retrieve the popped entry tuple (minimum item in the heap)
# and get the p_idx (the index of the current entry list in self.__elements)
# and get the p_idx (the index of the current entry
# list in all_entries_range)
entry_tuple = heappop(h)
(_, _, p_idx), type_name = entry_tuple
# get the index of current entry
# and locate the entry represented by the tuple for yielding
pointer = pointers[p_idx]
entry = self.__elements[type_name][pointer]
pointer = pointers[type_name]
entry = all_entries_range[type_name][pointer]
# check whether there is next entry in the current entry list
# if there is, then we push the new entry's tuple into the heap
if pointer + 1 < len(self.__elements[type_name]):
pointers[p_idx] += 1
new_pointer = pointers[p_idx]
new_entry = self.__elements[type_name][new_pointer]
if pointer + 1 < len(all_entries_range[type_name]):
pointers[type_name] += 1
new_pointer = pointers[type_name]
new_entry = all_entries_range[type_name][new_pointer]
new_entry_tuple = (
(
new_entry[constants.BEGIN_INDEX],
Expand All @@ -1317,7 +1431,7 @@ def get(
self,
type_name: str,
include_sub_type: bool = True,
range_annotation: Optional[Tuple[int]] = None,
range_span: Optional[Tuple[int, int]] = None,
) -> Iterator[List]:
r"""This function fetches entries from the data store of
type ``type_name``. If `include_sub_type` is set to True and
Expand All @@ -1328,27 +1442,23 @@ def get(
Args:
type_name: The fully qualified name of the entry.
include_sub_type: A boolean to indicate whether get its subclass.
range_annotation: A tuple that contains the begin and end indices
range_span: A tuple that contains the begin and end indices
of the searching range of entries.
Returns:
An iterator of the entries matching the provided arguments.
"""

def within_range(
entry: List[Any], range_annotation: Tuple[int]
) -> bool:
def within_range(entry: List[Any], range_span: Tuple[int, int]) -> bool:
"""
A helper function for deciding whether an annotation entry is
inside the `range_annotation`.
inside the `range_span`.
"""
if not self._is_annotation(entry[constants.ENTRY_TYPE_INDEX]):
return False
return (
entry[constants.BEGIN_INDEX]
>= range_annotation[constants.BEGIN_INDEX]
and entry[constants.END_INDEX]
<= range_annotation[constants.END_INDEX]
entry[constants.BEGIN_INDEX] >= range_span[0]
and entry[constants.END_INDEX] <= range_span[1]
)

entry_class = get_class(type_name)
Expand All @@ -1362,15 +1472,16 @@ def within_range(
all_types = list(all_types)
all_types.sort()
if self._is_annotation(type_name):
if range_annotation is None:
if range_span is None:
yield from self.co_iterator_annotation_like(all_types)
else:
for entry in self.co_iterator_annotation_like(all_types):
if within_range(entry, range_annotation):
yield entry
for entry in self.co_iterator_annotation_like(
all_types, range_span=range_span
):
yield entry
elif issubclass(entry_class, Link):
for type in all_types:
if range_annotation is None:
if range_span is None:
yield from self.iter(type)
else:
for entry in self.__elements[type]:
Expand All @@ -1388,12 +1499,12 @@ def within_range(
entry[constants.CHILD_TID_INDEX]
]
if within_range(
parent, range_annotation
) and within_range(child, range_annotation):
parent, range_span
) and within_range(child, range_span):
yield entry
elif issubclass(entry_class, Group):
for type in all_types:
if range_annotation is None:
if range_span is None:
yield from self.iter(type)
else:
for entry in self.__elements[type]:
Expand All @@ -1403,7 +1514,7 @@ def within_range(
within = True
for m in members:
e = self.__tid_ref_dict[m]
if not within_range(e, range_annotation):
if not within_range(e, range_span):
within = False
break
if within:
Expand Down
Loading

0 comments on commit 64eb43e

Please sign in to comment.