Skip to content

Commit

Permalink
new slab layout
Browse files Browse the repository at this point in the history
Namely:

* The slab is now `n` by 512 bytes, each row considered a
  "word". Addressing is word-aligned for now.
* Program values are now assumed 1-dimensional, for the time being.

Also:

* Tile shapes are now parameters of the tiling template operation.
* Views carry dtypes. The slab is stored in `uint8` and bitcast
  according to a view as needed.

Co-authored-by: Adam Paszke <[email protected]>
Co-authored-by: Matthew Johnson <[email protected]>
  • Loading branch information
3 people committed Jun 5, 2024
1 parent b5b55db commit eb68d68
Showing 1 changed file with 81 additions and 57 deletions.
138 changes: 81 additions & 57 deletions jax/experimental/slab/slab.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,121 +27,141 @@

map, zip = util.safe_map, util.safe_zip

Slab = jax.Array
Address = jax.Array
DShape = tuple[Union[int, jax.Array]]
SShape = tuple[int]

block_sz = 2

class Slab(NamedTuple):
data: jax.Array
cursor: Address

@jax.tree_util.register_pytree_node_class
class SlabView(NamedTuple):
addr: Address
shape: DShape
# We'll want dtypes eventually. For now, everything is f32.
#dtype: jax.typing.DTypeLike
dtype: jax.typing.DTypeLike

def size(self):
return jnp.prod(jnp.array(self.shape))

def ndim(self):
return len(self.shape)

def slab_make(sz, dtype):
return Slab(jnp.zeros(sz, dtype=dtype), jnp.array(0, dtype=int))
def tree_flatten(self):
return (self.addr, self.shape), self.dtype

def slab_alloc(slab, shape):
sz = jnp.prod(jnp.array(shape))
new_slab = Slab(slab.data, slab.cursor + sz)
slab_val = SlabView(slab.cursor, shape)
return new_slab, slab_val
@classmethod
def tree_unflatten(cls, dtype, xs):
addr, shape = xs
return cls(addr, shape, dtype)

word_sz = 512

def slab_read(slab, addr, shape):
sz = np.prod(shape)
flat = jax.lax.dynamic_slice_in_dim(slab.data, addr, sz, axis=0)
return flat.reshape(shape)
def slab_make(num_vmem_words):
return Slab(jnp.zeros((num_vmem_words, word_sz), dtype=jnp.uint8),
jnp.array(0, dtype=jnp.int32))

def slab_write(slab, addr, y):
flat = jnp.ravel(y)
data = jax.lax.dynamic_update_slice_in_dim(slab.data, flat, addr, axis=0)
def slab_alloc(slab, shape, dtype):
num_elts = jnp.prod(jnp.array(shape))
mem_sz = (num_elts * dtype.itemsize + word_sz - 1) // word_sz
new_slab = Slab(slab.data, slab.cursor + mem_sz)
slab_val = SlabView(slab.cursor, shape, dtype)
return new_slab, slab_val

def slab_read(slab, view, word_offset, tile_sshape):
assert view.ndim() == 1, 'for now'
assert view.ndim() == len(tile_sshape)
sz = np.prod(tile_sshape) * view.dtype.itemsize
assert sz % word_sz == 0, sz
mem = jax.lax.dynamic_slice_in_dim(
slab.data, view.addr + word_offset, sz // word_sz, axis=0)
cast = jax.lax.bitcast_convert_type(
mem.reshape((-1, view.dtype.itemsize)), view.dtype)
return cast.reshape(tile_sshape)

# TODO: just take vjp of slab_read
def slab_write(slab, view, word_offset, arr):
assert view.ndim() == 1, 'for now'
assert view.ndim() == arr.ndim
assert view.dtype == arr.dtype
sz = np.prod(arr.shape) * view.dtype.itemsize
assert sz % word_sz == 0, sz
cast = jax.lax.bitcast_convert_type(arr, slab.data.dtype)
mem = cast.reshape((-1, word_sz))
data = jax.lax.dynamic_update_slice_in_dim(
slab.data, mem, view.addr + word_offset, axis=0)
return Slab(data, slab.cursor)

def tile_loop_bounds(operands):
def elemwise_loop_bounds(operands):
x, *_ = operands
assert x.ndim() == 1, 'for now'
x_sz = x.size()
return 0, x_sz

def tile_loop_cond(kernel, slab, cursor, end, operands, results):
def elemwise_loop_cond(tile_sshape, kernel,
slab, cursor, end, operands, results):
return cursor < end

def tile_loop_body(kernel, slab, cursor, end, operands, results):
def elemwise_loop_body(tile_sshape, kernel,
slab, cursor, end, operands, results):
tile_sz, = tile_sshape
x, *_ = operands
in_tiles = [slab_read(slab, x.addr + cursor, (block_sz,) * x.ndim())
for x in operands]
in_tiles = [slab_read(slab, x, cursor, tile_sshape) for x in operands]
out_tiles = kernel(*in_tiles)
for y, r in zip(out_tiles, results):
slab = slab_write(slab, r.addr + cursor, y)
cursor = cursor + block_sz * x.ndim()
slab = slab_write(slab, r, cursor, y)
cursor = cursor + tile_sz
return slab, cursor, end, operands, results

def while_loop(cond, body, *args):
def c(x): return cond(*x)
def b(x): return body(*x)
return jax.lax.while_loop(c, b, args)

@partial(jax.jit, static_argnums=0)
def tile(kernel, slab: Slab, operands: tuple[SlabView], results: tuple[SlabView]):
start, end = tile_loop_bounds(operands)
slab, *_ = while_loop(partial(tile_loop_cond, kernel),
partial(tile_loop_body, kernel),
@partial(jax.jit, static_argnums=(0, 1))
def tile(tile_sshape, kernel,
slab: Slab, operands: tuple[SlabView], results: tuple[SlabView]):
start, end = elemwise_loop_bounds(operands)
slab, *_ = while_loop(partial(elemwise_loop_cond, tile_sshape, kernel),
partial(elemwise_loop_body, tile_sshape, kernel),
slab, start, end, operands, results)
return slab

def make_elementwise_op(name, op):
def make_elementwise_op(tile_sshape, name, op):
def kernel(*args): return [op(*args)]

def f(slab: Slab, *operands: tuple[SlabView]):
x, *_ = operands
slab, result = slab_alloc(slab, x.shape)
slab = tile(kernel, slab, operands, [result])
slab, result = slab_alloc(slab, x.shape, x.dtype)
slab = tile(tile_sshape, kernel, slab, operands, [result])
return slab, result

f.__name__ = name
return f

add = make_elementwise_op('add', jax.lax.add)
mul = make_elementwise_op('mul', jax.lax.mul)
tile_sz = word_sz * 2
add = make_elementwise_op((tile_sz,), 'add', jax.lax.add)
mul = make_elementwise_op((tile_sz,), 'mul', jax.lax.mul)

def parse_arr(i, s):
shape = eval(s)
assert all(d % block_sz == 0 for d in shape)
return 3 * i + jnp.arange(np.prod(shape), dtype=jnp.float32).reshape(shape)
z = 3 * i + jnp.arange(np.prod(shape), dtype=jnp.float32).reshape(shape)
return z.ravel()

def main(args):
xs = map(parse_arr, range(len(args)), args)
print('initial args:')
for x in xs:
print(x)

sz = xs[0].size
shape = xs[0].shape

slab = slab_make(1024, jnp.float32)
slab = slab_make(1024)

vals = []
for x in xs:
slab, v = slab_alloc(slab, x.shape)
slab = slab_write(slab, v.addr, x)
slab, v = slab_alloc(slab, x.shape, x.dtype)
slab = slab_write(slab, v, 0, x)
vals.append(v)

print()
print('-- args allocated on slab:')
print('slab:', slab)
print('ptrs:', vals)

def f(slab, *vals):
slab, y = add(slab, *vals)
slab, z = mul(slab, y, vals[0])
Expand All @@ -150,21 +170,25 @@ def f(slab, *vals):
print()
print(jax.make_jaxpr(f)(slab, *vals))

print()
print('-- initial args:')
for x in xs:
print(x)

print()
print('-- args allocated on slab:')
print('slab:', slab)
print('ptrs:', vals)

slab, y, z = jax.jit(f)(slab, *vals)
print()
print('-- slab ptr results')
print('add:', y)
print('mul:', z)
print()
print('-- slab space')
print('arg:', slab.data[:sz])
print('arg:', slab.data[sz:sz * 2])
print('add:', slab.data[sz * 2:sz * 3])
print('mul:', slab.data[sz * 3:sz * 4])
print()
print('-- read off slab')
print(slab_read(slab, y.addr, shape))
print(slab_read(slab, z.addr, shape))
print(slab_read(slab, y, 0, shape).astype(jnp.int32))
print(slab_read(slab, z, 0, shape).astype(jnp.int32))


if __name__ == '__main__':
Expand Down

0 comments on commit eb68d68

Please sign in to comment.