Skip to content

Commit

Permalink
Add linting rules for tfjs-models. (tensorflow#333)
Browse files Browse the repository at this point in the history
- Adds lint rules to mirror monorepo
- Fixes lint errors
- Adds CI to run lint
- Refactors body pix and posenet to simplify base models for resnet / mobilenet.
  • Loading branch information
Nikhil Thorat authored Oct 25, 2019
1 parent d101c73 commit 471429d
Show file tree
Hide file tree
Showing 79 changed files with 1,232 additions and 1,542 deletions.
58 changes: 58 additions & 0 deletions .tslint/noImportsFromDistRule.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
'use strict';
var __extends = (this && this.__extends) || (function() {
var extendStatics = function(d, b) {
extendStatics = Object.setPrototypeOf ||
({__proto__: []} instanceof Array && function(d, b) {
d.__proto__ = b;
}) || function(d, b) {
for (var p in b)
if (b.hasOwnProperty(p)) d[p] = b[p];
};
return extendStatics(d, b);
};
return function(d, b) {
extendStatics(d, b);
function __() {
this.constructor = d;
}
d.prototype = b === null ?
Object.create(b) :
(__.prototype = b.prototype, new __());
};
})();
exports.__esModule = true;
var Lint = require('tslint');
var Rule = /** @class */ (function(_super) {
__extends(Rule, _super);
function Rule() {
return _super !== null && _super.apply(this, arguments) || this;
}
Rule.prototype.apply = function(sourceFile) {
return this.applyWithWalker(
new NoImportsFromDistWalker(sourceFile, this.getOptions()));
};
Rule.FAILURE_STRING =
'importing from dist/ is prohibited. Please use public API';
return Rule;
}(Lint.Rules.AbstractRule));
exports.Rule = Rule;
var NoImportsFromDistWalker = /** @class */ (function(_super) {
__extends(NoImportsFromDistWalker, _super);
function NoImportsFromDistWalker() {
return _super !== null && _super.apply(this, arguments) || this;
}
NoImportsFromDistWalker.prototype.visitImportDeclaration = function(node) {
var importFrom = node.moduleSpecifier.getText();
var reg = /@tensorflow\/tfjs[-a-z]*\/dist/;
if (importFrom.match(reg)) {
var fix = new Lint.Replacement(
node.moduleSpecifier.getStart(), node.moduleSpecifier.getWidth(),
importFrom.replace(/\/dist[\/]*/, ''));
this.addFailure(this.createFailure(
node.moduleSpecifier.getStart(), node.moduleSpecifier.getWidth(),
Rule.FAILURE_STRING, fix));
}
_super.prototype.visitImportDeclaration.call(this, node);
};
return NoImportsFromDistWalker;
}(Lint.RuleWalker));
30 changes: 30 additions & 0 deletions .tslint/noImportsFromDistRule.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import * as Lint from 'tslint';
import * as ts from 'typescript';

export class Rule extends Lint.Rules.AbstractRule {
public static FAILURE_STRING =
'importing from dist/ is prohibited. Please use public API';

public apply(sourceFile: ts.SourceFile): Lint.RuleFailure[] {
return this.applyWithWalker(
new NoImportsFromDistWalker(sourceFile, this.getOptions()));
}
}

class NoImportsFromDistWalker extends Lint.RuleWalker {
public visitImportDeclaration(node: ts.ImportDeclaration) {
const importFrom = node.moduleSpecifier.getText();
const reg = /@tensorflow\/tfjs[-a-z]*\/dist/;
if (importFrom.match(reg)) {
const fix = new Lint.Replacement(
node.moduleSpecifier.getStart(), node.moduleSpecifier.getWidth(),
importFrom.replace(/\/dist[\/]*/, ''));

this.addFailure(this.createFailure(
node.moduleSpecifier.getStart(), node.moduleSpecifier.getWidth(),
Rule.FAILURE_STRING, fix));
}

super.visitImportDeclaration(node);
}
}
8 changes: 8 additions & 0 deletions body-pix/cloudbuild.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,14 @@ steps:
args: ['install']
waitFor: ['yarn-common']

# Lint.
- name: 'node:10'
dir: 'body-pix'
entrypoint: 'yarn'
id: 'lint'
args: ['lint']
waitFor: ['yarn']

# Build.
- name: 'node:10'
dir: 'body-pix'
Expand Down
3 changes: 2 additions & 1 deletion body-pix/demos/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
"scripts": {
"watch": "cross-env NODE_ENV=development parcel index.html --no-hmr --open ",
"build": "cross-env NODE_ENV=production parcel build index.html --public-url ./",
"lint": "eslint ."
"lint": "eslint .",
"link-local": "yalc link"
},
"browser": {
"crypto": false
Expand Down
5 changes: 0 additions & 5 deletions body-pix/demos/yarn.lock
Original file line number Diff line number Diff line change
Expand Up @@ -6243,11 +6243,6 @@ typedarray@^0.0.6:
resolved "https://registry.yarnpkg.com/typedarray/-/typedarray-0.0.6.tgz#867ac74e3864187b1d3d47d996a78ec5c8830777"
integrity sha1-hnrHTjhkGHsdPUfZlqeOxciDB3c=

[email protected]:
version "0.0.54"
resolved "https://registry.yarnpkg.com/typeface-oswald/-/typeface-oswald-0.0.54.tgz#1e253011622cdd50f580c04e7d625e7f449763d7"
integrity sha512-U1WMNp4qfy4/3khIfHMVAIKnNu941MXUfs3+H9R8PFgnoz42Hh9pboSFztWr86zut0eXC8byalmVhfkiKON/8Q==

uncss@^0.17.0:
version "0.17.2"
resolved "https://registry.yarnpkg.com/uncss/-/uncss-0.17.2.tgz#fac1c2429be72108e8a47437c647d58cf9ea66f1"
Expand Down
2 changes: 1 addition & 1 deletion body-pix/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
"rollup-plugin-typescript2": "~0.13.0",
"rollup-plugin-uglify": "~3.0.0",
"ts-node": "~5.0.0",
"tslint": "~5.8.0",
"tslint": "~5.18.0",
"typescript": "~3.5.3",
"yalc": "^1.0.0-pre.27"
},
Expand Down
113 changes: 113 additions & 0 deletions body-pix/src/base_model.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@

/**
* @license
* Copyright 2019 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import * as tfconv from '@tensorflow/tfjs-converter';
import * as tf from '@tensorflow/tfjs-core';
import {BodyPixOutputStride} from './types';

/**
* BodyPix supports using various convolution neural network models
* (e.g. ResNet and MobileNetV1) as its underlying base model.
* The following BaseModel interface defines a unified interface for
* creating such BodyPix base models. Currently both MobileNet (in
* ./mobilenet.ts) and ResNet (in ./resnet.ts) implements the BaseModel
* interface. New base models that conform to the BaseModel interface can be
* added to BodyPix.
*/
export abstract class BaseModel {
constructor(
protected readonly model: tfconv.GraphModel,
public readonly outputStride: BodyPixOutputStride) {
const inputShape =
this.model.inputs[0].shape as [number, number, number, number];
tf.util.assert(
(inputShape[1] === -1) && (inputShape[2] === -1),
() => `Input shape [${inputShape[1]}, ${inputShape[2]}] ` +
`must both be equal to or -1`);
}

abstract preprocessInput(input: tf.Tensor3D): tf.Tensor3D;

/**
* Predicts intermediate Tensor representations.
*
* @param input The input RGB image of the base model.
* A Tensor of shape: [`inputResolution`, `inputResolution`, 3].
*
* @return A dictionary of base model's intermediate predictions.
* The returned dictionary should contains the following elements:
* - heatmapScores: A Tensor3D that represents the keypoint heatmap scores.
* - offsets: A Tensor3D that represents the offsets.
* - displacementFwd: A Tensor3D that represents the forward displacement.
* - displacementBwd: A Tensor3D that represents the backward displacement.
* - segmentation: A Tensor3D that represents the segmentation of all
* people.
* - longOffsets: A Tensor3D that represents the long offsets used for
* instance grouping.
* - partHeatmaps: A Tensor3D that represents the body part segmentation.
*/
predict(input: tf.Tensor3D): {
heatmapScores: tf.Tensor3D,
offsets: tf.Tensor3D,
displacementFwd: tf.Tensor3D,
displacementBwd: tf.Tensor3D,
segmentation: tf.Tensor3D,
partHeatmaps: tf.Tensor3D,
longOffsets: tf.Tensor3D,
partOffsets: tf.Tensor3D
} {
return tf.tidy(() => {
const asFloat = this.preprocessInput(input.toFloat());
const asBatch = asFloat.expandDims(0);
const results = this.model.predict(asBatch) as tf.Tensor4D[];
const results3d: tf.Tensor3D[] = results.map(y => y.squeeze([0]));
const namedResults = this.nameOutputResults(results3d);

return {
heatmapScores: namedResults.heatmap.sigmoid(),
offsets: namedResults.offsets,
displacementFwd: namedResults.displacementFwd,
displacementBwd: namedResults.displacementBwd,
segmentation: namedResults.segmentation,
partHeatmaps: namedResults.partHeatmaps,
longOffsets: namedResults.longOffsets,
partOffsets: namedResults.partOffsets
};
});
}

// Because MobileNet and ResNet predict() methods output a different order for
// these values, we have a method that needs to be implemented to order them.
abstract nameOutputResults(results: tf.Tensor3D[]): {
heatmap: tf.Tensor3D,
offsets: tf.Tensor3D,
displacementFwd: tf.Tensor3D,
displacementBwd: tf.Tensor3D,
segmentation: tf.Tensor3D,
partHeatmaps: tf.Tensor3D,
longOffsets: tf.Tensor3D,
partOffsets: tf.Tensor3D
};

/**
* Releases the CPU and GPU memory allocated by the model.
*/
dispose() {
this.model.dispose();
}
}
65 changes: 14 additions & 51 deletions body-pix/src/body_pix_model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,56 +19,19 @@
import * as tfconv from '@tensorflow/tfjs-converter';
import * as tf from '@tensorflow/tfjs-core';

import {BaseModel} from './base_model';
import {decodeOnlyPartSegmentation, decodePartSegmentation, toMaskTensor} from './decode_part_map';
import {MobileNet, MobileNetMultiplier} from './mobilenet';
import {MobileNet} from './mobilenet';
import {decodePersonInstanceMasks, decodePersonInstancePartMasks} from './multi_person/decode_instance_masks';
import {decodeMultiplePoses} from './multi_person/decode_multiple_poses';
import {ResNet} from './resnet';
import {mobileNetSavedModel, resNet50SavedModel} from './saved_models';
import {decodeSinglePose} from './sinlge_person/decode_single_pose';
import {decodeSinglePose} from './single_person/decode_single_pose';
import {BodyPixArchitecture, BodyPixInput, BodyPixInternalResolution, BodyPixMultiplier, BodyPixOutputStride, BodyPixQuantBytes, Padding, PartSegmentation, PersonSegmentation} from './types';
import {getInputSize, padAndResizeTo, scaleAndCropToInputTensorShape, scaleAndFlipPoses, toTensorBuffers3D, toValidInternalResolutionNumber} from './util';


const APPLY_SIGMOID_ACTIVATION = true;

/**
* BodyPix supports using various convolution neural network models
* (e.g. ResNet and MobileNetV1) as its underlying base model.
* The following BaseModel interface defines a unified interface for
* creating such BodyPix base models. Currently both MobileNet (in
* ./mobilenet.ts) and ResNet (in ./resnet.ts) implements the BaseModel
* interface. New base models that conform to the BaseModel interface can be
* added to BodyPix.
*/
export interface BaseModel {
// The output stride of the base model.
readonly outputStride: BodyPixOutputStride;

/**
* Predicts intermediate Tensor representations.
*
* @param input The input RGB image of the base model.
* A Tensor of shape: [`inputResolution`, `inputResolution`, 3].
*
* @return A dictionary of base model's intermediate predictions.
* The returned dictionary should contains the following elements:
* - heatmapScores: A Tensor3D that represents the keypoint heatmap scores.
* - offsets: A Tensor3D that represents the offsets.
* - displacementFwd: A Tensor3D that represents the forward displacement.
* - displacementBwd: A Tensor3D that represents the backward displacement.
* - segmentation: A Tensor3D that represents the segmentation of all people.
* - longOffsets: A Tensor3D that represents the long offsets used for
* instance grouping.
* - partHeatmaps: A Tensor3D that represents the body part segmentation.
*/
predict(input: tf.Tensor3D): {[key: string]: tf.Tensor3D};
/**
* Releases the CPU and GPU memory allocated by the model.
*/
dispose(): void;
}

/**
* BodyPix model loading is configurable using the following config dictionary.
*
Expand Down Expand Up @@ -101,7 +64,7 @@ export interface BaseModel {
export interface ModelConfig {
architecture: BodyPixArchitecture;
outputStride: BodyPixOutputStride;
multiplier?: MobileNetMultiplier;
multiplier?: BodyPixMultiplier;
modelUrl?: string;
quantBytes?: BodyPixQuantBytes;
}
Expand Down Expand Up @@ -602,15 +565,15 @@ export class BodyPix {
};
});

const [scoresBuffer, offsetsBuffer, displacementsFwdBuffer, displacementsBwdBuffer] =
const [scoresBuf, offsetsBuf, displacementsFwdBuf, displacementsBwdBuf] =
await toTensorBuffers3D([
heatmapScoresRaw, offsetsRaw, displacementFwdRaw, displacementBwdRaw
]);

let poses = await decodeMultiplePoses(
scoresBuffer, offsetsBuffer, displacementsFwdBuffer,
displacementsBwdBuffer, this.baseModel.outputStride,
config.maxDetections, config.scoreThreshold, config.nmsRadius);
let poses = decodeMultiplePoses(
scoresBuf, offsetsBuf, displacementsFwdBuf, displacementsBwdBuf,
this.baseModel.outputStride, config.maxDetections,
config.scoreThreshold, config.nmsRadius);

poses = scaleAndFlipPoses(
poses, [height, width],
Expand Down Expand Up @@ -849,15 +812,15 @@ export class BodyPix {
};
});

const [scoresBuffer, offsetsBuffer, displacementsFwdBuffer, displacementsBwdBuffer] =
const [scoresBuf, offsetsBuf, displacementsFwdBuf, displacementsBwdBuf] =
await toTensorBuffers3D([
heatmapScoresRaw, offsetsRaw, displacementFwdRaw, displacementBwdRaw
]);

let poses = await decodeMultiplePoses(
scoresBuffer, offsetsBuffer, displacementsFwdBuffer,
displacementsBwdBuffer, this.baseModel.outputStride,
config.maxDetections, config.scoreThreshold, config.nmsRadius);
let poses = decodeMultiplePoses(
scoresBuf, offsetsBuf, displacementsFwdBuf, displacementsBwdBuf,
this.baseModel.outputStride, config.maxDetections,
config.scoreThreshold, config.nmsRadius);

poses = scaleAndFlipPoses(
poses, [height, width],
Expand Down
Loading

0 comments on commit 471429d

Please sign in to comment.