Skip to content

Commit

Permalink
updated
Browse files Browse the repository at this point in the history
  • Loading branch information
AntreasAntoniou committed Apr 2, 2024
1 parent d2dccd6 commit 7c45bad
Show file tree
Hide file tree
Showing 15 changed files with 263 additions and 27 deletions.
32 changes: 32 additions & 0 deletions .readthedocs.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# .readthedocs.yaml
# Read the Docs configuration file
# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details

# Required
version: 2

# Set the OS, Python version and other tools you might need
build:
os: ubuntu-22.04
tools:
python: "3.10"
# You can also specify other tool versions:
# nodejs: "19"
# rust: "1.64"
# golang: "1.19"

# Build documentation in the "docs/" directory with Sphinx
sphinx:
configuration: docs/conf.py

# Optionally build your docs in additional formats such as PDF and ePub
# (a downloadable PDF is published alongside the HTML docs)
formats:
- pdf
# - epub

# Optional but recommended, declare the Python requirements required
# to build your documentation
# See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html
python:
install:
- requirements: requirements_dev.txt
9 changes: 9 additions & 0 deletions gate/boilerplate/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,6 +702,15 @@ def load_checkpoint(
"per_epoch_metrics",
loaded_evaluator_per_epoch_metrics,
)
# renamed_model_state = {}
# for name, param in torch.load(
# checkpoint_path / "pytorch_model.bin"
# ).items():
# print(name)
# if ".module." in name:
# name = name.replace(".module.", ".")
# renamed_model_state[name] = param
# torch.save(renamed_model_state, checkpoint_path / "pytorch_model.bin")

self.accelerator.load_state(checkpoint_path)

Expand Down
6 changes: 3 additions & 3 deletions gate/data/image/segmentation/ade20k.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,21 +78,21 @@ def build_gate_dataset(
train_set = GATEDataset(
dataset=build_dataset("train", data_dir=data_dir),
infinite_sampling=True,
transforms=[input_transforms, train_transforms],
transforms=[input_transforms, train_transforms, transforms],
meta_data={"class_names": CLASSES, "num_classes": num_classes},
)

val_set = GATEDataset(
dataset=build_dataset("val", data_dir=data_dir),
infinite_sampling=False,
transforms=[input_transforms, eval_transforms],
transforms=[input_transforms, eval_transforms, transforms],
meta_data={"class_names": CLASSES, "num_classes": num_classes},
)

test_set = GATEDataset(
dataset=build_dataset("test", data_dir=data_dir),
infinite_sampling=False,
transforms=[input_transforms, eval_transforms],
transforms=[input_transforms, eval_transforms, transforms],
meta_data={"class_names": CLASSES, "num_classes": num_classes},
)

Expand Down
6 changes: 3 additions & 3 deletions gate/data/image/segmentation/cityscapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def build_gate_dataset(
tuple_to_dict,
input_transforms,
train_transforms,
# transforms,
transforms,
remap_train_labels,
remap_duds,
],
Expand All @@ -167,7 +167,7 @@ def build_gate_dataset(
tuple_to_dict,
input_transforms,
eval_transforms,
# transforms,
transforms,
remap_train_labels,
remap_duds,
],
Expand All @@ -181,7 +181,7 @@ def build_gate_dataset(
tuple_to_dict,
input_transforms,
eval_transforms,
# transforms,
transforms,
remap_train_labels,
remap_duds,
],
Expand Down
6 changes: 3 additions & 3 deletions gate/data/image/segmentation/coco_10k.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def build_gate_dataset(
transforms=[
input_transforms,
train_transforms,
# transforms,
transforms,
remap_train_labels,
],
meta_data={"class_names": CLASSES, "num_classes": num_classes},
Expand All @@ -117,7 +117,7 @@ def build_gate_dataset(
transforms=[
input_transforms,
eval_transforms,
# transforms,
transforms,
remap_train_labels,
],
meta_data={"class_names": CLASSES, "num_classes": num_classes},
Expand All @@ -129,7 +129,7 @@ def build_gate_dataset(
transforms=[
input_transforms,
eval_transforms,
# transforms,
transforms,
remap_train_labels,
],
meta_data={"class_names": CLASSES, "num_classes": num_classes},
Expand Down
6 changes: 3 additions & 3 deletions gate/data/image/segmentation/coco_164k.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def build_gate_dataset(
transforms=[
input_transforms,
train_transforms,
# transforms,
transforms,
remap_train_labels,
],
meta_data={"class_names": CLASSES, "num_classes": num_classes},
Expand All @@ -126,7 +126,7 @@ def build_gate_dataset(
transforms=[
input_transforms,
eval_transforms,
# transforms,
transforms,
remap_train_labels,
],
meta_data={"class_names": CLASSES, "num_classes": num_classes},
Expand All @@ -138,7 +138,7 @@ def build_gate_dataset(
transforms=[
input_transforms,
eval_transforms,
# transforms,
transforms,
remap_train_labels,
],
meta_data={"class_names": CLASSES, "num_classes": num_classes},
Expand Down
107 changes: 107 additions & 0 deletions gate/data/image/segmentation/debug_cityscapes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import os

import matplotlib.pyplot as plt
import torch
from rich import print
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

from gate.data.image.segmentation.cityscapes import CLASSES as CLASSES_HUG
from gate.data.image.segmentation.cityscapes import (
build_gate_dataset as build_gate_dataset_hug,
)
from gate.data.image.segmentation.cityscapes_pytorch import (
CLASSES,
build_gate_dataset,
)
from gate.metrics.segmentation import IoUMetric


def main():
    """Compare the native and HuggingFace Cityscapes validation splits.

    Iterates both "val" dataloaders in lockstep, accumulates per-label
    pixel frequencies for each data source, and runs the labels through
    ``IoUMetric`` against themselves as a sanity check of the label
    remapping (self-IoU should be perfect for every present class).

    Reads the dataset root from the ``DATASET_DIR`` environment variable.
    """
    dataset_dict = build_gate_dataset(data_dir=os.environ.get("DATASET_DIR"))
    dataset_hug_dict = build_gate_dataset_hug(
        data_dir=os.environ.get("DATASET_DIR")
    )

    dataloader = DataLoader(
        dataset_dict["val"], batch_size=32, shuffle=False, num_workers=32
    )
    dataloader_hug = DataLoader(
        dataset_hug_dict["val"], batch_size=32, shuffle=False, num_workers=32
    )

    # Unique class ids declared by the native dataset definition.
    class_ids = sorted({item.id for item in CLASSES})

    iou_metric = IoUMetric(
        num_classes=len(class_ids),
        ignore_index=0,
        class_idx_to_name={idx: item for idx, item in enumerate(class_ids)},
    )

    seen_labels = set()
    # Pixel counts per label, accumulated over ALL batches. (The previous
    # version rebuilt these dicts every batch and added at most 1 per label,
    # discarding within-batch pixel multiplicity.)
    label_frequency_dict = {}
    label_frequency_dict_hug = {}
    dataset_sizes = {
        "original_val": len(dataset_dict["val"]),
        "hug_val": len(dataset_hug_dict["val"]),
    }
    print(f"Dataset sizes: {dataset_sizes}")

    for item, item_hug in tqdm(zip(dataloader, dataloader_hug)):
        labels = item["labels"]
        labels_hug = item_hug["labels"]

        # preds are just the labels themselves: self-IoU sanity check.
        preds = labels.clone()
        labels = labels.squeeze()
        labels_hug = labels_hug.squeeze()

        seen_labels.update(labels.unique().tolist())
        print(seen_labels)
        print(len(seen_labels))

        # Exact per-label pixel counts for this batch, folded into the
        # running totals (bincount keys are the label ids themselves).
        label_freq = torch.bincount(labels.view(-1))
        for key in torch.nonzero(label_freq).view(-1).tolist():
            label_frequency_dict[key] = label_frequency_dict.get(
                key, 0
            ) + int(label_freq[key])

        label_freq_hug = torch.bincount(labels_hug.view(-1))
        for key in torch.nonzero(label_freq_hug).view(-1).tolist():
            label_frequency_dict_hug[key] = label_frequency_dict_hug.get(
                key, 0
            ) + int(label_freq_hug[key])

        print(
            f"Label Frequency Dict: {label_frequency_dict}, Hug: {label_frequency_dict_hug}"
        )
        iou_metric.update(preds, labels)

    metrics = iou_metric.compute_metrics()
    iou_metric.pretty_print(metrics=metrics)
    iou_metric.reset()  # Resetting the metrics after computation
    # Aggregate-only view of the metrics (per-class entries stripped).
    metrics_with_ignore = {
        k: v for k, v in metrics.items() if "per_class" not in k
    }
    print(f"Aggregate metrics: {metrics_with_ignore}")


# Script entry point: run the dataset comparison only when executed directly.
if __name__ == "__main__":
    main()
6 changes: 3 additions & 3 deletions gate/data/image/segmentation/nyu_depth_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,21 +85,21 @@ def build_gate_dataset(
train_set = GATEDataset(
dataset=build_dataset("train", data_dir=data_dir),
infinite_sampling=True,
transforms=[input_transforms, train_transforms],
transforms=[input_transforms, train_transforms, transforms],
meta_data={"class_names": CLASSES, "num_classes": num_classes},
)

val_set = GATEDataset(
dataset=build_dataset("val", data_dir=data_dir),
infinite_sampling=False,
transforms=[input_transforms, eval_transforms],
transforms=[input_transforms, eval_transforms, transforms],
meta_data={"class_names": CLASSES, "num_classes": num_classes},
)

test_set = GATEDataset(
dataset=build_dataset("test", data_dir=data_dir),
infinite_sampling=False,
transforms=[input_transforms, eval_transforms],
transforms=[input_transforms, eval_transforms, transforms],
meta_data={"class_names": CLASSES, "num_classes": num_classes},
)

Expand Down
6 changes: 3 additions & 3 deletions gate/data/image/segmentation/pascal_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def build_gate_dataset(
transforms=[
input_transforms,
train_transforms,
# transforms,
transforms,
],
meta_data={"class_names": CLASSES, "num_classes": num_classes},
)
Expand All @@ -241,7 +241,7 @@ def build_gate_dataset(
transforms=[
input_transforms,
eval_transforms,
# transforms,
transforms,
],
meta_data={"class_names": CLASSES, "num_classes": num_classes},
)
Expand All @@ -252,7 +252,7 @@ def build_gate_dataset(
transforms=[
input_transforms,
eval_transforms,
# transforms,
transforms,
],
meta_data={"class_names": CLASSES, "num_classes": num_classes},
)
Expand Down
16 changes: 13 additions & 3 deletions gate/data/medical/segmentation/automated_cardiac_diagnosis.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,12 +174,22 @@ def __call__(self, item: Dict):
image = patient_normalization(image)
annotation = annotation.long()

image = [T.ToPILImage()(i) for i in image]

return {
"image": image,
"labels": annotation,
}


def stack_slices(item: Dict) -> Dict:
    """Collapse a list of per-slice image tensors into one stacked tensor.

    Expects ``item["image"]`` to be a sequence of equally-shaped tensors
    (one per slice) and returns a new dict whose ``"image"`` entry is the
    single tensor produced by stacking them along a new leading dimension;
    ``"labels"`` is passed through untouched.
    """
    stacked_images = torch.stack(item["image"])
    return {"image": stacked_images, "labels": item["labels"]}


@configurable(
group="dataset",
name="acdc",
Expand Down Expand Up @@ -207,7 +217,7 @@ def build_gate_dataset(
train_set = GATEDataset(
dataset=build_dataset("train", data_dir=data_dir),
infinite_sampling=True,
transforms=[train_transforms],
transforms=[train_transforms, transforms, stack_slices],
meta_data={
"class_names": CLASSES,
"num_classes": num_classes,
Expand All @@ -217,7 +227,7 @@ def build_gate_dataset(
val_set = GATEDataset(
dataset=build_dataset("val", data_dir=data_dir),
infinite_sampling=False,
transforms=[eval_transforms],
transforms=[eval_transforms, transforms, stack_slices],
meta_data={
"class_names": CLASSES,
"num_classes": num_classes,
Expand All @@ -227,7 +237,7 @@ def build_gate_dataset(
test_set = GATEDataset(
dataset=build_dataset("test", data_dir=data_dir),
infinite_sampling=False,
transforms=[eval_transforms],
transforms=[eval_transforms, transforms, stack_slices],
meta_data={
"class_names": CLASSES,
"num_classes": num_classes,
Expand Down
6 changes: 3 additions & 3 deletions gate/data/transforms/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import cv2
import numpy as np
import PIL
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF
Expand Down Expand Up @@ -440,7 +441,6 @@ def __call__(self, inputs: Dict):
image = self.photo_metric_distortion(image)

annotation = torch.from_numpy(np.array(annotation))
# logger.info(f"annotation max: {annotation.max()}")

if len(annotation.shape) == 2:
annotation = annotation.unsqueeze(0)
Expand All @@ -451,8 +451,8 @@ def __call__(self, inputs: Dict):
else:
raise ValueError("Unsupported annotation shape")

if not isinstance(image, torch.Tensor):
image = T.ToTensor()(image)
if not isinstance(image, Image.Image):
image = T.ToPILImage()(image)

image = T.Resize(
(self.input_size[0], self.input_size[1]),
Expand Down
Loading

0 comments on commit 7c45bad

Please sign in to comment.