Skip to content

Commit

Permalink
Add updated architecture with lower mem footprint and custom training
Browse files Browse the repository at this point in the history
  • Loading branch information
mhamilton723 committed Apr 20, 2022
1 parent d862b3a commit 8c1ad0a
Show file tree
Hide file tree
Showing 19 changed files with 193 additions and 386 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ systems pytorch data directory where datasets are stored.
python download_data.py
cd /YOUR/PYTORCH/DATA/DIR
unzip cocostuff.zip
unzip cityscapes.zip
unzip potsdam.zip
unzip potsdamraw.zip
```


Expand Down
6 changes: 3 additions & 3 deletions src/configs/demo_config.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
output_root: '../'
model_path: "../saved_models/potsdam_test.ckpt"
image_dir: "../samples"
experiment_name: "demo_potsdam"
model_path: "../saved_models/cocostuff27_vit_base_5.ckpt"
image_dir: "/datadrive/pytorch-data/cocostuff/images/val2017"
experiment_name: "cocostuff_val"
res: 320
batch_size: 8
num_workers: 24
Expand Down
5 changes: 5 additions & 0 deletions src/configs/eval_config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ run_prediction: True
dark_mode: True
use_ddp: False

model_paths:
- "../saved_models/cocostuff27_vit_base_5.ckpt"
#- "../saved_models/cityscapes_vit_base_1.ckpt"
#- "../saved_models/potsdam_test.ckpt"

hydra:
run:
dir: "."
Expand Down
2 changes: 1 addition & 1 deletion src/configs/plot_config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ experiment_name: "exp1"
log_dir: "cleaning"

plot_correspondence: True
plot_movie: False
plot_movie: True


# Loader params
Expand Down
13 changes: 8 additions & 5 deletions src/configs/train_config.yml
Original file line number Diff line number Diff line change
@@ -1,21 +1,24 @@
output_root: '../'
pytorch_data_dir: '/datadrive/pytorch-data/'
experiment_name: "exp1"
log_dir: "iarpa"
log_dir: "cocostuff27"
azureml_logging: True
submitting_to_aml: False

# Loader params
num_workers: 24
max_steps: 5000

batch_size: 16

num_neighbors: 7
dataset_name: "cocostuff27"

# Used if dataset_name is "directory"
dir_dataset_name: ~
dir_dataset_n_classes: 5

batch_size: 16
dataset_name: "iarpa"
has_labels: False
crop_type: ~
crop_type: "five"
crop_ratio: .5
res: 224
loader_crop_type: "center"
Expand Down
2 changes: 1 addition & 1 deletion src/crf.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as VF
from core import unnorm
from utils import unnorm

MAX_ITER = 10
POS_W = 3
Expand Down
21 changes: 9 additions & 12 deletions src/crop_datasets.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,14 @@
try:
from .core import *
from .modules import *
except (ModuleNotFoundError, ImportError):
from core import *
from modules import *
from modules import *
import os
from .data import ContrastiveSegDataset
from data import ContrastiveSegDataset
import hydra
import torch
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.utilities.seed import seed_everything
from torch.utils.data import DataLoader
from torchvision.transforms.functional import five_crop, _get_image_size, crop
from tqdm import tqdm
from torch.utils.data import Dataset


def _random_crops(img, size, seed, n):
Expand Down Expand Up @@ -103,11 +99,13 @@ def __init__(self, cfg, dataset_name, img_set, crop_type, crop_ratio):
dataset_name,
None,
img_set,
cfg.num_neighbors,
T.ToTensor(),
ToTargetTensor(),
cfg=cfg,
pos_labels=False, pos_images=False, mask=False,
num_neighbors=cfg.num_neighbors,
pos_labels=False,
pos_images=False,
mask=False,
aug_geometric_transform=None,
aug_photometric_transform=None,
extra_transform=cropper
Expand Down Expand Up @@ -139,17 +137,16 @@ def my_app(cfg: DictConfig) -> None:
# crop_types = ["five","random"]
# crop_ratios = [.5, .7]

dataset_names = ["cocostuff27"]
dataset_names = ["cityscapes"]
img_sets = ["train", "val"]
crop_types = [None]
crop_types = ["five"]
crop_ratios = [.5]

for crop_ratio in crop_ratios:
for crop_type in crop_types:
for dataset_name in dataset_names:
for img_set in img_sets:
dataset = RandomCropComputer(cfg, dataset_name, img_set, crop_type, crop_ratio)
print(len(dataset))
loader = DataLoader(dataset, 1, shuffle=False, num_workers=cfg.num_workers, collate_fn=lambda l: l)
for _ in tqdm(loader):
pass
Expand Down
60 changes: 33 additions & 27 deletions src/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from tqdm import tqdm



def bit_get(val, idx):
"""Gets the bit value.
Args:
Expand Down Expand Up @@ -73,43 +72,50 @@ def create_cityscapes_colormap():
return np.array(colors)


class Iarpa(Dataset):
def __init__(self, root, image_set, transform, target_transform):
super(Iarpa, self).__init__()
class DirectoryDataset(Dataset):
def __init__(self, root, path, image_set, transform, target_transform):
super(DirectoryDataset, self).__init__()
self.split = image_set
self.root = os.path.join(root, "iarpa")
self.dir = join(root, path)
self.img_dir = join(self.dir, "imgs", self.split)
self.label_dir = join(self.dir, "labels", self.split)

self.transform = transform
self.target_transform = target_transform

self.all_files = np.array(sorted(os.listdir(self.root)))
np.random.seed(0)
random_vals = np.random.rand(len(self.all_files)) > .05

if image_set == "train":
self.files = self.all_files[np.where(random_vals)]
elif image_set == "val":
self.files = self.all_files[np.where(1-random_vals)]
self.img_files = np.array(sorted(os.listdir(self.img_dir)))
assert len(self.img_files) > 0
if os.path.exists(join(self.dir, "labels")):
self.label_files = np.array(sorted(os.listdir(self.label_dir)))
assert len(self.img_files) == len(self.label_files)
else:
raise ValueError("Unknown image set: {}".format(image_set))
self.label_files = None

def __getitem__(self, index):
image_fn = self.files[index]
img = Image.open(join(self.root, image_fn))
image_fn = self.img_files[index]
img = Image.open(join(self.img_dir, image_fn))

if self.label_files is not None:
label_fn = self.label_files[index]
label = Image.open(join(self.label_dir, label_fn))

seed = np.random.randint(2147483647)
random.seed(seed)
torch.manual_seed(seed)
img = self.transform(img)

random.seed(seed)
torch.manual_seed(seed)
label = torch.zeros(img.shape[1], img.shape[2], dtype=torch.int64) - 1
mask = (label > 0).to(torch.float32)
if self.label_files is not None:
random.seed(seed)
torch.manual_seed(seed)
label = self.target_transform(label)
else:
label = torch.zeros(img.shape[1], img.shape[2], dtype=torch.int64) - 1

mask = (label > 0).to(torch.float32)
return img, label, mask

def __len__(self):
return len(self.files)
return len(self.img_files)


class Potsdam(Dataset):
Expand Down Expand Up @@ -446,10 +452,10 @@ def __init__(self,
self.n_classes = 3
dataset_class = PotsdamRaw
extra_args = dict(coarse_labels=True)
elif dataset_name == "iarpa":
self.n_classes = 5
dataset_class = Iarpa
extra_args = dict()
elif dataset_name == "directory":
self.n_classes = cfg.dir_dataset_n_classes
dataset_class = DirectoryDataset
extra_args = dict(path=cfg.dir_dataset_name)
elif dataset_name == "cityscapes" and crop_type is None:
self.n_classes = 27
dataset_class = CityscapesSeg
Expand Down Expand Up @@ -493,8 +499,9 @@ def __init__(self,
else:
model_type = cfg.model_type

nice_dataset_name = cfg.dir_dataset_name if dataset_name == "directory" else dataset_name
feature_cache_file = join(pytorch_data_dir, "nns", "nns_{}_{}_{}_{}_{}.npz".format(
model_type, dataset_name, image_set, crop_type, cfg.res))
model_type, nice_dataset_name, image_set, crop_type, cfg.res))
if pos_labels or pos_images:
if not os.path.exists(feature_cache_file) or compute_knns:
raise ValueError("could not find nn file {} please run precompute_knns".format(feature_cache_file))
Expand Down Expand Up @@ -556,4 +563,3 @@ def __getitem__(self, ind):
ret["coord_aug"] = coord_aug.permute(1, 2, 0)

return ret

9 changes: 2 additions & 7 deletions src/demo_segmentation.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,4 @@
try:
from .core import *
from .modules import *
except (ModuleNotFoundError, ImportError):
from core import *
from modules import *
from modules import *
import hydra
import torch.multiprocessing
from PIL import Image
Expand All @@ -12,7 +7,7 @@
from torch.utils.data import DataLoader, Dataset
from train_segmentation import LitUnsupervisedSegmenter
from tqdm import tqdm

import random
torch.multiprocessing.set_sharing_strategy('file_system')


Expand Down
9 changes: 3 additions & 6 deletions src/download_datasets.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
try:
from .core import *
except (ModuleNotFoundError, ImportError):
from core import *
from utils import *
import hydra
from omegaconf import DictConfig
import os
import wget


@hydra.main(config_path="configs", config_name="eval_config.yml")
def my_app(cfg: DictConfig) -> None:
pytorch_data_dir = cfg.pytorch_data_dir
dataset_names = [
"potsdam",
#"cityscapes",
"cityscapes",
"cocostuff",
"potsdamraw"]
url_base = "https://marhamilresearch4.blob.core.windows.net/stego-public/pytorch_data/"
Expand All @@ -27,7 +25,6 @@ def my_app(cfg: DictConfig) -> None:
print("\n Found {}, skipping download".format(dataset_name))



if __name__ == "__main__":
prep_args()
my_app()
Loading

0 comments on commit 8c1ad0a

Please sign in to comment.