forked from tensorflow/tfjs
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcallbacks.ts
329 lines (307 loc) · 11 KB
/
callbacks.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import {CustomCallback, Logs, nextFrame, util} from '@tensorflow/tfjs';
import * as path from 'path';
import * as ProgressBar from 'progress';
import {summaryFileWriter, SummaryFileWriter} from './tensorboard';
// Indirection object that exists for testability: jasmine's `spyOn` can only
// replace member methods of objects, so the ProgressBar constructor and the
// logging function are exposed as properties here instead of being used
// directly by ProgbarLogger.
// tslint:disable-next-line:no-any
export const progressBarHelper: {ProgressBar: any, log: Function} = {
  ProgressBar: ProgressBar,
  log: console.log,
};
/**
 * Terminal-based progress bar callback for tf.Model.fit().
 *
 * At the start of every epoch it prints an `Epoch i / n` header. It then
 * advances a progress bar after every batch, showing the current loss and
 * metric values next to the bar. At the end of the epoch it prints the epoch
 * duration, the per-step duration and the final loss/metric values.
 */
export class ProgbarLogger extends CustomCallback {
  private numTrainBatchesPerEpoch: number;
  private progressBar: ProgressBar;
  private currentEpochBegin: number;
  private epochDurationMillis: number;
  private usPerStep: number;
  private batchesInLatestEpoch: number;
  private terminalWidth: number;

  private readonly RENDER_THROTTLE_MS = 50;

  /**
   * Constructor of ProgbarLogger.
   */
  constructor() {
    super({
      onTrainBegin: async (logs?: Logs) => {
        const samples = this.params.samples as number;
        const batchSize = this.params.batchSize as number;
        const steps = this.params.steps as number;
        if (samples != null || steps != null) {
          this.numTrainBatchesPerEpoch =
              samples != null ? Math.ceil(samples / batchSize) : steps;
        } else {
          // Undetermined number of batches per epoch, e.g., due to
          // `fitDataset()` without `batchesPerEpoch`.
          this.numTrainBatchesPerEpoch = 0;
        }
      },
      onEpochBegin: async (epoch: number, logs?: Logs) => {
        progressBarHelper.log(`Epoch ${epoch + 1} / ${this.params.epochs}`);
        this.currentEpochBegin = util.now();
        this.epochDurationMillis = null;
        this.usPerStep = null;
        this.batchesInLatestEpoch = 0;
        this.terminalWidth = process.stderr.columns;
      },
      onBatchEnd: async (batch: number, logs?: Logs) => {
        this.batchesInLatestEpoch++;
        if (batch === 0) {
          // A fresh bar per epoch. Its total is one more than the batch count
          // because onEpochEnd issues one final tick to complete the bar.
          this.progressBar = new progressBarHelper.ProgressBar(
              'eta=:eta :bar :placeholderForLossesAndMetrics', {
                width: Math.floor(0.5 * this.terminalWidth),
                total: this.numTrainBatchesPerEpoch + 1,
                head: `>`,
                renderThrottle: this.RENDER_THROTTLE_MS
              });
        }
        const maxMetricsStringLength =
            Math.floor(this.terminalWidth * 0.5 - 12);
        const tickTokens = {
          placeholderForLossesAndMetrics:
              this.formatLogsAsMetricsContent(logs, maxMetricsStringLength)
        };
        if (this.numTrainBatchesPerEpoch === 0) {
          // Undetermined number of batches per epoch: refresh the tokens
          // without advancing the bar.
          this.progressBar.tick(0, tickTokens);
        } else {
          this.progressBar.tick(tickTokens);
        }
        await nextFrame();
        if (batch === this.numTrainBatchesPerEpoch - 1) {
          this.epochDurationMillis = util.now() - this.currentEpochBegin;
          // When the sample count is known, the "per step" figure is
          // computed per sample; otherwise it is computed per batch.
          this.usPerStep = this.params.samples != null ?
              this.epochDurationMillis / (this.params.samples as number) * 1e3 :
              this.epochDurationMillis / this.batchesInLatestEpoch * 1e3;
        }
      },
      onEpochEnd: async (epoch: number, logs?: Logs) => {
        if (this.epochDurationMillis == null) {
          // In cases where the number of batches per epoch is not determined,
          // the calculation of the per-step duration is done at the end of the
          // epoch. N.B., this includes the time spent on validation.
          this.epochDurationMillis = util.now() - this.currentEpochBegin;
          this.usPerStep =
              this.epochDurationMillis / this.batchesInLatestEpoch * 1e3;
        }
        // Only complete the bar if this epoch actually created one. An epoch
        // with no batches would otherwise crash on an undefined bar (first
        // epoch) or tick the previous epoch's stale bar.
        if (this.progressBar != null && this.batchesInLatestEpoch > 0) {
          this.progressBar.tick({placeholderForLossesAndMetrics: ''});
        }
        const lossesAndMetricsString = this.formatLogsAsMetricsContent(logs);
        progressBarHelper.log(
            `${this.epochDurationMillis.toFixed(0)}ms ` +
            `${this.usPerStep.toFixed(0)}us/step - ` +
            `${lossesAndMetricsString}`);
        await nextFrame();
      },
    });
  }

  /**
   * Renders the relevant entries of `logs` as `key=value ` pairs sorted by key.
   *
   * @param logs Loss/metric values; may be absent, in which case '' is
   *   returned.
   * @param maxMetricsLength Optional cap on the content length. Longer content
   *   is cut and suffixed with '...'.
   * @returns The formatted metrics string.
   */
  private formatLogsAsMetricsContent(logs: Logs, maxMetricsLength?: number):
      string {
    if (logs == null) {
      return '';
    }
    let metricsContent = '';
    const keys = Object.keys(logs).sort();
    for (const key of keys) {
      if (this.isFieldRelevant(key)) {
        const value = logs[key];
        metricsContent += `${key}=${getSuccinctNumberDisplay(value)} `;
      }
    }
    if (maxMetricsLength != null && metricsContent.length > maxMetricsLength) {
      // Cut off metrics strings that are too long to avoid new lines being
      // constantly created. The end index is clamped at 0: a negative slice
      // end counts from the end of the string and would keep almost all of it
      // on very narrow terminals.
      metricsContent =
          metricsContent.slice(0, Math.max(0, maxMetricsLength - 3)) + '...';
    }
    return metricsContent;
  }

  /** Whether `key` is a loss/metric entry rather than batch bookkeeping. */
  private isFieldRelevant(key: string) {
    return key !== 'batch' && key !== 'size';
  }
}
const BASE_NUM_DIGITS = 2;
const MAX_NUM_DECIMAL_PLACES = 4;

/**
 * Formats a number compactly for progress output.
 *
 * Fixed-point notation is used unless that would need more than
 * MAX_NUM_DECIMAL_PLACES decimals, in which case exponential notation with
 * BASE_NUM_DIGITS fraction digits is used instead.
 *
 * @param x The value to format.
 * @return A short string representation of `x`.
 */
export function getSuccinctNumberDisplay(x: number): string {
  const places = getDisplayDecimalPlaces(x);
  if (places <= MAX_NUM_DECIMAL_PLACES) {
    return x.toFixed(places);
  }
  return x.toExponential(BASE_NUM_DIGITS);
}

/**
 * Chooses how many decimal places to show for `x`.
 *
 * Zero, non-finite values and values with magnitude above 1 get
 * BASE_NUM_DIGITS places. Smaller magnitudes get extra places so that
 * roughly BASE_NUM_DIGITS + 1 significant digits remain visible.
 *
 * @param x The value that will be displayed.
 * @return The decimal-place count to use for `x`.
 */
export function getDisplayDecimalPlaces(x: number): number {
  const magnitude = Math.abs(x);
  const useDefault =
      !Number.isFinite(x) || magnitude === 0 || magnitude > 1;
  if (useDefault) {
    return BASE_NUM_DIGITS;
  }
  const leadingExponent = Math.floor(Math.log10(magnitude));
  return BASE_NUM_DIGITS - leadingExponent;
}
/** Configuration options accepted by the TensorBoard callback. */
export interface TensorBoardCallbackArgs {
  /**
   * How often loss and metric values are written to the logs.
   *
   * - 'batch': write after every training batch, and also after every epoch.
   * - 'epoch': write only after every training epoch.
   *
   * More frequent writing slows training down.
   *
   * Default: 'epoch'.
   */
  updateFreq?: 'batch' | 'epoch';
}
/**
 * Callback for logging to TensorBoard during training.
 *
 * Users are expected to access this class through the `tensorBoard()`
 * factory function instead.
 */
export class TensorBoardCallback extends CustomCallback {
  private trainWriter: SummaryFileWriter;
  private valWriter: SummaryFileWriter;
  private batchesSeen: number;
  private epochsSeen: number;
  private readonly args: TensorBoardCallbackArgs;

  /**
   * @param logdir Directory under which the `train` and `val` summary
   *   subdirectories are written. Default: './logs'.
   * @param args Optional configuration. The object is copied, so the
   *   caller's instance is never modified.
   */
  constructor(readonly logdir = './logs', args?: TensorBoardCallbackArgs) {
    super({
      onBatchEnd: async (batch: number, logs?: Logs) => {
        this.batchesSeen++;
        if (this.args.updateFreq !== 'epoch') {
          this.logMetrics(logs, 'batch_', this.batchesSeen);
        }
      },
      onEpochEnd: async (epoch: number, logs?: Logs) => {
        this.epochsSeen++;
        this.logMetrics(logs, 'epoch_', this.epochsSeen);
      },
      onTrainEnd: async (logs?: Logs) => {
        if (this.trainWriter != null) {
          this.trainWriter.flush();
        }
        if (this.valWriter != null) {
          this.valWriter.flush();
        }
      }
    });

    // Copy before filling in defaults so the caller's object is not mutated.
    this.args = {...args};
    if (this.args.updateFreq == null) {
      this.args.updateFreq = 'epoch';
    }
    util.assert(
        ['batch', 'epoch'].indexOf(this.args.updateFreq) !== -1,
        () => `Expected updateFreq to be 'batch' or 'epoch', but got ` +
            `${this.args.updateFreq}`);
    this.batchesSeen = 0;
    this.epochsSeen = 0;
  }

  /**
   * Writes every loss/metric entry of `logs` as a scalar at `step`.
   *
   * Keys starting with 'val_' go to the validation writer with that prefix
   * stripped; all other keys go to the training writer. Bookkeeping keys
   * ('batch', 'size', 'num_steps') are skipped.
   */
  private logMetrics(logs: Logs, prefix: string, step: number) {
    if (logs == null) {
      return;
    }
    const VAL_PREFIX = 'val_';
    for (const key of Object.keys(logs)) {
      if (key === 'batch' || key === 'size' || key === 'num_steps') {
        continue;
      }
      if (key.startsWith(VAL_PREFIX)) {
        this.ensureValWriterCreated();
        const scalarName = prefix + key.slice(VAL_PREFIX.length);
        this.valWriter.scalar(scalarName, logs[key], step);
      } else {
        this.ensureTrainWriterCreated();
        this.trainWriter.scalar(`${prefix}${key}`, logs[key], step);
      }
    }
  }

  /** Lazily creates the training summary writer on first use. */
  private ensureTrainWriterCreated() {
    if (this.trainWriter == null) {
      this.trainWriter = summaryFileWriter(path.join(this.logdir, 'train'));
    }
  }

  /** Lazily creates the validation summary writer on first use. */
  private ensureValWriterCreated() {
    if (this.valWriter == null) {
      this.valWriter = summaryFileWriter(path.join(this.logdir, 'val'));
    }
  }
}
/**
 * Callback for logging to TensorBoard during training.
 *
 * Writes the loss and metric values (if any) to the given log directory
 * (`logdir`), where TensorBoard can ingest and visualize them.
 * This callback is typically passed in the `callbacks` field of
 * `tf.Model.fit()` or `tf.Model.fitDataset()` calls. How often values are
 * logged is controlled by the `updateFreq` field of the configuration object
 * (2nd argument).
 *
 * Usage example:
 * ```js
 * // Construct a toy multilayer-perceptron regressor for demo purposes.
 * const model = tf.sequential();
 * model.add(
 *     tf.layers.dense({units: 100, activation: 'relu', inputShape: [200]}));
 * model.add(tf.layers.dense({units: 1}));
 * model.compile({
 *   loss: 'meanSquaredError',
 *   optimizer: 'sgd',
 *   metrics: ['MAE']
 * });
 *
 * // Generate some random fake data for demo purposes.
 * const xs = tf.randomUniform([10000, 200]);
 * const ys = tf.randomUniform([10000, 1]);
 * const valXs = tf.randomUniform([1000, 200]);
 * const valYs = tf.randomUniform([1000, 1]);
 *
 * // Start model training process.
 * await model.fit(xs, ys, {
 *   epochs: 100,
 *   validationData: [valXs, valYs],
 *   // Add the tensorBoard callback here.
 *   callbacks: tf.node.tensorBoard('/tmp/fit_logs_1')
 * });
 * ```
 *
 * Then use the following commands to point TensorBoard at the logdir:
 *
 * ```sh
 * pip install tensorboard  # Unless you've already installed it.
 * tensorboard --logdir /tmp/fit_logs_1
 * ```
 *
 * @param logdir Directory to which the logs will be written.
 * @param args Optional configuration arguments.
 * @returns An instance of `TensorBoardCallback`, which is a subclass of
 *   `tf.CustomCallback`.
 */
/**
 * @doc {heading: 'TensorBoard', namespace: 'node'}
 */
export function tensorBoard(
    logdir = './logs', args?: TensorBoardCallbackArgs): TensorBoardCallback {
  const callback = new TensorBoardCallback(logdir, args);
  return callback;
}