Skip to content

Commit

Permalink
Fix not finding wheels in bazel output
Browse files Browse the repository at this point in the history
  • Loading branch information
mrodden committed Aug 12, 2024
1 parent df2d140 commit 701cda8
Showing 1 changed file with 24 additions and 11 deletions.
35 changes: 24 additions & 11 deletions build/rocm/tools/build_wheels.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def build_jaxlib_wheel(jax_path, rocm_path, python_version, xla_path=None):
LOG.info("Running %r from cwd=%r" % (cmd, jax_path))
pattern = re.compile("Output wheel: (.+)\n")

return _run_scan_for_output(cmd, pattern, env=env, cwd=jax_path, capture="stderr")
_run_scan_for_output(cmd, pattern, env=env, cwd=jax_path, capture="stderr")


def build_jax_wheel(jax_path, python_version):
Expand All @@ -101,10 +101,7 @@ def build_jax_wheel(jax_path, python_version):
LOG.info("Running %r from cwd=%r" % (cmd, jax_path))
pattern = re.compile("Successfully built jax-.+ and (jax-.+\.whl)\n")

wheels = _run_scan_for_output(cmd, pattern, env=env, cwd=jax_path, capture="stdout")

paths = list(map(lambda x: os.path.join(jax_path, "dist", x), wheels))
return paths
_run_scan_for_output(cmd, pattern, env=env, cwd=jax_path, capture="stdout")


def _run_scan_for_output(cmd, pattern, env=None, cwd=None, capture=None):
Expand Down Expand Up @@ -203,6 +200,17 @@ def parse_args():
return p.parse_args()


def find_wheels(path):
wheels = []

for f in os.listdir(path):
if f.endswith(".whl"):
wheels.append(os.path.join(path, f))

LOG.info("Found wheels: %r" % wheels)
return wheels


def main():
args = parse_args()
python_versions = args.python_versions.split(",")
Expand All @@ -217,22 +225,27 @@ def main():
update_rocm_targets(rocm_path, GPU_DEVICE_TARGETS)

for py in python_versions:
wheel_paths = build_jaxlib_wheel(args.jax_path, rocm_path, py, args.xla_path)
build_jaxlib_wheel(args.jax_path, rocm_path, py, args.xla_path)
wheel_paths = find_wheels(os.path.join(args.jax_path, "dist"))
for wheel_path in wheel_paths:
fix_wheel(wheel_path, args.jax_path)
# skip jax wheel since it is non-platform
if not os.path.basename(wheel_path).startswith("jax-"):
fix_wheel(wheel_path, args.jax_path)

# build JAX wheel for completeness
jax_wheels = build_jax_wheel(args.jax_path, python_versions[-1])
build_jax_wheel(args.jax_path, python_versions[-1])
wheels = find_wheels(os.path.join(args.jax_path, "dist"))

# NOTE(mrodden): the jax wheel is a "non-platform wheel", so auditwheel will
# do nothing, and in fact will throw an Exception. we just need to copy it
# along with the jaxlib and plugin ones

# copy jax wheel(s) to wheelhouse
wheelhouse_dir = "/wheelhouse/"
for whl in jax_wheels:
LOG.info("Copying %s into %s" % (whl, wheelhouse_dir))
shutil.copy(whl, wheelhouse_dir)
for whl in wheels:
if os.path.basename(whl).startswith("jax-"):
LOG.info("Copying %s into %s" % (whl, wheelhouse_dir))
shutil.copy(whl, wheelhouse_dir)


if __name__ == "__main__":
Expand Down

0 comments on commit 701cda8

Please sign in to comment.