Skip to content

Commit be137c8

Browse files
Add Classifier for TensorFlow Object Detection
1 parent 4e51973 commit be137c8

File tree

2 files changed

+323
-0
lines changed

2 files changed

+323
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
/*
2+
* Copyright (C) 2017 MINDORKS NEXTGEN PRIVATE LIMITED
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.mindorks.tensorflowexample;
18+
19+
import android.graphics.Bitmap;
20+
import android.graphics.RectF;
21+
22+
import java.util.List;
23+
24+
/**
25+
* Created by amitshekhar on 06/03/17.
26+
*/
27+
28+
/**
29+
* Generic interface for interacting with different recognition engines.
30+
*/
31+
public interface Classifier {
32+
/**
33+
* An immutable result returned by a Classifier describing what was recognized.
34+
*/
35+
public class Recognition {
36+
/**
37+
* A unique identifier for what has been recognized. Specific to the class, not the instance of
38+
* the object.
39+
*/
40+
private final String id;
41+
42+
/**
43+
* Display name for the recognition.
44+
*/
45+
private final String title;
46+
47+
/**
48+
* A sortable score for how good the recognition is relative to others. Higher should be better.
49+
*/
50+
private final Float confidence;
51+
52+
/**
53+
* Optional location within the source image for the location of the recognized object.
54+
*/
55+
private RectF location;
56+
57+
public Recognition(
58+
final String id, final String title, final Float confidence, final RectF location) {
59+
this.id = id;
60+
this.title = title;
61+
this.confidence = confidence;
62+
this.location = location;
63+
}
64+
65+
public String getId() {
66+
return id;
67+
}
68+
69+
public String getTitle() {
70+
return title;
71+
}
72+
73+
public Float getConfidence() {
74+
return confidence;
75+
}
76+
77+
public RectF getLocation() {
78+
return new RectF(location);
79+
}
80+
81+
public void setLocation(RectF location) {
82+
this.location = location;
83+
}
84+
85+
@Override
86+
public String toString() {
87+
String resultString = "";
88+
if (id != null) {
89+
resultString += "[" + id + "] ";
90+
}
91+
92+
if (title != null) {
93+
resultString += title + " ";
94+
}
95+
96+
if (confidence != null) {
97+
resultString += String.format("(%.1f%%) ", confidence * 100.0f);
98+
}
99+
100+
if (location != null) {
101+
resultString += location + " ";
102+
}
103+
104+
return resultString.trim();
105+
}
106+
}
107+
108+
List<Recognition> recognizeImage(Bitmap bitmap);
109+
110+
void enableStatLogging(final boolean debug);
111+
112+
String getStatString();
113+
114+
void close();
115+
}
116+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
/*
2+
* Copyright (C) 2017 MINDORKS NEXTGEN PRIVATE LIMITED
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.mindorks.tensorflowexample;
18+
19+
import android.content.res.AssetManager;
20+
import android.graphics.Bitmap;
21+
import android.os.Trace;
22+
import android.util.Log;
23+
24+
import org.tensorflow.contrib.android.TensorFlowInferenceInterface;
25+
26+
import java.io.BufferedReader;
27+
import java.io.IOException;
28+
import java.io.InputStreamReader;
29+
import java.util.ArrayList;
30+
import java.util.Comparator;
31+
import java.util.List;
32+
import java.util.PriorityQueue;
33+
import java.util.Vector;
34+
35+
/**
36+
* Created by amitshekhar on 06/03/17.
37+
*/
38+
39+
/**
40+
* A classifier specialized to label images using TensorFlow.
41+
*/
42+
public class TensorFlowImageClassifier implements Classifier {
43+
44+
private static final String TAG = "TensorFlowImageClassifier";
45+
46+
// Only return this many results with at least this confidence.
47+
private static final int MAX_RESULTS = 3;
48+
private static final float THRESHOLD = 0.1f;
49+
50+
// Config values.
51+
private String inputName;
52+
private String outputName;
53+
private int inputSize;
54+
private int imageMean;
55+
private float imageStd;
56+
57+
// Pre-allocated buffers.
58+
private Vector<String> labels = new Vector<String>();
59+
private int[] intValues;
60+
private float[] floatValues;
61+
private float[] outputs;
62+
private String[] outputNames;
63+
64+
private TensorFlowInferenceInterface inferenceInterface;
65+
66+
private TensorFlowImageClassifier() {
67+
}
68+
69+
/**
70+
* Initializes a native TensorFlow session for classifying images.
71+
*
72+
* @param assetManager The asset manager to be used to load assets.
73+
* @param modelFilename The filepath of the model GraphDef protocol buffer.
74+
* @param labelFilename The filepath of label file for classes.
75+
* @param inputSize The input size. A square image of inputSize x inputSize is assumed.
76+
* @param imageMean The assumed mean of the image values.
77+
* @param imageStd The assumed std of the image values.
78+
* @param inputName The label of the image input node.
79+
* @param outputName The label of the output node.
80+
* @throws IOException
81+
*/
82+
public static Classifier create(
83+
AssetManager assetManager,
84+
String modelFilename,
85+
String labelFilename,
86+
int inputSize,
87+
int imageMean,
88+
float imageStd,
89+
String inputName,
90+
String outputName)
91+
throws IOException {
92+
TensorFlowImageClassifier c = new TensorFlowImageClassifier();
93+
c.inputName = inputName;
94+
c.outputName = outputName;
95+
96+
// Read the label names into memory.
97+
// TODO(andrewharp): make this handle non-assets.
98+
String actualFilename = labelFilename.split("file:///android_asset/")[1];
99+
Log.i(TAG, "Reading labels from: " + actualFilename);
100+
BufferedReader br = null;
101+
br = new BufferedReader(new InputStreamReader(assetManager.open(actualFilename)));
102+
String line;
103+
while ((line = br.readLine()) != null) {
104+
c.labels.add(line);
105+
}
106+
br.close();
107+
108+
c.inferenceInterface = new TensorFlowInferenceInterface();
109+
if (c.inferenceInterface.initializeTensorFlow(assetManager, modelFilename) != 0) {
110+
throw new RuntimeException("TF initialization failed");
111+
}
112+
// The shape of the output is [N, NUM_CLASSES], where N is the batch size.
113+
int numClasses =
114+
(int) c.inferenceInterface.graph().operation(outputName).output(0).shape().size(1);
115+
Log.i(TAG, "Read " + c.labels.size() + " labels, output layer size is " + numClasses);
116+
117+
// Ideally, inputSize could have been retrieved from the shape of the input operation. Alas,
118+
// the placeholder node for input in the graphdef typically used does not specify a shape, so it
119+
// must be passed in as a parameter.
120+
c.inputSize = inputSize;
121+
c.imageMean = imageMean;
122+
c.imageStd = imageStd;
123+
124+
// Pre-allocate buffers.
125+
c.outputNames = new String[]{outputName};
126+
c.intValues = new int[inputSize * inputSize];
127+
c.floatValues = new float[inputSize * inputSize * 3];
128+
c.outputs = new float[numClasses];
129+
130+
return c;
131+
}
132+
133+
@Override
134+
public List<Recognition> recognizeImage(final Bitmap bitmap) {
135+
// Log this method so that it can be analyzed with systrace.
136+
Trace.beginSection("recognizeImage");
137+
138+
Trace.beginSection("preprocessBitmap");
139+
// Preprocess the image data from 0-255 int to normalized float based
140+
// on the provided parameters.
141+
bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
142+
for (int i = 0; i < intValues.length; ++i) {
143+
final int val = intValues[i];
144+
floatValues[i * 3 + 0] = (((val >> 16) & 0xFF) - imageMean) / imageStd;
145+
floatValues[i * 3 + 1] = (((val >> 8) & 0xFF) - imageMean) / imageStd;
146+
floatValues[i * 3 + 2] = ((val & 0xFF) - imageMean) / imageStd;
147+
}
148+
Trace.endSection();
149+
150+
// Copy the input data into TensorFlow.
151+
Trace.beginSection("fillNodeFloat");
152+
inferenceInterface.fillNodeFloat(
153+
inputName, new int[]{1, inputSize, inputSize, 3}, floatValues);
154+
Trace.endSection();
155+
156+
// Run the inference call.
157+
Trace.beginSection("runInference");
158+
inferenceInterface.runInference(outputNames);
159+
Trace.endSection();
160+
161+
// Copy the output Tensor back into the output array.
162+
Trace.beginSection("readNodeFloat");
163+
inferenceInterface.readNodeFloat(outputName, outputs);
164+
Trace.endSection();
165+
166+
// Find the best classifications.
167+
PriorityQueue<Recognition> pq =
168+
new PriorityQueue<Recognition>(
169+
3,
170+
new Comparator<Recognition>() {
171+
@Override
172+
public int compare(Recognition lhs, Recognition rhs) {
173+
// Intentionally reversed to put high confidence at the head of the queue.
174+
return Float.compare(rhs.getConfidence(), lhs.getConfidence());
175+
}
176+
});
177+
for (int i = 0; i < outputs.length; ++i) {
178+
if (outputs[i] > THRESHOLD) {
179+
pq.add(
180+
new Recognition(
181+
"" + i, labels.size() > i ? labels.get(i) : "unknown", outputs[i], null));
182+
}
183+
}
184+
final ArrayList<Recognition> recognitions = new ArrayList<Recognition>();
185+
int recognitionsSize = Math.min(pq.size(), MAX_RESULTS);
186+
for (int i = 0; i < recognitionsSize; ++i) {
187+
recognitions.add(pq.poll());
188+
}
189+
Trace.endSection(); // "recognizeImage"
190+
return recognitions;
191+
}
192+
193+
@Override
194+
public void enableStatLogging(boolean debug) {
195+
inferenceInterface.enableStatLogging(debug);
196+
}
197+
198+
@Override
199+
public String getStatString() {
200+
return inferenceInterface.getStatString();
201+
}
202+
203+
@Override
204+
public void close() {
205+
inferenceInterface.close();
206+
}
207+
}

0 commit comments

Comments
 (0)