MetadataWriter for AudioClassifier
PiperOrigin-RevId: 365262605
lu-wang-g authored and tflite-support-robot committed Mar 26, 2021
1 parent db89782 commit eca4b51
Showing 13 changed files with 524 additions and 0 deletions.
14 changes: 14 additions & 0 deletions tensorflow_lite_support/metadata/python/metadata_writers/BUILD
@@ -109,3 +109,17 @@ py_library(
":writer_utils",
],
)

py_library(
name = "audio_classifier",
srcs = [
"audio_classifier.py",
],
srcs_version = "PY3",
visibility = ["//visibility:public"],
deps = [
":metadata_info",
":metadata_writer",
":writer_utils",
],
)
159 changes: 159 additions & 0 deletions tensorflow_lite_support/metadata/python/metadata_writers/audio_classifier.py
@@ -0,0 +1,159 @@
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Writes metadata and label file to the audio classifier models."""

from typing import List, Optional

from tensorflow_lite_support.metadata.python.metadata_writers import metadata_info
from tensorflow_lite_support.metadata.python.metadata_writers import metadata_writer
from tensorflow_lite_support.metadata.python.metadata_writers import writer_utils

_MODEL_NAME = "AudioClassifier"
_MODEL_DESCRIPTION = (
"Identify the most prominent type in the audio clip from a known set of "
"categories.")
_INPUT_NAME = "audio_clip"
_INPUT_DESCRIPTION = "Input audio clip to be classified."
_OUTPUT_NAME = "probability"
_OUTPUT_DESCRIPTION = "Scores of the labels respectively."
_AUDIO_TENSOR_INDEX = 0


class MetadataWriter(metadata_writer.MetadataWriter):
"""Writes metadata into an audio classifier."""

@classmethod
def create_from_metadata_info(
cls,
model_buffer: bytearray,
general_md: Optional[metadata_info.GeneralMd] = None,
input_md: Optional[metadata_info.InputAudioTensorMd] = None,
output_md: Optional[metadata_info.ClassificationTensorMd] = None):
"""Creates MetadataWriter based on general/input/output information.
Args:
model_buffer: valid buffer of the model file.
general_md: general infromation about the model. If not specified, default
general metadata will be generated.
input_md: input audio tensor informaton, if not specified, default input
metadata will be generated.
output_md: output classification tensor informaton, if not specified,
default output metadata will be generated.
Returns:
A MetadataWriter object.
"""

if general_md is None:
general_md = metadata_info.GeneralMd(
name=_MODEL_NAME, description=_MODEL_DESCRIPTION)

if input_md is None:
input_md = metadata_info.InputAudioTensorMd(
name=_INPUT_NAME, description=_INPUT_DESCRIPTION)

if output_md is None:
output_md = metadata_info.ClassificationTensorMd(
name=_OUTPUT_NAME, description=_OUTPUT_DESCRIPTION)

output_md.associated_files = output_md.associated_files or []

return super().create_from_metadata_info(
model_buffer=model_buffer,
general_md=general_md,
input_md=[input_md],
output_md=[output_md],
associated_files=[
file.file_path for file in output_md.associated_files
])

@classmethod
def create_for_inference(cls, model_buffer: bytearray, sample_rate: int,
channels: int, min_required_samples: int,
label_file_paths: List[str]):
"""Creates mandatory metadata for TFLite Support inference.
The parameters required in this method are mandatory when using TFLite
Support features, such as Task library and Codegen tool (Android Studio ML
Binding). Other metadata fields will be set to default. If other fields need
to be filled, use the method `create_from_metadata_info` to edit them.
Args:
model_buffer: valid buffer of the model file.
sample_rate: the sample rate in Hz when the audio was captured.
channels: the channel count of the audio.
min_required_samples: the minimum required number of per-channel samples
in order to run inference properly. Optional for fixed-size audio
tensors and default to 0. The minimum required flat size of the audio
tensor is min_required_samples x channels.
label_file_paths: paths to the label files [1] in the classification
tensor. Pass in an empty list if the model does not have any label file.
[1]:
https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L95
Returns:
A MetadataWriter object.
Raises:
ValueError: if either sample_rate or channels is non-positive, or if
min_required_samples is negative.
ValueError: if min_required_samples is 0, but the input audio tensor is
not fixed-size.
"""
# To make the Task Library work properly, sample_rate and channels need to
# be positive, and min_required_samples needs to be non-negative.
if sample_rate <= 0:
raise ValueError(
"sample_rate should be positive, but got {}.".format(sample_rate))

if channels <= 0:
raise ValueError(
"channels should be positive, but got {}.".format(channels))

if min_required_samples < 0:
raise ValueError(
"min_required_samples should be non-negative, but got {}.".format(
min_required_samples))

if min_required_samples == 0:
tensor_shape = writer_utils.get_input_tensor_shape(
model_buffer, _AUDIO_TENSOR_INDEX)

# The dynamic size input shape can be an empty array or arrays like [1]
# and [1, 1], where the flat size is 1.
flat_size = writer_utils.compute_flat_size(tensor_shape)
if not tensor_shape or flat_size == 1:
raise ValueError(
"The audio tensor is not fixed-size, therefore min_required_samples"
"is required, and should be a positive value.")
# Update min_required_samples if the audio tensor is fixed-size using
# ceiling division.
min_required_samples = (flat_size + channels -1) // channels

input_md = metadata_info.InputAudioTensorMd(_INPUT_NAME, _INPUT_DESCRIPTION,
sample_rate, channels,
min_required_samples)

output_md = metadata_info.ClassificationTensorMd(
name=_OUTPUT_NAME,
description=_OUTPUT_DESCRIPTION,
label_files=[
metadata_info.LabelFileMd(file_path=file_path)
for file_path in label_file_paths
],
tensor_type=writer_utils.get_output_tensor_types(model_buffer)[0])

return cls.create_from_metadata_info(
model_buffer, input_md=input_md, output_md=output_md)
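
As an aside (not part of this commit), a minimal usage sketch of the new writer: the file paths and audio parameters below are hypothetical placeholders, and the sketch assumes the get_metadata_json() and populate() helpers inherited from the base metadata_writer.MetadataWriter class.

from tensorflow_lite_support.metadata.python.metadata_writers import audio_classifier
from tensorflow_lite_support.metadata.python.metadata_writers import writer_utils

# Hypothetical paths; replace with a real audio classifier model and label file.
_MODEL_PATH = "model.tflite"
_LABEL_FILE = "labels.txt"
_EXPORT_PATH = "model_with_metadata.tflite"

writer = audio_classifier.MetadataWriter.create_for_inference(
    writer_utils.load_file(_MODEL_PATH),
    sample_rate=16000,
    channels=1,
    min_required_samples=15600,
    label_file_paths=[_LABEL_FILE])

# Inspect the metadata that will be written, rendered as JSON.
print(writer.get_metadata_json())

# Save the model with the metadata and label file packed into it.
with open(_EXPORT_PATH, "wb") as f:
  f.write(writer.populate())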
23 changes: 23 additions & 0 deletions tensorflow_lite_support/metadata/python/metadata_writers/writer_utils.py
@@ -14,12 +14,28 @@
# ==============================================================================
"""Helper methods for writing metadata into TFLite models."""

import array
import functools
from typing import List, Union, Optional

from tensorflow_lite_support.metadata import metadata_schema_py_generated as _metadata_fb
from tensorflow_lite_support.metadata import schema_py_generated as _schema_fb


def compute_flat_size(tensor_shape: Optional[array.array]) -> int:
"""Computes the flat size (number of elements) of tensor shape.
Args:
tensor_shape: an array of the tensor shape values.
Returns:
The flat size of the tensor shape. Return 0 if tensor_shape is None.
"""
if not tensor_shape:
return 0
return functools.reduce(lambda x, y: x * y, tensor_shape)


def get_input_tensor_types(
model_buffer: bytearray) -> List[_schema_fb.TensorType]:
"""Gets a list of the input tensor types."""
@@ -42,6 +58,13 @@ def get_output_tensor_types(
return tensor_types


def get_input_tensor_shape(model_buffer: bytearray,
tensor_index: int) -> array.array:
"""Gets the shape of the specified input tensor."""
subgraph = _get_subgraph(model_buffer)
return subgraph.Tensors(subgraph.Inputs(tensor_index)).ShapeAsNumpy()


def load_file(file_path: str, mode: str = "rb") -> Union[str, bytes]:
"""Loads file from the file path.
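
As a quick illustration of the writer_utils helpers added above (a sketch, not part of this commit; the shape values are made up), compute_flat_size simply multiplies the shape entries together, which is how create_for_inference distinguishes fixed-size audio tensors from dynamic ones:

import array

from tensorflow_lite_support.metadata.python.metadata_writers import writer_utils

# A fixed-size audio tensor, e.g. shape [1, 15600]: flat size is 1 * 15600.
print(writer_utils.compute_flat_size(array.array("i", [1, 15600])))  # 15600

# Dynamic-size inputs surface as an empty shape or shapes such as [1], whose
# flat size is 0 or 1; create_for_inference rejects these when
# min_required_samples is 0.
print(writer_utils.compute_flat_size(array.array("i", [])))   # 0
print(writer_utils.compute_flat_size(array.array("i", [1])))  # 1

# With a real model buffer, get_input_tensor_shape(model_buffer, 0) would
# supply the shape read from the TFLite flatbuffer instead.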
15 changes: 15 additions & 0 deletions tensorflow_lite_support/metadata/python/tests/metadata_writers/BUILD
@@ -126,3 +126,18 @@ py_test(
"@flatbuffers//:runtime_py",
],
)

py_test(
name = "audio_classifier_test",
srcs = ["audio_classifier_test.py"],
data = ["//tensorflow_lite_support/metadata/python/tests/testdata/audio_classifier:test_files"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":test_utils",
# build rule placeholder: tensorflow dep,
"//tensorflow_lite_support/metadata/python/metadata_writers:audio_classifier",
"@absl_py//absl/testing:parameterized",
"@flatbuffers//:runtime_py",
],
)
142 changes: 142 additions & 0 deletions tensorflow_lite_support/metadata/python/tests/metadata_writers/audio_classifier_test.py
@@ -0,0 +1,142 @@
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for AudioClassifier.MetadataWriter."""

from absl.testing import parameterized

import tensorflow as tf

from tensorflow_lite_support.metadata.python.metadata_writers import audio_classifier
from tensorflow_lite_support.metadata.python.tests.metadata_writers import test_utils

_FIXED_INPUT_SIZE_MODEL = "../testdata/audio_classifier/daredevil_sound_recognizer_320ms.tflite"
_DYNAMIC_INPUT_SIZE_MODEL = "../testdata/audio_classifier/yamnet_tfhub.tflite"
_LABEL_FILE = "../testdata/audio_classifier/labelmap.txt"
_JSON_FOR_INFERENCE_DYNAMIC = "../testdata/audio_classifier/yamnet_tfhub.json"
_JSON_FOR_INFERENCE_FIXED = "../testdata/audio_classifier/daredevil_sound_recognizer_320ms.json"
_JSON_DEFAULT = "../testdata/audio_classifier/daredevil_sound_recognizer_320ms_default.json"
_SAMPLE_RATE = 2
_CHANNELS = 1
_MIN_REQUIRED_SAMPLES_FIXED = 0
_MIN_REQUIRED_SAMPLES_DYNAMIC = 15600


class MetadataWriterTest(tf.test.TestCase):

def test_create_for_inference_should_succeed_dynamic_input_shape_model(self):
writer = audio_classifier.MetadataWriter.create_for_inference(
test_utils.load_file(_DYNAMIC_INPUT_SIZE_MODEL), _SAMPLE_RATE,
_CHANNELS, _MIN_REQUIRED_SAMPLES_DYNAMIC, [_LABEL_FILE])

metadata_json = writer.get_metadata_json()
expected_json = test_utils.load_file(_JSON_FOR_INFERENCE_DYNAMIC, "r")
self.assertEqual(metadata_json, expected_json)

def test_create_for_inference_should_succeed_with_fixed_input_shape_model(
self):
writer = audio_classifier.MetadataWriter.create_for_inference(
test_utils.load_file(_FIXED_INPUT_SIZE_MODEL), _SAMPLE_RATE, _CHANNELS,
_MIN_REQUIRED_SAMPLES_FIXED, [_LABEL_FILE])

metadata_json = writer.get_metadata_json()
expected_json = test_utils.load_file(_JSON_FOR_INFERENCE_FIXED, "r")
self.assertEqual(metadata_json, expected_json)

def test_create_from_metadata_info_by_default_should_succeed(self):
writer = audio_classifier.MetadataWriter.create_from_metadata_info(
test_utils.load_file(_FIXED_INPUT_SIZE_MODEL))

metadata_json = writer.get_metadata_json()
expected_json = test_utils.load_file(_JSON_DEFAULT, "r")
self.assertEqual(metadata_json, expected_json)


class MetadataWriterSampleRateTest(tf.test.TestCase, parameterized.TestCase):

@parameterized.named_parameters(
{
"testcase_name": "negative",
"wrong_sample_rate": -1
}, {
"testcase_name": "zero",
"wrong_sample_rate": 0
})
def test_create_for_inference_fails_with_wrong_sample_rate(
self, wrong_sample_rate):

with self.assertRaises(ValueError) as error:
audio_classifier.MetadataWriter.create_for_inference(
test_utils.load_file(_DYNAMIC_INPUT_SIZE_MODEL), wrong_sample_rate,
_CHANNELS, _MIN_REQUIRED_SAMPLES_DYNAMIC, [_LABEL_FILE])

self.assertEqual(
"sample_rate should be positive, but got {}.".format(wrong_sample_rate),
str(error.exception))


class MetadataWriterChannelsTest(tf.test.TestCase, parameterized.TestCase):

@parameterized.named_parameters(
{
"testcase_name": "negative",
"wrong_channels": -1
}, {
"testcase_name": "zero",
"wrong_channels": 0
})
def test_create_for_inference_fails_with_wrong_channels(self, wrong_channels):

with self.assertRaises(ValueError) as error:
audio_classifier.MetadataWriter.create_for_inference(
test_utils.load_file(_DYNAMIC_INPUT_SIZE_MODEL), _SAMPLE_RATE,
wrong_channels, _MIN_REQUIRED_SAMPLES_DYNAMIC, [_LABEL_FILE])

self.assertEqual(
"channels should be positive, but got {}.".format(wrong_channels),
str(error.exception))


class MetadataWriterMinRequiredSamplesTest(tf.test.TestCase,
parameterized.TestCase):

@parameterized.named_parameters(
{
"testcase_name":
"negative",
"wrong_min_required_samples":
-1,
"expected_error_message":
"min_required_samples should be non-negative, but got -1."
},
{
# min_required_samples cannot be zero with a dynamic input size model.
"testcase_name": "zero",
"wrong_min_required_samples": 0,
"expected_error_message":
"The audio tensor is not fixed-size, therefore min_required_samples "
"is required, and should be a positive value."
})
def test_create_for_inference_fails_with_wrong_min_required_samples(
self, wrong_min_required_samples, expected_error_message):
with self.assertRaises(ValueError) as error:
audio_classifier.MetadataWriter.create_for_inference(
test_utils.load_file(_DYNAMIC_INPUT_SIZE_MODEL), _SAMPLE_RATE,
_CHANNELS, wrong_min_required_samples, [_LABEL_FILE])

self.assertEqual(expected_error_message, str(error.exception))


if __name__ == "__main__":
tf.test.main()