Skip to content

Commit

Permalink
WIP adding CocoSsd functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
tirtawr committed Oct 21, 2019
1 parent fcab242 commit 7eab584
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 6 deletions.
55 changes: 55 additions & 0 deletions src/ObjectDetector/CocoSsd/index.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// Copyright (c) 2018 ml5
//
// This software is released under the MIT License.
// https://opensource.org/licenses/MIT
/* eslint max-len: ["error", { "code": 180 }] */

/*
COCO-SSD Object detection
Wraps the coco-ssd model in tfjs to be used in ml5
*/

import * as tf from '@tensorflow/tfjs';
import * as cocoSsd from '@tensorflow-models/coco-ssd';

class CocoSsd {
/**
* Create CocoSsd model. Works on video and images.
* @param {HTMLVideoElement} video - Optional. The video to be used for object detection and classification.
* @param {Object} options - Optional. A set of options.
* @param {function} callback - Optional. A callback function that is called once the model has loaded. If no callback is provided, it will return a promise
* that will be resolved once the model has loaded.
*/
constructor(video, options, callback) {
this.isModelReady = false;
this.video = video;
this.options = options;
this.callback = callback;
cocoSsd.load().then(_cocoSsdModel => {
this.cocoSsdModel = _cocoSsdModel;
callback();
});
}

detect(callback) {
if (this.isModelReady) {
this.cocoSsdModel.detect(this.video).then((predictions) => {
let formattedPredictions = [];
for (let i = 0; i < predictions.length; i++) {
const prediction = predictions[i];
formattedPredictions.push({
label: prediction.class,
confidence: prediction.score,
x: prediction.bbox[0],
y: prediction.bbox[1],
w: prediction.bbox[2],
h: prediction.bbox[3],
});
}
callback(false, formattedPredictions);
})
}
}
}

export default CocoSsd;
24 changes: 18 additions & 6 deletions src/ObjectDetector/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
*/

import YOLO from './YOLO/index';
import CocoSsd from './CocoSsd/index';

class ObjectDetectorBase {
class ObjectDetector {
/**
* @typedef {Object} options
* @property {number} filterBoxesThreshold - Optional. default 0.01
Expand All @@ -30,12 +31,23 @@ class ObjectDetectorBase {
this.video = video;
this.options = options || {};
this.callback = callback;

switch (modelName) {
case 'YOLO':
options.disableDeprecationNotice = true;
this.model = new YOLO(video, options, callback);
break;
case 'CocoSsd':
this.model = new CocoSsd(video, options, callback);
break;
default:
throw new Error('Model name not supported')
}
}
}

const ObjectDetector = (modelName, video, options, callback) => {
options.disableDeprecationNotice = true;
return new YOLO(video, options, callback);
};
detect(callback) {
this.model.detect(callback);
}
}

export default ObjectDetector;

0 comments on commit 7eab584

Please sign in to comment.