From 71605f4b493b0bfee73accae89fbf62d184d1770 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 2 Jul 2019 15:16:01 -0400 Subject: [PATCH] Change DeviceArray.__iter__ and DeviceArray.__reversed__ to forward to the _value. This has the effect of transferring the entire array to the host and iterating over it in host memory, rather than slicing out individual elements in device memory one by one. This is much faster for examples like `list(np.arange(10000))`; previously this took several seconds the first time due to compilation and 100ms+ subsequent times. With this change it takes < 1ms. --- jax/interpreters/xla.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jax/interpreters/xla.py b/jax/interpreters/xla.py index 0f5af5889854..881ac58c71ae 100644 --- a/jax/interpreters/xla.py +++ b/jax/interpreters/xla.py @@ -526,13 +526,13 @@ def __iter__(self): if self.ndim == 0: raise TypeError("iteration over a 0-d array") # same as numpy error else: - return (self[i] for i in xrange(self.shape[0])) + return self._value.__iter__() def __reversed__(self): if self.ndim == 0: raise TypeError("iteration over a 0-d array") else: - return (self[i] for i in xrange(self.shape[0] - 1, -1, -1)) + return reversed(self._value) def __format__(self, format_spec): # Simulates behavior of https://github.com/numpy/numpy/pull/9883