Skip to content

Commit

Permalink
added support to single boolean input flag to TF graph
Browse files Browse the repository at this point in the history
  • Loading branch information
beniz committed May 1, 2017
1 parent c60b72b commit be216ab
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 4 deletions.
27 changes: 25 additions & 2 deletions src/tflib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ namespace dd
_ntargets = cl._ntargets;
_inputLayer = cl._inputLayer;
_outputLayer = cl._outputLayer;
_inputFlag = cl._inputFlag;
}

template <class TInputConnectorStrategy, class TOutputConnectorStrategy, class TMLModel>
Expand Down Expand Up @@ -83,6 +84,10 @@ namespace dd
{
_outputLayer = ad.get("outputlayer").get<std::string>();
}
if (ad.has("input_flag"))
{
_inputFlag = ad.getobj("input_flag");
}
if (ad.has("ntargets")) // XXX: unsupported
_ntargets = ad.get("ntargets").get<int>();
if (_nclasses == 0)
Expand Down Expand Up @@ -366,10 +371,28 @@ namespace dd
if (dv.size() > 1)
tf_concat(dv,vtfinputs);
else vtfinputs = dv;

// other input variables
std::pair<std::string,tensorflow::Tensor> othertfinputs;
std::vector<std::string> lkeys = _inputFlag.list_keys();
bool has_input_vars = false;
for (auto k: lkeys)
{
tensorflow::Tensor ivar(tensorflow::DT_BOOL,tensorflow::TensorShape());
ivar.scalar<bool>()() = _inputFlag.get(k).get<bool>();
othertfinputs.first = k;
othertfinputs.second = ivar;
has_input_vars = true;
break; // a single key for now, may have to use ClientSession for another scheme
}


// running the loded graph and saving the generated output
std::vector<tensorflow::Tensor> finalOutput; // To save the final Output generated by the tensorflow
tensorflow::Status run_status = _session->Run({{_inputLayer,*(vtfinputs.begin())}},{_outputLayer},{},&finalOutput);
std::vector<tensorflow::Tensor> finalOutput; // To save the final output generated by the tensorflow
tensorflow::Status run_status;
if (has_input_vars)
run_status = _session->Run({{_inputLayer,*(vtfinputs.begin())},othertfinputs},{_outputLayer},{},&finalOutput);
else run_status = _session->Run({{_inputLayer,*(vtfinputs.begin())}},{_outputLayer},{},&finalOutput);
if (!run_status.ok())
{
std::cout <<run_status.ToString()<<std::endl;
Expand Down
5 changes: 3 additions & 2 deletions src/tflib.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,9 @@ namespace dd
int _nclasses = 0; /**< required. */
bool _regression = false; /**< whether the net acts as a regressor. */
int _ntargets = 0; /**< number of classification or regression targets. */
std::string _inputLayer; // Input Layer of the Tensorflow Model
std::string _outputLayer; // OutPut layer of the tensorflow Model
std::string _inputLayer; // input Layer of the model
std::string _outputLayer; // output layer of the model
APIData _inputFlag; // boolean input to the model
std::unique_ptr<tensorflow::Session> _session = nullptr;
std::mutex _net_mutex; /**< mutex around net, e.g. no concurrent predict calls as net is not re-instantiated. Use batches instead. */
};
Expand Down

0 comments on commit be216ab

Please sign in to comment.