Commit: various fixes

ivicadimitrovski committed Mar 6, 2023
1 parent d0f8849 commit e642697

Showing 12 changed files with 103 additions and 74 deletions.
8 changes: 2 additions & 6 deletions aitlas/base/segmentation.py

@@ -1,12 +1,11 @@
 import logging
 import torch.nn as nn
 import torch.optim as optim
-import torch.nn.functional as nnf
 import torch

 from .models import BaseModel
 from .schemas import BaseSegmentationClassifierSchema
 from .metrics import SegmentationRunningScore
 from ..utils import DiceLoss

 logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
@@ -25,8 +24,6 @@ def get_predicted(self, outputs, threshold=None):
         predicted = (predicted_probs >= (
             threshold if threshold else self.config.threshold
         )).long()
-        #predicted_probs = nnf.softmax(outputs, dim=1)
-        #predicted = (outputs >= predicted_probs).long()
         return predicted_probs, predicted

     def load_optimizer(self):
@@ -35,8 +32,7 @@ def load_optimizer(self):

     def load_criterion(self):
         """Load the loss function"""
-        return nn.BCEWithLogitsLoss()
-        #return nn.CrossEntropyLoss()
+        return DiceLoss()

     def load_lr_scheduler(self, optimizer):
         return torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, factor=0.1, min_lr=1e-6)
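Note on the load_criterion change: the segmentation trainer now returns the DiceLoss imported from ..utils instead of nn.BCEWithLogitsLoss. The DiceLoss implementation itself is outside this diff; purely for orientation, a minimal sketch of a typical binary Dice loss on logits (an assumption, not necessarily the repository's exact code):

import torch
import torch.nn as nn

class DiceLoss(nn.Module):
    """Sketch of a binary Dice loss on raw logits; the smoothing constant
    and mean reduction are assumptions, aitlas.utils.DiceLoss may differ."""

    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, logits, targets):
        probs = torch.sigmoid(logits)            # map logits to probabilities in [0, 1]
        probs = probs.flatten(1)                 # (N, C*H*W)
        targets = targets.flatten(1).float()
        intersection = (probs * targets).sum(dim=1)
        dice = (2.0 * intersection + self.smooth) / (
            probs.sum(dim=1) + targets.sum(dim=1) + self.smooth
        )
        return 1.0 - dice.mean()                 # 0 when prediction and mask overlap perfectly

Unlike a plain BCE objective, Dice directly optimizes the overlap statistic that segmentation is typically scored on, which tends to behave better under heavy class imbalance.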
8 changes: 4 additions & 4 deletions aitlas/models/alexnet.py

@@ -34,9 +34,9 @@ def extract_features(self):

     def freeze(self):
         for param in self.model.parameters():
-            param.require_grad = False
+            param.requires_grad = False
         for param in self.model.classifier.parameters():
-            param.require_grad = True
+            param.requires_grad = True


 class AlexNetMultiLabel(BaseMultilabelClassifier):
@@ -70,6 +70,6 @@ def extract_features(self):

     def freeze(self):
         for param in self.model.parameters():
-            param.require_grad = False
+            param.requires_grad = False
         for param in self.model.classifier.parameters():
-            param.require_grad = True
+            param.requires_grad = True
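The require_grad -> requires_grad renames in this file (and repeated in the model files below) fix a silent bug rather than a style issue: PyTorch tensors accept arbitrary Python attributes, so the misspelled assignment merely attached an unused attribute and the backbone was never actually frozen. A quick demonstration:

import torch

p = torch.nn.Parameter(torch.zeros(3))

p.require_grad = False    # typo: silently creates a new, meaningless attribute
print(p.requires_grad)    # True  -> the parameter would still be updated

p.requires_grad = False   # the corrected spelling actually disables gradients
print(p.requires_grad)    # False -> the parameter is now frozen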
8 changes: 4 additions & 4 deletions aitlas/models/convnext.py

@@ -28,9 +28,9 @@ def forward(self, x):

     def freeze(self):
         for param in self.model.parameters():
-            param.require_grad = False
+            param.requires_grad = False
         for param in self.model.classifier.parameters():
-            param.require_grad = True
+            param.requires_grad = True

     def extract_features(self):
         """ Remove final layers if we only need to extract features """
@@ -69,6 +69,6 @@ def extract_features(self):

     def freeze(self):
         for param in self.model.parameters():
-            param.require_grad = False
+            param.requires_grad = False
         for param in self.model.classifier.parameters():
-            param.require_grad = True
+            param.requires_grad = True
8 changes: 4 additions & 4 deletions aitlas/models/densenet.py

@@ -44,9 +44,9 @@ def extract_features(self):

     def freeze(self):
         for param in self.model.parameters():
-            param.require_grad = False
+            param.requires_grad = False
         for param in self.model.classifier.parameters():
-            param.require_grad = True
+            param.requires_grad = True


 class DenseNet161MultiLabel(BaseMultilabelClassifier):
@@ -88,6 +88,6 @@ def extract_features(self):

     def freeze(self):
         for param in self.model.parameters():
-            param.require_grad = False
+            param.requires_grad = False
         for param in self.model.classifier.parameters():
-            param.require_grad = True
+            param.requires_grad = True
24 changes: 12 additions & 12 deletions aitlas/models/efficientnet.py

@@ -26,9 +26,9 @@ def forward(self, x):

     def freeze(self):
         for param in self.model.parameters():
-            param.require_grad = False
+            param.requires_grad = False
         for param in self.model.classifier.parameters():
-            param.require_grad = True
+            param.requires_grad = True

     """ Remove final layers if we only need to extract features """
     def extract_features(self):
@@ -65,9 +65,9 @@ def extract_features(self):

     def freeze(self):
         for param in self.model.parameters():
-            param.require_grad = False
+            param.requires_grad = False
         for param in self.model.classifier.parameters():
-            param.require_grad = True
+            param.requires_grad = True


 class EfficientNetB4(BaseMulticlassClassifier):
@@ -92,9 +92,9 @@ def forward(self, x):

     def freeze(self):
         for param in self.model.parameters():
-            param.require_grad = False
+            param.requires_grad = False
         for param in self.model.classifier.parameters():
-            param.require_grad = True
+            param.requires_grad = True

     """ Remove final layers if we only need to extract features """
     def extract_features(self):
@@ -131,9 +131,9 @@ def extract_features(self):

     def freeze(self):
         for param in self.model.parameters():
-            param.require_grad = False
+            param.requires_grad = False
         for param in self.model.classifier.parameters():
-            param.require_grad = True
+            param.requires_grad = True


 class EfficientNetB7(BaseMulticlassClassifier):
@@ -158,9 +158,9 @@ def forward(self, x):

     def freeze(self):
         for param in self.model.parameters():
-            param.require_grad = False
+            param.requires_grad = False
         for param in self.model.fc.parameters():
-            param.require_grad = True
+            param.requires_grad = True

     """ Remove final layers if we only need to extract features """
     def extract_features(self):
@@ -197,9 +197,9 @@ def extract_features(self):

     def freeze(self):
         for param in self.model.parameters():
-            param.require_grad = False
+            param.requires_grad = False
         for param in self.model.classifier.parameters():
-            param.require_grad = True
+            param.requires_grad = True
18 changes: 9 additions & 9 deletions aitlas/models/resnet.py

@@ -30,7 +30,7 @@ def __init__(self, config):
             weights=None, progress=False, num_classes=num_classes
         )
         # remove prefix "module."
-        checkpoint = {k.replace("module.backbone.", ""): v for k, v in checkpoint.items()}
+        checkpoint = {k.replace("backbone.", ""): v for k, v in checkpoint.items()}
         checkpoint = {k.replace("module.", ""): v for k, v in checkpoint.items()}
         for k, v in self.model.state_dict().items():
             if k not in list(checkpoint):
@@ -56,9 +56,9 @@ def forward(self, x):

     def freeze(self):
         for param in self.model.parameters():
-            param.require_grad = False
+            param.requires_grad = False
         for param in self.model.fc.parameters():
-            param.require_grad = True
+            param.requires_grad = True

     def extract_features(self):
         """ Remove final layers if we only need to extract features """
@@ -95,9 +95,9 @@ def extract_features(self):

     def freeze(self):
         for param in self.model.parameters():
-            param.require_grad = False
+            param.requires_grad = False
         for param in self.model.fc.parameters():
-            param.require_grad = True
+            param.requires_grad = True


 class ResNet50MultiLabel(BaseMultilabelClassifier):
@@ -139,9 +139,9 @@ def extract_features(self):

     def freeze(self):
         for param in self.model.parameters():
-            param.require_grad = False
+            param.requires_grad = False
         for param in self.model.fc.parameters():
-            param.require_grad = True
+            param.requires_grad = True


 class ResNet152MultiLabel(BaseMultilabelClassifier):
@@ -172,6 +172,6 @@ def extract_features(self):

     def freeze(self):
         for param in self.model.parameters():
-            param.require_grad = False
+            param.requires_grad = False
         for param in self.model.fc.parameters():
-            param.require_grad = True
+            param.requires_grad = True
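On the checkpoint remapping change: stripping the shorter "backbone." prefix (with the existing "module." pass still applied afterwards) handles both checkpoints saved under DataParallel's "module." wrapper and checkpoints saved without it. A toy run with hypothetical keys:

# Hypothetical state-dict keys, purely for illustration.
checkpoint = {
    "module.backbone.conv1.weight": 1,    # DataParallel + backbone wrapper
    "backbone.layer1.0.conv1.weight": 2,  # backbone wrapper only
    "module.fc.weight": 3,                # DataParallel wrapper only
}
checkpoint = {k.replace("backbone.", ""): v for k, v in checkpoint.items()}
checkpoint = {k.replace("module.", ""): v for k, v in checkpoint.items()}
print(list(checkpoint))
# ['conv1.weight', 'layer1.0.conv1.weight', 'fc.weight']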
18 changes: 18 additions & 0 deletions aitlas/models/swin_transformer.py

@@ -18,6 +18,15 @@ def __init__(self, config):
             in_features=768, out_features=self.config.num_classes, bias=True
         )

+        if self.config.freeze:
+            self.freeze()
+
+    def freeze(self):
+        for param in self.model.parameters():
+            param.requires_grad = False
+        for param in self.model.head.parameters():
+            param.requires_grad = True
+
     def forward(self, x):
         return self.model(x)
@@ -36,5 +45,14 @@ def __init__(self, config):
             in_features=768, out_features=self.config.num_classes, bias=True
         )

+        if self.config.freeze:
+            self.freeze()
+
+    def freeze(self):
+        for param in self.model.parameters():
+            param.requires_grad = False
+        for param in self.model.head.parameters():
+            param.requires_grad = True
+
     def forward(self, x):
         return self.model(x)
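The Swin models now honor the config's freeze flag, disabling gradients everywhere and re-enabling only the classification head (model.head here, versus classifier or fc in the models above). A generic way to check that such a freeze took effect, shown on a stand-in module rather than the aitlas API:

import torch.nn as nn

def count_trainable(model: nn.Module) -> int:
    """Number of scalar weights that will receive gradients."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Stand-in for backbone + head; the real model is a Swin transformer.
m = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))
for p in m.parameters():
    p.requires_grad = False      # freeze everything
for p in m[-1].parameters():
    p.requires_grad = True       # re-enable the "head"
print(count_trainable(m))        # 18: only the last layer (8*2 + 2) is trainable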
16 changes: 8 additions & 8 deletions aitlas/models/vgg.py

@@ -29,9 +29,9 @@ def forward(self, x):

     def freeze(self):
         for param in self.model.parameters():
-            param.require_grad = False
+            param.requires_grad = False
         for param in self.model.classifier.parameters():
-            param.require_grad = True
+            param.requires_grad = True

     def extract_features(self):
         """ Remove final layers if we only need to extract features """
@@ -70,9 +70,9 @@ def extract_features(self):

     def freeze(self):
         for param in self.model.parameters():
-            param.require_grad = False
+            param.requires_grad = False
         for param in self.model.classifier.parameters():
-            param.require_grad = True
+            param.requires_grad = True


 class VGG16MultiLabel(BaseMultilabelClassifier):
@@ -105,9 +105,9 @@ def extract_features(self):

     def freeze(self):
         for param in self.model.parameters():
-            param.require_grad = False
+            param.requires_grad = False
         for param in self.model.classifier.parameters():
-            param.require_grad = True
+            param.requires_grad = True


 class VGG19MultiLabel(BaseMultilabelClassifier):
@@ -134,9 +134,9 @@ def forward(self, x):

     def freeze(self):
         for param in self.model.parameters():
-            param.require_grad = False
+            param.requires_grad = False
         for param in self.model.classifier.parameters():
-            param.require_grad = True
+            param.requires_grad = True

     def extract_features(self):
         """ Remove final layers if we only need to extract features """
22 changes: 20 additions & 2 deletions aitlas/models/vision_transformer.py

@@ -31,7 +31,7 @@ def __init__(self, config):
             in_features=768, out_features=num_classes, bias=True
         )
         # remove prefix "module."
-        checkpoint = {k.replace("module.backbone.", ""): v for k, v in checkpoint.items()}
+        checkpoint = {k.replace("backbone.", ""): v for k, v in checkpoint.items()}
         checkpoint = {k.replace("module.", ""): v for k, v in checkpoint.items()}
         for k, v in self.model.state_dict().items():
             if k not in list(checkpoint):
@@ -51,6 +51,15 @@ def __init__(self, config):
             in_features=768, out_features=self.config.num_classes, bias=True
         )

+        if self.config.freeze:
+            self.freeze()
+
+    def freeze(self):
+        for param in self.model.parameters():
+            param.requires_grad = False
+        for param in self.model.head.parameters():
+            param.requires_grad = True
+
     def forward(self, x):
         return self.model(x)
@@ -78,7 +87,7 @@ def __init__(self, config):
             in_features=768, out_features=num_classes, bias=True
         )
         # remove prefix "module."
-        checkpoint = {k.replace("module.backbone.", ""): v for k, v in checkpoint.items()}
+        checkpoint = {k.replace("backbone.", ""): v for k, v in checkpoint.items()}
         checkpoint = {k.replace("module.", ""): v for k, v in checkpoint.items()}
         for k, v in self.model.state_dict().items():
             if k not in list(checkpoint):
@@ -98,5 +107,14 @@ def __init__(self, config):
             in_features=768, out_features=self.config.num_classes, bias=True
         )

+        if self.config.freeze:
+            self.freeze()
+
+    def freeze(self):
+        for param in self.model.parameters():
+            param.requires_grad = False
+        for param in self.model.head.parameters():
+            param.requires_grad = True
+
     def forward(self, x):
         return self.model(x)
12 changes: 12 additions & 0 deletions aitlas/transforms/segmentation.py

@@ -24,6 +24,18 @@ def __call__(self, sample):
         return torch.tensor(sample, dtype=torch.float32) / 255


+class Pad(BaseTransforms):
+    def __call__(self, sample):
+        data_transforms = transforms.Compose(
+            [
+                transforms.ToPILImage(),
+                transforms.Pad(4),
+                transforms.ToTensor()
+            ]
+        )
+        return data_transforms(sample)
+
+
 class ColorTransformations(BaseTransforms):
     def __call__(self, sample):
         sample = np.asarray(sample)
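The new Pad transform routes the sample through PIL to add a fixed 4-pixel border on every side, so spatial dimensions grow by 8. A minimal usage sketch, assuming an HWC uint8 NumPy input (which ToPILImage accepts) and that BaseTransforms needs no constructor arguments:

import numpy as np
from aitlas.transforms.segmentation import Pad  # module path as in this diff

sample = np.zeros((28, 28, 3), dtype=np.uint8)  # hypothetical 28x28 RGB patch
padded = Pad()(sample)
print(padded.shape)  # torch.Size([3, 36, 36]): ToTensor returns CHW floats in [0, 1]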
(diffs for the remaining 2 changed files not shown)
