Various updates #66

Merged: 9 commits, Sep 15, 2024

Commit: Fix instance norm tests.
Talmaj committed Sep 15, 2024
commit f20576f84bf68319ffc80eb7a2e81b9b9c631a01
25 changes: 15 additions & 10 deletions onnx2pytorch/operations/instancenorm.py
@@ -7,16 +7,13 @@
    from torch.nn.modules.batchnorm import _LazyNormBase

    class _LazyInstanceNorm(_LazyNormBase, _InstanceNorm):

        cls_to_become = _InstanceNorm


except ImportError:
    from torch.nn.modules.lazy import LazyModuleMixin
    from torch.nn.parameter import UninitializedBuffer, UninitializedParameter

    class _LazyInstanceNorm(LazyModuleMixin, _InstanceNorm):

        weight: UninitializedParameter # type: ignore[assignment]
        bias: UninitializedParameter # type: ignore[assignment]

@@ -78,24 +75,29 @@ def initialize_parameters(self, input) -> None: # type: ignore[override]
self.reset_parameters()


class LazyInstanceNormUnsafe(_LazyInstanceNorm):
class InstanceNormMixin:
    """Skips dimension check."""

    def __init__(self, *args, affine=True, **kwargs):
        self.no_batch_dim = None # no_batch_dim has to be set at runtime
        super().__init__(*args, affine=affine, **kwargs)

    def set_no_dim_batch_dim(self, no_batch_dim):
        self.no_batch_dim = no_batch_dim

    def _check_input_dim(self, input):
        return

    def _get_no_batch_dim(self):
        return self.no_batch_dim

class InstanceNormUnsafe(_InstanceNorm):
    """Skips dimension check."""

    def __init__(self, *args, affine=True, **kwargs):
        super().__init__(*args, affine=affine, **kwargs)
class LazyInstanceNormUnsafe(InstanceNormMixin, _LazyInstanceNorm):
    pass

    def _check_input_dim(self, input):
        return

class InstanceNormUnsafe(InstanceNormMixin, _InstanceNorm):
    pass


class InstanceNormWrapper(torch.nn.Module):
@@ -120,4 +122,7 @@ def forward(self, input, scale=None, B=None):
        if B is not None:
            getattr(self.inu, "bias").data = B

        if self.inu.no_batch_dim is None:
            self.inu.set_no_dim_batch_dim(input.dim() - 1)

        return self.inu.forward(input)
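
For reference, a minimal usage sketch, not part of the diff above: the num_features value and input shape are arbitrary, and it assumes a recent PyTorch where _InstanceNorm.forward consults _get_no_batch_dim(). It shows how the refactored classes defer the batch-dimension decision to runtime instead of hard-coding it per class.

import torch

from onnx2pytorch.operations.instancenorm import InstanceNormUnsafe

# Build the op without knowing the input rank up front.
norm = InstanceNormUnsafe(num_features=3, affine=False)

x = torch.randn(2, 3, 8, 8)  # batched (N, C, H, W) input

# Same rule InstanceNormWrapper.forward applies on the first call:
# everything except the leading batch axis counts as the "no batch" rank.
norm.set_no_dim_batch_dim(x.dim() - 1)

out = norm(x)
print(out.shape)  # torch.Size([2, 3, 8, 8])

In normal use the wrapper performs the set_no_dim_batch_dim call itself on the first forward, so converted models do not need to know the input rank at construction time.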