TensorFlow: Minor updates to docs, BUILD, GPU config / perf, etc.
Changes:
- Updates to op documentation and index by Josh

- More changes to BUILD files for python 3 support by @girving

- Fix to Eigen to use DenseIndex everywhere by @jiayq

- Enable configuration for cuda compute capability by @zheng-xq,
  including updates to docs.

- Route aggregation method through optimizer by schuster

- Updates to install instructions for bazel 0.1.1.

Base CL: 107702099
Vijay Vasudevan committed Nov 12, 2015
1 parent f2102f4 commit 4dffee7
Showing 48 changed files with 811 additions and 694 deletions.
64 changes: 64 additions & 0 deletions configure
@@ -76,6 +76,70 @@ CUDA_TOOLKIT_PATH="$CUDA_TOOLKIT_PATH"
CUDNN_INSTALL_PATH="$CUDNN_INSTALL_PATH"
EOF

function UnofficialSetting() {
echo -e "\nWARNING: You are configuring unofficial settings in TensorFlow. Because some external libraries are not backward compatible, these settings are largely untested and unsupported. \n"

# Configure the compute capabilities that TensorFlow builds for.
# Since Cuda toolkit is not backward-compatible, this is not guaranteed to work.
while true; do
fromuser=""
if [ -z "$TF_CUDA_COMPUTE_CAPABILITIES" ]; then
cat << EOF
Please specify a list of comma-separated Cuda compute capabilities you want to build with.
You can find the compute capability of your device at: https://developer.nvidia.com/cuda-gpus.
Please note that each additional compute capability significantly increases your build time and binary size.
EOF
read -p "[Default is: \"3.5,5.2\"]: " TF_CUDA_COMPUTE_CAPABILITIES
fromuser=1
fi
# Check whether all capabilities from the input are valid
COMPUTE_CAPABILITIES=${TF_CUDA_COMPUTE_CAPABILITIES//,/ }
ALL_VALID=1
for CAPABILITY in $COMPUTE_CAPABILITIES; do
if [[ ! "$CAPABILITY" =~ ^[0-9]+\.[0-9]+$ ]]; then
echo "Invalid compute capability: " $CAPABILITY
ALL_VALID=0
break
fi
done
if [ "$ALL_VALID" == "0" ]; then
if [ -z "$fromuser" ]; then
exit 1
fi
else
break
fi
TF_CUDA_COMPUTE_CAPABILITIES=""
done

if [ ! -z "$TF_CUDA_COMPUTE_CAPABILITIES" ]; then
export WARNING="Unofficial setting. DO NOT"" SUBMIT!!!"
function CudaGenCodeOpts() {
OUTPUT=""
for CAPABILITY in $@; do
OUTPUT=${OUTPUT}" \"${CAPABILITY}\", "
done
echo $OUTPUT
}
export CUDA_GEN_CODES_OPTS=$(CudaGenCodeOpts ${TF_CUDA_COMPUTE_CAPABILITIES//,/ })
perl -pi -0 -e 's,\n( *)([^\n]*supported_cuda_compute_capabilities\s*=\s*\[).*?(\]),\n\1# $ENV{WARNING}\n\1\2$ENV{CUDA_GEN_CODES_OPTS}\3,s' third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc
function CudaVersionOpts() {
OUTPUT=""
for CAPABILITY in $@; do
OUTPUT=$OUTPUT"CudaVersion(\"${CAPABILITY}\"), "
done
echo $OUTPUT
}
export CUDA_VERSION_OPTS=$(CudaVersionOpts ${TF_CUDA_COMPUTE_CAPABILITIES//,/ })
perl -pi -0 -e 's,\n( *)([^\n]*supported_cuda_compute_capabilities\s*=\s*\{).*?(\}),\n\1// $ENV{WARNING}\n\1\2$ENV{CUDA_VERSION_OPTS}\3,s' tensorflow/core/common_runtime/gpu/gpu_device.cc
fi
}

# Only run the unofficial settings when users explicitly choose to.
if [ "$TF_UNOFFICIAL_SETTING" == "1" ]; then
UnofficialSetting
fi

# Invoke cuda_config.sh and set up TensorFlow's canonical view of the Cuda libraries
(cd third_party/gpus/cuda; ./cuda_config.sh;) || exit -1

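For reference, the new path is driven entirely by environment variables, so under the script above a non-default build could be configured roughly like this (the capability list here is illustrative; the script's default is "3.5,5.2"):

# Opt in to the unofficial settings path, then list the compute
# capabilities to build for. "3.0,5.2" is an illustrative choice.
TF_UNOFFICIAL_SETTING=1 TF_CUDA_COMPUTE_CAPABILITIES="3.0,5.2" ./configure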
1 change: 1 addition & 0 deletions six.BUILD
@@ -9,4 +9,5 @@ py_library(
name = "six",
srcs = ["six.py"],
visibility = ["//visibility:public"],
srcs_version = "PY2AND3",
)
25 changes: 25 additions & 0 deletions tensorflow/core/common_runtime/executor.cc
@@ -294,6 +294,31 @@ Status ExecutorImpl::InferAllocAttr(
const DeviceNameUtils::ParsedName& local_dev_name,
AllocatorAttributes* attr) {
Status s;
// Note that it's possible for *n to be a Recv and *dst to be a Send,
// so these two cases are not mutually exclusive.
if (IsRecv(n)) {
string src_name;
s = GetNodeAttr(n->def(), "send_device", &src_name);
if (!s.ok()) return s;
DeviceNameUtils::ParsedName parsed_src_name;
if (!DeviceNameUtils::ParseFullName(src_name, &parsed_src_name)) {
s = errors::Internal("Bad send_device attr '", src_name, "' in node ",
n->name());
return s;
}
if (!DeviceNameUtils::IsSameAddressSpace(parsed_src_name, local_dev_name)) {
// Value is going to be the sink of an RPC.
attr->set_nic_compatible(true);
VLOG(2) << "node " << n->name() << " is the sink of an RPC in";
} else if (local_dev_name.type == "CPU" && parsed_src_name.type == "GPU") {
// Value is going to be the sink of a local DMA from GPU to CPU.
attr->set_gpu_compatible(true);
VLOG(2) << "node " << n->name() << " is the sink of a gpu->cpu copy";
} else {
VLOG(2) << "default alloc case local type " << local_dev_name.type
<< " remote type " << parsed_src_name.type;
}
}
if (IsSend(dst)) {
string dst_name;
s = GetNodeAttr(dst->def(), "recv_device", &dst_name);
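As a reading aid, here is a minimal, self-contained sketch of the Recv-side branch above, with hypothetical stand-ins (FakeAllocAttrs, plain strings) for TensorFlow's AllocatorAttributes and parsed device names:

#include <iostream>
#include <string>

// Hypothetical stand-in for AllocatorAttributes.
struct FakeAllocAttrs {
  bool nic_compatible = false;
  bool gpu_compatible = false;
};

// Mirrors the Recv branch of InferAllocAttr: a cross-address-space source
// means the value is the sink of an RPC and needs NIC-friendly memory; a
// local GPU source feeding a CPU device means the value is the sink of a
// GPU->CPU DMA and needs GPU-friendly memory.
FakeAllocAttrs InferRecvAttrs(bool same_address_space,
                              const std::string& local_type,
                              const std::string& src_type) {
  FakeAllocAttrs attrs;
  if (!same_address_space) {
    attrs.nic_compatible = true;
  } else if (local_type == "CPU" && src_type == "GPU") {
    attrs.gpu_compatible = true;
  }
  return attrs;
}

int main() {
  FakeAllocAttrs a = InferRecvAttrs(/*same_address_space=*/true, "CPU", "GPU");
  std::cout << "gpu_compatible=" << a.gpu_compatible << "\n";  // prints 1
  return 0;
}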
57 changes: 50 additions & 7 deletions tensorflow/core/common_runtime/gpu/gpu_device.cc
@@ -8,6 +8,7 @@

#include <stdlib.h>
#include <string.h>
#include <algorithm>

//#include "base/commandlineflags.h"
#include "tensorflow/stream_executor/cuda/cuda_activation.h"
@@ -590,10 +591,50 @@ static int GetMinGPUMultiprocessorCount() {
return kDefaultMinGPUMultiprocessorCount;
}

namespace {

struct CudaVersion {
// Initialize from version_name in the form of "3.5"
explicit CudaVersion(const std::string& version_name) {
size_t dot_pos = version_name.find('.');
CHECK(dot_pos != string::npos);
string major_str = version_name.substr(0, dot_pos);
CHECK(strings::safe_strto32(major_str.c_str(), &major_part));
string minor_str = version_name.substr(dot_pos + 1);
CHECK(strings::safe_strto32(minor_str.c_str(), &minor_part));
}
CudaVersion() {}
bool operator<(const CudaVersion& other) const {
if (this->major_part != other.major_part) {
return this->major_part < other.major_part;
}
return this->minor_part < other.minor_part;
}
friend std::ostream& operator<<(std::ostream& os,
const CudaVersion& version) {
os << version.major_part << "." << version.minor_part;
return os;
}
int major_part = -1;
int minor_part = -1;
};

// "configure" uses the specific name to substitute the following string.
// If you change it, make sure you modify "configure" as well.
std::vector<CudaVersion> supported_cuda_compute_capabilities = {
CudaVersion("3.5"), CudaVersion("5.2")};

} // namespace

void BaseGPUDeviceFactory::GetValidDeviceIds(std::vector<int>* ids) {
auto gpu_manager = GPUMachineManager();
int min_gpu_core_count = GetMinGPUMultiprocessorCount();
if (gpu_manager) {
CHECK(!supported_cuda_compute_capabilities.empty());
CudaVersion min_supported_capability =
*std::min_element(supported_cuda_compute_capabilities.begin(),
supported_cuda_compute_capabilities.end());

auto visible_device_count = gpu_manager->VisibleDeviceCount();
for (int i = 0; i < gpu_manager->VisibleDeviceCount(); ++i) {
auto exec_status = gpu_manager->ExecutorForDevice(i);
@@ -602,17 +643,19 @@ void BaseGPUDeviceFactory::GetValidDeviceIds(std::vector<int>* ids) {
}
gpu::StreamExecutor* se = exec_status.ValueOrDie();
const gpu::DeviceDescription& desc = se->GetDeviceDescription();
int major, minor;
if (!desc.cuda_compute_capability(&major, &minor)) {
CudaVersion device_capability;
if (!desc.cuda_compute_capability(&device_capability.major_part,
&device_capability.minor_part)) {
continue;
}
// Only consider GPUs with compute capability >= 3.5 (Kepler or
// higher)
if (major < 3 || (major == 3 && minor < 5)) {
// Only GPUs with at least the minimum supported compute capability are
// accepted.
if (device_capability < min_supported_capability) {
LOG(INFO) << "Ignoring gpu device "
<< "(" << GetShortDeviceDescription(i, desc) << ") "
<< "with Cuda compute capability " << major << "." << minor
<< ". The minimum required Cuda capability is 3.5.";
<< "with Cuda compute capability " << device_capability
<< ". The minimum required Cuda capability is "
<< min_supported_capability << ".";
continue;
}

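Stripped of the StreamExecutor plumbing, the filter reduces to the ordered compare defined on CudaVersion plus std::min_element; a standalone sketch under assumed device capabilities (the real code reads them from each device's description):

#include <algorithm>
#include <iostream>
#include <vector>

struct Version {
  int major_part;
  int minor_part;
  // Same ordering as CudaVersion::operator< above.
  bool operator<(const Version& other) const {
    if (major_part != other.major_part) return major_part < other.major_part;
    return minor_part < other.minor_part;
  }
};

int main() {
  const std::vector<Version> supported = {{3, 5}, {5, 2}};
  const Version min_supported =
      *std::min_element(supported.begin(), supported.end());

  // Assumed capabilities for three hypothetical devices; anything below
  // the 3.5 floor is skipped, mirroring the LOG(INFO) path above.
  for (const Version& device : {Version{3, 0}, Version{3, 5}, Version{5, 2}}) {
    std::cout << device.major_part << "." << device.minor_part
              << (device < min_supported ? ": ignored\n" : ": accepted\n");
  }
  return 0;
}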
2 changes: 1 addition & 1 deletion tensorflow/core/framework/rendezvous.cc
@@ -188,9 +188,9 @@ class LocalRendezvousImpl : public Rendezvous {
// message arrives.
Item* item = new Item;
item->waiter = done;
item->recv_alloc_attrs = recv_args.alloc_attrs;
if (recv_args.device_context) {
item->recv_dev_context = recv_args.device_context;
item->recv_alloc_attrs = recv_args.alloc_attrs;
item->recv_dev_context->Ref();
}
CHECK(table_.insert({key, item}).second);
11 changes: 6 additions & 5 deletions tensorflow/core/framework/tensor_slice.h
@@ -98,9 +98,10 @@ class TensorSlice {
// We allow NDIMS to be greater than dims(), in which case we will pad the
// higher dimensions with trivial dimensions.
template <int NDIMS>
void FillIndicesAndSizes(const TensorShape& shape,
Eigen::DSizes<ptrdiff_t, NDIMS>* indices,
Eigen::DSizes<ptrdiff_t, NDIMS>* sizes) const;
void FillIndicesAndSizes(
const TensorShape& shape,
Eigen::DSizes<Eigen::DenseIndex, NDIMS>* indices,
Eigen::DSizes<Eigen::DenseIndex, NDIMS>* sizes) const;

// Interaction with other TensorSlices.

@@ -162,8 +163,8 @@

template <int NDIMS>
void TensorSlice::FillIndicesAndSizes(
const TensorShape& shape, Eigen::DSizes<ptrdiff_t, NDIMS>* indices,
Eigen::DSizes<ptrdiff_t, NDIMS>* sizes) const {
const TensorShape& shape, Eigen::DSizes<Eigen::DenseIndex, NDIMS>* indices,
Eigen::DSizes<Eigen::DenseIndex, NDIMS>* sizes) const {
CHECK_EQ(shape.dims(), dims()) << "Incompatible dimensions between shape "
<< "slices: shape = " << shape.DebugString()
<< ", slice = " << DebugString();
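For context, Eigen's tensor expressions index with Eigen::DenseIndex, a typedef that defaults to std::ptrdiff_t but can be overridden via EIGEN_DEFAULT_DENSE_INDEX_TYPE, so spelling the type out keeps these interfaces correct under any configuration. A minimal slicing sketch with matching index types (assumes Eigen's unsupported Tensor module is on the include path):

#include <unsupported/Eigen/CXX11/Tensor>
#include <iostream>

int main() {
  Eigen::Tensor<float, 2> t(4, 4);
  t.setConstant(1.0f);
  // Indices and sizes carry Eigen::DenseIndex, matching what
  // FillIndicesAndSizes now produces.
  Eigen::DSizes<Eigen::DenseIndex, 2> indices(1, 1);
  Eigen::DSizes<Eigen::DenseIndex, 2> sizes(2, 2);
  Eigen::Tensor<float, 2> s = t.slice(indices, sizes);
  std::cout << s.dimension(0) << "x" << s.dimension(1) << "\n";  // 2x2
  return 0;
}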
4 changes: 2 additions & 2 deletions tensorflow/core/kernels/concat_op_gpu.cu.cc
@@ -18,9 +18,9 @@ void ConcatGPU(const GPUDevice& d,
const std::vector<
std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>& inputs,
typename TTypes<T, 2>::Matrix* output) {
Eigen::array<ptrdiff_t, 2> offset(0, 0);
Eigen::array<Eigen::DenseIndex, 2> offset(0, 0);
for (int i = 0; i < inputs.size(); ++i) {
Eigen::array<ptrdiff_t, 2> size = inputs[i]->dimensions();
Eigen::array<Eigen::DenseIndex, 2> size = inputs[i]->dimensions();
output->slice(offset, size).device(d) = *inputs[i];
offset[1] += size[1];
}
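The same slice expressions also run on the host; a self-contained sketch of the concat-by-slicing pattern above, with two assumed 2-D inputs concatenated along dimension 1:

#include <unsupported/Eigen/CXX11/Tensor>
#include <iostream>

int main() {
  Eigen::Tensor<float, 2> a(2, 3), b(2, 2);
  a.setConstant(1.0f);
  b.setConstant(2.0f);

  Eigen::Tensor<float, 2> out(2, 5);
  Eigen::array<Eigen::DenseIndex, 2> offset = {0, 0};
  for (const Eigen::Tensor<float, 2>* in : {&a, &b}) {
    Eigen::array<Eigen::DenseIndex, 2> size = in->dimensions();
    // Copy each input into its column range of the output.
    out.slice(offset, size) = *in;
    offset[1] += size[1];  // advance along the concat dimension
  }
  std::cout << out(0, 0) << " " << out(0, 4) << "\n";  // prints: 1 2
  return 0;
}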