Skip to content

Commit

Permalink
Merge pull request shaqian#102 from amerchDev/feature/optionalAssetLo…
Browse files Browse the repository at this point in the history
…cation

Tflite.loadModel should handle resources outside packaged assets
  • Loading branch information
shaqian authored Apr 22, 2020
2 parents 00a5c6f + c945a98 commit 850eaad
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 16 deletions.
43 changes: 32 additions & 11 deletions android/src/main/java/sq/flutter/tflite/TflitePlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.tensorflow.lite.Tensor;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.ByteArrayOutputStream;
Expand Down Expand Up @@ -193,14 +194,26 @@ public void onMethodCall(MethodCall call, Result result) {

private String loadModel(HashMap args) throws IOException {
String model = args.get("model").toString();
AssetManager assetManager = mRegistrar.context().getAssets();
String key = mRegistrar.lookupKeyForAsset(model);
AssetFileDescriptor fileDescriptor = assetManager.openFd(key);
FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
FileChannel fileChannel = inputStream.getChannel();
long startOffset = fileDescriptor.getStartOffset();
long declaredLength = fileDescriptor.getDeclaredLength();
MappedByteBuffer buffer = fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
Object isAssetObj = args.get("isAsset");
boolean isAsset = isAssetObj == null ? false : (boolean) isAssetObj;
MappedByteBuffer buffer = null;
String key = null;
AssetManager assetManager = null;
if (isAsset) {
assetManager = mRegistrar.context().getAssets();
key = mRegistrar.lookupKeyForAsset(model);
AssetFileDescriptor fileDescriptor = assetManager.openFd(key);
FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
FileChannel fileChannel = inputStream.getChannel();
long startOffset = fileDescriptor.getStartOffset();
long declaredLength = fileDescriptor.getDeclaredLength();
buffer = fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
} else {
FileInputStream inputStream = new FileInputStream(new File(model));
FileChannel fileChannel = inputStream.getChannel();
long declaredLength = fileChannel.size();
buffer = fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, declaredLength);
}

int numThreads = (int) args.get("numThreads");
final Interpreter.Options tfliteOptions = new Interpreter.Options();
Expand All @@ -210,8 +223,12 @@ private String loadModel(HashMap args) throws IOException {
String labels = args.get("labels").toString();

if (labels.length() > 0) {
key = mRegistrar.lookupKeyForAsset(labels);
loadLabels(assetManager, key);
if (isAsset) {
key = mRegistrar.lookupKeyForAsset(labels);
loadLabels(assetManager, key);
} else {
loadLabels(null, labels);
}
}

return "success";
Expand All @@ -220,7 +237,11 @@ private String loadModel(HashMap args) throws IOException {
private void loadLabels(AssetManager assetManager, String path) {
BufferedReader br;
try {
br = new BufferedReader(new InputStreamReader(assetManager.open(path)));
if (assetManager != null) {
br = new BufferedReader(new InputStreamReader(assetManager.open(path)));
} else {
br = new BufferedReader(new InputStreamReader(new FileInputStream(new File(path))));
}
String line;
labels = new Vector<>();
while ((line = br.readLine()) != null) {
Expand Down
23 changes: 19 additions & 4 deletions ios/Classes/TflitePlugin.mm
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,17 @@ static void LoadLabels(NSString* labels_path,
}

NSString* loadModel(NSObject<FlutterPluginRegistrar>* _registrar, NSDictionary* args) {
NSString* key = [_registrar lookupKeyForAsset:args[@"model"]];
NSString* graph_path = [[NSBundle mainBundle] pathForResource:key ofType:nil];
NSString* graph_path;
NSString* key;
NSNumber* isAssetNumber = args[@"isAsset"];
bool isAsset = [isAssetNumber boolValue];
if(isAsset){
key = [_registrar lookupKeyForAsset:args[@"model"]];
graph_path = [[NSBundle mainBundle] pathForResource:key ofType:nil];
}else{
graph_path = args[@"model"];
}

const int num_threads = [args[@"numThreads"] intValue];

model = tflite::FlatBufferModel::BuildFromFile([graph_path UTF8String]);
Expand All @@ -142,8 +151,13 @@ static void LoadLabels(NSString* labels_path,
LOG(INFO) << "resolved reporter";

if ([args[@"labels"] length] > 0) {
key = [_registrar lookupKeyForAsset:args[@"labels"]];
NSString* labels_path = [[NSBundle mainBundle] pathForResource:key ofType:nil];
NSString* labels_path;
if(isAsset){
key = [_registrar lookupKeyForAsset:args[@"labels"]];
labels_path = [[NSBundle mainBundle] pathForResource:key ofType:nil];
}else{
labels_path = args[@"labels"];
}
LoadLabels(labels_path, &labels);
}

Expand All @@ -163,6 +177,7 @@ static void LoadLabels(NSString* labels_path,
return @"success";
}


void runTflite(NSDictionary* args, TfLiteStatusCallback cb) {
const bool asynch = [args[@"asynch"] boolValue];
if (asynch) {
Expand Down
8 changes: 7 additions & 1 deletion lib/tflite.dart
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,16 @@ class Tflite {
@required String model,
String labels = "",
int numThreads = 1,
bool isAsset = true,
}) async {
return await _channel.invokeMethod(
'loadModel',
{"model": model, "labels": labels, "numThreads": numThreads},
{
"model": model,
"labels": labels,
"numThreads": numThreads,
"isAsset": isAsset,
},
);
}

Expand Down

0 comments on commit 850eaad

Please sign in to comment.