Skip to content

Commit

Permalink
added support for gpu delegate on android
Browse files Browse the repository at this point in the history
  • Loading branch information
AbdulRashidReshamwala committed Jun 4, 2020
1 parent 4caff96 commit 5c1b8cd
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 6 deletions.
1 change: 1 addition & 0 deletions android/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,6 @@ android {

dependencies {
compile 'org.tensorflow:tensorflow-lite:+'
compile 'org.tensorflow:tensorflow-lite-gpu:+'
}
}
9 changes: 9 additions & 0 deletions android/src/main/java/sq/flutter/tflite/TflitePlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
import org.tensorflow.lite.Interpreter;
import org.tensorflow.lite.Tensor;

import org.tensorflow.lite.gpu.GpuDelegate;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
Expand Down Expand Up @@ -195,6 +197,7 @@ public void onMethodCall(MethodCall call, Result result) {
private String loadModel(HashMap args) throws IOException {
String model = args.get("model").toString();
Object isAssetObj = args.get("isAsset");
Object useGpuDelegateObj = args.get("useGpuDelegate");
boolean isAsset = isAssetObj == null ? false : (boolean) isAssetObj;
MappedByteBuffer buffer = null;
String key = null;
Expand All @@ -216,8 +219,14 @@ private String loadModel(HashMap args) throws IOException {
}

int numThreads = (int) args.get("numThreads");
boolean useGpuDelegate = useGpuDelegateObj == null ? false : (boolean) useGpuDelegate;

final Interpreter.Options tfliteOptions = new Interpreter.Options();
tfliteOptions.setNumThreads(numThreads);
if (useGpuDelegate){
GpuDelegate delegate = new GpuDelegate();
tfliteOptions.addDelegate(delegate)
}
tfLite = new Interpreter(buffer, tfliteOptions);

String labels = args.get("labels").toString();
Expand Down
13 changes: 7 additions & 6 deletions lib/tflite.dart
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,20 @@ import 'package:flutter/services.dart';
class Tflite {
static const MethodChannel _channel = const MethodChannel('tflite');

static Future<String> loadModel({
@required String model,
String labels = "",
int numThreads = 1,
bool isAsset = true,
}) async {
static Future<String> loadModel(
{@required String model,
String labels = "",
int numThreads = 1,
bool isAsset = true,
bool useGpuDelegate = false}) async {
return await _channel.invokeMethod(
'loadModel',
{
"model": model,
"labels": labels,
"numThreads": numThreads,
"isAsset": isAsset,
'useGpuDelegate': useGpuDelegate
},
);
}
Expand Down

0 comments on commit 5c1b8cd

Please sign in to comment.