forked from jax-ml/jax
[pull] main from jax-ml:main #26
Open
pull wants to merge 3,036 commits into garymm:main from jax-ml:main
+125,142
−44,620
Should fix the error in https://github.com/jax-ml/jax/actions/runs/13682579939/job/38258344926. PiperOrigin-RevId: 733838895
PiperOrigin-RevId: 733857126
PiperOrigin-RevId: 733865978
PiperOrigin-RevId: 733885890
This change fixes https://github.com/jax-ml/jax/actions/runs/13686468791/job/38270929632. From the [documentation](https://docs.python.org/3/library/os.path.html#os.path.expanduser): `On Windows, USERPROFILE will be used if set, otherwise a combination of HOMEPATH and HOMEDRIVE will be used.` PiperOrigin-RevId: 733935305
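The lookup order quoted from the documentation can be observed on any platform through the standard library's `ntpath` module, which is the Windows implementation behind `os.path`. The following is a minimal sketch, not the code from this commit; the helper name `expand_home_windows` and the example paths are hypothetical.

```python
import ntpath
import os
from unittest import mock


def expand_home_windows(path, env):
    """Expand a leading '~' using Windows rules under a temporary environment.

    `env` replaces os.environ for the duration of the call, so the result
    depends only on the variables passed in (hypothetical helper).
    """
    with mock.patch.dict(os.environ, env, clear=True):
        return ntpath.expanduser(path)


# USERPROFILE wins when it is set.
with_profile = expand_home_windows("~\\cache", {"USERPROFILE": "C:\\Users\\alice"})

# Otherwise HOMEDRIVE and HOMEPATH are combined.
with_homepath = expand_home_windows(
    "~\\cache", {"HOMEDRIVE": "D:", "HOMEPATH": "\\Users\\bob"}
)

# With none of these variables set, the path is returned unchanged.
unexpanded = expand_home_windows("~\\cache", {})
```

A test that relies on a home directory on Windows therefore needs to set `USERPROFILE` (or the `HOMEDRIVE`/`HOMEPATH` pair) rather than `HOME`.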
…26946 PiperOrigin-RevId: 733953836
…pjit` based on resource_env. This is to start deprecating the need for `with mesh` and replace it with `use_mesh(mesh)`. PiperOrigin-RevId: 733959962
For now, most of the tests are skipped. PiperOrigin-RevId: 734026728
http://github.com/openxla/xla/commit/6e396aae2e534dc7fc5387e2aa8b1a3a8d79a3db. PiperOrigin-RevId: 734059108
PiperOrigin-RevId: 734081057
…ype_p` PiperOrigin-RevId: 734081448
This change allows us to get rid of extra env vars which used to control whether to install `jax` at head. Now, `jax` will be built and consumed in the same way as the other wheels in the continuous jobs. PiperOrigin-RevId: 734123590
PiperOrigin-RevId: 734126415
… place. Without this fix, lowerings of ops within the `for` body are always appended at the end, even if they have users earlier in the body. This caused an `operand #0 does not dominate this use` error. The fix was tested in the upcoming (but not yet submitted) `test_realistic_matmul` in Pallas with Workgroup semantics. PiperOrigin-RevId: 734157829
Key features:
* ***Support mixed prefill and decode*** to increase throughput for inference (e.g., ***5x*** speedup compared to the padded Multi-Queries Paged Attention implementation for llama-3-8b).
* ***No explicit `swapaxes`*** for `seq_len` and `num_head` in pre/post kernel. The kernel takes `num_head` in 2nd minor as it naturally is. We fold swapaxes into strided load/store in the kernel and apply the transpose on the fly.
* ***No GMM (Grouped Matmul) Metadata required!*** We calculate the metadata on the fly in the kernel. This can give a ***10%*** speedup!
* ***Increase MXU utilization 8x in GQA*** by grouping shared q heads for MXU in decode.
* ***Minimize recompilation:*** The only factors that can cause recompilation are model specs, `max_num_batched_tokens` and `max_num_seqs` in the setting of mixed engine.

PiperOrigin-RevId: 734269519
…plementation to utilize the recently added support for string arrays. Currently the serialized data and its length are being carried in two separate arrays, a fixed-width bytes array (with a hard-coded max size) and a uint32 array respectively. PiperOrigin-RevId: 734299259
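The "currently" layout described above, a fixed-width bytes array plus a separate uint32 length array, can be sketched with the standard library alone. This is an illustration under assumptions, not the code being replaced: `MAX_SIZE`, `pack_fixed_width`, and `unpack_fixed_width` are hypothetical names, and the real implementation carries the data in arrays rather than Python `bytes`.

```python
import struct

MAX_SIZE = 16  # hypothetical hard-coded maximum payload size, in bytes


def pack_fixed_width(payloads, max_size=MAX_SIZE):
    """Pack byte strings into fixed-width rows plus a uint32 length array."""
    rows = bytearray()
    lengths = []
    for payload in payloads:
        if len(payload) > max_size:
            raise ValueError(f"payload of {len(payload)} bytes exceeds {max_size}")
        # Pad every row with zero bytes up to the fixed width.
        rows += payload.ljust(max_size, b"\x00")
        lengths.append(len(payload))
    # One little-endian uint32 per payload.
    packed_lengths = struct.pack(f"<{len(lengths)}I", *lengths)
    return bytes(rows), packed_lengths


def unpack_fixed_width(rows, packed_lengths, max_size=MAX_SIZE):
    """Recover the original byte strings from the two arrays."""
    count = len(packed_lengths) // 4
    lengths = struct.unpack(f"<{count}I", packed_lengths)
    return [
        rows[i * max_size : i * max_size + length]
        for i, length in enumerate(lengths)
    ]


original = [b"abc", b"", b"hello\x00"]
rows, packed_lengths = pack_fixed_width(original)
restored = unpack_fixed_width(rows, packed_lengths)
```

The separate length array is what lets a payload that itself ends in a zero byte survive the round trip, and the hard-coded width is the limitation that a variable-length string array would avoid.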
PiperOrigin-RevId: 734323206
…ith `None`. So that for a 2D input, P('data') continues to work. PiperOrigin-RevId: 734325209
PiperOrigin-RevId: 734345644
PiperOrigin-RevId: 734351741
http://github.com/openxla/xla/commit/f1213b83af673729b60f5096da5186246568c0fb. PiperOrigin-RevId: 734484617
PiperOrigin-RevId: 734497907
PiperOrigin-RevId: 734500798
The difficulty here is that our register tiling is based on the (64, 8) shape, while the memory tiling is now (8, swizzle // bytewidth). Before, we would assume that each register tile fits neatly within a single memory tile, but now it is obviously not the case. Luckily, it wasn't too hard to add. PiperOrigin-RevId: 734517000
…vides the shape (until this restriction is lifted) to make sure we don't create bad shardings. Also improve the dynamic_update_slice sharding error by printing `aval.str_short()` instead of the full sharding, because it's concise and gives more info than the current error (i.e. it adds the shape to the error message). Also make some formatting changes in scan lowering to make it easier to debug. PiperOrigin-RevId: 734542862
PiperOrigin-RevId: 738038116
PiperOrigin-RevId: 738038138
…choose to use the return value or ignore it. PiperOrigin-RevId: 738039559
PiperOrigin-RevId: 738080051
… targets. These were temporary forwarding targets that are no longer needed; use //jaxlib/cpu:cpu_kernels and //jaxlib/cuda:cuda_gpu_kernels instead. PiperOrigin-RevId: 738085234
… lowering. A couple of dummy transform inference rules needed to be added in order to contend with parts of the lowering that do not use the dialect yet, along with a transform inference rule for `memref.view`. PiperOrigin-RevId: 738089782
PiperOrigin-RevId: 738110447
… in TAP PiperOrigin-RevId: 738113725
…ays. PiperOrigin-RevId: 738121106
Some GPU tests in //tests/logging_test fail on Linux machines that have only the NVIDIA driver installed (no system CUDA) when we use hermetic CUDA. Reason: `tsl::Env::Default()->GetExecutablePath()` doesn't work properly with the `-c` command flag. As a result, the subprocess couldn't get the path to the logging_test.py file and convert it to the path of the runtime where the hermetic CUDA libraries are placed. Solution: save the Python program to a file in the runtime directory, then run the script from the file. PiperOrigin-RevId: 738152663
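The workaround described above, writing the program to a file and running that file instead of passing the source with `-c`, might look roughly like the sketch below. It is not the code from the test: the helper name `run_program_from_file` and the file name `program.py` are hypothetical, and a temporary directory stands in for the runtime directory.

```python
import os
import subprocess
import sys
import tempfile


def run_program_from_file(program, directory=None):
    """Write `program` to a script file and run it in a fresh interpreter.

    Unlike `python -c <source>`, the child process has a real script path,
    so anything that derives paths from the running file can resolve them.
    `directory` stands in for the runtime directory (assumption).
    """
    with tempfile.TemporaryDirectory(dir=directory) as tmp:
        script = os.path.join(tmp, "program.py")
        with open(script, "w") as f:
            f.write(program)
        result = subprocess.run(
            [sys.executable, script],
            capture_output=True,
            text=True,
            check=True,
        )
    return result.stdout


# The child sees a script path in sys.argv[0] instead of the "-c" placeholder.
output = run_program_from_file(
    "import sys\nprint(sys.argv[0].endswith('program.py'))\n"
)
```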
PiperOrigin-RevId: 738154413
… shape_poly dimexpr and lower them alongside the graph if we are in a dynamic export regime. PiperOrigin-RevId: 738171437
Move the common jaxlib/gpu_plugin_extension into jaxlib/gpu/. Cleanup only, no functional changes intended. PiperOrigin-RevId: 738183402
PiperOrigin-RevId: 738190879
http://github.com/openxla/xla/commit/df971129bd82e381954da0185b534220e21798a4. PiperOrigin-RevId: 738213047
PiperOrigin-RevId: 738215055
… computations being staged out in block functions) PiperOrigin-RevId: 738218113
Updates LLVM usage to match [0230d63b4a8b](llvm/llvm-project@0230d63b4a8b) PiperOrigin-RevId: 738222096
PiperOrigin-RevId: 738235363
PiperOrigin-RevId: 738315605
PiperOrigin-RevId: 738321801
PiperOrigin-RevId: 738342517
http://github.com/openxla/xla/commit/0d20d73f2c8f21c21b9f343c4363a76e980f032e. PiperOrigin-RevId: 738352930
PiperOrigin-RevId: 738376533
…ady `jit`ted. PiperOrigin-RevId: 738393973
PiperOrigin-RevId: 738398099
This code is Mosaic-specific, so move it to the Mosaic directory. PiperOrigin-RevId: 738404429
PiperOrigin-RevId: 738410122
…MAOp`. Now that we have full control over strides in the lowering, these attributes are no longer necessary. PiperOrigin-RevId: 738418852
See Commits and Changes for more details.
Created by
pull[bot]