
Commit

Merge pull request #10 from Acellera/remove_tf
Remove TensorFlow code
Raimondas Galvelis authored Apr 10, 2020
2 parents 10a32b9 + 66e6bca commit 96c92f1
Showing 10 changed files with 73 additions and 187 deletions.
6 changes: 2 additions & 4 deletions openmmapi/include/TensorRTForce.h
@@ -20,15 +20,14 @@ class OPENMM_EXPORT_NN TensorRTForce : public Force {
*
* @param file the path to the file containing the graph
*/
TensorRTForce(const std::string& file, const std::string& file2);
TensorRTForce(const std::string& file);
/**
* Get the path to the file containing the graph.
*/
const std::string& getFile() const { return file; }
/**
* Get the content of the protocol buffer defining the graph.
*/
const std::string& getGraphProto() const { return graphProto; }
const std::string& getSerializedGraph() const { return serializedGraph; }
/**
* Set whether this force makes use of periodic boundary conditions. If this is set
@@ -43,8 +42,7 @@ class OPENMM_EXPORT_NN TensorRTForce : public Force {
protected:
ForceImpl* createImpl() const;
private:
std::string file, graphProto;
std::string file2;
std::string file;
std::string serializedGraph;
bool usePeriodic;
};
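
For orientation, a minimal usage sketch of the revised API above (not part of the commit): the force is now constructed from a single file containing the serialized TensorRT engine, which the constructor reads into the string returned by getSerializedGraph(). "model.trt" is a placeholder path, and the System calls are standard OpenMM API.

#include "TensorRTForce.h"
#include <openmm/System.h>
#include <iostream>

using namespace OpenMM;

int main() {
    System system;
    system.addParticle(12.0);                        // one illustrative particle
    auto* force = new TensorRTForce("model.trt");    // placeholder path to a serialized engine
    system.addForce(force);                          // the System takes ownership of the Force
    std::cout << force->getSerializedGraph().size()  // bytes read by the constructor
              << " bytes loaded from " << force->getFile() << std::endl;
    return 0;
}
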
3 changes: 1 addition & 2 deletions openmmapi/include/TensorRTKernels.h
@@ -5,7 +5,6 @@
#include "openmm/KernelImpl.h"
#include "openmm/Platform.h"
#include "openmm/System.h"
#include <tensorflow/c/c_api.h>
#include <NvInfer.h>
#include <string>

@@ -29,7 +28,7 @@ class CalcTesorRTForceKernel : public KernelImpl {
* @param session the TensorFlow session in which to do calculations
* @param graph the TensorFlow graph to use for computing forces and energy
*/
virtual void initialize(const System& system, const TensorRTForce& force, TF_Session* session, TF_Graph* graph, Engine& engine) = 0;
virtual void initialize(const System& system, const TensorRTForce& force, Engine& engine) = 0;
/**
* Execute the kernel to calculate the forces and/or energy.
*
4 changes: 0 additions & 4 deletions openmmapi/include/internal/TensorRTForceImpl.h
@@ -26,7 +26,6 @@ class Logger : public nvinfer1:: ILogger {
class OPENMM_EXPORT_NN TensorRTForceImpl : public ForceImpl {
public:
TensorRTForceImpl(const TensorRTForce& owner);
~TensorRTForceImpl();
void initialize(ContextImpl& context);
const TensorRTForce& getOwner() const { return owner; }
void updateContextState(ContextImpl& context, bool& forcesInvalid) {}
@@ -36,9 +35,6 @@ class OPENMM_EXPORT_NN TensorRTForceImpl : public ForceImpl {
private:
const TensorRTForce& owner;
Kernel kernel;
TF_Graph* graph;
TF_Session* session;
TF_Status* status;
Logger logger;
using Runtime = nvinfer1::IRuntime;
using Engine = nvinfer1::ICudaEngine;
7 changes: 3 additions & 4 deletions openmmapi/src/TensorRTForce.cpp
@@ -7,12 +7,11 @@

using namespace OpenMM;

TensorRTForce::TensorRTForce(const std::string& file, const std::string& file2) : file(file), file2(file2), usePeriodic(false) {
std::ifstream graphFile(file);
graphProto = std::string((std::istreambuf_iterator<char>(graphFile)), std::istreambuf_iterator<char>());
TensorRTForce::TensorRTForce(const std::string& file): file(file), usePeriodic(false) {

// Read the serialized graph from a file
std::stringstream stream;
stream << std::ifstream(file2, std::ifstream::binary).rdbuf();
stream << std::ifstream(file, std::ifstream::binary).rdbuf();
serializedGraph = stream.str();
}

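
The new constructor body boils down to the stream-into-string idiom for slurping a binary file. A standalone restatement (the helper name readSerializedEngine is introduced here for illustration only):

#include <fstream>
#include <sstream>
#include <string>

// Read a whole binary file into a string, as the constructor above does with the engine file.
std::string readSerializedEngine(const std::string& path) {
    std::stringstream stream;
    stream << std::ifstream(path, std::ifstream::binary).rdbuf();  // stream in the raw bytes
    return stream.str();                                           // empty if the file could not be opened
}
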
91 changes: 29 additions & 62 deletions openmmapi/src/TensorRTForceImpl.cpp
@@ -5,76 +5,44 @@

using namespace OpenMM;

TensorRTForceImpl::TensorRTForceImpl(const TensorRTForce& owner) : owner(owner), graph(NULL), session(NULL), status(TF_NewStatus()) {
TensorRTForceImpl::TensorRTForceImpl(const TensorRTForce& owner) : owner(owner) {

// Create TensorRT runtime
const auto destructor = [](Runtime* r) { r->destroy(); };
runtime = {nvinfer1::createInferRuntime(logger), destructor};
}

TensorRTForceImpl::~TensorRTForceImpl() {
if (session != NULL) {
TF_CloseSession(session, status);
TF_DeleteSession(session, status);
}
if (graph != NULL)
TF_DeleteGraph(graph);
TF_DeleteStatus(status);
}

void TensorRTForceImpl::initialize(ContextImpl& context) {
// Load the graph from the file.

const auto& graphProto = owner.getGraphProto();
auto buffer = TF_NewBufferFromString(graphProto.c_str(), graphProto.size());
graph = TF_NewGraph();
auto importOptions = TF_NewImportGraphDefOptions();
TF_GraphImportGraphDef(graph, buffer, importOptions, status);
if (TF_GetCode(status) != TF_OK)
throw OpenMMException(std::string("Error loading TensorFlow graph: ")+TF_Message(status));
TF_DeleteImportGraphDefOptions(importOptions);
TF_DeleteBuffer(buffer);

// Deserialize TensorRT graph
const auto& graph2 = owner.getSerializedGraph();
// Create TensorRT engine
const auto& graph = owner.getSerializedGraph();
const auto destructor = [](Engine* e) { e->destroy(); };
engine = {runtime->deserializeCudaEngine(graph2.data(), graph2.size()), destructor};
engine = {runtime->deserializeCudaEngine(graph.data(), graph.size()), destructor};

// Check that the graph contains all the expected elements and that their types
// are supported.

TF_Output positions = {TF_GraphOperationByName(graph, "positions"), 0};
if (positions.oper == NULL)
throw OpenMMException("TensorRTForce: the graph does not have a 'positions' input");
if (TF_OperationOutputType(positions) != TF_FLOAT)
throw OpenMMException("TensorRTForce: 'positions' must have type float32");

if (owner.usesPeriodicBoundaryConditions()) {
TF_Output boxvectors = {TF_GraphOperationByName(graph, "boxvectors"), 0};
if (boxvectors.oper == NULL)
throw OpenMMException("TensorRTForce: the graph does not have a 'boxvectors' input");
if (TF_OperationOutputType(boxvectors) != TF_FLOAT)
throw OpenMMException("TensorRTForce: 'boxvectors' must have type float32");
}

TF_Output energy = {TF_GraphOperationByName(graph, "energy"), 0};
if (energy.oper == NULL)
throw OpenMMException("TensorRTForce: the graph does not have an 'energy' output");
if (TF_OperationOutputType(energy) != TF_FLOAT)
throw OpenMMException("TensorRTForce: 'energy' must have type float32");

TF_Output forces = {TF_GraphOperationByName(graph, "forces"), 0};
if (forces.oper == NULL)
throw OpenMMException("TensorRTForce: the graph does not have a 'forces' output");
if (TF_OperationOutputType(forces) != TF_FLOAT)
throw OpenMMException("TensorRTForce: 'forces' must have type float32");

// Create the TensorFlow Session.

auto sessionOptions = TF_NewSessionOptions();
session = TF_NewSession(graph, sessionOptions, status);
if (TF_GetCode(status) != TF_OK)
throw OpenMMException(std::string("Error creating TensorFlow session: ")+TF_Message(status));
TF_DeleteSessionOptions(sessionOptions);
// TF_Output positions = {TF_GraphOperationByName(graph, "positions"), 0};
// if (positions.oper == NULL)
// throw OpenMMException("TensorRTForce: the graph does not have a 'positions' input");
// if (TF_OperationOutputType(positions) != TF_FLOAT)
// throw OpenMMException("TensorRTForce: 'positions' must have type float32");
// if (owner.usesPeriodicBoundaryConditions()) {
// TF_Output boxvectors = {TF_GraphOperationByName(graph, "boxvectors"), 0};
// if (boxvectors.oper == NULL)
// throw OpenMMException("TensorRTForce: the graph does not have a 'boxvectors' input");
// if (TF_OperationOutputType(boxvectors) != TF_FLOAT)
// throw OpenMMException("TensorRTForce: 'boxvectors' must have type float32");
// }
// TF_Output energy = {TF_GraphOperationByName(graph, "energy"), 0};
// if (energy.oper == NULL)
// throw OpenMMException("TensorRTForce: the graph does not have an 'energy' output");
// if (TF_OperationOutputType(energy) != TF_FLOAT)
// throw OpenMMException("TensorRTForce: 'energy' must have type float32");
// TF_Output forces = {TF_GraphOperationByName(graph, "forces"), 0};
// if (forces.oper == NULL)
// throw OpenMMException("TensorRTForce: the graph does not have a 'forces' output");
// if (TF_OperationOutputType(forces) != TF_FLOAT)
// throw OpenMMException("TensorRTForce: 'forces' must have type float32");

// Validate TesorRT graph
const auto periodic = owner.usesPeriodicBoundaryConditions();
@@ -88,10 +56,9 @@ void TensorRTForceImpl::initialize(ContextImpl& context) {

// TODO complete validation

// Create the kernel.

// Create the kernel
kernel = context.getPlatform().createKernel(CalcTesorRTForceKernel::Name(), context);
kernel.getAs<CalcTesorRTForceKernel>().initialize(context.getSystem(), owner, session, graph, *engine);
kernel.getAs<CalcTesorRTForceKernel>().initialize(context.getSystem(), owner, *engine);
}

double TensorRTForceImpl::calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups) {
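
The surviving TensorRT setup follows one ownership pattern: objects created through the TensorRT C++ API are released with destroy() rather than delete, so the implementation holds them in smart pointers with custom deleters. A self-contained sketch of that pattern using the same calls as the code above (TrtPtr, TrtContext, and buildEngine are names introduced here for illustration; the destroy()-based calls are the pre-TensorRT-8 style this diff uses):

#include <NvInfer.h>
#include <memory>
#include <string>
#include <utility>

using Runtime = nvinfer1::IRuntime;
using Engine = nvinfer1::ICudaEngine;

// unique_ptr whose deleter calls destroy(), mirroring the lambdas above.
template <typename T>
using TrtPtr = std::unique_ptr<T, void (*)(T*)>;

struct TrtContext {
    TrtPtr<Runtime> runtime;  // declared first so it outlives the engine during destruction
    TrtPtr<Engine> engine;
};

TrtContext buildEngine(nvinfer1::ILogger& logger, const std::string& serializedGraph) {
    TrtPtr<Runtime> runtime(nvinfer1::createInferRuntime(logger),
                            [](Runtime* r) { r->destroy(); });
    TrtPtr<Engine> engine(runtime->deserializeCudaEngine(serializedGraph.data(), serializedGraph.size()),
                          [](Engine* e) { e->destroy(); });
    return {std::move(runtime), std::move(engine)};
}
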
127 changes: 30 additions & 97 deletions platforms/cuda/src/CudaTensorRTKernels.cpp
@@ -1,147 +1,80 @@
#include "CudaTensorRTKernels.h"
#include "CudaTensorRTKernelSources.h"
#include "openmm/internal/ContextImpl.h"
#include <map>

using namespace OpenMM;

CudaCalcTensorRTForceKernel::~CudaCalcTensorRTForceKernel() {
if (positionsTensor != NULL)
TF_DeleteTensor(positionsTensor);
if (boxVectorsTensor != NULL)
TF_DeleteTensor(boxVectorsTensor);
}

void CudaCalcTensorRTForceKernel::initialize(const System& system, const TensorRTForce& force, TF_Session* session, TF_Graph* graph, Engine& engine) {
void CudaCalcTensorRTForceKernel::initialize(const System& system, const TensorRTForce& force, Engine& engine) {

cu.setAsCurrent();
this->session = session;
this->graph = graph;

const int numParticles = system.getNumParticles();
usePeriodic = force.usesPeriodicBoundaryConditions();
int numParticles = system.getNumParticles();

// Create TensorRT execution context
// TODO fix the destructor
const auto destructor = [](ExecutionContext* e) { /* e->destroy(); */ };
execution = {engine.createExecutionContext(), destructor};

// Construct input tensors.

const int64_t positionsDims[] = {numParticles, 3};
positionsTensor = TF_AllocateTensor(TF_FLOAT, positionsDims, 2, numParticles*3*TF_DataTypeSize(TF_FLOAT));
if (usePeriodic) {
const int64_t boxVectorsDims[] = {3, 3};
boxVectorsTensor = TF_AllocateTensor(TF_FLOAT, boxVectorsDims, 2, 9*TF_DataTypeSize(TF_FLOAT));
}

// Inititalize CUDA objects.
graphForces.initialize(cu, 3*numParticles, TF_DataTypeSize(TF_FLOAT), "graphForces");

// Initialize CUDA arrays
graphPositions.initialize<float>(cu, 3*numParticles, "graphPosition");
if (usePeriodic)
graphVectors.initialize<float>(cu, 9, "graphVectors");
graphEnergy.initialize<float>(cu, 1, "graphEnergy");
graphForces2.initialize<float>(cu, 3*numParticles, "graphForces2");
graphForces.initialize<float>(cu, 3*numParticles, "graphForces2");

// Create biding for the graph execution
static_assert(sizeof(CUdeviceptr) == sizeof(void*));

bindings.push_back(reinterpret_cast<void*>(graphPositions.getDevicePointer()));
if (usePeriodic)
bindings.push_back(reinterpret_cast<void*>(graphVectors.getDevicePointer()));
bindings.push_back(reinterpret_cast<void*>(graphEnergy.getDevicePointer()));
bindings.push_back(reinterpret_cast<void*>(graphForces2.getDevicePointer()));
bindings.push_back(reinterpret_cast<void*>(graphForces.getDevicePointer()));

// Create kernles
// Create kernels
auto module = cu.createModule(CudaTensorRTKernelSources::TensorRTForce);
addForcesKernel = cu.getKernel(module, "addForces");
}

double CudaCalcTensorRTForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {

std::vector<Vec3> pos;
context.getPositions(pos);
int numParticles = cu.getNumAtoms();
auto positions = reinterpret_cast<float*>(TF_TensorData(positionsTensor));
for (int i = 0; i < numParticles; i++) {
positions[3*i] = pos[i][0];
positions[3*i+1] = pos[i][1];
positions[3*i+2] = pos[i][2];
}

if (usePeriodic) {
Vec3 box[3];
cu.getPeriodicBoxVectors(box[0], box[1], box[2]);
auto boxVectors = reinterpret_cast<float*>(TF_TensorData(boxVectorsTensor));
for (int i = 0; i < 3; i++)
for (int j = 0; j < 3; j++)
boxVectors[3*i+j] = box[i][j];
std::vector<Vec3> positions;
context.getPositions(positions);
std::vector<float> positionsArray;
for (const auto& atom: positions) {
positionsArray.push_back(atom[0]);
positionsArray.push_back(atom[1]);
positionsArray.push_back(atom[2]);
}
graphPositions.upload(positionsArray);

std::vector<TF_Output> inputs;
std::vector<TF_Tensor*> inputTensors;
inputs.push_back({TF_GraphOperationByName(graph, "positions"), 0});
inputTensors.push_back(positionsTensor);
if (usePeriodic) {
inputs.push_back({TF_GraphOperationByName(graph, "boxvectors"), 0});
inputTensors.push_back(boxVectorsTensor);
}

std::vector<TF_Output> outputs;
int forceOutputIndex = 0;
if (includeEnergy)
outputs.push_back({TF_GraphOperationByName(graph, "energy"), 0});
if (includeForces) {
forceOutputIndex = outputs.size();
outputs.push_back({TF_GraphOperationByName(graph, "forces"), 0});
}
std::vector<TF_Tensor*> outputTensors(outputs.size());

auto status = TF_NewStatus();
TF_SessionRun(session, NULL, &inputs[0], &inputTensors[0], inputs.size(),
&outputs[0], &outputTensors[0], outputs.size(),
NULL, 0, NULL, status);
if (TF_GetCode(status) != TF_OK)
throw OpenMMException(std::string("Error running TensorFlow session: ")+TF_Message(status));
TF_DeleteStatus(status);

std::vector<float> positions2;
for (const auto& p: pos) {
positions2.push_back(p[0]);
positions2.push_back(p[1]);
positions2.push_back(p[2]);
}
graphPositions.upload(positions2);

if (usePeriodic) {
std::vector<float> vectors;
Vec3 box[3];
cu.getPeriodicBoxVectors(box[0], box[1], box[2]);
std::vector<float> vectorArray;
Vec3 vectors[3];
cu.getPeriodicBoxVectors(vectors[0], vectors[1], vectors[2]);
for (int i = 0; i < 3; i++) {
vectors.push_back(box[i][0]);
vectors.push_back(box[i][1]);
vectors.push_back(box[i][2]);
vectorArray.push_back(vectors[i][0]);
vectorArray.push_back(vectors[i][1]);
vectorArray.push_back(vectors[i][2]);
}
graphVectors.upload(vectors);
graphVectors.upload(vectorArray);
}

// Execute the graph
execution->executeV2(bindings.data());

double energy = 0.0;
if (includeEnergy)
energy = reinterpret_cast<float*>(TF_TensorData(outputTensors[0]))[0];

if (includeEnergy) {
std::vector<float> energy2;
graphEnergy.download(energy2);
energy = energy2[0];
std::vector<float> energyArray;
graphEnergy.download(energyArray);
energy = energyArray[0];
}

if (includeForces) {
const void* data = TF_TensorData(outputTensors[forceOutputIndex]);
graphForces.upload(data);
int numAtoms = cu.getNumAtoms();
int paddedNumAtoms = cu.getPaddedNumAtoms();
void* args[] = {&graphForces2.getDevicePointer(), &cu.getForce().getDevicePointer(), &cu.getAtomIndexArray().getDevicePointer(), &numParticles, &paddedNumAtoms};
cu.executeKernel(addForcesKernel, args, numParticles);
void* args[] = {&graphForces.getDevicePointer(), &cu.getForce().getDevicePointer(), &cu.getAtomIndexArray().getDevicePointer(), &numAtoms, &paddedNumAtoms};
cu.executeKernel(addForcesKernel, args, numAtoms);
}

return energy;
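
One assumption the rewritten kernel relies on implicitly: executeV2 takes device pointers ordered by the engine's binding indices, and the bindings vector is filled in the order positions, optional boxvectors, energy, forces. A sketch of a check that could make this explicit (not part of the commit; it assumes the engine's tensors carry those names and uses the pre-TensorRT-8.5 binding API):

#include <NvInfer.h>
#include <stdexcept>
#include <string>
#include <vector>

// Verify that the engine's binding order matches the order in which the kernel
// pushes device pointers into its bindings vector.
void checkBindingOrder(const nvinfer1::ICudaEngine& engine, bool usePeriodic) {
    std::vector<std::string> expected = {"positions"};
    if (usePeriodic)
        expected.push_back("boxvectors");
    expected.push_back("energy");
    expected.push_back("forces");

    for (int i = 0; i < static_cast<int>(expected.size()); i++) {
        if (engine.getBindingIndex(expected[i].c_str()) != i)  // returns -1 if the name is absent
            throw std::runtime_error("TensorRTForce: unexpected binding order for '" + expected[i] + "'");
    }
}
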
