Skip to content

Commit

Permalink
Removed dependency on tensorflow.contrib for HParams.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 358219418
  • Loading branch information
wanxinwx authored and copybara-github committed Feb 18, 2021
1 parent 265d410 commit c374ee4
Show file tree
Hide file tree
Showing 6 changed files with 1,177 additions and 7 deletions.
8 changes: 4 additions & 4 deletions graph_compression/compression_lib/compression_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
import numpy as np
import tensorflow.compat.v1 as tf
from graph_compression.compression_lib import compression_op_utils as comp_op_utils
from tensorflow.contrib import training as contrib_training
from model_pruning.python import hparam as contrib_hparam


class MatrixCompressorInferface(object):
Expand Down Expand Up @@ -131,7 +131,7 @@ def get_default_hparams():
Returns:
tf.HParams object initialized to default values.
"""
return contrib_training.HParams(
return contrib_hparam.HParams(
name='model_compression',
rank=100,
num_rows=10,
Expand Down Expand Up @@ -333,7 +333,7 @@ def get_default_hparams():
tf.HParams object initialized to default values.
"""
return contrib_training.HParams(
return contrib_hparam.HParams(
name='model_compression',
alpha_decrement_value=0.01,
begin_compression_step=0,
Expand Down Expand Up @@ -889,7 +889,7 @@ def get_default_hparams():
tf.HParams object initialized to default values.
"""
return contrib_training.HParams(
return contrib_hparam.HParams(
name='input_compression',
compression_frequency=10,
use_tpu=False,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
from graph_compression.compression_lib import compression_op
from graph_compression.compression_lib import decompose_matrix
from graph_compression.compression_lib import kmeans_quantize
from model_pruning.python import hparam
from model_pruning.python import pruning
from tensorflow.contrib.training.python.training import hparam


class SimhashMatrixCompressor(compression_op.LowRankDecompMatrixCompressor):
Expand Down
1 change: 1 addition & 0 deletions model_pruning/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from __future__ import print_function

# pylint: disable=unused-import
from model_pruning.python.hparam import HParams
from model_pruning.python.layers.rnn_cells import MaskedBasicLSTMCell
from model_pruning.python.layers.rnn_cells import MaskedLSTMCell
from model_pruning.python.pruning import apply_mask
Expand Down
Loading

0 comments on commit c374ee4

Please sign in to comment.