forked from tensorflow/models
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[samples]: Samples using the Java API.
1 parent
364f96d
commit 997f787
Showing
22 changed files
with
2,539 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# TensorFlow for Java: Examples | ||
|
||
Examples using the TensorFlow Java API. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
FROM tensorflow/tensorflow:1.4.0 | ||
WORKDIR / | ||
RUN apt-get update | ||
RUN apt-get -y install maven openjdk-8-jdk | ||
RUN mvn dependency:get -Dartifact=org.tensorflow:tensorflow:1.4.0 | ||
RUN mvn dependency:get -Dartifact=org.tensorflow:proto:1.4.0 | ||
CMD ["/bin/bash", "-l"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
Dockerfile for building an image suitable for running the Java examples. | ||
|
||
Typical usage: | ||
|
||
``` | ||
docker build -t java-tensorflow . | ||
docker run -it --rm -v ${PWD}/..:/examples java-tensorflow | ||
``` | ||
|
||
That second command will pop you into a shell which has all | ||
the dependencies required to execute the scripts and Java | ||
examples. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
images | ||
src/main/resources | ||
target |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# Image Classification Example | ||
|
||
1. Download the model: | ||
- If you have [TensorFlow 1.4+ for Python installed](https://www.tensorflow.org/install/), | ||
run `python ./download.py` | ||
- If not, but you have [docker](https://www.docker.com/get-docker) installed, | ||
run `download.sh`. | ||
|
||
2. Compile [`LabelImage.java`](src/main/java/LabelImage.java): | ||
|
||
``` | ||
mvn compile | ||
``` | ||
|
||
3. Download some sample images: | ||
If you already have some images, great. Otherwise `download_sample_images.sh` | ||
gets a few. | ||
|
||
3. Classify! | ||
|
||
``` | ||
mvn -q exec:java -Dexec.args="<path to image file>" | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
"""Create an image classification graph. | ||
Script to download a pre-trained image classifier and tweak it so that | ||
the model accepts raw bytes of an encoded image. | ||
Doing so involves some model-specific normalization of an image. | ||
Ideally, this would have been part of the image classifier model, | ||
but the particular model being used didn't include this normalization, | ||
so this script does the necessary tweaking. | ||
""" | ||
|
||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
from six.moves import urllib | ||
import os | ||
import zipfile | ||
import tensorflow as tf | ||
|
||
URL = 'https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip' | ||
LABELS_FILE = 'imagenet_comp_graph_label_strings.txt' | ||
GRAPH_FILE = 'tensorflow_inception_graph.pb' | ||
|
||
GRAPH_INPUT_TENSOR = 'input:0' | ||
GRAPH_PROBABILITIES_TENSOR = 'output:0' | ||
|
||
IMAGE_HEIGHT = 224 | ||
IMAGE_WIDTH = 224 | ||
MEAN = 117 | ||
SCALE = 1 | ||
|
||
LOCAL_DIR = 'src/main/resources' | ||
|
||
|
||
def download(): | ||
print('Downloading %s' % URL) | ||
zip_filename, _ = urllib.request.urlretrieve(URL) | ||
with zipfile.ZipFile(zip_filename) as zip: | ||
zip.extract(LABELS_FILE) | ||
zip.extract(GRAPH_FILE) | ||
os.rename(LABELS_FILE, os.path.join(LOCAL_DIR, 'labels.txt')) | ||
os.rename(GRAPH_FILE, os.path.join(LOCAL_DIR, 'graph.pb')) | ||
|
||
|
||
def create_graph_to_decode_and_normalize_image(): | ||
"""See file docstring. | ||
Returns: | ||
input: The placeholder to feed the raw bytes of an encoded image. | ||
y: A Tensor (the decoded, normalized image) to be fed to the graph. | ||
""" | ||
image = tf.placeholder(tf.string, shape=(), name='encoded_image_bytes') | ||
with tf.name_scope("preprocess"): | ||
y = tf.image.decode_image(image, channels=3) | ||
y = tf.cast(y, tf.float32) | ||
y = tf.expand_dims(y, axis=0) | ||
y = tf.image.resize_bilinear(y, (IMAGE_HEIGHT, IMAGE_WIDTH)) | ||
y = (y - MEAN) / SCALE | ||
return (image, y) | ||
|
||
|
||
def patch_graph(): | ||
"""Create graph.pb that applies the model in URL to raw image bytes.""" | ||
with tf.Graph().as_default() as g: | ||
input_image, image_normalized = create_graph_to_decode_and_normalize_image() | ||
original_graph_def = tf.GraphDef() | ||
with open(os.path.join(LOCAL_DIR, 'graph.pb')) as f: | ||
original_graph_def.ParseFromString(f.read()) | ||
softmax = tf.import_graph_def( | ||
original_graph_def, | ||
name='inception', | ||
input_map={GRAPH_INPUT_TENSOR: image_normalized}, | ||
return_elements=[GRAPH_PROBABILITIES_TENSOR]) | ||
# We're constructing a graph that accepts a single image (as opposed to a | ||
# batch of images), so might as well make the output be a vector of | ||
# probabilities, instead of a batch of vectors with batch size 1. | ||
output_probabilities = tf.squeeze(softmax, name='probabilities') | ||
# Overwrite the graph. | ||
with open(os.path.join(LOCAL_DIR, 'graph.pb'), 'w') as f: | ||
f.write(g.as_graph_def().SerializeToString()) | ||
print('------------------------------------------------------------') | ||
print('MODEL GRAPH : graph.pb') | ||
print('LABELS : labels.txt') | ||
print('INPUT TENSOR : %s' % input_image.op.name) | ||
print('OUTPUT TENSOR: %s' % output_probabilities.op.name) | ||
|
||
|
||
if __name__ == '__main__': | ||
if not os.path.exists(LOCAL_DIR): | ||
os.makedirs(LOCAL_DIR) | ||
download() | ||
patch_graph() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
#!/bin/bash | ||
|
||
DIR="$(cd "$(dirname "$0")" && pwd -P)" | ||
docker run -it -v ${DIR}:/x -w /x --rm tensorflow/tensorflow:1.4.0 python download.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
#!/bin/bash | ||
DIR=$(dirname $0) | ||
mkdir -p ${DIR}/images | ||
cd ${DIR}/images | ||
|
||
# Some random images | ||
curl -o "porcupine.jpg" -L "https://cdn.pixabay.com/photo/2014/11/06/12/46/porcupines-519145_960_720.jpg" | ||
curl -o "whale.jpg" -L "https://static.pexels.com/photos/417196/pexels-photo-417196.jpeg" | ||
curl -o "terrier1u.jpg" -L "https://upload.wikimedia.org/wikipedia/commons/3/34/Australian_Terrier_Melly_%282%29.JPG" | ||
curl -o "terrier2.jpg" -L "https://cdn.pixabay.com/photo/2014/05/13/07/44/yorkshire-terrier-343198_960_720.jpg" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
<project> | ||
<modelVersion>4.0.0</modelVersion> | ||
<groupId>org.myorg</groupId> | ||
<artifactId>label-image</artifactId> | ||
<version>1.0-SNAPSHOT</version> | ||
<properties> | ||
<exec.mainClass>LabelImage</exec.mainClass> | ||
<!-- The sample code requires at least JDK 1.7. --> | ||
<!-- The maven compiler plugin defaults to a lower version --> | ||
<maven.compiler.source>1.7</maven.compiler.source> | ||
<maven.compiler.target>1.7</maven.compiler.target> | ||
</properties> | ||
<dependencies> | ||
<dependency> | ||
<groupId>org.tensorflow</groupId> | ||
<artifactId>tensorflow</artifactId> | ||
<version>1.4.0</version> | ||
</dependency> | ||
<!-- For ByteStreams.toByteArray: https://google.github.io/guava/releases/23.0/api/docs/com/google/common/io/ByteStreams.html --> | ||
<dependency> | ||
<groupId>com.google.guava</groupId> | ||
<artifactId>guava</artifactId> | ||
<version>23.6-jre</version> | ||
</dependency> | ||
</dependencies> | ||
</project> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
==============================================================================*/ | ||
|
||
import com.google.common.io.ByteStreams; | ||
import java.io.BufferedReader; | ||
import java.io.IOException; | ||
import java.io.InputStream; | ||
import java.io.InputStreamReader; | ||
import java.nio.file.Files; | ||
import java.nio.file.Path; | ||
import java.nio.file.Paths; | ||
import java.util.ArrayList; | ||
import java.util.List; | ||
import org.tensorflow.Graph; | ||
import org.tensorflow.Session; | ||
import org.tensorflow.Tensor; | ||
import org.tensorflow.Tensors; | ||
|
||
/** | ||
* Simplified version of | ||
* https://github.com/tensorflow/tensorflow/blob/r1.4/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java | ||
*/ | ||
public class LabelImage { | ||
public static void main(String[] args) throws Exception { | ||
if (args.length < 1) { | ||
System.err.println("USAGE: Provide a list of image filenames"); | ||
System.exit(1); | ||
} | ||
final List<String> labels = loadLabels(); | ||
try (Graph graph = new Graph(); | ||
Session session = new Session(graph)) { | ||
graph.importGraphDef(loadGraphDef()); | ||
|
||
float[] probabilities = null; | ||
for (String filename : args) { | ||
byte[] bytes = Files.readAllBytes(Paths.get(filename)); | ||
try (Tensor<String> input = Tensors.create(bytes); | ||
Tensor<Float> output = | ||
session | ||
.runner() | ||
.feed("encoded_image_bytes", input) | ||
.fetch("probabilities") | ||
.run() | ||
.get(0) | ||
.expect(Float.class)) { | ||
if (probabilities == null) { | ||
probabilities = new float[(int) output.shape()[0]]; | ||
} | ||
output.copyTo(probabilities); | ||
int label = argmax(probabilities); | ||
System.out.printf( | ||
"%-30s --> %-15s (%.2f%% likely)\n", | ||
filename, labels.get(label), probabilities[label] * 100.0); | ||
} | ||
} | ||
} | ||
} | ||
|
||
private static byte[] loadGraphDef() throws IOException { | ||
try (InputStream is = LabelImage.class.getClassLoader().getResourceAsStream("graph.pb")) { | ||
return ByteStreams.toByteArray(is); | ||
} | ||
} | ||
|
||
private static ArrayList<String> loadLabels() throws IOException { | ||
ArrayList<String> labels = new ArrayList<String>(); | ||
String line; | ||
final InputStream is = LabelImage.class.getClassLoader().getResourceAsStream("labels.txt"); | ||
try (BufferedReader reader = new BufferedReader(new InputStreamReader(is))) { | ||
while ((line = reader.readLine()) != null) { | ||
labels.add(line); | ||
} | ||
} | ||
return labels; | ||
} | ||
|
||
private static int argmax(float[] probabilities) { | ||
int best = 0; | ||
for (int i = 1; i < probabilities.length; ++i) { | ||
if (probabilities[i] > probabilities[best]) { | ||
best = i; | ||
} | ||
} | ||
return best; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
images | ||
labels | ||
models | ||
src/main/protobuf | ||
target |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
# Object Detection in Java | ||
|
||
Example of using pre-trained models of the [TensorFlow Object Detection | ||
API](https://github.com/tensorflow/models/tree/master/research/object_detection) | ||
in Java. | ||
|
||
## Quickstart | ||
|
||
1. Download some metadata files: | ||
``` | ||
./download.sh | ||
``` | ||
|
||
2. Download a model from the [object detection API model | ||
zoo](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md). | ||
For example: | ||
``` | ||
mkdir -p models | ||
curl -L \ | ||
http://download.tensorflow.org/models/object_detection/ssd_inception_v2_coco_2017_11_17.tar.gz \ | ||
| tar -xz -C models/ | ||
``` | ||
|
||
3. Locate the corresponding labels file in the `data/` directory. | ||
|
||
3. Have some test images handy. For example: | ||
``` | ||
mkdir -p images | ||
curl -L -o images/test.jpg \ | ||
https://pixnio.com/free-images/people/mother-father-and-children-washing-dog-labrador-retriever-outside-in-the-fresh-air-725x483.jpg | ||
``` | ||
|
||
4. Compile and run! | ||
``` | ||
mvn -q compile exec:java \ | ||
-Dexec.args="models/ssd_inception_v2_coco_2017_11_17/saved_model labels/mscoco_label_map.pbtxt images/test.jpg" | ||
``` | ||
|
||
## Notes | ||
|
||
- This example demonstrates the use of the TensorFlow [SavedModel | ||
format](https://www.tensorflow.org/programmers_guide/saved_model). If you have | ||
TensorFlow for Python installed, you could explore the model to get the names | ||
of the tensors using `saved_model_cli` command. For example: | ||
``` | ||
saved_model_cli show --dir models/ssd_inception_v2_coco_2017_11_17/saved_model/ --all | ||
``` | ||
|
||
- The file in `src/main/object_detection/protos/` was generated using: | ||
|
||
``` | ||
./download.sh | ||
protoc -Isrc/main/protobuf --java_out=src/main/java src/main/protobuf/string_int_label_map.proto | ||
``` | ||
|
||
Where `protoc` was downloaded from | ||
https://github.com/google/protobuf/releases/tag/v3.5.1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
#!/bin/bash | ||
|
||
set -ex | ||
|
||
DIR="$(cd "$(dirname "$0")" && pwd -P)" | ||
cd "${DIR}" | ||
|
||
# The protobuf file needed for mapping labels to human readable names. | ||
# From: | ||
# https://github.com/tensorflow/models/blob/f87a58c/research/object_detection/protos/string_int_label_map.proto | ||
mkdir -p src/main/protobuf | ||
curl -L -o src/main/protobuf/string_int_label_map.proto "https://raw.githubusercontent.com/tensorflow/models/f87a58cd96d45de73c9a8330a06b2ab56749a7fa/research/object_detection/protos/string_int_label_map.proto" | ||
|
||
# Labels from: | ||
# https://github.com/tensorflow/models/tree/865c14c/research/object_detection/data | ||
mkdir -p labels | ||
curl -L -o labels/mscoco_label_map.pbtxt "https://raw.githubusercontent.com/tensorflow/models/865c14c1209cb9ae188b2a1b5f0883c72e050d4c/research/object_detection/data/mscoco_label_map.pbtxt" | ||
curl -L -o labels/oid_bbox_trainable_label_map.pbtxt "https://raw.githubusercontent.com/tensorflow/models/865c14c1209cb9ae188b2a1b5f0883c72e050d4c/research/object_detection/data/oid_bbox_trainable_label_map.pbtxt" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
<project> | ||
<modelVersion>4.0.0</modelVersion> | ||
<groupId>org.myorg</groupId> | ||
<artifactId>detect-objects</artifactId> | ||
<version>1.0-SNAPSHOT</version> | ||
<properties> | ||
<exec.mainClass>DetectObjects</exec.mainClass> | ||
<!-- The sample code requires at least JDK 1.7. --> | ||
<!-- The maven compiler plugin defaults to a lower version --> | ||
<maven.compiler.source>1.7</maven.compiler.source> | ||
<maven.compiler.target>1.7</maven.compiler.target> | ||
</properties> | ||
<dependencies> | ||
<dependency> | ||
<groupId>org.tensorflow</groupId> | ||
<artifactId>tensorflow</artifactId> | ||
<version>1.4.0</version> | ||
</dependency> | ||
<dependency> | ||
<groupId>org.tensorflow</groupId> | ||
<artifactId>proto</artifactId> | ||
<version>1.4.0</version> | ||
</dependency> | ||
</dependencies> | ||
</project> |
184 changes: 184 additions & 0 deletions
184
samples/java/object_detection/src/main/java/DetectObjects.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,184 @@ | ||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
==============================================================================*/ | ||
|
||
import static object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMap; | ||
import static object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItem; | ||
|
||
import com.google.protobuf.TextFormat; | ||
import java.awt.image.BufferedImage; | ||
import java.awt.image.DataBufferByte; | ||
import java.io.File; | ||
import java.io.IOException; | ||
import java.io.PrintStream; | ||
import java.nio.ByteBuffer; | ||
import java.nio.charset.StandardCharsets; | ||
import java.nio.file.Files; | ||
import java.nio.file.Paths; | ||
import java.util.List; | ||
import java.util.Map; | ||
import javax.imageio.ImageIO; | ||
import org.tensorflow.SavedModelBundle; | ||
import org.tensorflow.Tensor; | ||
import org.tensorflow.framework.MetaGraphDef; | ||
import org.tensorflow.framework.SignatureDef; | ||
import org.tensorflow.framework.TensorInfo; | ||
import org.tensorflow.types.UInt8; | ||
|
||
/** | ||
* Java inference for the Object Detection API at: | ||
* https://github.com/tensorflow/models/blob/master/research/object_detection/ | ||
*/ | ||
public class DetectObjects { | ||
public static void main(String[] args) throws Exception { | ||
if (args.length < 3) { | ||
printUsage(System.err); | ||
System.exit(1); | ||
} | ||
final String[] labels = loadLabels(args[1]); | ||
try (SavedModelBundle model = SavedModelBundle.load(args[0], "serve")) { | ||
printSignature(model); | ||
for (int arg = 2; arg < args.length; arg++) { | ||
final String filename = args[arg]; | ||
List<Tensor<?>> outputs = null; | ||
try (Tensor<UInt8> input = makeImageTensor(filename)) { | ||
outputs = | ||
model | ||
.session() | ||
.runner() | ||
.feed("image_tensor", input) | ||
.fetch("detection_scores") | ||
.fetch("detection_classes") | ||
.fetch("detection_boxes") | ||
.run(); | ||
} | ||
try (Tensor<Float> scoresT = outputs.get(0).expect(Float.class); | ||
Tensor<Float> classesT = outputs.get(1).expect(Float.class); | ||
Tensor<Float> boxesT = outputs.get(2).expect(Float.class)) { | ||
// All these tensors have: | ||
// - 1 as the first dimension | ||
// - maxObjects as the second dimension | ||
// While boxesT will have 4 as the third dimension (2 sets of (x, y) coordinates). | ||
// This can be verified by looking at scoresT.shape() etc. | ||
int maxObjects = (int) scoresT.shape()[1]; | ||
float[] scores = scoresT.copyTo(new float[1][maxObjects])[0]; | ||
float[] classes = classesT.copyTo(new float[1][maxObjects])[0]; | ||
float[][] boxes = boxesT.copyTo(new float[1][maxObjects][4])[0]; | ||
// Print all objects whose score is at least 0.5. | ||
System.out.printf("* %s\n", filename); | ||
boolean foundSomething = false; | ||
for (int i = 0; i < scores.length; ++i) { | ||
if (scores[i] < 0.5) { | ||
continue; | ||
} | ||
foundSomething = true; | ||
System.out.printf("\tFound %-20s (score: %.4f)\n", labels[(int) classes[i]], scores[i]); | ||
} | ||
if (!foundSomething) { | ||
System.out.println("No objects detected with a high enough score."); | ||
} | ||
} | ||
} | ||
} | ||
} | ||
|
||
private static void printSignature(SavedModelBundle model) throws Exception { | ||
MetaGraphDef m = MetaGraphDef.parseFrom(model.metaGraphDef()); | ||
SignatureDef sig = m.getSignatureDefOrThrow("serving_default"); | ||
int numInputs = sig.getInputsCount(); | ||
int i = 1; | ||
System.out.println("MODEL SIGNATURE"); | ||
System.out.println("Inputs:"); | ||
for (Map.Entry<String, TensorInfo> entry : sig.getInputsMap().entrySet()) { | ||
TensorInfo t = entry.getValue(); | ||
System.out.printf( | ||
"%d of %d: %-20s (Node name in graph: %-20s, type: %s)\n", | ||
i++, numInputs, entry.getKey(), t.getName(), t.getDtype()); | ||
} | ||
int numOutputs = sig.getOutputsCount(); | ||
i = 1; | ||
System.out.println("Outputs:"); | ||
for (Map.Entry<String, TensorInfo> entry : sig.getOutputsMap().entrySet()) { | ||
TensorInfo t = entry.getValue(); | ||
System.out.printf( | ||
"%d of %d: %-20s (Node name in graph: %-20s, type: %s)\n", | ||
i++, numOutputs, entry.getKey(), t.getName(), t.getDtype()); | ||
} | ||
System.out.println("-----------------------------------------------"); | ||
} | ||
|
||
private static String[] loadLabels(String filename) throws Exception { | ||
String text = new String(Files.readAllBytes(Paths.get(filename)), StandardCharsets.UTF_8); | ||
StringIntLabelMap.Builder builder = StringIntLabelMap.newBuilder(); | ||
TextFormat.merge(text, builder); | ||
StringIntLabelMap proto = builder.build(); | ||
int maxId = 0; | ||
for (StringIntLabelMapItem item : proto.getItemList()) { | ||
if (item.getId() > maxId) { | ||
maxId = item.getId(); | ||
} | ||
} | ||
String[] ret = new String[maxId + 1]; | ||
for (StringIntLabelMapItem item : proto.getItemList()) { | ||
ret[item.getId()] = item.getDisplayName(); | ||
} | ||
return ret; | ||
} | ||
|
||
private static void bgr2rgb(byte[] data) { | ||
for (int i = 0; i < data.length; i += 3) { | ||
byte tmp = data[i]; | ||
data[i] = data[i + 2]; | ||
data[i + 2] = tmp; | ||
} | ||
} | ||
|
||
private static Tensor<UInt8> makeImageTensor(String filename) throws IOException { | ||
BufferedImage img = ImageIO.read(new File(filename)); | ||
if (img.getType() != BufferedImage.TYPE_3BYTE_BGR) { | ||
throw new IOException( | ||
String.format( | ||
"Expected 3-byte BGR encoding in BufferedImage, found %d (file: %s). This code could be made more robust", | ||
img.getType(), filename)); | ||
} | ||
byte[] data = ((DataBufferByte) img.getData().getDataBuffer()).getData(); | ||
// ImageIO.read seems to produce BGR-encoded images, but the model expects RGB. | ||
bgr2rgb(data); | ||
final long BATCH_SIZE = 1; | ||
final long CHANNELS = 3; | ||
long[] shape = new long[] {BATCH_SIZE, img.getHeight(), img.getWidth(), CHANNELS}; | ||
return Tensor.create(UInt8.class, shape, ByteBuffer.wrap(data)); | ||
} | ||
|
||
private static void printUsage(PrintStream s) { | ||
s.println("USAGE: <model> <label_map> <image> [<image>] [<image>]"); | ||
s.println(""); | ||
s.println("Where"); | ||
s.println("<model> is the path to the SavedModel directory of the model to use."); | ||
s.println(" For example, the saved_model directory in tarballs from "); | ||
s.println( | ||
" https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md)"); | ||
s.println(""); | ||
s.println( | ||
"<label_map> is the path to a file containing information about the labels detected by the model."); | ||
s.println(" For example, one of the .pbtxt files from "); | ||
s.println( | ||
" https://github.com/tensorflow/models/tree/master/research/object_detection/data"); | ||
s.println(""); | ||
s.println("<image> is the path to an image file."); | ||
s.println(" Sample images can be found from the COCO, Kitti, or Open Images dataset."); | ||
s.println( | ||
" See: https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md"); | ||
} | ||
} |
1,785 changes: 1,785 additions & 0 deletions
1,785
...a/object_detection/src/main/java/object_detection/protos/StringIntLabelMapOuterClass.java
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
target | ||
checkpoint |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
# Training models in Java | ||
|
||
Example of training a model (and saving and restoring checkpoints) using the | ||
TensorFlow Java API. | ||
|
||
## Quickstart | ||
|
||
1. Train for a few steps: | ||
``` | ||
mvn -q compile exec:java -Dexec.args="model/graph.pb checkpoint" | ||
``` | ||
2. Resume training from previous checkpoint and train some more: | ||
``` | ||
mvn -q exec:java -Dexec.args="model/graph.pb checkpoint" | ||
``` | ||
3. Delete checkpoint: | ||
``` | ||
rm -rf checkpoint | ||
``` | ||
## Details | ||
The model in `model/graph.pb` represents a very simple linear model: | ||
``` | ||
y = x * W + b | ||
``` | ||
The `graph.pb` file is generated by executing `create_graph.py` in Python. | ||
The training is orchestrated by `src/main/java/Train.java`, which generates | ||
training data of the form `y = 3.0 * x + 2.0` and over time, using gradient | ||
descent, the model should "learn" and the value of `W` should converge to 3.0, | ||
and `b` to 2.0. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
from __future__ import print_function | ||
|
||
import tensorflow as tf | ||
|
||
x = tf.placeholder(tf.float32, name='input') | ||
y_ = tf.placeholder(tf.float32, name='target') | ||
|
||
W = tf.Variable(5., name='W') | ||
b = tf.Variable(3., name='b') | ||
|
||
y = x * W + b | ||
y = tf.identity(y, name='output') | ||
|
||
loss = tf.reduce_mean(tf.square(y - y_)) | ||
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01) | ||
train_op = optimizer.minimize(loss, name='train') | ||
|
||
init = tf.global_variables_initializer() | ||
|
||
# Creating a tf.train.Saver adds operations to the graph to save and | ||
# restore variables from checkpoints. | ||
saver_def = tf.train.Saver().as_saver_def() | ||
|
||
print('Operation to initialize variables: ', init.name) | ||
print('Tensor to feed as input data: ', x.name) | ||
print('Tensor to feed as training targets: ', y_.name) | ||
print('Tensor to fetch as prediction: ', y.name) | ||
print('Operation to train one step: ', train_op.name) | ||
print('Tensor to be fed for checkpoint filename:', saver_def.filename_tensor_name) | ||
print('Operation to save a checkpoint: ', saver_def.save_tensor_name) | ||
print('Operation to restore a checkpoint: ', saver_def.restore_op_name) | ||
print('Tensor to read value of W ', W.value().name) | ||
print('Tensor to read value of b ', b.value().name) | ||
|
||
with open('graph.pb', 'w') as f: | ||
f.write(tf.get_default_graph().as_graph_def().SerializeToString()) |
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
<project> | ||
<modelVersion>4.0.0</modelVersion> | ||
<groupId>org.myorg</groupId> | ||
<artifactId>training</artifactId> | ||
<version>1.0-SNAPSHOT</version> | ||
<properties> | ||
<exec.mainClass>Train</exec.mainClass> | ||
<!-- The sample code requires at least JDK 1.7. --> | ||
<!-- The maven compiler plugin defaults to a lower version --> | ||
<maven.compiler.source>1.7</maven.compiler.source> | ||
<maven.compiler.target>1.7</maven.compiler.target> | ||
</properties> | ||
<dependencies> | ||
<dependency> | ||
<groupId>org.tensorflow</groupId> | ||
<artifactId>tensorflow</artifactId> | ||
<version>1.4.0</version> | ||
</dependency> | ||
</dependencies> | ||
</project> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
==============================================================================*/ | ||
|
||
import java.nio.file.Files; | ||
import java.nio.file.Paths; | ||
import java.util.List; | ||
import java.util.Random; | ||
import org.tensorflow.Graph; | ||
import org.tensorflow.Session; | ||
import org.tensorflow.Tensor; | ||
import org.tensorflow.Tensors; | ||
|
||
/** | ||
* Training a trivial linear model. | ||
*/ | ||
public class Train { | ||
public static void main(String[] args) throws Exception { | ||
if (args.length != 2) { | ||
System.err.println("Require two arguments: The GraphDef file and checkpoint directory"); | ||
System.exit(1); | ||
} | ||
|
||
final byte[] graphDef = Files.readAllBytes(Paths.get(args[0])); | ||
final String checkpointDir = args[1]; | ||
final boolean checkpointExists = Files.exists(Paths.get(checkpointDir)); | ||
|
||
try (Graph graph = new Graph(); | ||
Session sess = new Session(graph); | ||
Tensor<String> checkpointPrefix = | ||
Tensors.create(Paths.get(checkpointDir, "ckpt").toString())) { | ||
graph.importGraphDef(graphDef); | ||
|
||
// Initialize or restore | ||
if (checkpointExists) { | ||
sess.runner().feed("save/Const", checkpointPrefix).addTarget("save/restore_all").run(); | ||
} else { | ||
sess.runner().addTarget("init").run(); | ||
} | ||
System.out.print("Starting from : "); | ||
printVariables(sess); | ||
|
||
// Train a bunch of times. | ||
// (Will be much more efficient if we sent batches instead of individual values). | ||
final Random r = new Random(); | ||
final int NUM_EXAMPLES = 500; | ||
for (int i = 1; i <= 5; i++) { | ||
for (int n = 0; n < NUM_EXAMPLES; n++) { | ||
float in = r.nextFloat(); | ||
try (Tensor<Float> input = Tensors.create(in); | ||
Tensor<Float> target = Tensors.create(3 * in + 2)) { | ||
sess.runner().feed("input", input).feed("target", target).addTarget("train").run(); | ||
} | ||
} | ||
System.out.printf("After %5d examples: ", i*NUM_EXAMPLES); | ||
printVariables(sess); | ||
} | ||
|
||
// Checkpoint | ||
sess.runner().feed("save/Const", checkpointPrefix).addTarget("save/control_dependency").run(); | ||
|
||
// Example of "inference" in the same graph: | ||
try (Tensor<Float> input = Tensors.create(1.0f); | ||
Tensor<Float> output = | ||
sess.runner().feed("input", input).fetch("output").run().get(0).expect(Float.class)) { | ||
System.out.printf( | ||
"For input %f, produced %f (ideally would produce 3*%f + 2)\n", | ||
input.floatValue(), output.floatValue(), input.floatValue()); | ||
} | ||
} | ||
} | ||
|
||
private static void printVariables(Session sess) { | ||
List<Tensor<?>> values = sess.runner().fetch("W/read").fetch("b/read").run(); | ||
System.out.printf("W = %f\tb = %f\n", values.get(0).floatValue(), values.get(1).floatValue()); | ||
for (Tensor<?> t : values) { | ||
t.close(); | ||
} | ||
} | ||
} |