Skip to content

Commit

Permalink
Change Grid from ontology to a data structure in core.py (asyml#876)
Browse files Browse the repository at this point in the history
* rewrite grids -> grid

* remove grid from adding entry

* rewrite the docstring

* import grid

* import grid

* pylint

* pylint

* pylint

* remove the requirement to initialize bounding box

* fix pylint and docstring

* DataPack.grids -> DataPack.grid

* fix docstring

* remove docstring

* simplify test

* rename: grids -> grid

* correct typing for grid

* correct the docstring

* get_grid_cell -> _get_image_within_grid_cell and it doesn't create new array to improve efficiency

* improve docstring

* remove unused DataPack.grid
  • Loading branch information
hepengfe authored Jul 25, 2022
1 parent 12152dc commit 8a61272
Show file tree
Hide file tree
Showing 7 changed files with 333 additions and 243 deletions.
4 changes: 1 addition & 3 deletions forte/data/data_pack.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@
Generics,
AudioAnnotation,
ImageAnnotation,
Grids,
Payload,
)

Expand Down Expand Up @@ -171,7 +170,6 @@ def __init__(self, pack_name: Optional[str] = None):
self._data_store: DataStore = DataStore()
self._entry_converter: EntryConverter = EntryConverter()
self.image_annotations: List[ImageAnnotation] = []
self.grids: List[Grids] = []

self.text_payloads: List[Payload] = []
self.audio_payloads: List[Payload] = []
Expand Down Expand Up @@ -242,7 +240,7 @@ def text(self) -> str:
return ""

@property
def audio(self) -> Optional[np.ndarray]:
def audio(self):
r"""Return the audio of the data pack"""
return self.get_payload_data_at(Modality.Audio, 0)

Expand Down
2 changes: 0 additions & 2 deletions forte/data/data_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
from forte.data.ontology.top import (
Annotation,
AudioAnnotation,
Grids,
Group,
ImageAnnotation,
Link,
Expand Down Expand Up @@ -774,7 +773,6 @@ def _add_entry_raw(
Group,
Generics,
ImageAnnotation,
Grids,
Payload,
MultiPackLink,
MultiPackGroup,
Expand Down
10 changes: 0 additions & 10 deletions forte/data/entry_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
Generics,
AudioAnnotation,
ImageAnnotation,
Grids,
MultiPackGeneric,
MultiPackGroup,
MultiPackLink,
Expand Down Expand Up @@ -124,15 +123,6 @@ def save_entry_object(
tid=entry.tid,
allow_duplicate=allow_duplicate,
)
elif data_store_ref._is_subclass(entry.entry_type(), Grids):
# Will be deprecated in future
data_store_ref.add_entry_raw(
type_name=entry.entry_type(),
attribute_data=[entry.image_payload_idx, None],
base_class=Grids,
tid=entry.tid,
allow_duplicate=allow_duplicate,
)
elif data_store_ref._is_subclass(entry.entry_type(), MultiPackLink):
data_store_ref.add_entry_raw(
type_name=entry.entry_type(),
Expand Down
229 changes: 228 additions & 1 deletion forte/data/ontology/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from typing import (
Iterable,
Optional,
Tuple,
Type,
Hashable,
TypeVar,
Expand All @@ -33,7 +34,7 @@
overload,
List,
)

import math
import numpy as np

from forte.data.container import ContainerType
Expand All @@ -49,6 +50,7 @@
"FList",
"FNdArray",
"MultiEntry",
"Grid",
]

default_entry_fields = [
Expand Down Expand Up @@ -635,5 +637,230 @@ def index_key(self) -> int:
return self.tid


class Grid:
"""
Regular grid with a grid configuration dependent on the image size.
It is a data structure used to retrieve grid-related objects such as grid
cells from the image. Grid itself doesn't store image data but only data
related to grid configurations such as grid shape and image size.
Based on the image size and the grid shape,
we compute the height and the width of grid cells.
For example, if the image size (image_height,image_width) is (640, 480)
and the grid shape (height, width) is (2, 3)
the size of grid cells (self.c_h, self.c_w) will be (320, 160).
However, when the image size is not divisible by the grid shape, we round
up the resulting divided size(floating number) to an integer.
In this way, as each grid cell possibly takes one more pixel,
we make the last grid cell per column and row
size(height and width) to be the remainder of the image size divided by the
grid cell size which is smaller than other grid cell.
For example, if the image
size is (128, 128) and the grid shape is (13, 13), the first 12 grid cells
per column and row will have a size of (10, 10) since 128/13=9.85, so we
round up to 10. The last grid cell per column and row will have a size of
(8, 8) since 128%10=8.
Args:
height: the number of grid cell per column.
width: the number of grid cell per row.
image_height: the number of pixels per column in the image.
image_width: the number of pixels per row in the image.
"""

def __init__(
self,
height: int,
width: int,
image_height: int,
image_width: int,
):
if image_height <= 0 or image_width <= 0:
raise ValueError(
"both image height and width must be positive"
f"but the image shape is {(image_height, image_width)}"
"please input a valid image shape"
)
if height <= 0 or width <= 0:
raise ValueError(
f"height({height}) and "
f"width({width}) both must be larger than 0"
)
if height >= image_height or width >= image_width:
raise ValueError(
"Grid height and width must be smaller than image height and width"
)

self._height = height
self._width = width

# We require each grid to be bounded/intialized with one image size since
# the number of different image shapes are limited per computer vision task.
# For example, we can only have one image size (640, 480) from a CV dataset,
# and we could augment the dataset with few other image sizes
# (320, 240), (480, 640). Then there are only three image sizes.
# Therefore, it won't be troublesome to
# have a grid for each image size, and we can check the image size during the
# initialization of the grid.

# By contrast, if we don't initialize it with any
# image size and pass the image size directly into the method/operation on
# the fly, the API would be more complex and image size check would be
# repeated everytime the method is called.
self._image_height = image_height
self._image_width = image_width

# if the resulting size of grid is not an integer, we round it up.
# The last grid cell per row and column might be out of the image size
# since we constrain the maximum pixel locations by the image size
self.c_h, self.c_w = (
math.ceil(image_height / self._height),
math.ceil(image_width / self._width),
)

def _get_image_within_grid_cell(
self,
img_arr: np.ndarray,
h_idx: int,
w_idx: int,
) -> np.ndarray:
"""
Get the array data within a grid cell from the image data.
The array is a masked version of the original image, and it has
the same size as the original image. The array entries that are not
within the grid cell will masked as zeros. The image array entries that
are within the grid cell will kept.
Note: all indices are zero-based and counted from top left corner of
the image.
Args:
img_arr: image data represented as a numpy array.
h_idx: the zero-based height(row) index of the grid cell in the
grid, the unit is one grid cell.
w_idx: the zero-based width(column) index of the grid cell in the
grid, the unit is one grid cell.
Raises:
ValueError: ``h_idx`` is out of the range specified by ``height``.
ValueError: ``w_idx`` is out of the range specified by ``width``.
Returns:
numpy array that represents the grid cell.
"""
if not 0 <= h_idx < self._height:
raise ValueError(
f"input parameter h_idx ({h_idx}) is"
"out of scope of h_idx range"
f" {(0, self._height)}"
)
if not 0 <= w_idx < self._width:
raise ValueError(
f"input parameter w_idx ({w_idx}) is"
"out of scope of w_idx range"
f" {(0, self._width)}"
)

return img_arr[
h_idx * self.c_h : min((h_idx + 1) * self.c_h, self._image_height),
w_idx * self.c_w : min((w_idx + 1) * self.c_w, self._image_width),
]

def get_overlapped_grid_cell_indices(
self, image_arr: np.ndarray
) -> List[Tuple[int, int]]:
"""
Get the grid cell indices in the form of (height index, width index)
that image array overlaps with.
Args:
image_arr: image data represented as a numpy array.
Returns:
a list of tuples that represents the grid cell indices that image array overlaps with.
"""
grid_cell_indices = []
for h_idx in range(self._height):
for w_idx in range(self._width):
if (
np.sum(
self._get_image_within_grid_cell(
image_arr, h_idx, w_idx
)
)
> 0
):
grid_cell_indices.append((h_idx, w_idx))
return grid_cell_indices

def get_grid_cell_center(self, h_idx: int, w_idx: int) -> Tuple[int, int]:
"""
Get the center pixel position of the grid cell at the specific height
index and width index in the ``Grid``.
The computation of the center position of the grid cell is
dividing the grid cell height range (unit: pixel) and
width range (unit: pixel) by 2 (round down)
Suppose an edge case that a grid cell has a height range
(unit: pixel) of (0, 3)
and a width range (unit: pixel) of (0, 3) the grid cell center
would be (1, 1).
Since the grid cell size is usually very large,
the offset of the grid cell center is minor.
Note: all indices are zero-based and counted from top left corner of
the grid.
Args:
h_idx: the height(row) index of the grid cell in the grid,
the unit is one grid cell.
w_idx: the width(column) index of the grid cell in the
grid, the unit is one grid cell.
Returns:
A tuple of (y index, x index)
"""

return (
(h_idx * self.c_h + min((h_idx + 1) * self.c_h, self._image_height))
// 2,
(w_idx * self.c_w + min((w_idx + 1) * self.c_w, self._image_width))
// 2,
)

@property
def num_grid_cells(self):
return self._height * self._width

@property
def height(self):
return self._height

@property
def width(self):
return self._width

def __repr__(self):
return str(
(self._height, self._width, self._image_height, self._image_width)
)

def __eq__(self, other):
if other is None:
return False
return (
self._height,
self._width,
self._image_height,
self._image_width,
) == (
other._height,
other._width,
other.image_height,
other.image_width,
)

def __hash__(self):
return hash(
(self._height, self._width, self._image_height, self._image_width)
)


GroupType = TypeVar("GroupType", bound=BaseGroup)
LinkType = TypeVar("LinkType", bound=BaseLink)
Loading

0 comments on commit 8a61272

Please sign in to comment.