Skip to content

Commit

Permalink
Workaround to make XRT work with multi-backend change.
Browse files Browse the repository at this point in the history
XrtBuffers don't expose their platform name. I couldn't figure out a nice way to plumb this through, so I added this workaround for now.
  • Loading branch information
skye authored Aug 27, 2019
1 parent 1cd37bd commit d582095
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion jax/interpreters/xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,8 +647,12 @@ def __hash__(self):
def _device_array_constant_handler(c, val, canonicalize_types=True):
return c.Constant(onp.asarray(val), canonicalize_types=canonicalize_types)
xb.register_constant_handler(DeviceArray, _device_array_constant_handler)

def _device_put_device_array(x, device_num, backend):
if xb.get_backend(backend).platform == x.device_buffer.platform():
# TODO(skye): we're assuming the DeviceBuffers without "platform" are
# XrtBuffers. Figure out a less risky way to deal with XrtBuffers.
if (not hasattr(x.device_buffer, "platform") or
xb.get_backend(backend).platform == x.device_buffer.platform()):
if x.device_buffer.device() == device_num:
return x.device_buffer
else:
Expand Down

0 comments on commit d582095

Please sign in to comment.