Skip to content

Commit

Permalink
Empty instance fixes (talmolab#569)
Browse files Browse the repository at this point in the history
* Fix augmentation with no instances

* Add Labels.copy

* Add empty instance removal methods

* Add empty instance removal in provider constructor

* Fix augmentation test

* Properly ignore empty instances and frames in DLC import

* Fix test

* Lint
  • Loading branch information
talmo authored Jul 29, 2021
1 parent 1eb06f8 commit 1bc73bf
Show file tree
Hide file tree
Showing 10 changed files with 147 additions and 35 deletions.
10 changes: 9 additions & 1 deletion sleap/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -876,7 +876,11 @@ def midpoint(self) -> np.ndarray:
@property
def n_visible_points(self) -> int:
"""Return the number of visible points in this instance."""
return sum(~np.isnan(self.points_array[:, 0]))
n = 0
for p in self.points:
if p.visible:
n += 1
return n

def __len__(self) -> int:
"""Return the number of visible points in this instance."""
Expand Down Expand Up @@ -1386,6 +1390,10 @@ def has_predicted_instances(self) -> bool:
"""Return whether the frame contains any predicted instances."""
return len(self.predicted_instances) > 0

def remove_empty_instances(self):
"""Remove instances with no visible nodes from the labeled frame."""
self.instances = [inst for inst in self.instances if inst.n_visible_points > 0]

@property
def unused_predictions(self) -> List[Instance]:
"""Return a list of "unused" :class:`PredictedInstance` objects in frame.
Expand Down
37 changes: 33 additions & 4 deletions sleap/io/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,16 +699,17 @@ def get(self, *args) -> Union[LabeledFrame, List[LabeledFrame]]:
except KeyError:
return None

def extract(self, inds) -> "Labels":
def extract(self, inds, copy: bool = False) -> "Labels":
"""Extract labeled frames from indices and return a new `Labels` object.
Args:
inds: Any valid indexing keys, e.g., a range, slice, list of label indices,
numpy array, `Video`, etc. See `__getitem__` for full list.
copy: If `True`, create a new copy of all of the extracted labeled frames
and associated labels. If `False` (the default), a shallow copy with
references to the original labeled frames and other objects will be
returned.
Returns:
A new `Labels` object with the specified labeled frames.
This will preserve the other data structures even if they are not found in
the extracted labels, including:
- `Labels.videos`
Expand All @@ -726,8 +727,19 @@ def extract(self, inds) -> "Labels":
suggestions=self.suggestions,
provenance=self.provenance,
)
if copy:
new_labels = new_labels.copy()
return new_labels

def copy(self) -> "Labels":
"""Return a full deep copy of the labels.
Notes:
All objects will be re-created by serializing and then deserializing the
labels. This may be slow and will create new instances of all data
structures.
"""
return type(self).from_json(self.to_json())

def __setitem__(self, index, value: LabeledFrame):
"""Set labeled frame at given index."""
# TODO: Maybe we should remove this method altogether?
Expand Down Expand Up @@ -776,6 +788,23 @@ def remove_frames(self, lfs: List[LabeledFrame]):
self.labeled_frames = [lf for lf in self.labeled_frames if lf not in to_remove]
self.update_cache()

def remove_empty_instances(self, keep_empty_frames: bool = True):
"""Remove instances with no visible points.
Args:
keep_empty_frames: If True (the default), frames with no remaining instances
will not be removed.
Notes:
This will modify the labels in place. If a copy is desired, call
`labels.copy()` before this.
"""
for lf in self.labeled_frames:
lf.remove_empty_instances()
self.update_cache()
if not keep_empty_frames:
self.remove_empty_frames()

def remove_empty_frames(self):
"""Remove frames with no instances."""
self.labeled_frames = [
Expand Down
9 changes: 5 additions & 4 deletions sleap/io/format/deeplabcut.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,10 +198,11 @@ def read_frames(
# frame.
instances.append(Instance(skeleton=skeleton, points=instance_points))

# Create LabeledFrame and add it to list.
lfs.append(
LabeledFrame(video=video, frame_idx=frame_idx, instances=instances)
)
if len(instances) > 0:
# Create LabeledFrame and add it to list.
lfs.append(
LabeledFrame(video=video, frame_idx=frame_idx, instances=instances)
)

return lfs

Expand Down
18 changes: 5 additions & 13 deletions sleap/nn/data/augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,23 +233,15 @@ def py_augment(image, instances):
# Augment the image.
aug_img = aug_det.augment_image(image.numpy())

# This will get converted to a rank 3 tensor (n_instances, n_nodes, 2).
aug_instances = np.full_like(instances, np.nan)

# Augment each set of points for each instance.
aug_instances = []
for instance in instances:
for i, instance in enumerate(instances):
kps = ia.KeypointsOnImage.from_xy_array(
instance.numpy(), tuple(image.shape)
)
aug_instance = aug_det.augment_keypoints(kps).to_xy_array()
aug_instances.append(aug_instance)

# Convert the results to tensors.
# aug_img = tf.convert_to_tensor(aug_img, dtype=image.dtype)

# This will get converted to a rank 3 tensor (n_instances, n_nodes, 2).
aug_instances = np.stack(aug_instances, axis=0)
# aug_instances = [
# tf.convert_to_tensor(x, dtype=instances.dtype) for x in aug_instances
# ]
aug_instances[i] = aug_det.augment_keypoints(kps).to_xy_array()

return aug_img, aug_instances

Expand Down
9 changes: 8 additions & 1 deletion sleap/nn/data/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,16 @@ def from_user_instances(cls, labels: sleap.Labels) -> "LabelsReader":
"""Create a `LabelsReader` using the user instances in a `Labels` set.
Args:
labels: A `sleap.Labels` instance containing user instances.
Returns:
A `LabelsReader` instance that can create a dataset for pipelining.
Notes:
This will remove "empty" instances, i.e., instances with no visible points,
in the original labels. Make a copy of the original labels if needed as they
will be modified in place.
"""
labels.remove_empty_instances(keep_empty_frames=False)
obj = cls.from_user_labeled_frames(labels)
obj.user_instances_only = True
return obj
Expand All @@ -51,7 +58,7 @@ def from_user_labeled_frames(cls, labels: sleap.Labels) -> "LabelsReader":
Returns:
A `LabelsReader` instance that can create a dataset for pipelining.
Note that this constructor will load ALL instances in frames that have user
instances. To load only user labeled indices, use
instances. To load only user labeled instances, use
`LabelsReader.from_user_instances`.
"""
return cls(labels=labels, example_indices=labels.user_labeled_frame_inds)
Expand Down
3 changes: 2 additions & 1 deletion tests/data/dlc/madlc_testdata.csv
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ bodyparts,A,A,B,B,C,C,A,A,B,B,C,C
coords,x,y,x,y,x,y,x,y,x,y,x,y
labeled-data/video/img000.png,0,1,2,3,4,5,6,7,8,9,10,11
labeled-data/video/img001.png,12,13,,,15,16,17,18,,,20,21
labeled-data/video/img002.png,22,23,24,25,26,27,,,,,,
labeled-data/video/img002.png,,,,,,,,,,,,
labeled-data/video/img003.png,22,23,24,25,26,27,,,,,,
36 changes: 35 additions & 1 deletion tests/io/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1131,7 +1131,7 @@ def test_has_frame():
@pytest.fixture
def removal_test_labels():
skeleton = Skeleton()
video = Video(backend=MediaVideo)
video = Video(backend=MediaVideo(filename="test"))
lf_user_only = LabeledFrame(
video=video, frame_idx=0, instances=[Instance(skeleton=skeleton)]
)
Expand All @@ -1147,6 +1147,14 @@ def removal_test_labels():
return labels


def test_copy(removal_test_labels):
new_labels = removal_test_labels.copy()
new_labels[0].instances = []
new_labels.remove_frame(new_labels[-1])
assert len(removal_test_labels[0].instances) == 1
assert len(removal_test_labels) == 3


def test_remove_user_instances(removal_test_labels):
labels = removal_test_labels
assert len(labels) == 3
Expand Down Expand Up @@ -1260,3 +1268,29 @@ def test_remove_all_tracks(centered_pair_predictions):
labels.remove_all_tracks()
assert len(labels.tracks) == 0
assert all(inst.track is None for inst in labels.instances())


def test_remove_empty_frames(min_labels):
min_labels.append(sleap.LabeledFrame(video=min_labels.video, frame_idx=2))
assert len(min_labels) == 2
assert len(min_labels[-1]) == 0
min_labels.remove_empty_frames()
assert len(min_labels) == 1
assert len(min_labels[0]) == 2


def test_remove_empty_instances(min_labels):
for inst in min_labels.labeled_frames[0].instances:
for pt in inst.points:
pt.visible = False
min_labels.remove_empty_instances(keep_empty_frames=True)
assert len(min_labels) == 1
assert len(min_labels[0]) == 0


def test_remove_empty_instances_and_frames(min_labels):
for inst in min_labels.labeled_frames[0].instances:
for pt in inst.points:
pt.visible = False
min_labels.remove_empty_instances(keep_empty_frames=False)
assert len(min_labels) == 0
8 changes: 8 additions & 0 deletions tests/io/test_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,13 @@ def test_madlc():
)

assert labels.skeleton.node_names == ["A", "B", "C"]
assert len(labels.videos) == 1
assert len(labels.video.filenames) == 4
assert labels.videos[0].filenames[0].endswith("img000.png")
assert labels.videos[0].filenames[1].endswith("img001.png")
assert labels.videos[0].filenames[2].endswith("img002.png")
assert labels.videos[0].filenames[3].endswith("img003.png")

assert len(labels) == 3
assert len(labels[0]) == 2
assert len(labels[1]) == 2
Expand All @@ -176,3 +183,4 @@ def test_madlc():
assert_array_equal(labels[1][0].numpy(), [[12, 13], [np.nan, np.nan], [15, 16]])
assert_array_equal(labels[1][1].numpy(), [[17, 18], [np.nan, np.nan], [20, 21]])
assert_array_equal(labels[2][0].numpy(), [[22, 23], [24, 25], [26, 27]])
assert labels[2].frame_idx == 3
23 changes: 23 additions & 0 deletions tests/nn/data/test_augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,29 @@ def test_augmentation(min_labels):
assert tf.reduce_all(example["instances"] != example_preaug["instances"])


def test_augmentation_with_no_instances(min_labels):
# reproduces #555
min_labels.append(
sleap.LabeledFrame(
video=min_labels.video,
frame_idx=min_labels[-1].frame_idx + 1,
instances=[
sleap.Instance.from_numpy(
np.full([len(min_labels.skeleton.nodes), 2], np.nan),
skeleton=min_labels.skeleton,
)
],
)
)

p = min_labels.to_pipeline(user_labeled_only=False)
p += augmentation.ImgaugAugmenter.from_config(
augmentation.AugmentationConfig(rotate=True)
)
exs = p.run()
assert exs[-1]["instances"].shape[0] == 0


def test_random_cropper(min_labels):
cropper = augmentation.RandomCropper(crop_height=64, crop_width=32)
assert "image" in cropper.input_keys
Expand Down
29 changes: 19 additions & 10 deletions tests/nn/data/test_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,22 +43,31 @@ def test_labels_reader(min_labels):


def test_labels_reader_no_visible_points(min_labels):
inst = min_labels.labeled_frames[0].instances[0]
for pt in inst.points:
pt.visible = False

labels_reader = providers.LabelsReader.from_user_instances(min_labels)
ds = labels_reader.make_dataset()
assert not labels_reader.is_from_multi_size_videos

example = next(iter(ds))

# There should be two instances in the labels dataset
assert len(min_labels.labeled_frames[0].instances) == 2
labels = min_labels.copy()
assert len(labels.labeled_frames[0].instances) == 2

# Non-visible ones will be removed in place
inst = labels.labeled_frames[0].instances[0]
for pt in inst.points:
pt.visible = False
labels_reader = providers.LabelsReader.from_user_instances(labels)
assert len(labels.labeled_frames[0].instances) == 1

# Make sure there's only one included with the instances for training
example = next(iter(labels_reader.make_dataset()))
assert len(example["instances"]) == 1

# Now try with no visible instances
labels = min_labels.copy()
for inst in labels.labeled_frames[0].instances:
for pt in inst.points:
pt.visible = False
labels_reader = providers.LabelsReader.from_user_instances(labels)
assert len(labels) == 0
assert len(labels_reader) == 0


def test_labels_reader_subset(min_labels):
labels = sleap.Labels([min_labels[0], min_labels[0], min_labels[0]])
Expand Down

0 comments on commit 1bc73bf

Please sign in to comment.