forked from tensorflow/tflite-support
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
13 changed files
with
524 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
159 changes: 159 additions & 0 deletions
159
tensorflow_lite_support/metadata/python/metadata_writers/audio_classifier.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,159 @@ | ||
# Copyright 2021 The TensorFlow Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Writes metadata and label file to the audio classifier models.""" | ||
|
||
from typing import List, Optional | ||
|
||
from tensorflow_lite_support.metadata.python.metadata_writers import metadata_info | ||
from tensorflow_lite_support.metadata.python.metadata_writers import metadata_writer | ||
from tensorflow_lite_support.metadata.python.metadata_writers import writer_utils | ||
|
||
# Default general (model-level) metadata, used when the caller supplies none.
_MODEL_NAME = "AudioClassifier"
_MODEL_DESCRIPTION = (
    "Identify the most prominent type in the audio clip "
    "from a known set of categories.")

# Default name/description of the input audio tensor.
_INPUT_NAME = "audio_clip"
_INPUT_DESCRIPTION = "Input audio clip to be classified."

# Default name/description of the output classification tensor.
_OUTPUT_NAME = "probability"
_OUTPUT_DESCRIPTION = "Scores of the labels respectively."

# Index of the audio tensor among the model inputs.
_AUDIO_TENSOR_INDEX = 0
|
||
|
||
class MetadataWriter(metadata_writer.MetadataWriter):
  """Writes metadata into an audio classifier."""

  @classmethod
  def create_from_metadata_info(
      cls,
      model_buffer: bytearray,
      general_md: Optional[metadata_info.GeneralMd] = None,
      input_md: Optional[metadata_info.InputAudioTensorMd] = None,
      output_md: Optional[metadata_info.ClassificationTensorMd] = None):
    """Creates MetadataWriter based on general/input/output information.

    Args:
      model_buffer: valid buffer of the model file.
      general_md: general information about the model. If not specified,
        default general metadata will be generated.
      input_md: input audio tensor information. If not specified, default
        input metadata will be generated.
      output_md: output classification tensor information. If not specified,
        default output metadata will be generated. A falsy `associated_files`
        attribute (e.g. None) on this object is replaced with an empty list.

    Returns:
      A MetadataWriter object.
    """
    if general_md is None:
      general_md = metadata_info.GeneralMd(
          name=_MODEL_NAME, description=_MODEL_DESCRIPTION)

    if input_md is None:
      input_md = metadata_info.InputAudioTensorMd(
          name=_INPUT_NAME, description=_INPUT_DESCRIPTION)

    if output_md is None:
      output_md = metadata_info.ClassificationTensorMd(
          name=_OUTPUT_NAME, description=_OUTPUT_DESCRIPTION)

    # Normalize a missing file list to [] so that it can be iterated below.
    output_md.associated_files = output_md.associated_files or []

    return super().create_from_metadata_info(
        model_buffer=model_buffer,
        general_md=general_md,
        input_md=[input_md],
        output_md=[output_md],
        associated_files=[
            file.file_path for file in output_md.associated_files
        ])

  @classmethod
  def create_for_inference(cls, model_buffer: bytearray, sample_rate: int,
                           channels: int, min_required_samples: int,
                           label_file_paths: List[str]):
    """Creates mandatory metadata for TFLite Support inference.

    The parameters required in this method are mandatory when using TFLite
    Support features, such as Task library and Codegen tool (Android Studio ML
    Binding). Other metadata fields will be set to default. If other fields need
    to be filled, use the method `create_from_metadata_info` to edit them.

    Args:
      model_buffer: valid buffer of the model file.
      sample_rate: the sample rate in Hz when the audio was captured.
      channels: the channel count of the audio.
      min_required_samples: the minimum required number of per-channel samples
        in order to run inference properly. May be 0 for fixed-size audio
        tensors, in which case it is derived from the tensor's flat size as
        ceil(flat_size / channels). The minimum required flat size of the
        audio tensor is min_required_samples x channels.
      label_file_paths: paths to the label files [1] in the classification
        tensor. Pass in an empty list if the model does not have any label
        file.
        [1]:
          https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L95

    Returns:
      A MetadataWriter object.

    Raises:
      ValueError: if either sample_rate or channels is non-positive, or if
        min_required_samples is negative.
      ValueError: if min_required_samples is 0, but the input audio tensor is
        not fixed-size.
    """
    # To make Task Library work properly, sample_rate and channels need to be
    # positive, and min_required_samples needs to be non-negative (0 means
    # "derive it from a fixed-size input tensor", handled further below).
    if sample_rate <= 0:
      raise ValueError(
          "sample_rate should be positive, but got {}.".format(sample_rate))

    if channels <= 0:
      raise ValueError(
          "channels should be positive, but got {}.".format(channels))

    if min_required_samples < 0:
      raise ValueError(
          "min_required_samples should be non-negative, but got {}.".format(
              min_required_samples))

    if min_required_samples == 0:
      tensor_shape = writer_utils.get_input_tensor_shape(
          model_buffer, _AUDIO_TENSOR_INDEX)

      # The dynamic size input shape can be an empty array or arrays like [1]
      # and [1, 1], where the flat size is 1.
      flat_size = writer_utils.compute_flat_size(tensor_shape)
      if not tensor_shape or flat_size == 1:
        # NOTE(review): the two literals below concatenate without a space
        # ("min_required_samplesis required"). The text is kept byte-identical
        # because the unit test asserts on the exact message; fix both
        # together.
        raise ValueError(
            "The audio tensor is not fixed-size, therefore min_required_samples"
            "is required, and should be a positive value.")
      # Update min_required_samples if the audio tensor is fixed-size using
      # ceiling division.
      min_required_samples = (flat_size + channels - 1) // channels

    input_md = metadata_info.InputAudioTensorMd(_INPUT_NAME, _INPUT_DESCRIPTION,
                                                sample_rate, channels,
                                                min_required_samples)

    output_md = metadata_info.ClassificationTensorMd(
        name=_OUTPUT_NAME,
        description=_OUTPUT_DESCRIPTION,
        label_files=[
            metadata_info.LabelFileMd(file_path=file_path)
            for file_path in label_file_paths
        ],
        # The classification tensor takes the type of the first output tensor.
        tensor_type=writer_utils.get_output_tensor_types(model_buffer)[0])

    return cls.create_from_metadata_info(
        model_buffer, input_md=input_md, output_md=output_md)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
142 changes: 142 additions & 0 deletions
142
tensorflow_lite_support/metadata/python/tests/metadata_writers/audio_classifier_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
# Copyright 2021 The TensorFlow Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Tests for AudioClassifier.MetadataWriter.""" | ||
|
||
from absl.testing import parameterized | ||
|
||
import tensorflow as tf | ||
|
||
from tensorflow_lite_support.metadata.python.metadata_writers import audio_classifier | ||
from tensorflow_lite_support.metadata.python.tests.metadata_writers import test_utils | ||
|
||
# All fixtures for these tests live under one testdata directory.
_TESTDATA_DIR = "../testdata/audio_classifier/"

# Models under test.
_FIXED_INPUT_SIZE_MODEL = (
    _TESTDATA_DIR + "daredevil_sound_recognizer_320ms.tflite")
_DYNAMIC_INPUT_SIZE_MODEL = _TESTDATA_DIR + "yamnet_tfhub.tflite"

# Label file attached to the classification tensor.
_LABEL_FILE = _TESTDATA_DIR + "labelmap.txt"

# Expected metadata JSON files.
_JSON_FOR_INFERENCE_DYNAMIC = _TESTDATA_DIR + "yamnet_tfhub.json"
_JSON_FOR_INFERENCE_FIXED = (
    _TESTDATA_DIR + "daredevil_sound_recognizer_320ms.json")
_JSON_DEFAULT = (
    _TESTDATA_DIR + "daredevil_sound_recognizer_320ms_default.json")

# Audio parameters passed to `create_for_inference`.
_SAMPLE_RATE = 2
_CHANNELS = 1
_MIN_REQUIRED_SAMPLES_FIXED = 0
_MIN_REQUIRED_SAMPLES_DYNAMIC = 15600
|
||
|
||
class MetadataWriterTest(tf.test.TestCase):
  """Compares generated metadata JSON against the expected JSON files."""

  def _assert_metadata_json_matches(self, writer, expected_json_path):
    """Asserts `writer` yields the JSON stored at `expected_json_path`."""
    metadata_json = writer.get_metadata_json()
    expected_json = test_utils.load_file(expected_json_path, "r")
    self.assertEqual(metadata_json, expected_json)

  def test_create_for_inference_should_succeed_with_dynamic_input_shape_model(
      self):
    # A dynamic-size input tensor needs an explicit min_required_samples.
    writer = audio_classifier.MetadataWriter.create_for_inference(
        test_utils.load_file(_DYNAMIC_INPUT_SIZE_MODEL), _SAMPLE_RATE,
        _CHANNELS, _MIN_REQUIRED_SAMPLES_DYNAMIC, [_LABEL_FILE])

    self._assert_metadata_json_matches(writer, _JSON_FOR_INFERENCE_DYNAMIC)

  def test_create_for_inference_should_succeed_with_fixed_input_shape_model(
      self):
    # min_required_samples is passed as 0 here (_MIN_REQUIRED_SAMPLES_FIXED).
    writer = audio_classifier.MetadataWriter.create_for_inference(
        test_utils.load_file(_FIXED_INPUT_SIZE_MODEL), _SAMPLE_RATE, _CHANNELS,
        _MIN_REQUIRED_SAMPLES_FIXED, [_LABEL_FILE])

    self._assert_metadata_json_matches(writer, _JSON_FOR_INFERENCE_FIXED)

  def test_create_from_metadata_info_by_default_should_succeed(self):
    # Only the model buffer is given; every *_md argument is left unset.
    writer = audio_classifier.MetadataWriter.create_from_metadata_info(
        test_utils.load_file(_FIXED_INPUT_SIZE_MODEL))

    self._assert_metadata_json_matches(writer, _JSON_DEFAULT)
|
||
|
||
class MetadataWriterSampleRateTest(tf.test.TestCase, parameterized.TestCase):
  """Checks that `create_for_inference` rejects non-positive sample rates."""

  @parameterized.named_parameters(
      ("negative", -1),
      ("zero", 0),
  )
  def test_create_for_inference_fails_with_wrong_sample_rate(
      self, wrong_sample_rate):
    expected_message = "sample_rate should be positive, but got {}.".format(
        wrong_sample_rate)

    with self.assertRaises(ValueError) as error:
      audio_classifier.MetadataWriter.create_for_inference(
          test_utils.load_file(_DYNAMIC_INPUT_SIZE_MODEL), wrong_sample_rate,
          _CHANNELS, _MIN_REQUIRED_SAMPLES_DYNAMIC, [_LABEL_FILE])

    # The message must echo the offending value back to the caller.
    self.assertEqual(expected_message, str(error.exception))
|
||
|
||
class MetadataWriterChannelsTest(tf.test.TestCase, parameterized.TestCase):
  """Checks that `create_for_inference` rejects non-positive channel counts."""

  @parameterized.named_parameters(
      ("negative", -1),
      ("zero", 0),
  )
  def test_create_for_inference_fails_with_wrong_channels(self, wrong_channels):
    expected_message = "channels should be positive, but got {}.".format(
        wrong_channels)

    with self.assertRaises(ValueError) as error:
      audio_classifier.MetadataWriter.create_for_inference(
          test_utils.load_file(_DYNAMIC_INPUT_SIZE_MODEL), _SAMPLE_RATE,
          wrong_channels, _MIN_REQUIRED_SAMPLES_DYNAMIC, [_LABEL_FILE])

    # The message must echo the offending value back to the caller.
    self.assertEqual(expected_message, str(error.exception))
|
||
|
||
class MetadataWriterMinRequiredSamplesTest(tf.test.TestCase,
                                           parameterized.TestCase):
  """Checks `create_for_inference` validation of min_required_samples."""

  @parameterized.named_parameters(
      ("negative", -1,
       "min_required_samples should be non-negative, but got -1."),
      # min_required_samples cannot be zero with dynamic input size model.
      # The expected text mirrors the raised message verbatim, including the
      # absent space between the two concatenated literals.
      ("zero", 0,
       ("The audio tensor is not fixed-size, therefore min_required_samples"
        "is required, and should be a positive value.")),
  )
  def test_create_for_inference_fails_with_wrong_min_required_samples(
      self, wrong_min_required_samples, expected_error_message):
    with self.assertRaises(ValueError) as error:
      audio_classifier.MetadataWriter.create_for_inference(
          test_utils.load_file(_DYNAMIC_INPUT_SIZE_MODEL), _SAMPLE_RATE,
          _CHANNELS, wrong_min_required_samples, [_LABEL_FILE])

    actual_message = str(error.exception)
    self.assertEqual(expected_error_message, actual_message)
|
||
|
||
# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
  tf.test.main()
Oops, something went wrong.