Merge pull request axinc-ai#438 from axinc-ai/swiftnet

Implement swiftnet

kyakuno authored Apr 19, 2021
2 parents 3000dbb + 57d29eb commit aa95b21
Showing 9 changed files with 1,049 additions and 1 deletion.
1 change: 1 addition & 0 deletions README.md
@@ -162,6 +162,7 @@ The collection of pre-trained, state-of-the-art models.
| [<img src="image_segmentation/semantic-segmentation-mobilenet-v3/output.png" width=128px>](image_segmentation/semantic-segmentation-mobilenet-v3/) | [semantic-segmentation-mobilenet-v3](/image_segmentation/semantic-segmentation-mobilenet-v3) | [Semantic segmentation with MobileNetV3](https://github.com/OniroAI/Semantic-segmentation-with-MobileNetV3) | TensorFlow | 1.2.5 and later |
| [<img src="image_segmentation/pytorch-unet/data/masks/0cdf5b5d0ce1_01.jpg" width=128px>](image_segmentation/pytorch-unet/) | [pytorch-unet](/image_segmentation/pytorch-unet/) | [Pytorch-Unet](https://github.com/milesial/Pytorch-UNet) | Pytorch | 1.2.5 and later |
| [<img src="image_segmentation/yet-another-anime-segmenter/output.png" width=128px>](image_segmentation/yet-another-anime-segmenter/) | [yet-another-anime-segmenter](/image_segmentation/yet-another-anime-segmenter/) | [Yet-Another-Anime-Segmenter](https://github.com/zymk9/Yet-Another-Anime-Segmenter) | Pytorch | 1.2.6 and later |
| [<img src="image_segmentation/swiftnet/output.png" width=128px>](image_segmentation/swiftnet/) | [swiftnet](/image_segmentation/swiftnet/) | [SwiftNet](https://github.com/orsic/swiftnet) | Pytorch | 1.2.6 and later |

## Natural language processing

674 changes: 674 additions & 0 deletions image_segmentation/swiftnet/LICENSE

Large diffs are not rendered by default.

55 changes: 55 additions & 0 deletions image_segmentation/swiftnet/README.md
@@ -0,0 +1,55 @@
# SwiftNet

## Input

![Input](input.png)

(Image from https://www.cityscapes-dataset.com/)

/datasets/Cityscapes/rgb/test/munich/munich_000000_000019_leftImg8bit.png

Ailia input shape: (1, 3, 1024, 2048)
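
This shape comes from the preprocessing in `swiftnet.py` (included later in this diff); a minimal sketch:
```python
import cv2
import numpy as np

img = cv2.imread('input.png')        # HWC, BGR, uint8
img = cv2.resize(img, (2048, 1024))  # cv2.resize takes (width, height)
img = img.transpose(2, 0, 1)         # HWC -> CHW
img = np.expand_dims(img, 0)         # (1, 3, 1024, 2048)
```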

## Output

![Output](output.png)

## Usage

The ONNX and prototxt model files are downloaded automatically on the first run; an Internet connection is required during the download.

For the sample image,
```bash
$ python3 swiftnet.py
```

To run inference on your own image, pass the image path with the `--input` option.
You can use the `--savepath` option to change the name of the saved output file.
```bash
$ python3 swiftnet.py --input IMAGE_PATH --savepath SAVE_IMAGE_PATH
```

Add the `--video` option to run inference on a video.
If you pass `0` as VIDEO_PATH, webcam input is used instead of a video file.
```bash
$ python3 swiftnet.py --video VIDEO_PATH
```

By default, the optimized model and weights are used; you can switch to the normal model with the `--normal` option.
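
For example:
```bash
$ python3 swiftnet.py --normal
```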

## Reference

[SwiftNet](https://github.com/orsic/swiftnet)

## Framework

Pytorch

## Model Format

ONNX opset = 11

## Netron

[swiftnet.onnx.prototxt](https://netron.app/?url=https://storage.googleapis.com/ailia-models/swiftnet/swiftnet.opt.onnx.prototxt)
Binary file added image_segmentation/swiftnet/input.png
Binary file added image_segmentation/swiftnet/output.png
150 changes: 150 additions & 0 deletions image_segmentation/swiftnet/swiftnet.py
@@ -0,0 +1,150 @@
import sys
import time
import numpy as np
import cv2
from PIL import Image as pimg

from swiftnet_utils.labels import labels
from swiftnet_utils.color_lables import ColorizeLabels

import ailia

# import original modules
sys.path.append('../../util')
from utils import get_base_parser, update_parser, get_savepath # noqa: E402
from model_utils import check_and_download_models # noqa: E402
import webcamera_utils  # noqa: E402

# logger
from logging import getLogger # noqa: E402

logger = getLogger(__name__)

# ======================
# Parameters
# ======================
REMOTE_PATH = 'https://storage.googleapis.com/ailia-models/swiftnet/'

WEIGHT_PATH = "swiftnet.opt.onnx"
MODEL_PATH = "swiftnet.opt.onnx.prototxt"

IMAGE_PATH = 'input.png'
SAVE_IMAGE_PATH = 'output.png'
HEIGHT = 1024
WIDTH = 2048

# RGB palette for the 19 Cityscapes classes that are evaluated (ignoreInEval is False)
color_info = [label.color for label in labels if label.ignoreInEval is False]

# ======================
# Argument Parser Config
# ======================
parser = get_base_parser('swiftnet model', IMAGE_PATH, SAVE_IMAGE_PATH)
args = update_parser(parser)


# ======================
# Main functions
# ======================
def recognize_from_image():
    # net initialize (use the environment selected via the --env_id option)
    net = ailia.Net(MODEL_PATH, WEIGHT_PATH, env_id=args.env_id)

# input image loop
for image_path in args.input:
# prepare input data
logger.debug(f'input image: {image_path}')
img = cv2.imread(image_path)
logger.debug(f'input image shape: {img.shape}')
        img = cv2.resize(img, (WIDTH, HEIGHT))  # (1024, 2048, 3)
        img = img.transpose(2, 0, 1)            # HWC -> CHW
        img = np.expand_dims(img, 0)            # (1, 3, 1024, 2048)

# inference
logger.info('Start inference...')
if args.benchmark:
logger.info('BENCHMARK mode')
for i in range(5):
start = int(round(time.time() * 1000))
pred = net.predict(img)
end = int(round(time.time() * 1000))
logger.info(f'\tailia processing time {end - start} ms')
else:
pred = net.predict(img)

        # postprocessing: argmax over the class axis, then map class ids to colors
        to_color = ColorizeLabels(color_info)
        pred = np.argmax(pred, axis=1)          # (1, num_classes, H, W) -> (1, H, W)
        pred = to_color(pred).astype(np.uint8)  # (1, H, W, 3), RGB
        pred = pimg.fromarray(pred[0])

# save
savepath = get_savepath(args.savepath, image_path)
logger.info(f'saved at : {savepath}')
pred.save(savepath)

        # space bar (32) advances to the next input image; any other key exits
        if cv2.waitKey(0) != 32:
            exit()


def recognize_from_video():
# net initialize
net = ailia.Net(MODEL_PATH, WEIGHT_PATH, env_id=args.env_id)

capture = webcamera_utils.get_capture(args.video)

# create video writer if savepath is specified as video format
f_h = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
f_w = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
if args.savepath != SAVE_IMAGE_PATH:
logger.warning(
'currently, video results cannot be output correctly...'
)
writer = webcamera_utils.get_writer(args.savepath, f_h, f_w, rgb=False)
else:
writer = None

    while True:
ret, frame = capture.read()
if (cv2.waitKey(1) & 0xFF == ord('q')) or not ret:
break

        # preprocess the frame: resize, HWC -> CHW, add a batch dimension
        img = cv2.resize(frame, (WIDTH, HEIGHT))
        img = img.transpose(2, 0, 1)
        img = np.expand_dims(img, 0)

        # inference
        pred = net.predict(img)

# postprocessing
to_color = ColorizeLabels(color_info)
pred = np.argmax(pred, axis=1)[0]
pred = to_color(pred).astype(np.uint8)

        # ColorizeLabels returns RGB; OpenCV expects BGR for display and saving
        pred = cv2.cvtColor(pred, cv2.COLOR_RGB2BGR)
        cv2.imshow('frame', pred)

        # save results
        if writer is not None:
            writer.write(pred)

capture.release()
cv2.destroyAllWindows()
if writer is not None:
writer.release()
logger.info('Script finished successfully.')


def main():
# model files check and download
check_and_download_models(WEIGHT_PATH, MODEL_PATH, REMOTE_PATH)

if args.video is not None:
# video mode
recognize_from_video()
else:
# image mode
recognize_from_image()


if __name__ == '__main__':
main()
67 changes: 67 additions & 0 deletions image_segmentation/swiftnet/swiftnet_utils/color_lables.py
@@ -0,0 +1,67 @@
from collections import defaultdict

import numpy as np
from PIL import Image as pimg

__all__ = ['ExtractInstances', 'RemapLabels', 'ColorizeLabels']


class ExtractInstances:
def __init__(self, inst_map_to_id=None):
self.inst_map_to_id = inst_map_to_id

def __call__(self, example: dict):
labels = np.int32(example['labels'])
unique_ids = np.unique(labels)
instances = defaultdict(list)
        for inst_id in filter(lambda x: x > 1000, unique_ids):
            cls = self.inst_map_to_id.get(inst_id // 1000, None)
            if cls is not None:
                instances[cls] += [labels == inst_id]
example['instances'] = instances
return example


class RemapLabels:
def __init__(self, mapping: dict, ignore_id, total=35):
self.mapping = np.ones((max(total, max(mapping.keys())) + 1,), dtype=np.uint8) * ignore_id
self.ignore_id = ignore_id
for i in range(len(self.mapping)):
self.mapping[i] = mapping[i] if i in mapping else ignore_id

def _trans(self, labels):
max_k = self.mapping.shape[0] - 1
labels[labels > max_k] //= 1000
labels = self.mapping[labels].astype(labels.dtype)
return labels

def __call__(self, example):
if not isinstance(example, dict):
return self._trans(example)
if 'labels' not in example:
return example
ret_dict = {'labels': pimg.fromarray(self._trans(np.array(example['labels'])))}
if 'original_labels' in example:
ret_dict['original_labels'] = pimg.fromarray(self._trans(np.array(example['original_labels'])))
return {**example, **ret_dict}
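
# Usage sketch (hypothetical): collapse raw Cityscapes ids, including instance
# ids encoded as class_id * 1000 + instance, to train ids:
#   remap = RemapLabels({7: 0, 8: 1, 26: 13}, ignore_id=255)
#   remap(np.array([[7, 26004, 3]]))  # -> [[0, 13, 255]]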


class ColorizeLabels:
def __init__(self, color_info):
self.color_info = np.array(color_info)

    def _trans(self, lab):
        # build per-channel planes by looking up each label id's RGB color
        R, G, B = [np.zeros_like(lab) for _ in range(3)]
        for label_id in np.unique(lab):
            mask = lab == label_id
            R[mask] = self.color_info[label_id][0]
            G[mask] = self.color_info[label_id][1]
            B[mask] = self.color_info[label_id][2]
        return np.stack((R, G, B), axis=-1).astype(np.uint8)

def __call__(self, example):
if not isinstance(example, dict):
return self._trans(example)
assert 'labels' in example
return {**example, **{'labels': self._trans(example['labels']),
'original_labels': self._trans(example['original_labels'])}}
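
# Usage sketch (hypothetical): colorize a batch of predicted train-id maps.
#   to_color = ColorizeLabels(color_info)      # color_info: one RGB triple per class
#   rgb = to_color(np.argmax(logits, axis=1))  # (N, H, W) ids -> (N, H, W, 3) RGB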
100 changes: 100 additions & 0 deletions image_segmentation/swiftnet/swiftnet_utils/labels.py
@@ -0,0 +1,100 @@
from collections import namedtuple

#--------------------------------------------------------------------------------
# Definitions
#--------------------------------------------------------------------------------

# a label and all meta information
Label = namedtuple( 'Label' , [

'name' , # The identifier of this label, e.g. 'car', 'person', ... .
# We use them to uniquely name a class

'id' , # An integer ID that is associated with this label.
# The IDs are used to represent the label in ground truth images
# An ID of -1 means that this label does not have an ID and thus
# is ignored when creating ground truth images (e.g. license plate).
# Do not modify these IDs, since exactly these IDs are expected by the
# evaluation server.

'trainId' , # Feel free to modify these IDs as suitable for your method. Then create
# ground truth images with train IDs, using the tools provided in the
# 'preparation' folder. However, make sure to validate or submit results
# to our evaluation server using the regular IDs above!
# For trainIds, multiple labels might have the same ID. Then, these labels
# are mapped to the same class in the ground truth images. For the inverse
# mapping, we use the label that is defined first in the list below.
# For example, mapping all void-type classes to the same ID in training,
# might make sense for some approaches.
# Max value is 255!

'category' , # The name of the category that this label belongs to

'categoryId' , # The ID of this category. Used to create ground truth images
# on category level.

'hasInstances', # Whether this label distinguishes between single instances or not

'ignoreInEval', # Whether pixels having this class as ground truth label are ignored
# during evaluations or not

'color' , # The color of this label
] )


#--------------------------------------------------------------------------------
# A list of all labels
#--------------------------------------------------------------------------------

# Please adapt the train IDs as appropriate for your approach.
# Note that you might want to ignore labels with ID 255 during training.
# Further note that the current train IDs are only a suggestion. You can use whatever you like.
# Make sure to provide your results using the original IDs and not the training IDs.
# Note that many IDs are ignored in evaluation and thus you never need to predict these!

labels = [
# name id trainId category catId hasInstances ignoreInEval color
Label( 'unlabeled' , 0 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
Label( 'ego vehicle' , 1 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
Label( 'rectification border' , 2 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
Label( 'out of roi' , 3 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
Label( 'static' , 4 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
Label( 'dynamic' , 5 , 255 , 'void' , 0 , False , True , (111, 74, 0) ),
Label( 'ground' , 6 , 255 , 'void' , 0 , False , True , ( 81, 0, 81) ),
Label( 'road' , 7 , 0 , 'flat' , 1 , False , False , (128, 64,128) ),
Label( 'sidewalk' , 8 , 1 , 'flat' , 1 , False , False , (244, 35,232) ),
Label( 'parking' , 9 , 255 , 'flat' , 1 , False , True , (250,170,160) ),
Label( 'rail track' , 10 , 255 , 'flat' , 1 , False , True , (230,150,140) ),
Label( 'building' , 11 , 2 , 'construction' , 2 , False , False , ( 70, 70, 70) ),
Label( 'wall' , 12 , 3 , 'construction' , 2 , False , False , (102,102,156) ),
Label( 'fence' , 13 , 4 , 'construction' , 2 , False , False , (190,153,153) ),
Label( 'guard rail' , 14 , 255 , 'construction' , 2 , False , True , (180,165,180) ),
Label( 'bridge' , 15 , 255 , 'construction' , 2 , False , True , (150,100,100) ),
Label( 'tunnel' , 16 , 255 , 'construction' , 2 , False , True , (150,120, 90) ),
Label( 'pole' , 17 , 5 , 'object' , 3 , False , False , (153,153,153) ),
Label( 'polegroup' , 18 , 255 , 'object' , 3 , False , True , (153,153,153) ),
Label( 'traffic light' , 19 , 6 , 'object' , 3 , False , False , (250,170, 30) ),
Label( 'traffic sign' , 20 , 7 , 'object' , 3 , False , False , (220,220, 0) ),
Label( 'vegetation' , 21 , 8 , 'nature' , 4 , False , False , (107,142, 35) ),
Label( 'terrain' , 22 , 9 , 'nature' , 4 , False , False , (152,251,152) ),
Label( 'sky' , 23 , 10 , 'sky' , 5 , False , False , ( 70,130,180) ),
Label( 'person' , 24 , 11 , 'human' , 6 , True , False , (220, 20, 60) ),
Label( 'rider' , 25 , 12 , 'human' , 6 , True , False , (255, 0, 0) ),
Label( 'car' , 26 , 13 , 'vehicle' , 7 , True , False , ( 0, 0,142) ),
Label( 'truck' , 27 , 14 , 'vehicle' , 7 , True , False , ( 0, 0, 70) ),
Label( 'bus' , 28 , 15 , 'vehicle' , 7 , True , False , ( 0, 60,100) ),
Label( 'caravan' , 29 , 255 , 'vehicle' , 7 , True , True , ( 0, 0, 90) ),
Label( 'trailer' , 30 , 255 , 'vehicle' , 7 , True , True , ( 0, 0,110) ),
Label( 'train' , 31 , 16 , 'vehicle' , 7 , True , False , ( 0, 80,100) ),
Label( 'motorcycle' , 32 , 17 , 'vehicle' , 7 , True , False , ( 0, 0,230) ),
Label( 'bicycle' , 33 , 18 , 'vehicle' , 7 , True , False , (119, 11, 32) ),
Label( 'license plate' , -1 , -1 , 'vehicle' , 7 , False , True , ( 0, 0,142) ),
]


def get_train_ids():
train_ids = []
for i in labels:
if not i.ignoreInEval:
train_ids.append(i.id)
return train_ids
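
# For the 19 classes above with ignoreInEval == False, get_train_ids() returns
# the original Cityscapes ids:
# [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33]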
3 changes: 2 additions & 1 deletion scripts/download_all_models.sh
@@ -76,7 +76,8 @@ cd ../../image_segmentation/u2net; python3 u2net.py -a small ${OPTION}
cd ../../image_segmentation/human_part_segmentation; python3 human_part_segmentation.py ${OPTION}
cd ../../image_segmentation/pytorch-unet; python3 pytorch-unet.py ${OPTION}
cd ../../image_segmentation/semantic-segmentation-mobilenet-v3; python3 semantic-segmentation-mobilenet-v3.py ${OPTION}
cd ../../image_segmentation/yet-another-anime-segmenter python3 yet-another-anime-segmenter.py ${OPTION}
cd ../../image_segmentation/yet-another-anime-segmenter; python3 yet-another-anime-segmenter.py ${OPTION}
cd ../../image_segmentation/swiftnet; python3 swiftnet.py ${OPTION}
cd ../../neural_language_processing/bert; python3 bert.py ${OPTION}
cd ../../neural_language_processing/bert_tweets_sentiment; python3 bert_tweets_sentiment.py ${OPTION}
cd ../../neural_language_processing/bert_maskedlm; python3 bert_maskedlm.py ${OPTION}
