Skip to content

Commit

Permalink
Merge pull request #856 from khanhlvg:bert-clu-annotator-python
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 461776140
  • Loading branch information
tflite-support-robot committed Jul 19, 2022
2 parents c69ac32 + 476c2fa commit d8d5e93
Show file tree
Hide file tree
Showing 12 changed files with 914 additions and 0 deletions.
40 changes: 40 additions & 0 deletions tensorflow_lite_support/cc/task/processor/proto/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,46 @@ support_py_proto_library(
proto_deps = [":search_result_proto"],
)

proto_library(
name = "clu_annotation_options_proto",
srcs = ["clu_annotation_options.proto"],
)

cc_proto_library(
name = "clu_annotation_options_cc_proto",
deps = [":clu_annotation_options_proto"],
)

support_py_proto_library(
name = "clu_annotation_options_py_pb2",
srcs = ["clu_annotation_options.proto"],
api_version = 2,
proto_deps = [":clu_annotation_options_proto"],
)

proto_library(
name = "clu_proto",
srcs = ["clu.proto"],
deps = [
":class_proto",
],
)

cc_proto_library(
name = "clu_cc_proto",
deps = [":clu_proto"],
)

support_py_proto_library(
name = "clu_py_pb2",
srcs = ["clu.proto"],
api_version = 2,
proto_deps = [":clu_proto"],
py_proto_deps = [
":class_py_pb2",
],
)

proto_library(
name = "qa_answers_proto",
srcs = ["qa_answers.proto"],
Expand Down
91 changes: 91 additions & 0 deletions tensorflow_lite_support/cc/task/processor/proto/clu.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

syntax = "proto2";

package tflite.task.processor;

import "tensorflow_lite_support/cc/task/processor/proto/class.proto";

// The input to CLU (Conversational Language Understanding).
message CluRequest {
// The utterances of dialogue conversation turns in the chronological order.
// The last utterance is the current turn.
repeated string utterances = 1;
}

// The output of CLU.
//
// Example input request:
// ```
// utterances: "I would like to make a restaurant reservation at morning 11:15."
// utterances: "Which restaurant do you want to go to?"
// utterances: "Can I get a reservation for two people at Andes Cafe? Where is
// their address?"
// ```
//
// Example output:
// ```
// domains { display_name: "Restaurants" score: 0.91 }
// intents { display_name: "request(street_address)" score: 0.79 }
// categorical_slots {
// slot: "party_size"
// prediction: { display_name="2" score: 0.78 }
// }
// noncategorical_slots {
// slot: "restaurant_name"
// extraction: { value: "Andes Cafe" confidence: 0.91 start: 42 end: 52 }
// }
// ```
message CluResponse {
// The list of predicted domains.
repeated Class domains = 1;
// The list of predicted intents.
repeated Class intents = 2;
// The list of predicted categorical slots.
repeated CategoricalSlot categorical_slots = 3;
// The list of predicted noncategorical slots.
repeated NonCategoricalSlot noncategorical_slots = 4;
}

// Represents a categorical slot whose values are within a finite set.
message CategoricalSlot {
// The name of the slot.
optional string slot = 1;
// The predicted class.
optional Class prediction = 2;
}

// A single extraction result.
message Extraction {
// The text value of the extraction.
optional string value = 2;
// The score for this extraction e.g. (but not necessarily) a probability in
// [0,1].
optional float score = 3;
// Start of the bytes of this extraction.
optional uint32 start = 4;
// Exclusive end of the bytes of this extraction.
optional uint32 end = 5;
}

// Represents a non-categorical slot whose values are open text extracted from
// the input text.
message NonCategoricalSlot {
// The name of the slot.
optional string slot = 1;
// The predicted extraction.
optional Extraction extraction = 2;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

syntax = "proto2";

package tflite.task.processor;

// Options for setting up an BertCluAnnotator.
// Next Id: 6
message BertCluAnnotationOptions {
// Max number of history turns to encode by the model.
optional int32 max_history_turns = 1 [default = 5];

// The threshold of domain prediction.
optional float domain_threshold = 2 [default = 0.5];

// The threshold of intent prediction.
optional float intent_threshold = 3 [default = 0.5];

// The threshold of categorical slot prediction.
optional float categorical_slot_threshold = 4 [default = 0.5];

// The threshold of noncategorical slot prediction.
optional float noncategorical_slot_threshold = 5 [default = 0.5];
}
19 changes: 19 additions & 0 deletions tensorflow_lite_support/python/task/processor/proto/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -126,3 +126,22 @@ py_library(
"//tensorflow_lite_support/python/task/core:optional_dependencies",
],
)

py_library(
name = "clu_annotation_options_pb2",
srcs = ["clu_annotation_options_pb2.py"],
deps = [
"//tensorflow_lite_support/cc/task/processor/proto:clu_annotation_options_py_pb2",
"//tensorflow_lite_support/python/task/core:optional_dependencies",
],
)

py_library(
name = "clu_pb2",
srcs = ["clu_pb2.py"],
deps = [
":class_pb2",
"//tensorflow_lite_support/cc/task/processor/proto:clu_py_pb2",
"//tensorflow_lite_support/python/task/core:optional_dependencies",
],
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""CLU annotation options protobuf."""

import dataclasses
from typing import Any, Optional

from tensorflow_lite_support.cc.task.processor.proto import clu_annotation_options_pb2
from tensorflow_lite_support.python.task.core.optional_dependencies import doc_controls

_BertCluAnnotationOptionsProto = clu_annotation_options_pb2.BertCluAnnotationOptions


@dataclasses.dataclass
class BertCluAnnotationOptions:
"""Options for Bert CLU Annotator processor.
Attributes:
max_history_turns: Max number of history turns to encode by the model.
domain_threshold: The threshold of domain prediction.
intent_threshold: The threshold of intent prediction.
categorical_slot_threshold: The threshold of categorical slot prediction.
noncategorical_slot_threshold: The threshold of noncategorical slot
prediction.
"""

max_history_turns: Optional[int] = 5
domain_threshold: Optional[float] = 0.5
intent_threshold: Optional[float] = 0.5
categorical_slot_threshold: Optional[float] = 0.5
noncategorical_slot_threshold: Optional[float] = 0.5

@doc_controls.do_not_generate_docs
def to_pb2(self) -> _BertCluAnnotationOptionsProto:
"""Generates a protobuf object to pass to the C++ layer."""
return _BertCluAnnotationOptionsProto(
max_history_turns=self.max_history_turns,
domain_threshold=self.domain_threshold,
intent_threshold=self.intent_threshold,
categorical_slot_threshold=self.categorical_slot_threshold,
noncategorical_slot_threshold=self.noncategorical_slot_threshold)

@classmethod
@doc_controls.do_not_generate_docs
def create_from_pb2(
cls,
pb2_obj: _BertCluAnnotationOptionsProto) -> "BertCluAnnotationOptions":
"""Creates a `BertCluAnnotationOptions` object from the given protobuf object."""
return BertCluAnnotationOptions(
max_history_turns=pb2_obj.max_history_turns,
domain_threshold=pb2_obj.domain_threshold,
intent_threshold=pb2_obj.intent_threshold,
categorical_slot_threshold=pb2_obj.categorical_slot_threshold,
noncategorical_slot_threshold=pb2_obj.noncategorical_slot_threshold)

def __eq__(self, other: Any) -> bool:
"""Checks if this object is equal to the given object.
Args:
other: The object to be compared with.
Returns:
True if the objects are equal.
"""
if not isinstance(other, BertCluAnnotationOptions):
return False

return self.to_pb2().__eq__(other.to_pb2())
Loading

0 comments on commit d8d5e93

Please sign in to comment.