Skip to content

Commit

Permalink
Expose RunOptions and RunOutputs to Python's Run().
Browse files Browse the repository at this point in the history
The Run() call that takes these protos previously only worked in C++.
Example usage in Python (see also new unit test in session_test.py):

      sess.run(constant_op.constant(1.0),
               options=run_options,
               run_outputs=run_outputs)

A thin TF_Buffer struct is introduced to the C API, which takes care of
handling in/out bytes.  For instance, the "in protobuf" RunOptions is
serialized to bytes, which then get typemap'd to TF_Buffer.  The "out protobuf"
RunOutputs is handled analogously.
Change: 117152952
  • Loading branch information
concretevitamin authored and tensorflower-gardener committed Mar 14, 2016
1 parent 80344f3 commit b54ad57
Show file tree
Hide file tree
Showing 7 changed files with 185 additions and 47 deletions.
74 changes: 61 additions & 13 deletions tensorflow/core/client/tensor_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ using tensorflow::Session;
using tensorflow::Tensor;
using tensorflow::TensorBuffer;
using tensorflow::SessionOptions;
using tensorflow::RunOptions;
using tensorflow::RunOutputs;
using tensorflow::TensorShape;

extern "C" {
Expand Down Expand Up @@ -183,6 +185,31 @@ void TF_SetConfig(TF_SessionOptions* options, const void* proto,
tensorflow::errors::InvalidArgument("Unparseable ConfigProto");
}
}
// --------------------------------------------------------------------------
// Allocates an empty TF_Buffer; its fields are set by TF_Buffer's in-class
// default initializers (null data, zero length, no deallocator).
TF_Buffer* TF_NewBuffer() {
  TF_Buffer* buf = new TF_Buffer;
  return buf;
}

// Makes a heap-allocated copy of `proto` (so the caller keeps ownership of
// its input) and installs a deallocator that releases the copy when the
// buffer is destroyed via TF_DeleteBuffer.
TF_Buffer* TF_NewBufferFromString(const void* proto, size_t proto_len) {
  void* copy = malloc(proto_len);
  memcpy(copy, proto, proto_len);

  TF_Buffer* buf = new TF_Buffer;
  buf->data = copy;
  buf->length = proto_len;
  // `copy` was obtained with malloc(), so it must be released with free().
  // The previous deallocator used `delete` on the malloc'd block, which is
  // undefined behavior (mismatched allocation/deallocation functions).
  buf->data_deallocator = [](void* data, size_t length) { free(data); };
  return buf;
}

// Destroys `buffer`, first invoking its data_deallocator (if any) on the
// pointed-to block. Accepts nullptr as a no-op, matching the convention of
// free()/delete so callers can unconditionally delete optional buffers.
void TF_DeleteBuffer(TF_Buffer* buffer) {
  if (buffer == nullptr) return;  // robustness: tolerate null, like free()
  if (buffer->data_deallocator != nullptr) {
    (*buffer->data_deallocator)(const_cast<void*>(buffer->data),
                                buffer->length);
  }
  delete buffer;
}

// Returns a shallow copy of *buffer: the data pointer, length, and
// deallocator fields are copied as-is; ownership of the underlying block
// is not transferred.
TF_Buffer TF_GetBuffer(TF_Buffer* buffer) {
  TF_Buffer result = *buffer;
  return result;
}

// --------------------------------------------------------------------------
struct TF_Session {
Expand Down Expand Up @@ -331,6 +358,7 @@ Status LoadLibrary(const char* library_filename, void** result,
} // namespace tensorflow

void TF_Run_Helper(TF_Session* s, const char* handle,
const TF_Buffer* run_options,
// Input tensors
const char** c_input_names, TF_Tensor** c_inputs,
int ninputs,
Expand All @@ -339,7 +367,7 @@ void TF_Run_Helper(TF_Session* s, const char* handle,
int noutputs,
// Target nodes
const char** c_target_node_names, int ntargets,
TF_Status* status) {
TF_Buffer* run_outputs, TF_Status* status) {
status->status = Status::OK();
for (int i = 0; i < noutputs; i++) {
c_outputs[i] = NULL;
Expand Down Expand Up @@ -380,10 +408,33 @@ void TF_Run_Helper(TF_Session* s, const char* handle,
target_node_names[i] = c_target_node_names[i];
}
Status result;

if (handle == nullptr) {
result = s->session->Run(inputs, output_tensor_names, target_node_names,
&outputs);
if (run_options == nullptr) {
result = s->session->Run(inputs, output_tensor_names, target_node_names,
&outputs);
} else {
// Prepares (input) RunOptions and (output) RunOutputs params
RunOptions run_options_proto;
if (!run_options_proto.ParseFromArray(run_options->data,
run_options->length)) {
status->status =
tensorflow::errors::InvalidArgument("Unparseable RunOptions proto");
}
RunOutputs run_outputs_proto;

result = s->session->Run(run_options_proto, inputs, output_tensor_names,
target_node_names, &outputs, &run_outputs_proto);

// Serialize back to upstream client, who now owns the new buffer
int proto_size = run_outputs_proto.ByteSize();
void* str_buf = reinterpret_cast<void*>(operator new(proto_size));
run_outputs_proto.SerializeToArray(str_buf, proto_size);
run_outputs->data = str_buf;
run_outputs->length = proto_size;
}
} else {
// NOTE(zongheng): PRun does not support RunOptions yet.
result = s->session->PRun(handle, inputs, output_tensor_names, &outputs);
}
if (!result.ok()) {
Expand Down Expand Up @@ -413,17 +464,18 @@ void TF_Run_Helper(TF_Session* s, const char* handle,

extern "C" {

void TF_Run(TF_Session* s,
void TF_Run(TF_Session* s, const TF_Buffer* run_options,
// Input tensors
const char** c_input_names, TF_Tensor** c_inputs, int ninputs,
// Output tensors
const char** c_output_tensor_names, TF_Tensor** c_outputs,
int noutputs,
// Target nodes
const char** c_target_node_names, int ntargets, TF_Status* status) {
TF_Run_Helper(s, nullptr, c_input_names, c_inputs, ninputs,
const char** c_target_node_names, int ntargets,
TF_Buffer* run_outputs, TF_Status* status) {
TF_Run_Helper(s, nullptr, run_options, c_input_names, c_inputs, ninputs,
c_output_tensor_names, c_outputs, noutputs, c_target_node_names,
ntargets, status);
ntargets, run_outputs, status);
}

void TF_PRunSetup(TF_Session* s,
Expand Down Expand Up @@ -469,15 +521,11 @@ void TF_PRun(TF_Session* s, const char* handle,
// Target nodes
const char** c_target_node_names, int ntargets,
TF_Status* status) {
TF_Run_Helper(s, handle, c_input_names, c_inputs, ninputs,
TF_Run_Helper(s, handle, nullptr, c_input_names, c_inputs, ninputs,
c_output_tensor_names, c_outputs, noutputs, c_target_node_names,
ntargets, status);
ntargets, nullptr, status);
}

// Accessor: returns the (possibly null) data pointer held by `buffer`.
const void* TF_BufferData(TF_Buffer* buffer) { return buffer->data; }

// Accessor: returns the length in bytes of the block held by `buffer`.
size_t TF_BufferLength(TF_Buffer* buffer) { return buffer->length; }

TF_Library* TF_LoadLibrary(const char* library_filename, TF_Status* status) {
TF_Library* lib_handle = new TF_Library;
status->status = tensorflow::LoadLibrary(
Expand Down
29 changes: 25 additions & 4 deletions tensorflow/core/public/tensor_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,26 @@ typedef struct TF_Status TF_Status;
// Typically, the data consists of a serialized protocol buffer, but other data
// may also be held in a buffer.
//
// TF_Buffer itself does not do any memory management of the pointed-to block.
// By default, TF_Buffer itself does not do any memory management of the
// pointed-to block. If need be, users of this struct should specify how to
// deallocate the block by setting the `data_deallocator` function pointer.
// NOTE(review): the diff artifact declared `data` and `length` twice (old
// and new diff lines interleaved), which cannot compile; this is the clean
// post-change struct. Also note the default member initializers are
// C++-only — this header cannot be consumed by a plain C compiler as-is;
// confirm that is intended for this "C API" header.
typedef struct {
  const void* data = nullptr;
  size_t length = 0;
  // Optional: when non-null, TF_DeleteBuffer invokes this on `data` to
  // release the pointed-to block.
  void (*data_deallocator)(void* data, size_t length) = nullptr;
} TF_Buffer;

// Makes a copy of the input and sets an appropriate deallocator. Useful for
// passing in read-only, input protobufs.
extern TF_Buffer* TF_NewBufferFromString(const void* proto, size_t proto_len);

// Useful for passing *out* a protobuf.
extern TF_Buffer* TF_NewBuffer();

extern void TF_DeleteBuffer(TF_Buffer*);

extern TF_Buffer TF_GetBuffer(TF_Buffer* buffer);

// --------------------------------------------------------------------------
// TF_Library holds information about dynamically loaded TensorFlow plugins.
typedef struct TF_Library TF_Library;
Expand Down Expand Up @@ -172,7 +186,7 @@ typedef struct TF_Tensor TF_Tensor;
// Return a new tensor that holds the bytes data[0,len-1].
//
// The data will be deallocated by a subsequent call to TF_DeleteTensor via:
// (*deallocator_fn)(data, len, deallocator_arg)
// (*deallocator)(data, len, deallocator_arg)
// Clients must provide a custom deallocator function so they can pass in
// memory managed by something like numpy.
extern TF_Tensor* TF_NewTensor(TF_DataType, long long* dims, int num_dims,
Expand Down Expand Up @@ -252,20 +266,27 @@ extern void TF_ExtendGraph(TF_Session*, const void* proto, size_t proto_len,
// failure, inputs[] become the property of the implementation (the
// implementation will eventually call TF_DeleteTensor on each input).
//
// The caller retains the ownership of both `run_options` and `run_outputs`, and
// should manually call TF_DeleteBuffer on them.
//
// On success, the tensors corresponding to output_names[0,noutputs-1]
// are placed in outputs[], and these outputs[] become the property
// of the caller (the caller must eventually call TF_DeleteTensor on
// them).
//
// On failure, outputs[] contains nulls.
extern void TF_Run(TF_Session*,
// RunOptions
const TF_Buffer* run_options,
// Input tensors
const char** input_names, TF_Tensor** inputs, int ninputs,
// Output tensors
const char** output_tensor_names, TF_Tensor** outputs,
int noutputs,
// Target nodes
const char** target_node_names, int ntargets,
// RunOutputs
TF_Buffer* run_outputs,
// Output status
TF_Status*);

Expand Down
56 changes: 46 additions & 10 deletions tensorflow/python/client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def sess_str(self):
"""The TensorFlow process to which this session will connect."""
raise NotImplementedError('sess_str')

def run(self, fetches, feed_dict=None):
def run(self, fetches, feed_dict=None, options=None, run_outputs=None):
  """Runs operations in the session. See `Session.run()` for details.

  Args:
    fetches: A single graph element, or a list of graph elements, to
      evaluate.
    feed_dict: Optional dictionary that maps graph elements to substitute
      values.
    options: Optional `RunOptions` protocol buffer controlling the behavior
      of this particular step (e.g. turning tracing on).
    run_outputs: Optional `RunOutputs` protocol buffer into which
      non-Tensor outputs of the step (e.g. trace info) are collected.
  """
  raise NotImplementedError('run')

Expand Down Expand Up @@ -254,7 +254,7 @@ def as_default(self):
lambda feed: [feed])]
# pylint: enable=g-long-lambda

def run(self, fetches, feed_dict=None):
def run(self, fetches, feed_dict=None, options=None, run_outputs=None):
"""Runs the operations and evaluates the tensors in `fetches`.
This method runs one "step" of TensorFlow computation, by
Expand Down Expand Up @@ -293,11 +293,22 @@ def run(self, fetches, feed_dict=None):
the value should be a
[`SparseTensorValue`](../../api_docs/python/sparse_ops.md#SparseTensorValue).
The optional `options` argument expects a [`RunOptions`] proto. The options
allow controlling the behavior of this particular step (e.g. turning tracing
on).
The optional `run_outputs` argument expects a [`RunOutputs`] proto. When
appropriate, the non-Tensor output of this step will be collected there. For
example, when users turn on tracing in `options`, the profiled info will be
collected into this argument and passed back.
Args:
fetches: A single graph element, or a list of graph elements
(described above).
feed_dict: A dictionary that maps graph elements to values
(described above).
options: A [`RunOptions`] protocol buffer
run_outputs: A [`RunOutputs`] protocol buffer
Returns:
Either a single value if `fetches` is a single graph element, or
Expand All @@ -310,7 +321,23 @@ def run(self, fetches, feed_dict=None):
ValueError: If `fetches` or `feed_dict` keys are invalid or refer to a
`Tensor` that doesn't exist.
"""
return self._run(None, fetches, feed_dict)
run_outputs_ptr = tf_session.TF_NewBuffer()
if options:
options_ptr = tf_session.TF_NewBufferFromString(
compat.as_bytes(options.SerializeToString()))
else:
options_ptr = None

try:
result = self._run(None, fetches, feed_dict, options_ptr, run_outputs_ptr)
if run_outputs:
proto_data = tf_session.TF_GetBuffer(run_outputs_ptr)
run_outputs.ParseFromString(compat.as_bytes(proto_data))
finally:
tf_session.TF_DeleteBuffer(run_outputs_ptr)
if options:
tf_session.TF_DeleteBuffer(options_ptr)
return result

def partial_run(self, handle, fetches, feed_dict=None):
"""Continues the execution with more feeds and fetches.
Expand Down Expand Up @@ -345,7 +372,7 @@ def partial_run(self, handle, fetches, feed_dict=None):
Either a single value if `fetches` is a single graph element, or
a list of values if `fetches` is a list (described above).
"""
return self._run(handle, fetches, feed_dict)
return self._run(handle, fetches, feed_dict, None, None)

def partial_run_setup(self, fetches, feeds=None):
"""Sets up a graph with feeds and fetches for partial run.
Expand Down Expand Up @@ -457,7 +484,7 @@ def _fetch_fn(fetch):
unique_fetch_targets = list(unique_fetch_targets)
return unique_fetch_targets, target_list, fetch_info

def _run(self, handle, fetches, feed_dict):
def _run(self, handle, fetches, feed_dict, options, run_outputs):
"""Perform either run or partial_run, depending the exitence of `handle`."""
def _feed_fn(feed, feed_val):
for tensor_type, _, feed_fn, _ in BaseSession._REGISTERED_EXPANSIONS:
Expand Down Expand Up @@ -506,7 +533,7 @@ def _feed_fn(feed, feed_val):

# Run request and get response.
results = self._do_run(handle, target_list, unique_fetches,
feed_dict_string)
feed_dict_string, options, run_outputs)

# User may have fetched the same tensor multiple times, but we
# only fetch them from the runtime once. Furthermore, they may
Expand All @@ -529,7 +556,8 @@ def _feed_fn(feed, feed_val):
# Captures the name of a node in an error status.
_NODEDEF_NAME_RE = re.compile(r'\[\[Node: ([^ ]*?) =')

def _do_run(self, handle, target_list, fetch_list, feed_dict):
def _do_run(self, handle, target_list, fetch_list, feed_dict,
options, run_outputs):
"""Runs a step based on the given fetches and feeds.
Args:
Expand All @@ -540,17 +568,25 @@ def _do_run(self, handle, target_list, fetch_list, feed_dict):
be fetched and operations to be run.
feed_dict: A dictionary that maps tensor names (as byte arrays) to
numpy ndarrays.
options: A (pointer to a) [`RunOptions`] protocol buffer, or None
run_outputs: A (pointer to a) [`RunOutputs`] protocol buffer, or None
Returns:
A list of numpy ndarrays, corresponding to the elements of
`fetch_list`. If the ith element of `fetch_list` contains the
name of an operation, the first Tensor output of that operation
will be returned for that element.
"""
def _run_fn(session, feed_dict, fetch_list, target_list):
def _run_fn(session, feed_dict, fetch_list, target_list, options, run_outputs):
  """Executes one step via the C API, forwarding RunOptions/RunOutputs.

  NOTE(review): the diff span retained the stale pre-change line
  `return tf_session.TF_Run(session, feed_dict, fetch_list, target_list)`
  ahead of the options branch, which would return before the new code ran;
  this is the clean post-change version.
  """
  # Ensure any changes to the graph are reflected in the runtime.
  self._extend_graph()
  # `options` and `run_outputs` are C-level TF_Buffer pointers (or None);
  # both are forwarded to TF_Run only when options were supplied.
  if options:
    return tf_session.TF_Run(session, options, feed_dict, fetch_list,
                             target_list, run_outputs)
  return tf_session.TF_Run(session, None, feed_dict, fetch_list, target_list,
                           None)

def _prun_fn(session, handle, feed_dict, fetch_list):
if target_list:
Expand All @@ -559,7 +595,7 @@ def _prun_fn(session, handle, feed_dict, fetch_list):

if handle is None:
return self._do_call(_run_fn, self._session, feed_dict, fetch_list,
target_list)
target_list, options, run_outputs)
else:
return self._do_call(_prun_fn, self._session, handle, feed_dict,
fetch_list)
Expand Down
25 changes: 25 additions & 0 deletions tensorflow/python/client/session_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import six
from six.moves import xrange # pylint: disable=redefined-builtin

from tensorflow.core.framework import step_stats_pb2
from tensorflow.core.lib.core import error_codes_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
Expand Down Expand Up @@ -906,5 +907,29 @@ def testFeedDictKeyException(self):
with self.assertRaisesRegexp(TypeError, "Cannot interpret feed_dict"):
sess.run(a, feed_dict={'a': [2.0]})

def testPerStepTrace(self):
  """End-to-end check that RunOptions tracing populates RunOutputs."""
  run_options = config_pb2.RunOptions(
      trace_level=config_pb2.RunOptions.FULL_TRACE)
  run_outputs = config_pb2.RunOutputs()

  with ops.device('/cpu:0'):
    with session.Session() as sess:
      # No run_outputs requested: nothing should be collected.
      sess.run(constant_op.constant(1.0))
      self.assertFalse(run_outputs.HasField('step_stats'))

      # run_outputs given but tracing is off: still nothing collected.
      sess.run(constant_op.constant(1.0), run_outputs=run_outputs)
      self.assertFalse(run_outputs.HasField('step_stats'))

      # Full tracing requested: step_stats must be populated.
      sess.run(constant_op.constant(1.0),
               options=run_options,
               run_outputs=run_outputs)
      self.assertTrue(run_outputs.HasField('step_stats'))

      # Round-trip the stats proto to confirm it carries device stats.
      step_stats = step_stats_pb2.StepStats()
      self.assertEqual(len(step_stats.dev_stats), 0)

      step_stats.CopyFrom(run_outputs.step_stats)
      self.assertEqual(len(step_stats.dev_stats), 1)

if __name__ == '__main__':
googletest.main()
Loading

0 comments on commit b54ad57

Please sign in to comment.