-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathpredict.mjs
115 lines (95 loc) · 3.59 KB
/
predict.mjs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
// Module-level state shared by the functions in this file. Everything starts
// out undefined; setup() fills in the model, its key, the class labels and the
// DOM element handles, while connect()/disconnect() manage `connection`.
let modelKey; // storage key used to build the model load/save URLs
let model; // model object returned by tf.loadLayersModel
let statusEl; // element with id "status"
let predictionsContainerEl; // element with id "predictionsContainer"
let predictionsEl; // element with id "predictions"
let connection; // currently open Puck connection, if any
let gestureClasses; // labels indexed by the model's winning output index
let feedbackEl; // element with id "feedback" (carries the ribbon-* classes)
/**
 * Runs the loaded model on a single gesture sample and shows the result.
 *
 * The sample is wrapped in a [1, 50] tensor, passed to `model.predict`, and the
 * index of the largest output (argMax over the last axis) is looked up in
 * `gestureClasses`. When a label is found it is written to the predictions
 * element and the container background is set to a random colour; otherwise
 * "Unknown gesture" is shown.
 *
 * @param {number[]} gesture - Exactly 50 values (the tensor shape is fixed at [1, 50]).
 * @returns {void}
 */
const predict = (gesture) => {
  // tf.tidy wraps the tensor work; presumably this disposes the intermediate
  // tensors created inside the callback (TensorFlow.js convention).
  tf.tidy(() => {
    const input = tf.tensor2d([gesture], [1, 50]);
    const predictOut = model.predict(input);
    const winnerIndex = predictOut.argMax(-1).dataSync()[0];
    const winner = gestureClasses[winnerIndex];
    if (winner) {
      // Always emit six hex digits: an unpadded value such as "#a7" is not a
      // valid CSS colour and would leave the background unchanged.
      const colourValue = Math.floor(Math.random() * 0x1000000);
      const randCol = `#${colourValue.toString(16).padStart(6, "0")}`;
      predictionsContainerEl.style.setProperty("background", randCol);
      predictionsEl.innerHTML = winner;
    } else {
      predictionsEl.innerHTML = "Unknown gesture";
    }
  });
};
/**
 * Handles one line of text that is expected to contain a gesture.
 *
 * The first bracketed run of digits, minus signs and commas in the line is
 * parsed as a JSON array, truncated or zero-padded to exactly 50 values, and
 * handed to `predict`. Lines with no such run, or whose run is not valid JSON
 * (e.g. "[1,,2]"), are ignored so that a single bad line cannot throw into
 * the caller.
 *
 * @param {string} buf - A single line of received text.
 * @returns {void}
 */
const onGesture = (buf) => {
  const found = buf.match(/\[[0-9\-,]+\]/);
  if (!found) return;

  let rawGesture;
  try {
    rawGesture = JSON.parse(found[0]);
  } catch (e) {
    // The pattern is looser than JSON, so a match can still fail to parse.
    console.log("Error parsing gesture", e);
    return;
  }

  // The model input is fixed at 50 values: drop extras, pad short samples.
  const gesture = rawGesture.slice(0, 50);
  while (gesture.length < 50) gesture.push(0);

  try {
    predict(gesture);
  } catch (e) {
    console.log("Error predicting", e);
  }
};
/**
 * Closes the current connection, if there is one, forgets it and sets the
 * status text to "Disconnected". Does nothing when no connection is held.
 */
function disconnect() {
  if (!connection) {
    return;
  }
  connection.close();
  connection = undefined;
  statusEl.innerHTML = "Disconnected";
}
/**
 * Opens a new connection through `Puck.connect` and starts streaming gestures.
 *
 * Any existing connection is dropped first, the status text is updated and the
 * results panel is emptied. Once connected, incoming data is split on "\n" and
 * each complete line is passed to `onGesture`. The device is sent "reset();",
 * and 1.5 s later a small program that prints every gesture as JSON; when that
 * write completes a "bangleConnected" event is dispatched on `document`.
 *
 * @returns {void}
 */
function connect() {
  disconnect();
  // Update status text
  statusEl.innerHTML = "Connecting to BangleJS...";
  // Clear results panel
  predictionsEl.innerHTML = "";
  Puck.connect((conn) => {
    if (!conn) {
      statusEl.innerHTML = "Disconnected";
      return;
    }
    feedbackEl.classList.remove("ribbon-neutral", "ribbon-connected", "ribbon-connecting", "ribbon-complete");
    feedbackEl.classList.add("ribbon-connected");
    statusEl.innerHTML = "Connected... please wait for initialization to complete";
    connection = conn;

    let pending = "";
    conn.on("data", (chunk) => {
      pending += chunk;
      const lines = pending.split("\n");
      // The last piece is an incomplete line (possibly empty); keep it for the
      // next chunk. Updating the buffer before dispatching means a handler
      // failure can never cause the same line to be replayed.
      pending = lines.pop();
      for (const line of lines) {
        onGesture(line);
      }
    });

    // First, reset the device
    conn.write("reset();\n", () => {
      const gestureProgram = "Bangle.on('gesture', g => Bluetooth.println(JSON.stringify(g))); NRF.on('disconnect', () => reset());\n";
      const evt = new CustomEvent("bangleConnected");
      // Wait for it to reset itself
      setTimeout(() => {
        // The session may have been closed or replaced while waiting; only
        // program the connection this callback belongs to.
        if (connection !== conn) return;
        conn.write(gestureProgram, () => document.dispatchEvent(evt));
      }, 1500);
    });
  });
}
/**
 * Looks up the page elements this module writes to, stores the options and
 * loads the model from `localstorage://<modelKey>`.
 *
 * While loading, the status text reads "Loading model" and the feedback
 * element carries only the "ribbon-connecting" ribbon class.
 *
 * @param {{modelKey: string, gestureClasses: Array}} opts - Model storage key
 *   and the labels indexed by model output.
 * @returns {Promise<void>}
 * @throws {Error} If the loader resolves to a falsy value.
 */
const setup = async (opts) => {
  const lookup = (id) => document.getElementById(id);
  statusEl = lookup("status");
  predictionsContainerEl = lookup("predictionsContainer");
  predictionsEl = lookup("predictions");
  feedbackEl = lookup("feedback");

  modelKey = opts.modelKey;
  gestureClasses = opts.gestureClasses;

  statusEl.innerHTML = "Loading model";
  const ribbon = feedbackEl.classList;
  ribbon.remove("ribbon-neutral", "ribbon-connected", "ribbon-connecting", "ribbon-complete");
  ribbon.add("ribbon-connecting");

  const modelUrl = `localstorage://${modelKey}`;
  model = await tf.loadLayersModel(modelUrl);
  if (model) {
    return;
  }
  throw new Error(`Couldn't load model ${modelKey}`);
};
/**
 * Saves the loaded model to `downloads://<modelKey>`, updating the status text
 * before and after, and logs the key once the save has finished.
 *
 * @returns {Promise<void>}
 */
const download = async () => {
  statusEl.innerHTML = "Downloading model";
  const target = `downloads://${modelKey}`;
  await model.save(target);
  statusEl.innerHTML = "Downloaded model";
  console.log("Model downloaded to local filesystem with name:", modelKey);
};
// Public API of this module, exposed as a single default-exported object.
export default { setup, connect, predict, download };