Skip to content

Commit

Permalink
🚸 Add predict_kwargs for models predict() function
Browse files Browse the repository at this point in the history
  • Loading branch information
BrikerMan committed Jul 15, 2019
1 parent 22ad76f commit d81611f
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 7 deletions.
8 changes: 6 additions & 2 deletions kashgari/tasks/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,8 @@ def compile_model(self, **kwargs):
def predict(self,
x_data,
batch_size=32,
debug_info=False):
debug_info=False,
predict_kwargs: Dict = None):
"""
Generates output predictions for the input samples.
Expand All @@ -383,17 +384,20 @@ def predict(self,
x_data: The input data, as a Numpy array (or list of Numpy arrays if the model has multiple inputs).
batch_size: Integer. If unspecified, it will default to 32.
debug_info: Bool, Should print out the logging info.
predict_kwargs: arguments passed to ``predict()`` function of ``tf.keras.Model``
Returns:
array(s) of predictions.
"""
if predict_kwargs is None:
predict_kwargs = {}
with utils.custom_object_scope():
if isinstance(x_data, tuple):
lengths = [len(sen) for sen in x_data[0]]
else:
lengths = [len(sen) for sen in x_data]
tensor = self.embedding.process_x_dataset(x_data)
pred = self.tf_model.predict(tensor, batch_size=batch_size)
pred = self.tf_model.predict(tensor, batch_size=batch_size, **predict_kwargs)
res = self.embedding.reverse_numerize_label_sequences(pred.argmax(-1),
lengths)
if debug_info:
Expand Down
12 changes: 9 additions & 3 deletions kashgari/tasks/classification/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ def predict(self,
x_data,
batch_size=32,
multi_label_threshold: float = 0.5,
debug_info=False):
debug_info=False,
predict_kwargs: Dict = None):
"""
Generates output predictions for the input samples.
Expand All @@ -54,6 +55,7 @@ def predict(self,
batch_size: Integer. If unspecified, it will default to 32.
multi_label_threshold:
debug_info: Bool, Should print out the logging info.
predict_kwargs: arguments passed to ``predict()`` function of ``tf.keras.Model``
Returns:
array(s) of predictions.
Expand Down Expand Up @@ -81,7 +83,8 @@ def predict_top_k_class(self,
x_data,
top_k=5,
batch_size=32,
debug_info=False) -> List[Dict]:
debug_info=False,
predict_kwargs: Dict = None) -> List[Dict]:
"""
Generates output predictions with confidence for the input samples.
Expand All @@ -92,6 +95,7 @@ def predict_top_k_class(self,
top_k: int
batch_size: Integer. If unspecified, it will default to 32.
debug_info: Bool, Should print out the logging info.
predict_kwargs: arguments passed to ``predict()`` function of ``tf.keras.Model``
Returns:
array(s) of predictions.
Expand Down Expand Up @@ -121,9 +125,11 @@ def predict_top_k_class(self,
}
]
"""
if predict_kwargs is None:
predict_kwargs = {}
with kashgari.utils.custom_object_scope():
tensor = self.embedding.process_x_dataset(x_data)
pred = self.tf_model.predict(tensor, batch_size=batch_size)
pred = self.tf_model.predict(tensor, batch_size=batch_size, **predict_kwargs)
new_results = []

for sample_prob in pred:
Expand Down
6 changes: 4 additions & 2 deletions kashgari/tasks/labeling/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,16 @@ def predict_entities(self,
x_data,
batch_size=None,
join_chunk=' ',
debug_info=False):
debug_info=False,
predict_kwargs: Dict = None):
"""Gets entities from sequence.
Args:
x_data: The input data, as a Numpy array (or list of Numpy arrays if the model has multiple inputs).
batch_size: Integer. If unspecified, it will default to 32.
join_chunk: str or False,
debug_info: Bool, Should print out the logging info.
predict_kwargs: arguments passed to ``predict()`` function of ``tf.keras.Model``
Returns:
list: list of entity.
Expand All @@ -47,7 +49,7 @@ def predict_entities(self,
text_seq = x_data[0]
else:
text_seq = x_data
res = self.predict(x_data, batch_size, debug_info)
res = self.predict(x_data, batch_size, debug_info, predict_kwargs)
new_res = [get_entities(seq) for seq in res]
final_res = []
for index, seq in enumerate(new_res):
Expand Down

0 comments on commit d81611f

Please sign in to comment.