[pull] main from jax-ml:main #26

Open
wants to merge 3,036 commits into
base: main

@pull pull bot commented Oct 1, 2024

See Commits and Changes for more details.


Created by pull[bot]

Can you help keep this open source service alive? 💖 Please sponsor : )

@pull pull bot added the ⤵️ pull label Oct 1, 2024
Google-ML-Automation and others added 29 commits March 5, 2025 12:59
PiperOrigin-RevId: 733865978
This change fixes https://github.com/jax-ml/jax/actions/runs/13686468791/job/38270929632.

From the [documentation](https://docs.python.org/3/library/os.path.html#os.path.expanduser):
`On Windows, USERPROFILE will be used if set, otherwise a combination of HOMEPATH and HOMEDRIVE will be used.`
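
The lookup order quoted above can be sketched in a few lines. This is an illustrative re-implementation of the documented behavior, not CPython's code; `windows_home` is a hypothetical helper, and the environment is passed in as a plain dict so the sketch runs on any platform.

```python
def windows_home(env):
    """Return the home directory using the documented Windows lookup order.

    `env` is a mapping of environment variables (for example os.environ).
    Hypothetical helper for illustration only.
    """
    # 1. USERPROFILE wins whenever it is set.
    if "USERPROFILE" in env:
        return env["USERPROFILE"]
    # 2. Otherwise combine HOMEDRIVE and HOMEPATH (HOMEDRIVE may be absent).
    if "HOMEPATH" in env:
        return env.get("HOMEDRIVE", "") + env["HOMEPATH"]
    # 3. Nothing usable is set, so "~" cannot be expanded.
    return None
```

A CI job that sets only `HOME` on Windows would fall through to the last case, which is the kind of mismatch this lookup order can cause.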

PiperOrigin-RevId: 733935305
…pjit` based on resource_env. This is to start deprecating the need for `with mesh` and replace it with `use_mesh(mesh)`.

PiperOrigin-RevId: 733959962
For now, most of the tests are skipped.

PiperOrigin-RevId: 734026728
This change allows us to get rid of the extra env vars that used to control whether to install `jax` at head. Now, `jax` will be built and consumed in the same way as the other wheels in the continuous jobs.

PiperOrigin-RevId: 734123590
PiperOrigin-RevId: 734126415
… place.

Without this fix, lowerings of ops within the `for` body are always appended at the end, even if they have users earlier in the body. This caused an `operand #0 does not dominate this use` error.

The fix was tested in the upcoming (but not yet submitted) `test_realistic_matmul` in Pallas with Workgroup semantics.

PiperOrigin-RevId: 734157829
Key features:
* ***Support mixed prefill and decode*** to increase throughput for inference. (e.g., ***5x*** speedup compared to the padded Multi-Queries Paged Attention implementation for llama-3-8b.)
* ***No explicit `swapaxes`*** for `seq_len` and `num_head` in pre/post kernel. The kernel takes `num_head` in 2nd minor as it naturally was. We fold swapaxes to strided load/store in the kernel and apply transpose on the fly.
* ***No GMM (Grouped Matmul) Metadata required!*** We calculate the metadata on the fly in the kernel. This can yield a ***10%*** speedup!
* ***Increase MXU utilization 8x in GQA*** by grouping shared q heads for MXU in decode.
* ***Minimize recompilation:*** The only factors that can cause recompilation are model specs, `max_num_batched_tokens` and `max_num_seqs` in the setting of mixed engine.
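
To make the "mixed prefill and decode" point concrete, here is a minimal sketch of a ragged token layout compared with a padded one. It is not the kernel's interface: `ragged_offsets`, `padded_tokens`, and the example lengths are hypothetical, and the sketch only counts token slots, it does not run attention.

```python
from itertools import accumulate


def ragged_offsets(q_lens):
    """Cumulative start offsets when sequences are packed back to back."""
    return [0] + list(accumulate(q_lens))


def padded_tokens(q_lens):
    """Token slots needed when every sequence is padded to the longest one."""
    return len(q_lens) * max(q_lens)


# Hypothetical batch: two prefill chunks (5 and 3 query tokens)
# mixed with two decode steps (1 query token each).
q_lens = [5, 1, 3, 1]
offsets = ragged_offsets(q_lens)   # sequence i owns tokens offsets[i]:offsets[i + 1]
ragged_total = offsets[-1]         # slots actually occupied in the ragged layout
padded_total = padded_tokens(q_lens)
```

In this toy batch the ragged layout occupies 10 token slots where padding would occupy 20, which is the kind of waste a mixed ragged batch avoids.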

PiperOrigin-RevId: 734269519
…plementation to utilize the recently added support for string arrays.

Currently the serialized data and its length are carried in two separate arrays: a fixed-width bytes array (with a hard-coded max size) and a uint32 array, respectively.
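
A minimal sketch of the two-array scheme described above, in plain Python rather than JAX arrays. `MAX_SIZE`, `pack`, and `unpack` are hypothetical names chosen for illustration; the real implementation and its size limit are not shown in this commit message.

```python
MAX_SIZE = 16  # hypothetical hard-coded maximum payload size, in bytes


def pack(payload: bytes):
    """Return (fixed-width buffer, length) for one serialized payload."""
    if len(payload) > MAX_SIZE:
        raise ValueError(f"payload of {len(payload)} bytes exceeds {MAX_SIZE}")
    # Pad with zero bytes so every buffer has the same width.
    return payload.ljust(MAX_SIZE, b"\x00"), len(payload)


def unpack(buffer: bytes, length: int) -> bytes:
    """Recover the original payload by trimming the padding."""
    return buffer[:length]
```

The separate length is what makes the round trip exact: a payload that itself ends in zero bytes could not be told apart from its padding otherwise.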

PiperOrigin-RevId: 734299259
PiperOrigin-RevId: 734323206
…ith `None`. So that for a 2D input, P('data') continues to work.
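
As an illustration of why `P('data')` keeps working for a 2D input, a shorter spec can be padded with `None` up to the array rank. The sketch below uses plain tuples in place of `PartitionSpec`, and `pad_spec` is a hypothetical helper, not a JAX function.

```python
def pad_spec(spec, ndim):
    """Pad a partition spec (a tuple of axis names or None) to `ndim` entries."""
    if len(spec) > ndim:
        raise ValueError(f"spec {spec} has more entries than the array rank {ndim}")
    # Trailing dimensions that the spec does not mention are left unsharded.
    return tuple(spec) + (None,) * (ndim - len(spec))
```

Under this convention `pad_spec(("data",), 2)` gives `("data", None)`: the first dimension is sharded over `data` and the second is not sharded.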

PiperOrigin-RevId: 734325209
PiperOrigin-RevId: 734345644
The difficulty here is that our register tiling is based on the (64, 8)
shape, while the memory tiling is now (8, swizzle // bytewidth). Before,
we would assume that each register tile fits neatly within a single
memory tile, but now it is obviously not the case. Luckily, it wasn't
too hard to add.
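
The mismatch between the two tilings can be seen with a little arithmetic. This sketch assumes `swizzle` is given in bytes and `bytewidth` is the element size in bytes, and that the register tile starts on a memory tile boundary; the helper names are hypothetical and this is not the Mosaic GPU lowering code.

```python
REGISTER_TILE = (64, 8)  # register tiling shape from the description above


def memory_tile(swizzle, bytewidth):
    """Memory tile shape (rows, columns) in elements: (8, swizzle // bytewidth)."""
    return (8, swizzle // bytewidth)


def memory_tiles_spanned(register_tile, mem_tile):
    """Number of memory tiles an aligned register tile overlaps."""
    rows = -(-register_tile[0] // mem_tile[0])  # ceiling division
    cols = -(-register_tile[1] // mem_tile[1])
    return rows * cols
```

For example, with a 128-byte swizzle and 2-byte elements the memory tile is `(8, 64)`, so one `(64, 8)` register tile overlaps 8 memory tiles instead of fitting inside one.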

PiperOrigin-RevId: 734517000
…vides the shape (until this restriction is lifted) to make sure we don't create bad shardings.

Also improve the `dynamic_update_slice` sharding error by printing `aval.str_short()` instead of the full sharding, because it is concise and gives more information than the current error (i.e. it also adds the shape to the error message).

Also make some formatting changes in scan lowering to make it easier to debug.

PiperOrigin-RevId: 734542862
Google-ML-Automation and others added 30 commits March 18, 2025 09:38
PiperOrigin-RevId: 738038116
…choose to use the return value or ignore it.

PiperOrigin-RevId: 738039559
… targets.

These were temporary forwarding targets that are no longer needed; use //jaxlib/cpu:cpu_kernels and //jaxlib/cuda:cuda_gpu_kernels instead.

PiperOrigin-RevId: 738085234
… lowering.

A couple of dummy transform inference rules needed to be added in order to
contend with parts of the lowering that do not use the dialect yet, along with
a transform inference rule for `memref.view`.

PiperOrigin-RevId: 738089782
PiperOrigin-RevId: 738110447
Some GPU tests in //tests/logging_test fail on Linux machines that have only the NVIDIA driver installed when we use hermetic CUDA (CUDA itself isn't installed on the machine).

Reason: the method `tsl::Env::Default()->GetExecutablePath()` doesn't work properly with the command flag (`-c`). As a result, the subprocess couldn't get the path to the logging_test.py file and convert it to the path of the runtime where the hermetic CUDA libraries are placed.

Solution: Save the Python program to a file in the runtime directory, then run the script from that file.
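
The difference between the two launch styles can be reproduced at the Python level. This is a minimal sketch, not the test's actual code: `PROGRAM`, `run_inline`, and `run_from_file` are hypothetical, and a temporary directory stands in for the runtime directory. It only shows that a program started with `-c` has no script path to resolve, while one started from a file does.

```python
import os
import subprocess
import sys
import tempfile

# Hypothetical child program: report how it was launched.
PROGRAM = "import os, sys\nprint(os.path.basename(sys.argv[0]))\n"


def run_inline(program):
    """Run the program with `python -c`; the child has no script file."""
    result = subprocess.run(
        [sys.executable, "-c", program],
        capture_output=True, text=True, check=True,
    )
    return result.stdout.strip()


def run_from_file(program, directory):
    """Write the program into `directory` and run it as a script file."""
    path = os.path.join(directory, "program.py")
    with open(path, "w") as f:
        f.write(program)
    result = subprocess.run(
        [sys.executable, path],
        capture_output=True, text=True, check=True,
    )
    return result.stdout.strip()


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        print(run_inline(PROGRAM))          # child sees "-c" as argv[0]
        print(run_from_file(PROGRAM, tmp))  # child sees its own file name
```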
PiperOrigin-RevId: 738152663
PiperOrigin-RevId: 738154413
… shape_poly dimexpr and lower them alongside the graph if we are in a dynamic export regime.

PiperOrigin-RevId: 738171437
Move the common jaxlib/gpu_plugin_extension into jaxlib/gpu/

Cleanup only, no functional changes intended.

PiperOrigin-RevId: 738183402
PiperOrigin-RevId: 738215055
… computations being staged out in block functions)

PiperOrigin-RevId: 738218113
Updates LLVM usage to match
[0230d63b4a8b](llvm/llvm-project@0230d63b4a8b)

PiperOrigin-RevId: 738222096
This code is Mosaic specific, move it to the Mosaic directory.

PiperOrigin-RevId: 738404429
…MAOp`.

Now that we have full control over strides in the lowering, these attributes
are no longer necessary.

PiperOrigin-RevId: 738418852