forked from tensorflow/models
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request tensorflow#4232 from pkulzc/master
Release ssdlite mobilenet v2 coco trained model, add quantized training and minor fixes.
Showing
42 changed files
with
1,088 additions
and
299 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
42 changes: 42 additions & 0 deletions
42
research/object_detection/builders/graph_rewriter_builder.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Functions for quantized training and evaluation.""" | ||
|
||
import tensorflow as tf | ||
|
||
|
||
def build(graph_rewriter_config, is_training): | ||
"""Returns a function that modifies default graph based on options. | ||
Args: | ||
graph_rewriter_config: graph_rewriter_pb2.GraphRewriter proto. | ||
is_training: whether in training of eval mode. | ||
""" | ||
def graph_rewrite_fn(): | ||
"""Function to quantize weights and activation of the default graph.""" | ||
if (graph_rewriter_config.quantization.weight_bits != 8 or | ||
graph_rewriter_config.quantization.activation_bits != 8): | ||
raise ValueError('Only 8bit quantization is supported') | ||
|
||
# Quantize the graph by inserting quantize ops for weights and activations | ||
if is_training: | ||
tf.contrib.quantize.create_training_graph( | ||
input_graph=tf.get_default_graph(), | ||
quant_delay=graph_rewriter_config.quantization.delay) | ||
else: | ||
tf.contrib.quantize.create_eval_graph(input_graph=tf.get_default_graph()) | ||
|
||
tf.contrib.layers.summarize_collection('quant_vars') | ||
return graph_rewrite_fn |
57 changes: 57 additions & 0 deletions
57
research/object_detection/builders/graph_rewriter_builder_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Tests for graph_rewriter_builder.""" | ||
import mock | ||
import tensorflow as tf | ||
from object_detection.builders import graph_rewriter_builder | ||
from object_detection.protos import graph_rewriter_pb2 | ||
|
||
|
||
class QuantizationBuilderTest(tf.test.TestCase): | ||
|
||
def testQuantizationBuilderSetsUpCorrectTrainArguments(self): | ||
with mock.patch.object( | ||
tf.contrib.quantize, 'create_training_graph') as mock_quant_fn: | ||
with mock.patch.object(tf.contrib.layers, | ||
'summarize_collection') as mock_summarize_col: | ||
graph_rewriter_proto = graph_rewriter_pb2.GraphRewriter() | ||
graph_rewriter_proto.quantization.delay = 10 | ||
graph_rewriter_proto.quantization.weight_bits = 8 | ||
graph_rewriter_proto.quantization.activation_bits = 8 | ||
graph_rewrite_fn = graph_rewriter_builder.build( | ||
graph_rewriter_proto, is_training=True) | ||
graph_rewrite_fn() | ||
_, kwargs = mock_quant_fn.call_args | ||
self.assertEqual(kwargs['input_graph'], tf.get_default_graph()) | ||
self.assertEqual(kwargs['quant_delay'], 10) | ||
mock_summarize_col.assert_called_with('quant_vars') | ||
|
||
def testQuantizationBuilderSetsUpCorrectEvalArguments(self): | ||
with mock.patch.object(tf.contrib.quantize, | ||
'create_eval_graph') as mock_quant_fn: | ||
with mock.patch.object(tf.contrib.layers, | ||
'summarize_collection') as mock_summarize_col: | ||
graph_rewriter_proto = graph_rewriter_pb2.GraphRewriter() | ||
graph_rewriter_proto.quantization.delay = 10 | ||
graph_rewrite_fn = graph_rewriter_builder.build( | ||
graph_rewriter_proto, is_training=False) | ||
graph_rewrite_fn() | ||
_, kwargs = mock_quant_fn.call_args | ||
self.assertEqual(kwargs['input_graph'], tf.get_default_graph()) | ||
mock_summarize_col.assert_called_with('quant_vars') | ||
|
||
|
||
if __name__ == '__main__': | ||
tf.test.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.