Skip to content

Commit

Permalink
Add pix2pix support for iOS
Browse files Browse the repository at this point in the history
  • Loading branch information
jfoutts committed Mar 15, 2019
1 parent 0b662f6 commit 8f546cb
Show file tree
Hide file tree
Showing 4 changed files with 226 additions and 23 deletions.
41 changes: 22 additions & 19 deletions android/src/main/java/sq/flutter/tflite/TflitePlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -226,9 +226,9 @@ Bitmap feedOutput(ByteBuffer imgData, float mean, float std) {
for (int i = 0; i < outputSize; ++i) {
for (int j = 0; j < outputSize; ++j) {
int pixelValue = 0xFF << 24;
pixelValue |= ((Math.round(Math.min(1.0, Math.max(0.0, imgData.getFloat())) * std + mean) & 0xFF) << 16);
pixelValue |= ((Math.round(Math.min(1.0, Math.max(0.0, imgData.getFloat())) * std + mean) & 0xFF) << 8);
pixelValue |= ((Math.round(Math.min(1.0, Math.max(0.0, imgData.getFloat())) * std + mean) & 0xFF));
pixelValue |= ((Math.round(imgData.getFloat() * std + mean) & 0xFF) << 16);
pixelValue |= ((Math.round(imgData.getFloat() * std + mean) & 0xFF) << 8);
pixelValue |= ((Math.round(imgData.getFloat() * std + mean) & 0xFF));
bitmapRaw.setPixel(j, i, pixelValue);
}
}
Expand All @@ -255,11 +255,14 @@ ByteBuffer feedInputTensor(Bitmap bitmapRaw, float mean, float std) throws IOExc
ByteBuffer imgData = ByteBuffer.allocateDirect(1 * inputSize * inputSize * inputChannels * bytePerChannel);
imgData.order(ByteOrder.nativeOrder());

Matrix matrix = getTransformationMatrix(bitmapRaw.getWidth(), bitmapRaw.getHeight(),
inputSize, inputSize, false);
Bitmap bitmap = Bitmap.createBitmap(inputSize, inputSize, Bitmap.Config.ARGB_8888);
final Canvas canvas = new Canvas(bitmap);
canvas.drawBitmap(bitmapRaw, matrix, null);
Bitmap bitmap = bitmapRaw;
if (bitmapRaw.getWidth() != inputSize || bitmapRaw.getHeight() != inputSize) {
Matrix matrix = getTransformationMatrix(bitmapRaw.getWidth(), bitmapRaw.getHeight(),
inputSize, inputSize, false);
bitmap = Bitmap.createBitmap(inputSize, inputSize, Bitmap.Config.ARGB_8888);
final Canvas canvas = new Canvas(bitmap);
canvas.drawBitmap(bitmapRaw, matrix, null);
}

if (tensor.dataType() == DataType.FLOAT32) {
for (int i = 0; i < inputSize; ++i) {
Expand Down Expand Up @@ -470,12 +473,12 @@ private List<Map<String, Object>> runPix2PixOnImage(HashMap args) throws IOExcep

long startTime = SystemClock.uptimeMillis();
ByteBuffer input = feedInputTensorImage(path, IMAGE_MEAN, IMAGE_STD);
ByteBuffer output = ByteBuffer.allocateDirect(input.position());
ByteBuffer output = ByteBuffer.allocateDirect(input.limit());
output.order(ByteOrder.nativeOrder());
if (input.position() == 0) throw new RuntimeException("Unexpected input position, bad file?");
if (input.limit() == 0) throw new RuntimeException("Unexpected input position, bad file?");
if (output.position() != 0) throw new RuntimeException("Unexpected output position");
tfLite.run(input, output);
if (output.position() != input.position()) throw new RuntimeException("Mismatching input/output position");
if (output.position() != input.limit()) throw new RuntimeException("Mismatching input/output position");

output.flip();
Bitmap bitmapRaw = feedOutput(output, IMAGE_MEAN, IMAGE_STD);
Expand All @@ -500,18 +503,18 @@ private List<Map<String, Object>> runPix2PixOnBinary(HashMap args) throws IOExce

long startTime = SystemClock.uptimeMillis();
ByteBuffer input = ByteBuffer.wrap(binary);
ByteBuffer output = ByteBuffer.allocateDirect(input.position());
ByteBuffer output = ByteBuffer.allocateDirect(input.limit());
output.order(ByteOrder.nativeOrder());

if (input.position() == 0) throw new RuntimeException("Unexpected input position, bad file?");
if (input.limit() == 0) throw new RuntimeException("Unexpected input position, bad file?");
if (output.position() != 0) throw new RuntimeException("Unexpected output position");
tfLite.run(input, output);
Log.v("time", "Generating took " + (SystemClock.uptimeMillis() - startTime));
if (output.position() != input.position()) throw new RuntimeException("Mismatching input/output position");
if (output.position() != input.limit()) throw new RuntimeException("Mismatching input/output position");

final ArrayList<Map<String, Object>> result = new ArrayList<>();
Map<String, Object> res = new HashMap<>();
res.put("binary", output);
res.put("binary", output.array());
result.add(res);
return result;
}
Expand All @@ -528,18 +531,18 @@ private List<Map<String, Object>> runPix2PixOnFrame(HashMap args) throws IOExcep

long startTime = SystemClock.uptimeMillis();
ByteBuffer input = feedInputTensorFrame(bytesList, imageHeight, imageWidth, IMAGE_MEAN, IMAGE_STD, rotation);
ByteBuffer output = ByteBuffer.allocateDirect(input.position());
ByteBuffer output = ByteBuffer.allocateDirect(input.limit());
output.order(ByteOrder.nativeOrder());

if (input.position() == 0) throw new RuntimeException("Unexpected input position, bad file?");
if (input.limit() == 0) throw new RuntimeException("Unexpected input position, bad file?");
if (output.position() != 0) throw new RuntimeException("Unexpected output position");
tfLite.run(input, output);
Log.v("time", "Generating took " + (SystemClock.uptimeMillis() - startTime));
if (output.position() != input.position()) throw new RuntimeException("Mismatching input/output position");
if (output.position() != input.limit()) throw new RuntimeException("Mismatching input/output position");

final ArrayList<Map<String, Object>> result = new ArrayList<>();
Map<String, Object> res = new HashMap<>();
res.put("binary", output);
res.put("binary", output.array());
result.add(res);
return result;
}
Expand Down
175 changes: 171 additions & 4 deletions ios/Classes/TflitePlugin.mm
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
#include <sstream>
#include <string>

#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/model.h"
#include "tensorflow/contrib/lite/string_util.h"
#include "tensorflow/contrib/lite/op_resolver.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/string_util.h"
#include "tensorflow/lite/op_resolver.h"

#include "ios_image_load.h"

Expand All @@ -24,6 +24,9 @@
NSMutableArray* detectObjectOnImage(NSDictionary* args);
NSMutableArray* detectObjectOnBinary(NSDictionary* args);
NSMutableArray* detectObjectOnFrame(NSDictionary* args);
// pix2pix entry points: each takes the Flutter method-call arguments dictionary and returns
// the result array that handleMethodCall passes back through the FlutterResult callback.
NSMutableArray* runPix2PixOnImage(NSDictionary* args);
NSMutableArray* runPix2PixOnBinary(NSDictionary* args);
NSMutableArray* runPix2PixOnFrame(NSDictionary* args);
void close();

@implementation TflitePlugin {
Expand Down Expand Up @@ -68,6 +71,15 @@ - (void)handleMethodCall:(FlutterMethodCall*)call result:(FlutterResult)result {
} else if ([@"detectObjectOnFrame" isEqualToString:call.method]) {
NSMutableArray* inference_result = detectObjectOnFrame(call.arguments);
result(inference_result);
} else if ([@"runPix2PixOnImage" isEqualToString:call.method]) {
NSMutableArray* generated_result = runPix2PixOnImage(call.arguments);
result(generated_result);
} else if ([@"runPix2PixOnBinary" isEqualToString:call.method]) {
NSMutableArray* generated_result = runPix2PixOnBinary(call.arguments);
result(generated_result);
} else if ([@"runPix2PixOnFrame" isEqualToString:call.method]) {
NSMutableArray* generated_result = runPix2PixOnFrame(call.arguments);
result(generated_result);
} else if ([@"close" isEqualToString:call.method]) {
close();
} else {
Expand Down Expand Up @@ -129,6 +141,58 @@ static void LoadLabels(NSString* labels_path,
return @"success";
}

// Converts a de-normalized float sample to an 8-bit channel value, saturating at 0 and 255.
// Written as !(v > 0) so that NaN also maps to 0 instead of hitting an undefined conversion.
static inline uint8_t ClampToUint8(float value) {
  if (!(value > 0.0f)) return 0;
  if (value >= 255.0f) return 255;
  return (uint8_t)value;  // truncation, as in the original conversion for in-range values
}

/**
 * Copies the interpreter's single output tensor into a freshly allocated NSMutableData buffer.
 *
 * The tensor dims are read as [batch, height, width, channels]
 * (NOTE(review): NHWC layout is assumed — confirm against the models this plugin loads).
 *
 * @param outputChannelsIn Channels per pixel in the returned buffer; 0 means "same as the
 *                         tensor". When larger than the tensor's channel count, every extra
 *                         channel is filled with 255 (e.g. an opaque alpha for RGB -> RGBA).
 * @param mean             Added to each float sample after scaling (ignored for uint8 tensors).
 * @param std              Multiplied with each float sample (ignored for uint8 tensors).
 * @param convertToUint8   For float tensors: emit clamped uint8 samples instead of floats.
 * @param widthOut         Optional; receives the tensor width.
 * @param heightOut        Optional; receives the tensor height.
 * @return The pixel buffer, or nil when the tensor shape/type is unusable or the requested
 *         channel count is smaller than the tensor's channel count.
 */
NSMutableData *feedOutputTensor(int outputChannelsIn, float mean, float std, bool convertToUint8,
                                int *widthOut, int *heightOut) {
  assert(interpreter->outputs().size() == 1);
  const int output = interpreter->outputs()[0];
  TfLiteTensor* output_tensor = interpreter->tensor(output);
  if (output_tensor == NULL || output_tensor->dims == NULL || output_tensor->dims->size < 4) {
    NSLog(@"Unexpected output tensor shape.");
    return nil;
  }

  const int height = output_tensor->dims->data[1];
  const int width = output_tensor->dims->data[2];
  const int channels = output_tensor->dims->data[3];
  const int outputChannels = outputChannelsIn ? outputChannelsIn : channels;
  // A runtime check (not only an assert): with fewer output channels than tensor channels
  // the copy loops below would run past the end of the destination buffer.
  if (width <= 0 || height <= 0 || channels <= 0 || outputChannels < channels) {
    NSLog(@"Unsupported output tensor dimensions.");
    return nil;
  }
  if (widthOut) *widthOut = width;
  if (heightOut) *heightOut = height;

  const size_t elementCount = (size_t)width * (size_t)height * (size_t)outputChannels;
  const int padding = outputChannels - channels;

  NSMutableData *data = nil;
  if (output_tensor->type == kTfLiteUInt8) {
    const uint8_t* src = interpreter->typed_tensor<uint8_t>(output);
    data = [NSMutableData dataWithLength:elementCount];
    if (src == NULL || data == nil) return nil;
    uint8_t* out = (uint8_t*)[data mutableBytes];
    uint8_t* const outEnd = out + elementCount;
    while (out != outEnd) {
      for (int c = 0; c < channels; c++)
        *out++ = *src++;
      for (int c = 0; c < padding; c++)
        *out++ = 255;
    }
  } else {  // treated as kTfLiteFloat32
    // typed_tensor<float> yields NULL when the tensor is not actually float32.
    const float* src = interpreter->typed_tensor<float>(output);
    if (src == NULL) {
      NSLog(@"Unsupported output tensor type.");
      return nil;
    }
    if (convertToUint8) {
      data = [NSMutableData dataWithLength:elementCount];
      if (data == nil) return nil;
      uint8_t* out = (uint8_t*)[data mutableBytes];
      uint8_t* const outEnd = out + elementCount;
      while (out != outEnd) {
        for (int c = 0; c < channels; c++)
          *out++ = ClampToUint8((*src++ * std) + mean);
        for (int c = 0; c < padding; c++)
          *out++ = 255;
      }
    } else {
      data = [NSMutableData dataWithLength:elementCount * sizeof(float)];
      if (data == nil) return nil;
      float* out = (float*)[data mutableBytes];
      float* const outEnd = out + elementCount;
      while (out != outEnd) {
        for (int c = 0; c < channels; c++)
          *out++ = (*src++ * std) + mean;
        for (int c = 0; c < padding; c++)
          *out++ = 255;
      }
    }
  }
  return data;
}

void feedInputTensorBinary(const FlutterStandardTypedData* typedData, int* input_size) {
assert(interpreter->inputs().size() == 1);
int input = interpreter->inputs()[0];
Expand Down Expand Up @@ -601,6 +665,109 @@ void softmax(float vals[], int count) {
threshold, input_size);
}

// Runs the loaded model on the image file at args[@"path"] and saves the generated picture
// next to it as "<name>_pix2pix.<ext>". Returns a one-element array holding a dictionary with
// the key "filename" on success, or an empty array on any failure.
NSMutableArray* runPix2PixOnImage(NSDictionary* args) {
  const NSString* source_path = args[@"path"];
  const float mean = [args[@"imageMean"] floatValue];
  const float std = [args[@"imageStd"] floatValue];

  // Stays empty on every failure path; gets a single entry on success.
  NSMutableArray* results = [NSMutableArray array];

  if (!interpreter) {
    NSLog(@"Failed to construct interpreter.");
    return results;
  }

  int input_size;
  feedInputTensorImage(source_path, mean, std, &input_size);

  if (interpreter->Invoke() != kTfLiteOk) {
    NSLog(@"Failed to invoke!");
    return results;
  }

  // Request 4 output channels converted to 8-bit samples for the image writer.
  int out_width = 0;
  int out_height = 0;
  NSMutableData* pixels = feedOutputTensor(4, mean, std, true, &out_width, &out_height);
  if (pixels == nil) {
    return results;
  }

  NSString* extension = [source_path pathExtension];
  NSString* base_path = [source_path stringByDeletingPathExtension];
  NSString* out_path = [NSString stringWithFormat:@"%@_pix2pix.%@", base_path, extension];

  if (!SaveImageToFile(pixels, [out_path UTF8String], out_width, out_height, 1)) {
    return results;
  }

  NSMutableDictionary* entry = [NSMutableDictionary dictionaryWithObject:out_path
                                                                  forKey:@"filename"];
  [results addObject:entry];
  return results;
}

// Feeds the raw buffer in args[@"binary"] to the model and returns the output tensor,
// unconverted, as typed data under the key "binary" in a one-element array.
// Returns an empty array when the interpreter is missing, inference fails, or no output
// buffer could be produced.
NSMutableArray* runPix2PixOnBinary(NSDictionary* args) {
  const FlutterStandardTypedData* input_data = args[@"binary"];
  NSMutableArray* results = [NSMutableArray array];

  if (!interpreter) {
    NSLog(@"Failed to construct interpreter.");
    return results;
  }

  int input_size;
  feedInputTensorBinary(input_data, &input_size);

  if (interpreter->Invoke() != kTfLiteOk) {
    NSLog(@"Failed to invoke!");
    return results;
  }

  // Channel count 0, mean 0, std 1 and no uint8 conversion: the tensor is passed through as is.
  int out_width = 0;
  int out_height = 0;
  NSMutableData* raw_output = feedOutputTensor(0, 0, 1, false, &out_width, &out_height);
  if (raw_output == nil) {
    return results;
  }

  FlutterStandardTypedData* payload = [FlutterStandardTypedData typedDataWithBytes:raw_output];
  NSMutableDictionary* entry = [NSMutableDictionary dictionaryWithObject:payload
                                                                  forKey:@"binary"];
  [results addObject:entry];
  return results;
}

// Feeds the first plane of args[@"bytesList"] (a 4-channel frame of imageWidth x imageHeight)
// to the model and returns the output tensor as typed data under the key "binary" in a
// one-element array. Returns an empty array on any failure.
NSMutableArray* runPix2PixOnFrame(NSDictionary* args) {
  const FlutterStandardTypedData* frame_data = args[@"bytesList"][0];
  const int frame_height = [args[@"imageHeight"] intValue];
  const int frame_width = [args[@"imageWidth"] intValue];
  const float mean = [args[@"imageMean"] floatValue];
  const float std = [args[@"imageStd"] floatValue];
  NSMutableArray* results = [NSMutableArray array];

  if (!interpreter) {
    NSLog(@"Failed to construct interpreter.");
    return results;
  }

  int input_size;
  const int frame_channels = 4;
  feedInputTensorFrame(frame_data, &input_size, frame_height, frame_width, frame_channels,
                       mean, std);

  if (interpreter->Invoke() != kTfLiteOk) {
    NSLog(@"Failed to invoke!");
    return results;
  }

  // Channel count 0, mean 0, std 1 and no uint8 conversion: the tensor is passed through as is.
  int out_width = 0;
  int out_height = 0;
  NSMutableData* raw_output = feedOutputTensor(0, 0, 1, false, &out_width, &out_height);
  if (raw_output == nil) {
    return results;
  }

  FlutterStandardTypedData* payload = [FlutterStandardTypedData typedDataWithBytes:raw_output];
  NSMutableDictionary* entry = [NSMutableDictionary dictionaryWithObject:payload
                                                                  forKey:@"binary"];
  [results addObject:entry];
  return results;
}

void close() {
interpreter.release();
interpreter = NULL;
Expand Down
6 changes: 6 additions & 0 deletions ios/Classes/ios_image_load.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,9 @@ std::vector<uint8_t> LoadImageFromFile(const char* file_name,
int* out_height,
int* out_channels);

/// Encodes a raw 4-channel pixel buffer as PNG and writes it to `file_name`.
///
/// @param image         (first, unnamed parameter) The pixel buffer, `width * height` pixels
///                      of 4 channels each, `bytesPerPixel` bytes per channel.
/// @param file_name     Destination path, opened with fopen in "wb" mode.
/// @param width         Image width in pixels.
/// @param height        Image height in pixels.
/// @param bytesPerPixel Bytes per channel sample; 4 selects float components, any other
///                      value selects big-endian 32-bit byte order.
/// @return YES when all encoded bytes were written; NO if the bitmap context, image or PNG
///         data could not be created or the write was short.
BOOL SaveImageToFile(NSMutableData*,
const char* file_name,
int width,
int height,
int bytesPerPixel);

27 changes: 27 additions & 0 deletions ios/Classes/ios_image_load.mm
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#import <Flutter/Flutter.h>
#include "ios_image_load.h"

#include <stdlib.h>
Expand Down Expand Up @@ -70,3 +71,29 @@
*out_channels = channels;
return result;
}

/**
 * Encodes a raw 4-channel pixel buffer as PNG and writes it to a file.
 *
 * @param image         Pixel buffer with `width * height` pixels, 4 channels per pixel and
 *                      `bytesPerPixel` bytes per channel sample.
 * @param file_name     Destination path; the file is created or truncated.
 * @param width         Image width in pixels.
 * @param height        Image height in pixels.
 * @param bytesPerPixel Bytes per channel sample; 4 selects float components, any other value
 *                      selects big-endian 32-bit byte order.
 * @return YES only when the complete PNG was written and the file closed cleanly; NO on
 *         invalid arguments, an undersized buffer, any CoreGraphics/UIKit failure, or an
 *         I/O error.
 */
BOOL SaveImageToFile(NSMutableData *image, const char* file_name, int width, int height, int bytesPerPixel) {
  const int channels = 4;

  // A nil buffer would make CGBitmapContextCreate allocate its own (blank) backing store and
  // an undersized one would be read past its end, so reject both up front.
  if (image == nil || file_name == NULL || width <= 0 || height <= 0 || bytesPerPixel <= 0) {
    return NO;
  }
  const size_t bytesPerRow = (size_t)width * channels * bytesPerPixel;
  if (image.length < bytesPerRow * (size_t)height) return NO;

  CGColorSpaceRef color_space = CGColorSpaceCreateDeviceRGB();
  CGContextRef context = CGBitmapContextCreate([image mutableBytes], width, height,
      bytesPerPixel*8, bytesPerRow, color_space,
      kCGImageAlphaPremultipliedLast | (bytesPerPixel == 4 ? kCGBitmapFloatComponents : kCGBitmapByteOrder32Big));
  CGColorSpaceRelease(color_space);
  if (context == nil) return NO;

  CGImageRef imgRef = CGBitmapContextCreateImage(context);
  CGContextRelease(context);
  if (imgRef == nil) return NO;

  UIImage* img = [UIImage imageWithCGImage:imgRef];
  CGImageRelease(imgRef);
  if (img == nil) return NO;

  NSData *data = UIImagePNGRepresentation(img);
  if (data == nil) return NO;

  // fopen fails for unwritable or non-existent directories; fwrite/fclose on NULL would crash.
  FILE* file_handle = fopen(file_name, "wb");
  if (file_handle == NULL) return NO;

  const size_t written = fwrite([data bytes], 1, data.length, file_handle);
  // fclose flushes buffered data, so a failure there also means the file is incomplete.
  const BOOL closed = (fclose(file_handle) == 0);
  return (written == data.length && closed) ? YES : NO;
}

0 comments on commit 8f546cb

Please sign in to comment.