Skip to content

Commit

Permalink
Add custom call on output along with S(5) because XLA requires the cu…
Browse files Browse the repository at this point in the history
…stom call to show the transfer.

Enable paramater streaming and weight offloading

PiperOrigin-RevId: 619711649
  • Loading branch information
yashk2810 authored and jax authors committed Mar 28, 2024
1 parent ec73c40 commit 9e86aa5
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 21 deletions.
8 changes: 8 additions & 0 deletions jax/_src/interpreters/mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1453,6 +1453,14 @@ def aval_to_types(aval):
o if s is None else wrap_with_sharding_op(entry_lowering_ctx, o, o_aval, s)
for o, s, o_aval in zip(flat_outputs, ir_result_shardings, output_avals)]

# Insert a custom call if output is on host because XLA needs that to do the
# transfer.
if ir_result_memory_kinds is not None:
# TODO: We should have a default memory kind which we can check against.
flat_outputs = [
o if mk is None or mk == 'device' else wrap_with_memory_kind(o, mk, o_aval)
for o, mk, o_aval in zip(flat_outputs, ir_result_memory_kinds, output_avals)]

if ir_result_shardings is not None and name == "main":
flat_outputs = [
a.dtype._rules.replicate_trailing_dims(entry_lowering_ctx, o, a) # type: ignore
Expand Down
44 changes: 23 additions & 21 deletions tests/memories_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,27 +196,6 @@ def test_default_memory_kind(self):
dev = jax.devices()[0]
self.assertEqual(dev.default_memory().kind, "device")

def test_parameter_streaming(self):
self.skipTest("Enable after pinned_host support exists")

_, s_host, np_inp, inp_host = _create_inputs(
(8, 2), P("x", "y"), mem_kind="pinned_host")
s_dev = s_host.with_memory_kind('device')
inp_dev = jax.device_put(inp_host, s_dev)

@functools.partial(jax.jit, out_shardings=s_dev)
def f(a, b):
x = b * 2
y = jax.device_put(a, s_dev)
z = x * y
return z * 4, z

compiled = f.lower(inp_host, inp_dev).compile() # doesn't crash
compiled_text = compiled.as_text()
self.assertRegex(compiled_text, r"entry_computation_layout=.*S\(5\)}")

# TODO(yashkatariya): Add execution tests when it works.


class MemoriesComputationTest(jtu.BufferDonationTestCase):

Expand Down Expand Up @@ -1115,6 +1094,29 @@ def test_device_put_python_int(self):
self._check_device_put_addressable_shards(
out_host, py_inp, s_host, "unpinned_host", index=False)

def test_parameter_streaming(self):
_, s_host, np_inp, inp_host = _create_inputs(
(8, 2), P("x", "y"), mem_kind="pinned_host")
s_dev = s_host.with_memory_kind('device')
inp_dev = jax.device_put(np_inp, s_dev)

@functools.partial(jax.jit, out_shardings=s_host)
def f(a, b):
x = b * 2
y = jax.device_put(a, s_dev)
z = x * y
return z * 4, z

compiled = f.lower(inp_host, inp_dev).compile() # doesn't crash
compiled_text = compiled.as_text()
self.assertRegex(compiled_text, r"entry_computation_layout=.*S\(5\)}")

out1, out2 = f(inp_host, inp_dev)
self._check_device_put_addressable_shards(
out1, np_inp * np_inp * 8, s_host, 'pinned_host')
self._check_device_put_addressable_shards(
out2, np_inp * np_inp * 2, s_host, 'pinned_host')


class ActivationOffloadingTest(jtu.JaxTestCase):

Expand Down

0 comments on commit 9e86aa5

Please sign in to comment.