Skip to content

Commit

Permalink
Add flax and friends for training neural networks
Browse files Browse the repository at this point in the history
This updates our formatter, which changed what it expects very slightly.
Update all the pieces of code to match

optax is for the actual optimizers, jax is the underlying accelerated
linear algebra library, tensorflow is for loading datasets and exporting
models.

Change-Id: Ic4c3b425cda74267e1d0ad1615c42452cbefab8a
Signed-off-by: Austin Schuh <[email protected]>
  • Loading branch information
AustinSchuh committed Sep 21, 2024
1 parent 9e69512 commit ea5f0a7
Show file tree
Hide file tree
Showing 14 changed files with 2,009 additions and 779 deletions.
3 changes: 3 additions & 0 deletions .bazelrc
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,9 @@ build --sandbox_tmpfs_path=/tmp
# Prevent cypress from using its own binary. We want to use the hermetic one.
build --action_env=CYPRESS_INSTALL_BINARY=0

# Allow spaces in runfiles filenames.
build --experimental_inprocess_symlink_creation

# From our one and only phil.schrader: https://groups.google.com/g/bazel-discuss/c/5cbRuLuTwNg :)
# Enable -Werror and warnings for our code
# TODO: It would be nice to enable Wcast-align and Wcast-qual.
Expand Down
4 changes: 2 additions & 2 deletions frc971/control_loops/python/drivetrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,8 +664,8 @@ def PlotDrivetrainSprint(drivetrain_params):
max_battery_wattage = kMaxBreakerCurrent * (vbat -
kMaxBreakerCurrent * Rw)
if (max_current_request_left * max_voltage_left +
max_current_request_right * max_voltage_right >
max_battery_wattage):
max_current_request_right * max_voltage_right
> max_battery_wattage):
# Now solve the quadratic equation to figure out what the overall
# motor current can be which puts us at the max battery wattage.
max_motor_current = (
Expand Down
2 changes: 1 addition & 1 deletion motors/pistol_grip/generate_cogging.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

def main(argv):
if len(argv) < 4:
print 'Args: input output.cc struct_name'
print('Args: input output.cc struct_name')
return 1
data_sum = [0.0] * 4096
data_count = [0] * 4096
Expand Down
43 changes: 22 additions & 21 deletions motors/python/haptic_phase_current.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,11 @@ def all_phases(function, theta_range):
# by to get motor current.
one_amp_scalar = (phases(f_single, 0.0).T * phases(g_single, 0.0))[0, 0]

print 'Max BEMF', max(f(theta_range))
print 'Max current', max(g(theta_range))
print 'Max drive voltage (one_amp_driving_voltage)', max(
one_amp_driving_voltage)
print 'one_amp_scalar', one_amp_scalar
print('Max BEMF', max(f(theta_range)))
print('Max current', max(g(theta_range)))
print('Max drive voltage (one_amp_driving_voltage)',
max(one_amp_driving_voltage))
print('one_amp_scalar', one_amp_scalar)

pylab.figure()
pylab.subplot(1, 1, 1)
Expand Down Expand Up @@ -464,19 +464,20 @@ def __init__(self):
self.B_discrete_inverse_model = numpy.matrix(numpy.eye(3)) / (
self.B_discrete_model[0, 0] - self.B_discrete_model[1, 0])

print 'constexpr double kL = %g;' % self.L_model
print 'constexpr double kM = %g;' % self.M_model
print 'constexpr double kR = %g;' % self.R_model
print 'constexpr float kAdiscrete_diagonal = %gf;' % self.A_discrete_model[
0, 0]
print 'constexpr float kAdiscrete_offdiagonal = %gf;' % self.A_discrete_model[
1, 0]
print 'constexpr float kBdiscrete_inv_diagonal = %gf;' % self.B_discrete_inverse_model[
0, 0]
print 'constexpr float kBdiscrete_inv_offdiagonal = %gf;' % self.B_discrete_inverse_model[
1, 0]
print 'constexpr double kOneAmpScalar = %g;' % one_amp_scalar
print 'constexpr double kMaxOneAmpDrivingVoltage = %g;' % max_one_amp_driving_voltage
print('constexpr double kL = %g;' % self.L_model)
print('constexpr double kM = %g;' % self.M_model)
print('constexpr double kR = %g;' % self.R_model)
print('constexpr float kAdiscrete_diagonal = %gf;' %
self.A_discrete_model[0, 0])
print('constexpr float kAdiscrete_offdiagonal = %gf;' %
self.A_discrete_model[1, 0])
print('constexpr float kBdiscrete_inv_diagonal = %gf;' %
self.B_discrete_inverse_model[0, 0])
print('constexpr float kBdiscrete_inv_offdiagonal = %gf;' %
self.B_discrete_inverse_model[1, 0])
print('constexpr double kOneAmpScalar = %g;' % one_amp_scalar)
print('constexpr double kMaxOneAmpDrivingVoltage = %g;' %
max_one_amp_driving_voltage)
print('A_discrete', self.A_discrete)
print('B_discrete', self.B_discrete)
print('B_discrete_sub', numpy.linalg.inv(self.B_discrete[0:2, 0:2]))
Expand Down Expand Up @@ -574,8 +575,8 @@ def DoControls(self, goal_current):
# Subtract that, and then run the stock statespace equation.
Vn_ff = self.B_discrete_inverse * (Inext - self.A_discrete *
(Icurrent - p) - p_next_imag.real)
print 'Vn_ff', Vn_ff
print 'Inext', Inext
print('Vn_ff', Vn_ff)
print('Inext', Inext)
Vn = Vn_ff + self.K * (Icurrent - measured_current)

E = phases(f_single, self.X[3, 0]) / Kv * self.X[4, 0]
Expand Down Expand Up @@ -629,7 +630,7 @@ def Simulate(self):

self.current_time = t

print 'Took %f to simulate' % (time.time() - start_wall_time)
print('Took %f to simulate' % (time.time() - start_wall_time))

self.data_logger.plot()

Expand Down
7 changes: 4 additions & 3 deletions tools/foxglove/creation_wrapper_npm.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,10 @@ def main(argv: list[str]):
package.rsplit("@", maxsplit=1) for package in packages)
package_json_file = Path.cwd() / "package.json"
package_json = json.loads(package_json_file.read_text())
package_json.setdefault("dependencies", {}).update(
{package: version
for package, version in package_version_pairs})
package_json.setdefault("dependencies", {}).update({
package: version
for package, version in package_version_pairs
})

package_json_file.write_text(
json.dumps(package_json, sort_keys=True, indent=4))
Expand Down
4 changes: 4 additions & 0 deletions tools/python/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ load("@rules_python//python:pip.bzl", "compile_pip_requirements")
# Invoke this via "bazel run //tools/python:requirements.update".
compile_pip_requirements(
name = "requirements",
extra_args = [
# Make it so we can depend on setuptools.
"--allow-unsafe",
],
requirements_in = "requirements.txt",
requirements_txt = "requirements.lock.txt",
tags = [
Expand Down
Loading

0 comments on commit ea5f0a7

Please sign in to comment.