Skip to content

Commit

Permalink
Java: Provide the Java API in the Android inference library.
Browse files Browse the repository at this point in the history
A couple of things here:

(1) Avoid use of "try-with-resources" blocks in the Java API implementation,
    as older versions of Android (minSdkVersion < 19) do not support this.
    While we remove the use of such blocks in the implementation, the
    documentation and examples still advocate the use of try-with-resources,
    which should be used when possible (JDK 7, Android API >= 19)

(2) Include the Java sources in the Android Inference Library.
    This allows a .aar file to be built that contains both the
    TensorFlowInferenceInterface and the full Java API.
    Making transition from one to the other easier.

Another step in the journey that is tensorflow#5
Change: 145361038
  • Loading branch information
asimshankar authored and tensorflower-gardener committed Jan 24, 2017
1 parent 9b7c47c commit cd4a964
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 24 deletions.
2 changes: 1 addition & 1 deletion tensorflow/contrib/android/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ filegroup(
# JAR with Java bindings to TF.
android_library(
name = "android_tensorflow_inference_java",
srcs = glob(["java/**/*.java"]),
srcs = glob(["java/**/*.java"]) + ["//tensorflow/java:java_sources"],
tags = [
"manual",
"notap",
Expand Down
13 changes: 12 additions & 1 deletion tensorflow/java/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,22 @@ licenses(["notice"]) # Apache 2.0

java_library(
name = "tensorflow",
srcs = glob(["src/main/java/org/tensorflow/*.java"]),
srcs = [":java_sources"],
data = [":libtensorflow_jni"],
visibility = ["//visibility:public"],
)

# NOTE(ashankar): Rule to include the Java API in the Android Inference Library
# .aar. At some point, might make sense for a .aar rule here instead.
filegroup(
name = "java_sources",
srcs = glob(["src/main/java/org/tensorflow/*.java"]),
visibility = [
"//tensorflow/contrib/android:__pkg__",
"//tensorflow/java:__pkg__",
],
)

java_library(
name = "testutil",
testonly = 1,
Expand Down
15 changes: 12 additions & 3 deletions tensorflow/java/src/main/java/org/tensorflow/Operation.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,11 @@ public final class Operation {

/** Returns the full name of the Operation. */
public String name() {
try (Graph.Reference r = graph.ref()) {
Graph.Reference r = graph.ref();
try {
return name(unsafeNativeHandle);
} finally {
r.close();
}
}

Expand All @@ -49,15 +52,21 @@ public String name() {
* operation.
*/
public String type() {
try (Graph.Reference r = graph.ref()) {
Graph.Reference r = graph.ref();
try {
return type(unsafeNativeHandle);
} finally {
r.close();
}
}

/** Returns the number of tensors produced by this operation. */
public int numOutputs() {
try (Graph.Reference r = graph.ref()) {
Graph.Reference r = graph.ref();
try {
return numOutputs(unsafeNativeHandle);
} finally {
r.close();
}
}

Expand Down
80 changes: 64 additions & 16 deletions tensorflow/java/src/main/java/org/tensorflow/OperationBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,11 @@ public final class OperationBuilder {

OperationBuilder(Graph graph, String type, String name) {
this.graph = graph;
try (Graph.Reference r = graph.ref()) {
Graph.Reference r = graph.ref();
try {
this.unsafeNativeHandle = allocate(r.nativeHandle(), type, name);
} finally {
r.close();
}
}

Expand All @@ -50,36 +53,48 @@ public final class OperationBuilder {
* <p>The OperationBuilder is not usable after build() returns.
*/
public Operation build() {
try (Graph.Reference r = graph.ref()) {
Graph.Reference r = graph.ref();
try {
Operation op = new Operation(graph, finish(unsafeNativeHandle));
unsafeNativeHandle = 0;
return op;
} finally {
r.close();
}
}

public OperationBuilder addInput(Output input) {
try (Graph.Reference r = graph.ref()) {
Graph.Reference r = graph.ref();
try {
addInput(unsafeNativeHandle, input.op().getUnsafeNativeHandle(), input.index());
} finally {
r.close();
}
return this;
}

public OperationBuilder addInputList(Output[] inputs) {
try (Graph.Reference r = graph.ref()) {
Graph.Reference r = graph.ref();
try {
long[] opHandles = new long[inputs.length];
int[] indices = new int[inputs.length];
for (int i = 0; i < inputs.length; ++i) {
opHandles[i] = inputs[i].op().getUnsafeNativeHandle();
indices[i] = inputs[i].index();
}
addInputList(unsafeNativeHandle, opHandles, indices);
} finally {
r.close();
}
return this;
}

public OperationBuilder setDevice(String device) {
try (Graph.Reference r = graph.ref()) {
Graph.Reference r = graph.ref();
try {
setDevice(unsafeNativeHandle, device);
} finally {
r.close();
}
return this;
}
Expand All @@ -90,57 +105,81 @@ public OperationBuilder setAttr(String name, String value) {
}

public OperationBuilder setAttr(String name, byte[] value) {
try (Graph.Reference r = graph.ref()) {
Graph.Reference r = graph.ref();
try {
setAttrString(unsafeNativeHandle, name, value);
} finally {
r.close();
}
return this;
}

public OperationBuilder setAttr(String name, long value) {
try (Graph.Reference r = graph.ref()) {
Graph.Reference r = graph.ref();
try {
setAttrInt(unsafeNativeHandle, name, value);
} finally {
r.close();
}
return this;
}

public OperationBuilder setAttr(String name, long[] value) {
try (Graph.Reference r = graph.ref()) {
Graph.Reference r = graph.ref();
try {
setAttrIntList(unsafeNativeHandle, name, value);
} finally {
r.close();
}
return this;
}

public OperationBuilder setAttr(String name, float value) {
try (Graph.Reference r = graph.ref()) {
Graph.Reference r = graph.ref();
try {
setAttrFloat(unsafeNativeHandle, name, value);
} finally {
r.close();
}
return this;
}

public OperationBuilder setAttr(String name, float[] value) {
try (Graph.Reference r = graph.ref()) {
Graph.Reference r = graph.ref();
try {
setAttrFloatList(unsafeNativeHandle, name, value);
} finally {
r.close();
}
return this;
}

public OperationBuilder setAttr(String name, boolean value) {
try (Graph.Reference r = graph.ref()) {
Graph.Reference r = graph.ref();
try {
setAttrBool(unsafeNativeHandle, name, value);
} finally {
r.close();
}
return this;
}

public OperationBuilder setAttr(String name, boolean[] value) {
try (Graph.Reference r = graph.ref()) {
Graph.Reference r = graph.ref();
try {
setAttrBoolList(unsafeNativeHandle, name, value);
} finally {
r.close();
}
return this;
}

public OperationBuilder setAttr(String name, DataType value) {
try (Graph.Reference r = graph.ref()) {
Graph.Reference r = graph.ref();
try {
setAttrType(unsafeNativeHandle, name, value.c());
} finally {
r.close();
}
return this;
}
Expand All @@ -150,15 +189,21 @@ public OperationBuilder setAttr(String name, DataType[] value) {
for (int i = 0; i < value.length; ++i) {
ctypes[i] = value[i].c();
}
try (Graph.Reference r = graph.ref()) {
Graph.Reference r = graph.ref();
try {
setAttrTypeList(unsafeNativeHandle, name, ctypes);
} finally {
r.close();
}
return this;
}

public OperationBuilder setAttr(String name, Tensor value) {
try (Graph.Reference r = graph.ref()) {
Graph.Reference r = graph.ref();
try {
setAttrTensor(unsafeNativeHandle, name, value.getNativeHandle());
} finally {
r.close();
}
return this;
}
Expand All @@ -169,8 +214,11 @@ public OperationBuilder setAttr(String name, Tensor[] value) {
for (Tensor t : value) {
handles[idx++] = t.getNativeHandle();
}
try (Graph.Reference r = graph.ref()) {
Graph.Reference r = graph.ref();
try {
setAttrTensorList(unsafeNativeHandle, name, handles);
} finally {
r.close();
}
return this;
}
Expand Down
10 changes: 8 additions & 2 deletions tensorflow/java/src/main/java/org/tensorflow/Session.java
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,12 @@ public final class Session implements AutoCloseable {
/** Construct a new session with the associated {@link Graph}. */
public Session(Graph g) {
graph = g;
try (Graph.Reference r = g.ref()) {
Graph.Reference r = g.ref();
try {
nativeHandle = allocate(r.nativeHandle());
graphRef = g.ref();
} finally {
r.close();
}
}

Expand Down Expand Up @@ -193,7 +196,8 @@ public List<Tensor> run() {
for (Operation op : targets) {
targetOpHandles[idx++] = op.getUnsafeNativeHandle();
}
try (Reference runref = new Reference()) {
Reference runRef = new Reference();
try {
Session.run(
nativeHandle,
null, /* runOptions */
Expand All @@ -205,6 +209,8 @@ public List<Tensor> run() {
targetOpHandles,
false, /* wantRunMetadata */
outputTensorHandles);
} finally {
runRef.close();
}
List<Tensor> ret = new ArrayList<Tensor>();
for (long h : outputTensorHandles) {
Expand Down
15 changes: 14 additions & 1 deletion tensorflow/java/src/main/java/org/tensorflow/TensorFlow.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,20 @@ private TensorFlow() {}

/** Load the TensorFlow runtime C library. */
static void init() {
System.loadLibrary("tensorflow_jni");
try {
System.loadLibrary("tensorflow_jni");
} catch (UnsatisfiedLinkError e) {
// The native code might have been statically linked (through a custom launcher) or be part of
// an application-level library. For example, tensorflow/examples/android and
// tensorflow/contrib/android include the required native code in differently named libraries.
// To allow for such cases, the UnsatisfiedLinkError does not bubble up here.
try {
version();
} catch (UnsatisfiedLinkError e2) {
System.err.println(
"TensorFlow Java API methods will throw an UnsatisfiedLinkError unless native code shared libraries are loaded");
}
}
}

static {
Expand Down

0 comments on commit cd4a964

Please sign in to comment.