Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Various updates #66

Merged
merged 9 commits into from
Sep 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 0 additions & 32 deletions .circleci/config.yml

This file was deleted.

43 changes: 43 additions & 0 deletions .github/workflows/lint_and_test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
name: Lint and Test

on: [pull_request]

jobs:
test:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [ '3.9', '3.10', '3.11', '3.12' ]

steps:
- name: Checkout code
uses: actions/checkout@v4

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

- name: Cache dependencies
uses: actions/cache@v4
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('**/requirements.txt') }}
restore-keys: |
${{ runner.os }}-pip-${{ matrix.python-version }}-
${{ runner.os }}-pip-

- name: Install dependencies
run: |
pip install tox tox-gh-actions

- name: Run tests
run: |
bash download_fixtures.sh
tox

- name: Upload coverage to GitHub Artifacts
uses: actions/upload-artifact@v4
with:
name: coverage-${{ matrix.python-version }}
path: htmlcov/
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/psf/black
rev: stable
rev: 24.8.0
hooks:
- id: black
language_version: python3.8
language_version: python3.10
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# ONNX to PyTorch
![PyPI - License](https://img.shields.io/pypi/l/onnx2pytorch?color)
[![CircleCI](https://circleci.com/gh/ToriML/onnx2pytorch.svg?style=shield)](https://app.circleci.com/pipelines/github/ToriML/onnx2pytorch)
[![Lint and Test](https://github.com/Talmaj/onnx2pytorch/actions/workflows/lint_and_test.yml/badge.svg)](https://github.com/Talmaj/onnx2pytorch/actions/workflows/lint_and_test.yml)
[![Downloads](https://pepy.tech/badge/onnx2pytorch)](https://pepy.tech/project/onnx2pytorch)
![PyPI](https://img.shields.io/pypi/v/onnx2pytorch)

Expand Down
14 changes: 7 additions & 7 deletions download_fixtures.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ fi

if [[ ! -f shufflenet_v2.onnx ]]; then
echo Downloading shufflenet_v2
curl -LJo shufflenet_v2.onnx https://github.com/onnx/models/blob/master/vision/classification/shufflenet/model/shufflenet-v2-10.onnx\?raw\=true
curl -LJo shufflenet_v2.onnx https://github.com/onnx/models/raw/main/validated/vision/classification/shufflenet/model/shufflenet-v2-10.onnx
fi

if [[ $1 == "--all" ]]; then
Expand All @@ -20,32 +20,32 @@ if [[ $1 == "--all" ]]; then

if [[ ! -f bertsquad-10.onnx ]]; then
echo Downloading bertsquad-10
curl -LJo bertsquad-10.onnx https://github.com/onnx/models/blob/master/text/machine_comprehension/bert-squad/model/bertsquad-10.onnx\?raw\=true
curl -LJo bertsquad-10.onnx https://github.com/onnx/models/raw/main/validated/text/machine_comprehension/bert-squad/model/bertsquad-10.onnx
fi

if [[ ! -f yolo_v4.onnx ]]; then
echo Downloading yolo_v4
curl -LJo yolo_v4.onnx https://github.com/onnx/models/blob/master/vision/object_detection_segmentation/yolov4/model/yolov4.onnx\?raw\=true
curl -LJo yolo_v4.onnx https://github.com/onnx/models/raw/main/validated/vision/object_detection_segmentation/yolov4/model/yolov4.onnx
fi

if [[ ! -f super_res.onnx ]]; then
echo Downloading super_res
curl -LJo super_res.onnx https://github.com/onnx/models/blob/master/vision/super_resolution/sub_pixel_cnn_2016/model/super-resolution-10.onnx\?raw\=true
curl -LJo super_res.onnx https://github.com/onnx/models/raw/main/validated/vision/super_resolution/sub_pixel_cnn_2016/model/super-resolution-10.onnx
fi

if [[ ! -f fast_neural_style.onnx ]]; then
echo Downloading fast_neural_style
curl -LJo fast_neural_style.onnx https://github.com/onnx/models/blob/master/vision/style_transfer/fast_neural_style/model/rain-princess-9.onnx\?raw\=true
curl -LJo fast_neural_style.onnx https://github.com/onnx/models/raw/main/validated/vision/style_transfer/fast_neural_style/model/rain-princess-9.onnx
fi

if [[ ! -f efficientnet-lite4.onnx ]]; then
echo Downloading efficientnet-lite4
curl -LJo efficientnet-lite4.onnx https://github.com/onnx/models/blob/master/vision/classification/efficientnet-lite4/model/efficientnet-lite4-11.onnx\?raw\=true
curl -LJo efficientnet-lite4.onnx https://github.com/onnx/models/raw/main/validated/vision/classification/efficientnet-lite4/model/efficientnet-lite4-11.onnx
fi

if [[ ! -f mobilenetv2-7.onnx ]]; then
echo Downloading mobilenetv2-7
curl -LJo mobilenetv2-7.onnx https://github.com/onnx/models/raw/master/vision/classification/mobilenet/model/mobilenetv2-7.onnx\?raw\=true
curl -LJo mobilenetv2-7.onnx https://github.com/onnx/models/raw/main/validated/vision/classification/mobilenet/model/mobilenetv2-7.onnx
fi

fi
Expand Down
25 changes: 15 additions & 10 deletions onnx2pytorch/operations/instancenorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,13 @@
from torch.nn.modules.batchnorm import _LazyNormBase

class _LazyInstanceNorm(_LazyNormBase, _InstanceNorm):

cls_to_become = _InstanceNorm


except ImportError:
from torch.nn.modules.lazy import LazyModuleMixin
from torch.nn.parameter import UninitializedBuffer, UninitializedParameter

class _LazyInstanceNorm(LazyModuleMixin, _InstanceNorm):

weight: UninitializedParameter # type: ignore[assignment]
bias: UninitializedParameter # type: ignore[assignment]

Expand Down Expand Up @@ -78,24 +75,29 @@ def initialize_parameters(self, input) -> None: # type: ignore[override]
self.reset_parameters()


class LazyInstanceNormUnsafe(_LazyInstanceNorm):
class InstanceNormMixin:
"""Skips dimension check."""

def __init__(self, *args, affine=True, **kwargs):
self.no_batch_dim = None # no_batch_dim has to be set at runtime
super().__init__(*args, affine=affine, **kwargs)

def set_no_dim_batch_dim(self, no_batch_dim):
self.no_batch_dim = no_batch_dim

def _check_input_dim(self, input):
return

def _get_no_batch_dim(self):
return self.no_batch_dim

class InstanceNormUnsafe(_InstanceNorm):
"""Skips dimension check."""

def __init__(self, *args, affine=True, **kwargs):
super().__init__(*args, affine=affine, **kwargs)
class LazyInstanceNormUnsafe(InstanceNormMixin, _LazyInstanceNorm):
pass

def _check_input_dim(self, input):
return

class InstanceNormUnsafe(InstanceNormMixin, _InstanceNorm):
pass


class InstanceNormWrapper(torch.nn.Module):
Expand All @@ -120,4 +122,7 @@ def forward(self, input, scale=None, B=None):
if B is not None:
getattr(self.inu, "bias").data = B

if self.inu.no_batch_dim is None:
self.inu.set_no_dim_batch_dim(input.dim() - 1)

return self.inu.forward(input)
12 changes: 6 additions & 6 deletions tests/onnx2pytorch/convert/test_lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,17 +52,17 @@ def test_single_layer_lstm(
o2p_lstm = ConvertModel(onnx_lstm, experimental=True)
with torch.no_grad():
o2p_output, o2p_h_n, o2p_c_n = o2p_lstm(input, h_0, c_0)
assert torch.equal(o2p_output, output)
assert torch.equal(o2p_h_n, h_n)
assert torch.equal(o2p_c_n, c_n)
torch.testing.assert_allclose(o2p_output, output, rtol=1e-6, atol=1e-6)
torch.testing.assert_allclose(o2p_h_n, h_n, rtol=1e-6, atol=1e-6)
torch.testing.assert_allclose(o2p_c_n, c_n, rtol=1e-6, atol=1e-6)

onnx_lstm = onnx.ModelProto.FromString(bitstream_data)
o2p_lstm = ConvertModel(onnx_lstm, experimental=True)
with torch.no_grad():
o2p_output, o2p_h_n, o2p_c_n = o2p_lstm(h_0=h_0, input=input, c_0=c_0)
assert torch.equal(o2p_output, output)
assert torch.equal(o2p_h_n, h_n)
assert torch.equal(o2p_c_n, c_n)
torch.testing.assert_allclose(o2p_output, output, rtol=1e-6, atol=1e-6)
torch.testing.assert_allclose(o2p_h_n, h_n, rtol=1e-6, atol=1e-6)
torch.testing.assert_allclose(o2p_c_n, c_n, rtol=1e-6, atol=1e-6)
with pytest.raises(KeyError):
o2p_output, o2p_h_n, o2p_c_n = o2p_lstm(h_0=h_0, input=input)
with pytest.raises(Exception):
Expand Down
4 changes: 2 additions & 2 deletions tests/onnx2pytorch/operations/test_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ def test_clip():
assert torch.equal(op(x), exp_y)

op = Clip(max=0)
exp_y_np = np.clip(x_np, np.NINF, 0)
exp_y_np = np.clip(x_np, -np.inf, 0)
exp_y = torch.from_numpy(exp_y_np)
assert torch.equal(op(x), exp_y)

op = Clip()
exp_y_np = np.clip(x_np, np.NINF, np.inf)
exp_y_np = np.clip(x_np, -np.inf, np.inf)
exp_y = torch.from_numpy(exp_y_np)
assert torch.equal(op(x), exp_y)
9 changes: 8 additions & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,14 @@
# and then run "tox" from this directory.

[tox]
envlist = clean,py36,py37,py38,py38-torch19,py39
envlist = clean,py39,py310,py311,py312

[gh-actions]
python =
3.9: py39
3.10: py310
3.11: py311
3.12: py312

[testenv]
passenv =
Expand Down