Skip to content

Commit

Permalink
Merge changes from github.
Browse files Browse the repository at this point in the history
Change: 131310818
  • Loading branch information
tensorflower-gardener committed Aug 25, 2016
1 parent 1fa09b5 commit 2c598e8
Show file tree
Hide file tree
Showing 41 changed files with 1,469 additions and 382 deletions.
73 changes: 15 additions & 58 deletions configure
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ while [ "$TF_NEED_CUDA" == "" ]; do
esac
done

export TF_NEED_CUDA
if [ "$TF_NEED_CUDA" == "0" ]; then
echo "Configuration finished"
exit
Expand All @@ -97,6 +98,7 @@ while true; do
fi
fi
if [ -e "$GCC_HOST_COMPILER_PATH" ]; then
export CC=$GCC_HOST_COMPILER_PATH
break
fi
echo "Invalid gcc path. ${GCC_HOST_COMPILER_PATH} cannot be found" 1>&2
Expand All @@ -107,7 +109,6 @@ while true; do
# Retry
done


# Find out where the CUDA toolkit is installed
OSNAME=`uname -s`

Expand Down Expand Up @@ -140,6 +141,8 @@ while true; do
fi

if [ -e "${CUDA_TOOLKIT_PATH}/${CUDA_RT_LIB_PATH}" ]; then
export CUDA_TOOLKIT_PATH
export CUDA_VERSION=$TF_CUDA_VERSION
break
fi
echo "Invalid path to CUDA $TF_CUDA_VERSION toolkit. ${CUDA_TOOLKIT_PATH}/${CUDA_RT_LIB_PATH} cannot be found"
Expand Down Expand Up @@ -200,13 +203,16 @@ while true; do
fi

if [ -e "$CUDNN_INSTALL_PATH/${CUDA_DNN_LIB_ALT_PATH}" -o -e "$CUDNN_INSTALL_PATH/${CUDA_DNN_LIB_PATH}" ]; then
export CUDNN_VERSION=$TF_CUDNN_VERSION
export CUDNN_INSTALL_PATH
break
fi

if [ "$OSNAME" == "Linux" ]; then
CUDNN_PATH_FROM_LDCONFIG="$(ldconfig -p | sed -n 's/.*libcudnn.so .* => \(.*\)/\1/p')"
if [ -e "${CUDNN_PATH_FROM_LDCONFIG}${TF_CUDNN_EXT}" ]; then
CUDNN_INSTALL_PATH="$(dirname ${CUDNN_PATH_FROM_LDCONFIG})"
export CUDNN_VERSION=$TF_CUDNN_VERSION
export CUDNN_INSTALL_PATH="$(dirname ${CUDNN_PATH_FROM_LDCONFIG})"
break
fi
fi
Expand All @@ -225,42 +231,11 @@ while true; do
CUDNN_INSTALL_PATH=""
done

# Record the CUDA/cuDNN locations and versions collected above in
# third_party/gpus/cuda/cuda.config. The heredoc delimiter is unquoted, so the
# $VARIABLES in the body are expanded now and the file receives their current
# values; the lines starting with '#' inside the heredoc are file content, not
# shell comments, and are written out verbatim.
# NOTE(review): presumably consumed by cuda_config.sh, which this script
# invokes afterwards — confirm.
cat > third_party/gpus/cuda/cuda.config <<EOF
# CUDA_TOOLKIT_PATH refers to the CUDA toolkit.
CUDA_TOOLKIT_PATH="$CUDA_TOOLKIT_PATH"
# CUDNN_INSTALL_PATH refers to the cuDNN toolkit. The cuDNN header and library
# files can be either in this directory, or under include/ and lib64/
# directories separately.
CUDNN_INSTALL_PATH="$CUDNN_INSTALL_PATH"
# The Cuda SDK version that should be used in this build (empty to use libcudart.so symlink)
TF_CUDA_VERSION=$TF_CUDA_VERSION
# The Cudnn version that should be used in this build
TF_CUDNN_VERSION=$TF_CUDNN_VERSION
EOF

# Configure the gcc host compiler to use
export WARNING=$DO_NOT_SUBMIT_WARNING
perl -pi -e "s,CPU_COMPILER = \('.*'\),# \$ENV{WARNING}\nCPU_COMPILER = ('$GCC_HOST_COMPILER_PATH'),s" third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc
perl -pi -e "s,GCC_HOST_COMPILER_PATH = \('.*'\),# \$ENV{WARNING}\nGCC_HOST_COMPILER_PATH = ('$GCC_HOST_COMPILER_PATH'),s" third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc

# Configure the platform name.
perl -pi -e "s,PLATFORM = \".*\",PLATFORM = \"$OSNAME\",s" third_party/gpus/cuda/platform.bzl

# Configure the Cuda toolkit version to work with.
perl -pi -e "s,(GetCudaVersion.*return )\"[0-9\.]*\",\1\"$TF_CUDA_VERSION\",s" tensorflow/stream_executor/dso_loader.cc
perl -pi -e "s,CUDA_VERSION = \"[0-9\.]*\",CUDA_VERSION = \"$TF_CUDA_VERSION\",s" third_party/gpus/cuda/platform.bzl

# Configure the Cudnn version to work with.
perl -pi -e "s,(GetCudnnVersion.*return )\"[0-9\.]*\",\1\"$TF_CUDNN_VERSION\",s" tensorflow/stream_executor/dso_loader.cc
perl -pi -e "s,CUDNN_VERSION = \"[0-9\.]*\",CUDNN_VERSION = \"$TF_CUDNN_VERSION\",s" third_party/gpus/cuda/platform.bzl


# Configure the compute capabilities that TensorFlow builds for.
# Since Cuda toolkit is not backward-compatible, this is not guaranteed to work.
while true; do
fromuser=""
default_cuda_compute_capabilities="3.5,5.2"
if [ -z "$TF_CUDA_COMPUTE_CAPABILITIES" ]; then
cat << EOF
Please specify a list of comma-separated Cuda compute capabilities you want to build with.
Expand All @@ -270,6 +245,9 @@ EOF
read -p "[Default is: \"3.5,5.2\"]: " TF_CUDA_COMPUTE_CAPABILITIES
fromuser=1
fi
if [ -z "$TF_CUDA_COMPUTE_CAPABILITIES" ]; then
TF_CUDA_COMPUTE_CAPABILITIES=$default_cuda_compute_capabilities
fi
# Check whether all capabilities from the input is valid
COMPUTE_CAPABILITIES=${TF_CUDA_COMPUTE_CAPABILITIES//,/ }
ALL_VALID=1
Expand All @@ -285,34 +263,13 @@ EOF
exit 1
fi
else
export CUDA_COMPUTE_CAPABILITIES=$TF_CUDA_COMPUTE_CAPABILITIES
break
fi
TF_CUDA_COMPUTE_CAPABILITIES=""
done

if [ ! -z "$TF_CUDA_COMPUTE_CAPABILITIES" ]; then
export WARNING=$DO_NOT_SUBMIT_WARNING
# Formats Cuda compute capabilities as a list of quoted, comma-terminated
# entries suitable for splicing between the brackets of a Python list literal.
#
# Arguments: one compute capability per argument (e.g. 3.5 5.2).
# Outputs:   a single line on stdout, e.g. `"3.5", "5.2",`
#            (an empty line when no capabilities are given).
#
# Working variables are local so that calling the function directly (outside a
# command substitution) does not clobber the caller's OUTPUT/CAPABILITY, and
# "$@" is quoted so that arguments are never glob-expanded.
function CudaGenCodeOpts() {
  local capability
  local output=""
  for capability in "$@"; do
    # Empty arguments contribute nothing, matching unquoted-expansion callers.
    [[ -n "${capability}" ]] || continue
    output+="\"${capability}\", "
  done
  # Every entry ends in ", "; drop the final space so the line ends at the
  # trailing comma.
  printf '%s\n' "${output% }"
}
export CUDA_GEN_CODES_OPTS=$(CudaGenCodeOpts ${TF_CUDA_COMPUTE_CAPABILITIES//,/ })
perl -pi -0 -e 's,\n( *)([^\n]*supported_cuda_compute_capabilities\s*=\s*\[).*?(\]),\n\1# $ENV{WARNING}\n\1\2$ENV{CUDA_GEN_CODES_OPTS}\3,s' third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc
# Formats Cuda compute capabilities as a list of `CudaVersion("X.Y")`
# constructor calls suitable for splicing between the braces of a C++
# initializer list.
#
# Arguments: one compute capability per argument (e.g. 3.5 5.2).
# Outputs:   a single line on stdout, e.g.
#            `CudaVersion("3.5"), CudaVersion("5.2"),`
#            (an empty line when no capabilities are given).
#
# Working variables are local so that calling the function directly (outside a
# command substitution) does not clobber the caller's OUTPUT/CAPABILITY, and
# "$@" is quoted so that arguments are never glob-expanded.
function CudaVersionOpts() {
  local capability
  local output=""
  for capability in "$@"; do
    # Empty arguments contribute nothing, matching unquoted-expansion callers.
    [[ -n "${capability}" ]] || continue
    output+="CudaVersion(\"${capability}\"), "
  done
  # Every entry ends in ", "; drop the final space so the line ends at the
  # trailing comma.
  printf '%s\n' "${output% }"
}
export CUDA_VERSION_OPTS=$(CudaVersionOpts ${TF_CUDA_COMPUTE_CAPABILITIES//,/ })
perl -pi -0 -e 's,\n( *)([^\n]*supported_cuda_compute_capabilities\s*=\s*\{).*?(\}),\n\1// $ENV{WARNING}\n\1\2$ENV{CUDA_VERSION_OPTS}\3,s' tensorflow/core/common_runtime/gpu/gpu_device.cc
fi

# Invoke the cuda_config.sh and set up the TensorFlow's canonical view of the Cuda libraries
(cd third_party/gpus/cuda; ./cuda_config.sh;) || exit -1
bazel clean --expunge
bazel fetch //...

echo "Configuration finished"
4 changes: 1 addition & 3 deletions tensorflow/core/common_runtime/gpu/gpu_device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -785,10 +785,8 @@ struct CudaVersion {
int minor_part = -1;
};

// "configure" uses the specific name to substitute the following string.
// If you change it, make sure you modify "configure" as well.
std::vector<CudaVersion> supported_cuda_compute_capabilities = {
CudaVersion("3.5"), CudaVersion("5.2")};
TF_CUDA_CAPABILITIES,};

std::vector<CudaVersion> GetSupportedCudaComputeCapabilities() {
auto cuda_caps = supported_cuda_compute_capabilities;
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/core/kernels/lrn_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ limitations under the License.
#endif

#if GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cuda.h"
#include "cuda/include/cuda.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/util/stream_executor_util.h"
#endif // GOOGLE_CUDA
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/core/kernels/matmul_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ limitations under the License.
#include "tensorflow/core/kernels/fill_functor.h"

#if GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cuda.h"
#include "cuda/include/cuda.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif // GOOGLE_CUDA

Expand Down
36 changes: 36 additions & 0 deletions tensorflow/core/ops/ops.pbtxt
Original file line number Diff line number Diff line change
Expand Up @@ -4680,6 +4680,42 @@ op {
summary: "Decode a PNG-encoded image to a uint8 or uint16 tensor."
description: "The attr `channels` indicates the desired number of color channels for the\ndecoded image.\n\nAccepted values are:\n\n* 0: Use the number of channels in the PNG-encoded image.\n* 1: output a grayscale image.\n* 3: output an RGB image.\n* 4: output an RGBA image.\n\nIf needed, the PNG-encoded image is transformed to match the requested number\nof color channels."
}
op {
name: "DecodeGif"
input_arg {
name: "contents"
description: "0-D. The GIF-encoded image."
type: DT_STRING
}
output_arg {
name: "image"
description: "3-D with shape `[height, width, channels]`."
type_attr: "dtype"
}
attr {
name: "channels"
type: "int"
default_value {
i: 0
}
description: "Number of color channels for the decoded image."
}
attr {
name: "dtype"
type: "type"
default_value {
type: DT_UINT8
}
allowed_values {
list {
type: DT_UINT8
type: DT_UINT16
}
}
}
summary: "Decode a GIF-encoded image to a uint8 or uint16 tensor."
description: "The attr `channels` indicates the desired number of color channels for the\ndecoded image.\n\nAccepted values are:\n\n* 0: Use the number of channels in the GIF-encoded image.\n* 1: output a grayscale image.\n* 3: output an RGB image.\n* 4: output an RGBA image.\n\nIf needed, the GIF-encoded image is transformed to match the requested number\nof color channels."
}
op {
name: "DecodeRaw"
input_arg {
Expand Down
18 changes: 9 additions & 9 deletions tensorflow/core/platform/default/build_config/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ exports_files(["LICENSE"])

load("//tensorflow:tensorflow.bzl", "tf_copts")
load("//tensorflow:tensorflow.bzl", "tf_cuda_library")
load("//third_party/gpus/cuda:platform.bzl", "cuda_library_path")
load("@local_config_cuda//cuda:platform.bzl", "cuda_library_path")

cc_library(
name = "gtest",
Expand All @@ -32,7 +32,7 @@ tf_cuda_library(
deps = [
"//tensorflow/stream_executor",
] + select({
"//third_party/gpus/cuda:darwin": ["IOKit"],
"@local_config_cuda//cuda:darwin": ["IOKit"],
"//conditions:default": [],
}),
)
Expand Down Expand Up @@ -91,20 +91,20 @@ filegroup(
cc_library(
name = "cuda",
data = [
"//third_party/gpus/cuda:{}".format(cuda_library_path("cudart")),
"@local_config_cuda//cuda:{}".format(cuda_library_path("cudart")),
],
linkopts = select({
"//third_party/gpus/cuda:darwin": [
"-Wl,-rpath,third_party/gpus/cuda/lib",
"-Wl,-rpath,third_party/gpus/cuda/extras/CUPTI/lib",
"@local_config_cuda//cuda:darwin": [
"-Wl,-rpath,../local_config_cuda/cuda/lib",
"-Wl,-rpath,../local_config_cuda/cuda/extras/CUPTI/lib",
],
"//conditions:default": [
"-Wl,-rpath,third_party/gpus/cuda/lib64",
"-Wl,-rpath,third_party/gpus/cuda/extras/CUPTI/lib64",
"-Wl,-rpath,../local_config_cuda/cuda/lib64",
"-Wl,-rpath,../local_config_cuda/cuda/extras/CUPTI/lib64",
],
}),
deps = [
"//third_party/gpus/cuda:cudart",
"@local_config_cuda//cuda:cudart",
],
)

Expand Down
6 changes: 3 additions & 3 deletions tensorflow/core/platform/default/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ tf_cuda_library(
copts = tf_copts(),
cuda_deps = [
"//tensorflow/core:stream_executor",
"//third_party/gpus/cuda:cuda_headers",
"//third_party/gpus/cuda:cupti_headers",
"@local_config_cuda//cuda:cuda_headers",
"@local_config_cuda//cuda:cupti_headers",
],
data = ["//third_party/gpus/cuda:cupti_dsos"],
data = ["@local_config_cuda//cuda:cupti_dsos"],
visibility = ["//visibility:public"],
)
2 changes: 1 addition & 1 deletion tensorflow/core/platform/default/gpu/cupti_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ limitations under the License.
#include <stddef.h>
#include <stdint.h>

#include "third_party/gpus/cuda/extras/CUPTI/include/cupti.h"
#include "cuda/extras/CUPTI/include/cupti.h"

namespace perftools {
namespace gputools {
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/core/util/port.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ limitations under the License.
#include "tensorflow/core/util/port.h"

#if GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cuda.h"
#include "cuda/include/cuda.h"
#endif

namespace tensorflow {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,9 +184,8 @@ applies gradients.

### Gating Gradients

Both `minimize()` and `compute_gradients()` accept a `gate_gradient` argument
that controls the degree of parallelism during the application of the
gradients.
Both `minimize()` and `compute_gradients()` accept a `gate_gradients` argument
that controls the degree of parallelism during the application of the gradients.

The possible values are: `GATE_NONE`, `GATE_OP`, and `GATE_GRAPH`.

Expand Down
5 changes: 2 additions & 3 deletions tensorflow/g3doc/api_docs/python/train.md
Original file line number Diff line number Diff line change
Expand Up @@ -204,9 +204,8 @@ applies gradients.

### Gating Gradients

Both `minimize()` and `compute_gradients()` accept a `gate_gradient` argument
that controls the degree of parallelism during the application of the
gradients.
Both `minimize()` and `compute_gradients()` accept a `gate_gradients` argument
that controls the degree of parallelism during the application of the gradients.

The possible values are: `GATE_NONE`, `GATE_OP`, and `GATE_GRAPH`.

Expand Down
9 changes: 4 additions & 5 deletions tensorflow/g3doc/get_started/os_setup.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,10 @@ github source.

The TensorFlow Python API supports Python 2.7 and Python 3.3+.

The GPU version (Linux only) works best with Cuda Toolkit 7.5 and
cuDNN v4. other versions are supported (Cuda toolkit >= 7.0 and
cuDNN 6.5(v2), 7.0(v3), v5) only when installing from sources.
Please see [Cuda installation](#optional-install-cuda-gpus-on-linux)
for details.
The GPU version (Linux & Mac OS X only) works best with Cuda Toolkit 7.5 and
cuDNN v4. Other versions are supported (Cuda toolkit >= 7.0 and cuDNN 6.5(v2),
7.0(v3), v5) only when installing from sources. Please see
[Cuda installation](#optional-install-cuda-gpus-on-linux) for details.

## Overview

Expand Down
18 changes: 16 additions & 2 deletions tensorflow/python/ops/ctc_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,18 @@ def ctc_loss(inputs, labels, sequence_length,
<= sequence_length(b) for all b.
```
Notes:
This function performs the softmax operation for you, so inputs should
be e.g. linear projections of outputs by an LSTM.
The `inputs` Tensor's innermost dimension size, `num_classes`, represents
`num_labels + 1` classes, where num_labels is the number of true labels, and
the largest value `(num_classes - 1)` is reserved for the blank label.
For example, for a vocabulary containing 3 labels `[a, b, c]`,
`num_classes = 4` and the labels indexing is `{a: 0, b: 1, c: 2, blank: 3}`.
Regarding the arguments `preprocess_collapse_repeated` and
`ctc_merge_repeated`:
Expand Down Expand Up @@ -84,10 +96,12 @@ def ctc_loss(inputs, labels, sequence_length,
Args:
inputs: 3-D `float` `Tensor` sized
`[max_time x batch_size x num_classes]`. The logits.
`[max_time x batch_size x num_classes]`. The logits.
labels: An `int32` `SparseTensor`.
`labels.indices[i, :] == [b, t]` means `labels.values[i]` stores
the id for (batch b, time t). See `core/ops/ctc_ops.cc` for more details.
the id for (batch b, time t).
`labels.values[i]` must take on values in `[0, num_labels)`.
See `core/ops/ctc_ops.cc` for more details.
sequence_length: 1-D `int32` vector, size `[batch_size]`.
The sequence lengths.
preprocess_collapse_repeated: Boolean. Default: False.
Expand Down
3 changes: 1 addition & 2 deletions tensorflow/python/ops/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -1036,8 +1036,7 @@ def report_uninitialized_variables(var_list=None,
Returns:
A 1-D tensor containing names of the uninitialized variables, or an empty
1-D
tensor if there are no variables or no uninitialized variables.
1-D tensor if there are no variables or no uninitialized variables.
"""
if var_list is None:
var_list = all_variables() + local_variables()
Expand Down
6 changes: 3 additions & 3 deletions tensorflow/python/training/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,9 @@ class directly, but instead instantiate one of its subclasses such as
### Gating Gradients
Both `minimize()` and `compute_gradients()` accept a `gate_gradient` argument
that controls the degree of parallelism during the application of the
gradients.
Both `minimize()` and `compute_gradients()` accept a `gate_gradients`
argument that controls the degree of parallelism during the application of
the gradients.
The possible values are: `GATE_NONE`, `GATE_OP`, and `GATE_GRAPH`.
Expand Down
Loading

0 comments on commit 2c598e8

Please sign in to comment.