Skip to content

Commit

Permalink
Merge pull request tensorflow#3157 from asimshankar/java-samples
Browse files Browse the repository at this point in the history
[samples]: Samples using the Java API.
  • Loading branch information
asimshankar authored Jan 16, 2018
2 parents 9949146 + 28bd85d commit ca15f5d
Show file tree
Hide file tree
Showing 24 changed files with 2,559 additions and 0 deletions.
3 changes: 3 additions & 0 deletions samples/languages/java/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# TensorFlow for Java: Examples

Examples using the TensorFlow Java API.
7 changes: 7 additions & 0 deletions samples/languages/java/docker/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
FROM tensorflow/tensorflow:1.4.0
WORKDIR /
RUN apt-get update
RUN apt-get -y install maven openjdk-8-jdk
RUN mvn dependency:get -Dartifact=org.tensorflow:tensorflow:1.4.0
RUN mvn dependency:get -Dartifact=org.tensorflow:proto:1.4.0
CMD ["/bin/bash", "-l"]
15 changes: 15 additions & 0 deletions samples/languages/java/docker/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
Dockerfile for building an image suitable for running the Java examples.

Typical usage:

```
docker build -t java-tensorflow .
docker run -it --rm -v ${PWD}/..:/examples -w /examples java-tensorflow
```

That second command will pop you into a shell which has all
the dependencies required to execute the scripts and Java
examples.

The script `sanity_test.sh` builds this container and runs a compilation
check on all the maven projects.
7 changes: 7 additions & 0 deletions samples/languages/java/docker/sanity_test.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#!/bin/bash
#
# Silly sanity test
DIR="$(cd "$(dirname "$0")" && pwd -P)"

docker build -t java-tensorflow .
docker run -it --rm -v ${PWD}/..:/examples java-tensorflow bash /examples/docker/test_inside_container.sh
12 changes: 12 additions & 0 deletions samples/languages/java/docker/test_inside_container.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#!/bin/bash

set -ex

cd /examples/label_image
mvn compile

cd /examples/object_detection
mvn compile

cd /examples/training
mvn compile
3 changes: 3 additions & 0 deletions samples/languages/java/label_image/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
images
src/main/resources
target
23 changes: 23 additions & 0 deletions samples/languages/java/label_image/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Image Classification Example

1. Download the model:
- If you have [TensorFlow 1.4+ for Python installed](https://www.tensorflow.org/install/),
run `python ./download.py`
- If not, but you have [docker](https://www.docker.com/get-docker) installed,
run `download.sh`.

2. Compile [`LabelImage.java`](src/main/java/LabelImage.java):

```
mvn compile
```

3. Download some sample images:
If you already have some images, great. Otherwise `download_sample_images.sh`
gets a few.

3. Classify!

```
mvn -q exec:java -Dexec.args="<path to image file>"
```
93 changes: 93 additions & 0 deletions samples/languages/java/label_image/download.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"""Create an image classification graph.
Script to download a pre-trained image classifier and tweak it so that
the model accepts raw bytes of an encoded image.
Doing so involves some model-specific normalization of an image.
Ideally, this would have been part of the image classifier model,
but the particular model being used didn't include this normalization,
so this script does the necessary tweaking.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from six.moves import urllib
import os
import zipfile
import tensorflow as tf

URL = 'https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip'
LABELS_FILE = 'imagenet_comp_graph_label_strings.txt'
GRAPH_FILE = 'tensorflow_inception_graph.pb'

GRAPH_INPUT_TENSOR = 'input:0'
GRAPH_PROBABILITIES_TENSOR = 'output:0'

IMAGE_HEIGHT = 224
IMAGE_WIDTH = 224
MEAN = 117
SCALE = 1

LOCAL_DIR = 'src/main/resources'


def download():
print('Downloading %s' % URL)
zip_filename, _ = urllib.request.urlretrieve(URL)
with zipfile.ZipFile(zip_filename) as zip:
zip.extract(LABELS_FILE)
zip.extract(GRAPH_FILE)
os.rename(LABELS_FILE, os.path.join(LOCAL_DIR, 'labels.txt'))
os.rename(GRAPH_FILE, os.path.join(LOCAL_DIR, 'graph.pb'))


def create_graph_to_decode_and_normalize_image():
"""See file docstring.
Returns:
input: The placeholder to feed the raw bytes of an encoded image.
y: A Tensor (the decoded, normalized image) to be fed to the graph.
"""
image = tf.placeholder(tf.string, shape=(), name='encoded_image_bytes')
with tf.name_scope("preprocess"):
y = tf.image.decode_image(image, channels=3)
y = tf.cast(y, tf.float32)
y = tf.expand_dims(y, axis=0)
y = tf.image.resize_bilinear(y, (IMAGE_HEIGHT, IMAGE_WIDTH))
y = (y - MEAN) / SCALE
return (image, y)


def patch_graph():
"""Create graph.pb that applies the model in URL to raw image bytes."""
with tf.Graph().as_default() as g:
input_image, image_normalized = create_graph_to_decode_and_normalize_image()
original_graph_def = tf.GraphDef()
with open(os.path.join(LOCAL_DIR, 'graph.pb')) as f:
original_graph_def.ParseFromString(f.read())
softmax = tf.import_graph_def(
original_graph_def,
name='inception',
input_map={GRAPH_INPUT_TENSOR: image_normalized},
return_elements=[GRAPH_PROBABILITIES_TENSOR])
# We're constructing a graph that accepts a single image (as opposed to a
# batch of images), so might as well make the output be a vector of
# probabilities, instead of a batch of vectors with batch size 1.
output_probabilities = tf.squeeze(softmax, name='probabilities')
# Overwrite the graph.
with open(os.path.join(LOCAL_DIR, 'graph.pb'), 'w') as f:
f.write(g.as_graph_def().SerializeToString())
print('------------------------------------------------------------')
print('MODEL GRAPH : graph.pb')
print('LABELS : labels.txt')
print('INPUT TENSOR : %s' % input_image.op.name)
print('OUTPUT TENSOR: %s' % output_probabilities.op.name)


if __name__ == '__main__':
if not os.path.exists(LOCAL_DIR):
os.makedirs(LOCAL_DIR)
download()
patch_graph()
4 changes: 4 additions & 0 deletions samples/languages/java/label_image/download.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#!/bin/bash

DIR="$(cd "$(dirname "$0")" && pwd -P)"
docker run -it -v ${DIR}:/x -w /x --rm tensorflow/tensorflow:1.4.0 python download.py
10 changes: 10 additions & 0 deletions samples/languages/java/label_image/download_sample_images.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#!/bin/bash
DIR=$(dirname $0)
mkdir -p ${DIR}/images
cd ${DIR}/images

# Some random images
curl -o "porcupine.jpg" -L "https://cdn.pixabay.com/photo/2014/11/06/12/46/porcupines-519145_960_720.jpg"
curl -o "whale.jpg" -L "https://static.pexels.com/photos/417196/pexels-photo-417196.jpeg"
curl -o "terrier1u.jpg" -L "https://upload.wikimedia.org/wikipedia/commons/3/34/Australian_Terrier_Melly_%282%29.JPG"
curl -o "terrier2.jpg" -L "https://cdn.pixabay.com/photo/2014/05/13/07/44/yorkshire-terrier-343198_960_720.jpg"
26 changes: 26 additions & 0 deletions samples/languages/java/label_image/pom.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
<project>
<modelVersion>4.0.0</modelVersion>
<groupId>org.myorg</groupId>
<artifactId>label-image</artifactId>
<version>1.0-SNAPSHOT</version>
<properties>
<exec.mainClass>LabelImage</exec.mainClass>
<!-- The sample code requires at least JDK 1.7. -->
<!-- The maven compiler plugin defaults to a lower version -->
<maven.compiler.source>1.7</maven.compiler.source>
<maven.compiler.target>1.7</maven.compiler.target>
</properties>
<dependencies>
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow</artifactId>
<version>1.4.0</version>
</dependency>
<!-- For ByteStreams.toByteArray: https://google.github.io/guava/releases/23.0/api/docs/com/google/common/io/ByteStreams.html -->
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
<version>23.6-jre</version>
</dependency>
</dependencies>
</project>
98 changes: 98 additions & 0 deletions samples/languages/java/label_image/src/main/java/LabelImage.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

import com.google.common.io.ByteStreams;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.Tensors;

/**
* Simplified version of
* https://github.com/tensorflow/tensorflow/blob/r1.4/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java
*/
public class LabelImage {
public static void main(String[] args) throws Exception {
if (args.length < 1) {
System.err.println("USAGE: Provide a list of image filenames");
System.exit(1);
}
final List<String> labels = loadLabels();
try (Graph graph = new Graph();
Session session = new Session(graph)) {
graph.importGraphDef(loadGraphDef());

float[] probabilities = null;
for (String filename : args) {
byte[] bytes = Files.readAllBytes(Paths.get(filename));
try (Tensor<String> input = Tensors.create(bytes);
Tensor<Float> output =
session
.runner()
.feed("encoded_image_bytes", input)
.fetch("probabilities")
.run()
.get(0)
.expect(Float.class)) {
if (probabilities == null) {
probabilities = new float[(int) output.shape()[0]];
}
output.copyTo(probabilities);
int label = argmax(probabilities);
System.out.printf(
"%-30s --> %-15s (%.2f%% likely)\n",
filename, labels.get(label), probabilities[label] * 100.0);
}
}
}
}

private static byte[] loadGraphDef() throws IOException {
try (InputStream is = LabelImage.class.getClassLoader().getResourceAsStream("graph.pb")) {
return ByteStreams.toByteArray(is);
}
}

private static ArrayList<String> loadLabels() throws IOException {
ArrayList<String> labels = new ArrayList<String>();
String line;
final InputStream is = LabelImage.class.getClassLoader().getResourceAsStream("labels.txt");
try (BufferedReader reader = new BufferedReader(new InputStreamReader(is))) {
while ((line = reader.readLine()) != null) {
labels.add(line);
}
}
return labels;
}

private static int argmax(float[] probabilities) {
int best = 0;
for (int i = 1; i < probabilities.length; ++i) {
if (probabilities[i] > probabilities[best]) {
best = i;
}
}
return best;
}
}
5 changes: 5 additions & 0 deletions samples/languages/java/object_detection/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
images
labels
models
src/main/protobuf
target
55 changes: 55 additions & 0 deletions samples/languages/java/object_detection/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Object Detection in Java

Example of using pre-trained models of the [TensorFlow Object Detection
API](https://github.com/tensorflow/models/tree/master/research/object_detection)
in Java.

## Quickstart

1. Download some metadata files:
```
./download.sh
```

2. Download a model from the [object detection API model
zoo](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md).
For example:
```
mkdir -p models
curl -L \
http://download.tensorflow.org/models/object_detection/ssd_inception_v2_coco_2017_11_17.tar.gz \
| tar -xz -C models/
```

3. Have some test images handy. For example:
```
mkdir -p images
curl -L -o images/test.jpg \
https://pixnio.com/free-images/people/mother-father-and-children-washing-dog-labrador-retriever-outside-in-the-fresh-air-725x483.jpg
```

4. Compile and run!
```
mvn -q compile exec:java \
-Dexec.args="models/ssd_inception_v2_coco_2017_11_17/saved_model labels/mscoco_label_map.pbtxt images/test.jpg"
```

## Notes

- This example demonstrates the use of the TensorFlow [SavedModel
format](https://www.tensorflow.org/programmers_guide/saved_model). If you have
TensorFlow for Python installed, you could explore the model to get the names
of the tensors using `saved_model_cli` command. For example:
```
saved_model_cli show --dir models/ssd_inception_v2_coco_2017_11_17/saved_model/ --all
```

- The file in `src/main/object_detection/protos/` was generated using:

```
./download.sh
protoc -Isrc/main/protobuf --java_out=src/main/java src/main/protobuf/string_int_label_map.proto
```

Where `protoc` was downloaded from
https://github.com/google/protobuf/releases/tag/v3.5.1
18 changes: 18 additions & 0 deletions samples/languages/java/object_detection/download.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#!/bin/bash

set -ex

DIR="$(cd "$(dirname "$0")" && pwd -P)"
cd "${DIR}"

# The protobuf file needed for mapping labels to human readable names.
# From:
# https://github.com/tensorflow/models/blob/f87a58c/research/object_detection/protos/string_int_label_map.proto
mkdir -p src/main/protobuf
curl -L -o src/main/protobuf/string_int_label_map.proto "https://raw.githubusercontent.com/tensorflow/models/f87a58cd96d45de73c9a8330a06b2ab56749a7fa/research/object_detection/protos/string_int_label_map.proto"

# Labels from:
# https://github.com/tensorflow/models/tree/865c14c/research/object_detection/data
mkdir -p labels
curl -L -o labels/mscoco_label_map.pbtxt "https://raw.githubusercontent.com/tensorflow/models/865c14c1209cb9ae188b2a1b5f0883c72e050d4c/research/object_detection/data/mscoco_label_map.pbtxt"
curl -L -o labels/oid_bbox_trainable_label_map.pbtxt "https://raw.githubusercontent.com/tensorflow/models/865c14c1209cb9ae188b2a1b5f0883c72e050d4c/research/object_detection/data/oid_bbox_trainable_label_map.pbtxt"
Loading

0 comments on commit ca15f5d

Please sign in to comment.