-
Notifications
You must be signed in to change notification settings - Fork 128
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #856 from khanhlvg:bert-clu-annotator-python
PiperOrigin-RevId: 461776140
- Loading branch information
Showing
12 changed files
with
914 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
==============================================================================*/ | ||
|
||
syntax = "proto2"; | ||
|
||
package tflite.task.processor; | ||
|
||
import "tensorflow_lite_support/cc/task/processor/proto/class.proto"; | ||
|
||
// The input to CLU (Conversational Language Understanding). | ||
// Carries the dialogue as an ordered list of plain-text utterances. | ||
message CluRequest { | ||
// The utterances of the dialogue's conversation turns, in chronological | ||
// order (oldest first). The last utterance is the current turn. | ||
repeated string utterances = 1; | ||
} | ||
|
||
// The output of CLU: the predicted domains, intents and slots. | ||
// | ||
// Example input request: | ||
// ``` | ||
// utterances: "I would like to make a restaurant reservation at morning 11:15." | ||
// utterances: "Which restaurant do you want to go to?" | ||
// utterances: "Can I get a reservation for two people at Andes Cafe? Where is | ||
// their address?" | ||
// ``` | ||
// | ||
// Example output (in this example `start`/`end` are the byte offsets of | ||
// "Andes Cafe" within the last utterance): | ||
// ``` | ||
// domains { display_name: "Restaurants" score: 0.91 } | ||
// intents { display_name: "request(street_address)" score: 0.79 } | ||
// categorical_slots { | ||
// slot: "party_size" | ||
// prediction: { display_name: "2" score: 0.78 } | ||
// } | ||
// noncategorical_slots { | ||
// slot: "restaurant_name" | ||
// extraction: { value: "Andes Cafe" score: 0.91 start: 42 end: 52 } | ||
// } | ||
// ``` | ||
message CluResponse { | ||
// The list of predicted domains. | ||
repeated Class domains = 1; | ||
// The list of predicted intents. | ||
repeated Class intents = 2; | ||
// The list of predicted categorical slots. | ||
repeated CategoricalSlot categorical_slots = 3; | ||
// The list of predicted noncategorical slots. | ||
repeated NonCategoricalSlot noncategorical_slots = 4; | ||
} | ||
|
||
// Represents a categorical slot whose values are within a finite set. | ||
// Pairs the slot's name with the class predicted for it. | ||
message CategoricalSlot { | ||
// The name of the slot, e.g. "party_size". | ||
optional string slot = 1; | ||
// The predicted class, i.e. the value chosen for this slot. | ||
optional Class prediction = 2; | ||
} | ||
|
||
// A single extraction result: a span of text together with its score and | ||
// its byte range [start, end). | ||
// | ||
// NOTE(review): field numbers start at 2; number 1 is unused in this | ||
// message. If it belonged to a removed field it should be declared | ||
// `reserved` so it cannot be reused - TODO confirm. | ||
message Extraction { | ||
// The text value of the extraction. | ||
optional string value = 2; | ||
// The score for this extraction e.g. (but not necessarily) a probability in | ||
// [0,1]. | ||
optional float score = 3; | ||
// Inclusive start of the bytes of this extraction. | ||
// NOTE(review): presumably an offset into the utterance the value was | ||
// extracted from - confirm against the producer. | ||
optional uint32 start = 4; | ||
// Exclusive end of the bytes of this extraction. | ||
optional uint32 end = 5; | ||
} | ||
|
||
// Represents a non-categorical slot whose values are open text extracted from | ||
// the input text. Pairs the slot's name with the extracted span. | ||
message NonCategoricalSlot { | ||
// The name of the slot, e.g. "restaurant_name". | ||
optional string slot = 1; | ||
// The predicted extraction, i.e. the text span chosen for this slot. | ||
optional Extraction extraction = 2; | ||
} |
37 changes: 37 additions & 0 deletions
37
tensorflow_lite_support/cc/task/processor/proto/clu_annotation_options.proto
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
==============================================================================*/ | ||
|
||
syntax = "proto2"; | ||
|
||
package tflite.task.processor; | ||
|
||
// Options for setting up a BertCluAnnotator. | ||
// Every field is optional and carries an explicit default. | ||
// Next Id: 6 | ||
message BertCluAnnotationOptions { | ||
// Max number of history turns to encode by the model. Defaults to 5. | ||
optional int32 max_history_turns = 1 [default = 5]; | ||
|
||
// The threshold of domain prediction. Defaults to 0.5. | ||
optional float domain_threshold = 2 [default = 0.5]; | ||
|
||
// The threshold of intent prediction. Defaults to 0.5. | ||
optional float intent_threshold = 3 [default = 0.5]; | ||
|
||
// The threshold of categorical slot prediction. Defaults to 0.5. | ||
optional float categorical_slot_threshold = 4 [default = 0.5]; | ||
|
||
// The threshold of noncategorical slot prediction. Defaults to 0.5. | ||
optional float noncategorical_slot_threshold = 5 [default = 0.5]; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
79 changes: 79 additions & 0 deletions
79
tensorflow_lite_support/python/task/processor/proto/clu_annotation_options_pb2.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
# Copyright 2022 The TensorFlow Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""CLU annotation options protobuf.""" | ||
|
||
import dataclasses | ||
from typing import Any, Optional | ||
|
||
from tensorflow_lite_support.cc.task.processor.proto import clu_annotation_options_pb2 | ||
from tensorflow_lite_support.python.task.core.optional_dependencies import doc_controls | ||
|
||
_BertCluAnnotationOptionsProto = clu_annotation_options_pb2.BertCluAnnotationOptions | ||
|
||
|
||
@dataclasses.dataclass
class BertCluAnnotationOptions:
  """Options for Bert CLU Annotator processor.

  Python-side mirror of the `BertCluAnnotationOptions` protobuf message; it
  converts to and from that message via `to_pb2` / `create_from_pb2`.

  Attributes:
    max_history_turns: Max number of history turns to encode by the model.
    domain_threshold: The threshold of domain prediction.
    intent_threshold: The threshold of intent prediction.
    categorical_slot_threshold: The threshold of categorical slot prediction.
    noncategorical_slot_threshold: The threshold of noncategorical slot
      prediction.
  """

  max_history_turns: Optional[int] = 5
  domain_threshold: Optional[float] = 0.5
  intent_threshold: Optional[float] = 0.5
  categorical_slot_threshold: Optional[float] = 0.5
  noncategorical_slot_threshold: Optional[float] = 0.5

  @doc_controls.do_not_generate_docs
  def to_pb2(self) -> _BertCluAnnotationOptionsProto:
    """Generates a protobuf object to pass to the C++ layer.

    Returns:
      A `BertCluAnnotationOptions` protobuf message populated field-by-field
      from this object's attributes.
    """
    return _BertCluAnnotationOptionsProto(
        max_history_turns=self.max_history_turns,
        domain_threshold=self.domain_threshold,
        intent_threshold=self.intent_threshold,
        categorical_slot_threshold=self.categorical_slot_threshold,
        noncategorical_slot_threshold=self.noncategorical_slot_threshold)

  @classmethod
  @doc_controls.do_not_generate_docs
  def create_from_pb2(
      cls,
      pb2_obj: _BertCluAnnotationOptionsProto) -> "BertCluAnnotationOptions":
    """Creates a `BertCluAnnotationOptions` object from the given protobuf object.

    Args:
      pb2_obj: The protobuf message to read the option values from.

    Returns:
      A new instance of `cls` carrying the values of `pb2_obj`.
    """
    # Build through `cls` rather than the hard-coded class name so that a
    # subclass calling this factory gets an instance of itself.
    return cls(
        max_history_turns=pb2_obj.max_history_turns,
        domain_threshold=pb2_obj.domain_threshold,
        intent_threshold=pb2_obj.intent_threshold,
        categorical_slot_threshold=pb2_obj.categorical_slot_threshold,
        noncategorical_slot_threshold=pb2_obj.noncategorical_slot_threshold)

  def __eq__(self, other: Any) -> bool:
    """Checks if this object is equal to the given object.

    Equality is decided by comparing the protobuf representations of both
    objects, not the raw Python attribute values.

    Args:
      other: The object to be compared with.

    Returns:
      True if the objects are equal.
    """
    if not isinstance(other, BertCluAnnotationOptions):
      return False

    # Use the `==` operator instead of invoking the dunder method directly.
    return self.to_pb2() == other.to_pb2()
Oops, something went wrong.