forked from tensorflow/tfjs-models
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add linting rules for tfjs-models. (tensorflow#333)
- Adds lint rules to mirror monorepo - Fixes lint errors - Adds CI to run lint - Refactors body pix and posenet to simplify base models for resnet / mobilenet.
- Loading branch information
Nikhil Thorat
authored
Oct 25, 2019
1 parent
d101c73
commit 471429d
Showing
79 changed files
with
1,232 additions
and
1,542 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
'use strict'; | ||
var __extends = (this && this.__extends) || (function() { | ||
var extendStatics = function(d, b) { | ||
extendStatics = Object.setPrototypeOf || | ||
({__proto__: []} instanceof Array && function(d, b) { | ||
d.__proto__ = b; | ||
}) || function(d, b) { | ||
for (var p in b) | ||
if (b.hasOwnProperty(p)) d[p] = b[p]; | ||
}; | ||
return extendStatics(d, b); | ||
}; | ||
return function(d, b) { | ||
extendStatics(d, b); | ||
function __() { | ||
this.constructor = d; | ||
} | ||
d.prototype = b === null ? | ||
Object.create(b) : | ||
(__.prototype = b.prototype, new __()); | ||
}; | ||
})(); | ||
exports.__esModule = true; | ||
var Lint = require('tslint'); | ||
var Rule = /** @class */ (function(_super) { | ||
__extends(Rule, _super); | ||
function Rule() { | ||
return _super !== null && _super.apply(this, arguments) || this; | ||
} | ||
Rule.prototype.apply = function(sourceFile) { | ||
return this.applyWithWalker( | ||
new NoImportsFromDistWalker(sourceFile, this.getOptions())); | ||
}; | ||
Rule.FAILURE_STRING = | ||
'importing from dist/ is prohibited. Please use public API'; | ||
return Rule; | ||
}(Lint.Rules.AbstractRule)); | ||
exports.Rule = Rule; | ||
var NoImportsFromDistWalker = /** @class */ (function(_super) { | ||
__extends(NoImportsFromDistWalker, _super); | ||
function NoImportsFromDistWalker() { | ||
return _super !== null && _super.apply(this, arguments) || this; | ||
} | ||
NoImportsFromDistWalker.prototype.visitImportDeclaration = function(node) { | ||
var importFrom = node.moduleSpecifier.getText(); | ||
var reg = /@tensorflow\/tfjs[-a-z]*\/dist/; | ||
if (importFrom.match(reg)) { | ||
var fix = new Lint.Replacement( | ||
node.moduleSpecifier.getStart(), node.moduleSpecifier.getWidth(), | ||
importFrom.replace(/\/dist[\/]*/, '')); | ||
this.addFailure(this.createFailure( | ||
node.moduleSpecifier.getStart(), node.moduleSpecifier.getWidth(), | ||
Rule.FAILURE_STRING, fix)); | ||
} | ||
_super.prototype.visitImportDeclaration.call(this, node); | ||
}; | ||
return NoImportsFromDistWalker; | ||
}(Lint.RuleWalker)); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import * as Lint from 'tslint'; | ||
import * as ts from 'typescript'; | ||
|
||
export class Rule extends Lint.Rules.AbstractRule { | ||
public static FAILURE_STRING = | ||
'importing from dist/ is prohibited. Please use public API'; | ||
|
||
public apply(sourceFile: ts.SourceFile): Lint.RuleFailure[] { | ||
return this.applyWithWalker( | ||
new NoImportsFromDistWalker(sourceFile, this.getOptions())); | ||
} | ||
} | ||
|
||
class NoImportsFromDistWalker extends Lint.RuleWalker { | ||
public visitImportDeclaration(node: ts.ImportDeclaration) { | ||
const importFrom = node.moduleSpecifier.getText(); | ||
const reg = /@tensorflow\/tfjs[-a-z]*\/dist/; | ||
if (importFrom.match(reg)) { | ||
const fix = new Lint.Replacement( | ||
node.moduleSpecifier.getStart(), node.moduleSpecifier.getWidth(), | ||
importFrom.replace(/\/dist[\/]*/, '')); | ||
|
||
this.addFailure(this.createFailure( | ||
node.moduleSpecifier.getStart(), node.moduleSpecifier.getWidth(), | ||
Rule.FAILURE_STRING, fix)); | ||
} | ||
|
||
super.visitImportDeclaration(node); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6243,11 +6243,6 @@ typedarray@^0.0.6: | |
resolved "https://registry.yarnpkg.com/typedarray/-/typedarray-0.0.6.tgz#867ac74e3864187b1d3d47d996a78ec5c8830777" | ||
integrity sha1-hnrHTjhkGHsdPUfZlqeOxciDB3c= | ||
|
||
[email protected]: | ||
version "0.0.54" | ||
resolved "https://registry.yarnpkg.com/typeface-oswald/-/typeface-oswald-0.0.54.tgz#1e253011622cdd50f580c04e7d625e7f449763d7" | ||
integrity sha512-U1WMNp4qfy4/3khIfHMVAIKnNu941MXUfs3+H9R8PFgnoz42Hh9pboSFztWr86zut0eXC8byalmVhfkiKON/8Q== | ||
|
||
uncss@^0.17.0: | ||
version "0.17.2" | ||
resolved "https://registry.yarnpkg.com/uncss/-/uncss-0.17.2.tgz#fac1c2429be72108e8a47437c647d58cf9ea66f1" | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
|
||
/** | ||
* @license | ||
* Copyright 2019 Google Inc. All Rights Reserved. | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
* ============================================================================= | ||
*/ | ||
|
||
import * as tfconv from '@tensorflow/tfjs-converter'; | ||
import * as tf from '@tensorflow/tfjs-core'; | ||
import {BodyPixOutputStride} from './types'; | ||
|
||
/** | ||
* BodyPix supports using various convolution neural network models | ||
* (e.g. ResNet and MobileNetV1) as its underlying base model. | ||
* The following BaseModel interface defines a unified interface for | ||
* creating such BodyPix base models. Currently both MobileNet (in | ||
* ./mobilenet.ts) and ResNet (in ./resnet.ts) implements the BaseModel | ||
* interface. New base models that conform to the BaseModel interface can be | ||
* added to BodyPix. | ||
*/ | ||
export abstract class BaseModel { | ||
constructor( | ||
protected readonly model: tfconv.GraphModel, | ||
public readonly outputStride: BodyPixOutputStride) { | ||
const inputShape = | ||
this.model.inputs[0].shape as [number, number, number, number]; | ||
tf.util.assert( | ||
(inputShape[1] === -1) && (inputShape[2] === -1), | ||
() => `Input shape [${inputShape[1]}, ${inputShape[2]}] ` + | ||
`must both be equal to or -1`); | ||
} | ||
|
||
abstract preprocessInput(input: tf.Tensor3D): tf.Tensor3D; | ||
|
||
/** | ||
* Predicts intermediate Tensor representations. | ||
* | ||
* @param input The input RGB image of the base model. | ||
* A Tensor of shape: [`inputResolution`, `inputResolution`, 3]. | ||
* | ||
* @return A dictionary of base model's intermediate predictions. | ||
* The returned dictionary should contains the following elements: | ||
* - heatmapScores: A Tensor3D that represents the keypoint heatmap scores. | ||
* - offsets: A Tensor3D that represents the offsets. | ||
* - displacementFwd: A Tensor3D that represents the forward displacement. | ||
* - displacementBwd: A Tensor3D that represents the backward displacement. | ||
* - segmentation: A Tensor3D that represents the segmentation of all | ||
* people. | ||
* - longOffsets: A Tensor3D that represents the long offsets used for | ||
* instance grouping. | ||
* - partHeatmaps: A Tensor3D that represents the body part segmentation. | ||
*/ | ||
predict(input: tf.Tensor3D): { | ||
heatmapScores: tf.Tensor3D, | ||
offsets: tf.Tensor3D, | ||
displacementFwd: tf.Tensor3D, | ||
displacementBwd: tf.Tensor3D, | ||
segmentation: tf.Tensor3D, | ||
partHeatmaps: tf.Tensor3D, | ||
longOffsets: tf.Tensor3D, | ||
partOffsets: tf.Tensor3D | ||
} { | ||
return tf.tidy(() => { | ||
const asFloat = this.preprocessInput(input.toFloat()); | ||
const asBatch = asFloat.expandDims(0); | ||
const results = this.model.predict(asBatch) as tf.Tensor4D[]; | ||
const results3d: tf.Tensor3D[] = results.map(y => y.squeeze([0])); | ||
const namedResults = this.nameOutputResults(results3d); | ||
|
||
return { | ||
heatmapScores: namedResults.heatmap.sigmoid(), | ||
offsets: namedResults.offsets, | ||
displacementFwd: namedResults.displacementFwd, | ||
displacementBwd: namedResults.displacementBwd, | ||
segmentation: namedResults.segmentation, | ||
partHeatmaps: namedResults.partHeatmaps, | ||
longOffsets: namedResults.longOffsets, | ||
partOffsets: namedResults.partOffsets | ||
}; | ||
}); | ||
} | ||
|
||
// Because MobileNet and ResNet predict() methods output a different order for | ||
// these values, we have a method that needs to be implemented to order them. | ||
abstract nameOutputResults(results: tf.Tensor3D[]): { | ||
heatmap: tf.Tensor3D, | ||
offsets: tf.Tensor3D, | ||
displacementFwd: tf.Tensor3D, | ||
displacementBwd: tf.Tensor3D, | ||
segmentation: tf.Tensor3D, | ||
partHeatmaps: tf.Tensor3D, | ||
longOffsets: tf.Tensor3D, | ||
partOffsets: tf.Tensor3D | ||
}; | ||
|
||
/** | ||
* Releases the CPU and GPU memory allocated by the model. | ||
*/ | ||
dispose() { | ||
this.model.dispose(); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.