Skip to content

Commit

Permalink
Add imgviz.io.lblsave
Browse files Browse the repository at this point in the history
  • Loading branch information
wkentaro committed Nov 19, 2024
1 parent 5feea2a commit 088f36a
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 0 deletions.
1 change: 1 addition & 0 deletions imgviz/_io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from ._pyglet import pyglet_run
from .base import imread
from .base import imsave
from .base import lblsave
from .opencv import cv_imshow
from .opencv import cv_waitkey
from .pil import pil_imshow
Expand Down
29 changes: 29 additions & 0 deletions imgviz/_io/base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import os
import os.path as osp
import pathlib
from typing import Union

import numpy as np # NOQA
import PIL.Image

from .. import utils
from ..label import label_colormap


def imread(filename):
Expand Down Expand Up @@ -45,3 +48,29 @@ def imsave(filename, arr):
except OSError:
pass
return utils.numpy_to_pillow(arr).save(filename)


def lblsave(filename: Union[str, pathlib.Path], lbl: np.ndarray) -> None:
"""Save label image to PNG file with a colormap.
Parameters
----------
filename: str | pathlib.Path
Filename. Must end with '.png'.
lbl: numpy.ndarray, (H, W), np.uint8
Label image to save.
Returns
-------
None
"""
if not str(filename).lower().endswith(".png"):
raise ValueError(f"filename must end with '.png': {filename}")
if lbl.dtype != np.uint8:
raise ValueError(f"lbl.dtype must be np.uint8, but got {lbl.dtype}")

lbl_pil = PIL.Image.fromarray(lbl, mode="P")
colormap = label_colormap(n_label=256)
lbl_pil.putpalette(colormap.flatten())
lbl_pil.save(filename)
22 changes: 22 additions & 0 deletions tests/io_tests/test_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import pathlib

import numpy as np

import imgviz


def test_lblsave(tmp_path: pathlib.Path) -> None:
data: dict = imgviz.data.arc2017()

label_cls: np.ndarray = data["class_label"]

assert label_cls.min() == 0
assert label_cls.max() == 25

label_cls = label_cls.astype(np.uint8)

png_file: pathlib.Path = tmp_path / "label_cls.png"
imgviz.io.lblsave(png_file, label_cls)
label_cls_read = imgviz.io.imread(png_file)

np.testing.assert_allclose(label_cls, label_cls_read)

0 comments on commit 088f36a

Please sign in to comment.