Skip to content

Commit

Permalink
Move TensorForestEstimator to contrib, since that's where most of its…
Browse files Browse the repository at this point in the history
… code is and it will not be considered a canned estimator in the near future.

Change: 143989623
  • Loading branch information
tensorflower-gardener committed Jan 9, 2017
1 parent c71ac2d commit 7ad7e4d
Show file tree
Hide file tree
Showing 7 changed files with 39 additions and 29 deletions.
19 changes: 0 additions & 19 deletions tensorflow/contrib/learn/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,6 @@ py_library(
"//tensorflow/contrib/rnn:rnn_py",
"//tensorflow/contrib/session_bundle:exporter",
"//tensorflow/contrib/session_bundle:gc",
"//tensorflow/contrib/tensor_forest:client_lib",
"//tensorflow/contrib/tensor_forest:data_ops_py",
"//tensorflow/contrib/tensor_forest:eval_metrics",
"//tensorflow/contrib/tensor_forest:tensor_forest_py",
"//tensorflow/contrib/training:training_py",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
Expand Down Expand Up @@ -674,21 +670,6 @@ py_test(
],
)

py_test(
name = "random_forest_test",
size = "medium",
srcs = ["python/learn/estimators/random_forest_test.py"],
srcs_version = "PY2AND3",
deps = [
":learn",
"//tensorflow/contrib/learn/python/learn/datasets",
"//tensorflow/contrib/tensor_forest:tensor_forest_py",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_test_lib",
"//third_party/py/numpy",
],
)

py_test(
name = "dynamic_rnn_estimator_test",
size = "medium",
Expand Down
2 changes: 0 additions & 2 deletions tensorflow/contrib/learn/python/learn/estimators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,8 +322,6 @@ def model_fn(features, targets, mode, params):
from tensorflow.contrib.learn.python.learn.estimators.metric_key import MetricKey
from tensorflow.contrib.learn.python.learn.estimators.model_fn import ModeKeys
from tensorflow.contrib.learn.python.learn.estimators.prediction_key import PredictionKey
from tensorflow.contrib.learn.python.learn.estimators.random_forest import TensorForestEstimator
from tensorflow.contrib.learn.python.learn.estimators.random_forest import TensorForestLossHook
from tensorflow.contrib.learn.python.learn.estimators.run_config import ClusterConfig
from tensorflow.contrib.learn.python.learn.estimators.run_config import Environment
from tensorflow.contrib.learn.python.learn.estimators.run_config import RunConfig
Expand Down
32 changes: 32 additions & 0 deletions tensorflow/contrib/tensor_forest/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ py_library(
":constants",
":data_ops_py",
":eval_metrics",
":random_forest",
":tensor_forest_ops_py",
":tensor_forest_py",
],
Expand Down Expand Up @@ -395,3 +396,34 @@ py_test(
"//tensorflow/python:variables",
],
)

py_library(
name = "random_forest",
srcs = ["client/random_forest.py"],
srcs_version = "PY2AND3",
deps = [
":client_lib",
":data_ops_py",
"//tensorflow/contrib/framework:framework_py",
"//tensorflow/contrib/learn",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
"//tensorflow/python:state_ops",
],
)

py_test(
name = "random_forest_test",
size = "medium",
srcs = ["client/random_forest_test.py"],
srcs_version = "PY2AND3",
deps = [
":random_forest",
":tensor_forest_py",
"//tensorflow/contrib/learn/python/learn/datasets",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_test_lib",
"//third_party/py/numpy",
],
)
1 change: 1 addition & 0 deletions tensorflow/contrib/tensor_forest/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,5 @@

# pylint: disable=unused-import
from tensorflow.contrib.tensor_forest.client import eval_metrics
from tensorflow.contrib.tensor_forest.client import random_forest
# pylint: enable=unused-import
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from __future__ import print_function

from tensorflow.contrib import framework as contrib_framework
from tensorflow.contrib.framework import deprecated_arg_values
from tensorflow.contrib.framework.python.framework import experimental
from tensorflow.contrib.learn.python.learn import evaluable
from tensorflow.contrib.learn.python.learn import trainable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import numpy as np

from tensorflow.contrib.learn.python.learn.datasets import base
from tensorflow.contrib.learn.python.learn.estimators import random_forest
from tensorflow.contrib.tensor_forest.client import random_forest
from tensorflow.contrib.tensor_forest.python import tensor_forest
from tensorflow.python.platform import test

Expand Down
11 changes: 5 additions & 6 deletions tensorflow/examples/learn/random_forest_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,25 +21,24 @@
import sys
import tempfile

import tensorflow as tf

# pylint: disable=g-backslash-continuation
from tensorflow.contrib.learn.python.learn\
import metric_spec
from tensorflow.contrib.learn.python.learn.estimators\
import random_forest
from tensorflow.contrib.tensor_forest.client\
import eval_metrics
from tensorflow.contrib.tensor_forest.client\
import random_forest
from tensorflow.contrib.tensor_forest.python\
import tensor_forest
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.python.platform import app

FLAGS = None


def build_estimator(model_dir):
"""Build an estimator."""
params = tf.contrib.tensor_forest.python.tensor_forest.ForestHParams(
params = tensor_forest.ForestHParams(
num_classes=10, num_features=784,
num_trees=FLAGS.num_trees, max_nodes=FLAGS.max_nodes)
graph_builder_class = tensor_forest.RandomForestGraphs
Expand Down Expand Up @@ -129,4 +128,4 @@ def main(_):
help='If true, use training loss as termination criteria.'
)
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
app.run(main=main, argv=[sys.argv[0]] + unparsed)

0 comments on commit 7ad7e4d

Please sign in to comment.