JAX: Support multiple inputs X, e.g., for Seq2Seq data
Previously, the input tuple was unpacked when calling the loss
function in the training step. This was fragile: some arguments must
be marked as static when wrapping the loss with jax.jit, and if the X
argument (a tuple) holds more than one input, as in the Machine
Translation dataset, the positions referenced by static_argnums shift
and become wrong. We therefore delay unpacking the input batch tuple
until the forward pass is invoked via `apply_fn`.
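
To see why, here is a minimal, self-contained sketch (not the book's actual d2l classes; `apply_fn` below is a hypothetical stand-in for `state.apply_fn`). Keeping X as a single tuple argument and unpacking it only at the forward pass means the position of `averaged` never depends on how many inputs the dataset yields, so the static_argnums indices stay valid.

```python
from functools import partial

import jax
import jax.numpy as jnp


def apply_fn(params, *inputs):
    # Stand-in for flax's state.apply_fn: a linear map over the first input;
    # extra inputs (e.g. decoder sequences) are simply ignored here.
    return inputs[0] @ params['w']


# X stays a single tuple argument and is unpacked only inside the loss, so the
# position of `averaged` never shifts, however many inputs X carries.
@partial(jax.jit, static_argnums=(3,))  # `averaged` is always argument 3
def loss(params, X, Y, averaged=True):
    Y_hat = apply_fn(params, *X)        # unpack here, at the forward pass
    l = (Y_hat - Y) ** 2 / 2
    return l.mean() if averaged else l


params = {'w': jnp.ones((4, 1))}
Y = jnp.zeros((2, 1))
X_single = (jnp.ones((2, 4)),)                 # one input (classification)
X_pair = (jnp.ones((2, 4)), jnp.ones((2, 3)))  # two inputs (e.g. Seq2Seq)
print(loss(params, X_single, Y, True))
print(loss(params, X_pair, Y, True))
```

In the book's class methods the same reasoning applies with `static_argnums=(0, 5)`, since `self` occupies position 0 and `averaged` position 5 of `loss(self, params, X, Y, state, averaged)`.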
AnirudhDagar committed Dec 8, 2022
1 parent 1e52389 commit a9e32bb
Showing 10 changed files with 27 additions and 27 deletions.
2 changes: 1 addition & 1 deletion chapter_convolutional-modern/batch-norm.md
@@ -674,7 +674,7 @@ variables.
def loss(self, params, X, Y, state, averaged=True):
Y_hat, updates = state.apply_fn({'params': params,
'batch_stats': state.batch_stats},
-X, mutable=['batch_stats'],
+*X, mutable=['batch_stats'],
rngs={'dropout': jax.random.PRNGKey(0)})
Y_hat = d2l.reshape(Y_hat, (-1, Y_hat.shape[-1]))
Y = d2l.reshape(Y, (-1,))
8 changes: 4 additions & 4 deletions chapter_linear-classification/classification.md
@@ -70,17 +70,17 @@ class Classifier(d2l.Module): #@save
# Here value is a tuple since models with BatchNorm layers require
# the loss to return auxiliary data
value, grads = jax.value_and_grad(
-self.loss, has_aux=True)(params, *batch[:-1], batch[-1], state)
+self.loss, has_aux=True)(params, batch[:-1], batch[-1], state)
l, _ = value
self.plot("loss", l, train=True)
return value, grads
def validation_step(self, params, batch, state):
# Discard the second returned value. It is used for training models
# with BatchNorm layers since loss also returns auxiliary data
-l, _ = self.loss(params, *batch[:-1], batch[-1], state)
+l, _ = self.loss(params, batch[:-1], batch[-1], state)
self.plot('loss', l, train=False)
-self.plot('acc', self.accuracy(params, *batch[:-1], batch[-1], state),
+self.plot('acc', self.accuracy(params, batch[:-1], batch[-1], state),
train=False)
```

@@ -162,7 +162,7 @@ def accuracy(self, params, X, Y, state, averaged=True):
"""Compute the number of correct predictions."""
Y_hat = state.apply_fn({'params': params,
'batch_stats': state.batch_stats}, # BatchNorm Only
-X)
+*X)
Y_hat = d2l.reshape(Y_hat, (-1, Y_hat.shape[-1]))
preds = d2l.astype(d2l.argmax(Y_hat, axis=1), Y.dtype)
compare = d2l.astype(preds == d2l.reshape(Y, -1), d2l.float32)
@@ -193,7 +193,7 @@ def loss(self, Y_hat, Y, averaged=True):
@d2l.add_to_class(d2l.Classifier) #@save
@partial(jax.jit, static_argnums=(0, 5))
def loss(self, params, X, Y, state, averaged=True):
-Y_hat = state.apply_fn({'params': params}, X,
+Y_hat = state.apply_fn({'params': params}, *X,
mutable=False, rngs=None) # To be used later (e.g., for batch norm)
Y_hat = d2l.reshape(Y_hat, (-1, Y_hat.shape[-1]))
Y = d2l.reshape(Y, (-1,))
@@ -290,7 +290,7 @@ def loss(self, y_hat, y):
def loss(self, params, X, y, state):
def cross_entropy(y_hat, y):
return - d2l.reduce_mean(d2l.log(y_hat[list(range(len(y_hat))), y]))
-y_hat = state.apply_fn({'params': params}, X)
+y_hat = state.apply_fn({'params': params}, *X)
# The returned empty dictionary is a placeholder for auxiliary data,
# which will be used later (e.g., for batch norm)
return cross_entropy(y_hat, y), {}
2 changes: 1 addition & 1 deletion chapter_linear-regression/linear-regression-concise.md
@@ -230,7 +230,7 @@ def loss(self, y_hat, y):
%%tab jax
@d2l.add_to_class(LinearRegression) #@save
def loss(self, params, X, y, state):
-y_hat = state.apply_fn({'params': params}, X)
+y_hat = state.apply_fn({'params': params}, *X)
return d2l.reduce_mean(optax.l2_loss(y_hat, y))
```

2 changes: 1 addition & 1 deletion chapter_linear-regression/linear-regression-scratch.md
@@ -161,7 +161,7 @@ def loss(self, y_hat, y):
%%tab jax
@d2l.add_to_class(LinearRegressionScratch) #@save
def loss(self, params, X, y, state):
-y_hat = state.apply_fn({'params': params}, X)
+y_hat = state.apply_fn({'params': params}, *X) # X unpacked from a tuple
l = (y_hat - d2l.reshape(y, y_hat.shape)) ** 2 / 2
return d2l.reduce_mean(l)
```
4 changes: 2 additions & 2 deletions chapter_linear-regression/oo-design.md
@@ -260,13 +260,13 @@ class Module(d2l.nn_Module, d2l.HyperParameters): #@save
if tab.selected('jax'):
def training_step(self, params, batch, state):
-l, grads = jax.value_and_grad(self.loss)(params, *batch[:-1],
+l, grads = jax.value_and_grad(self.loss)(params, batch[:-1],
batch[-1], state)
self.plot("loss", l, train=True)
return l, grads
def validation_step(self, params, batch, state):
-l = self.loss(params, *batch[:-1], batch[-1], state)
+l = self.loss(params, batch[:-1], batch[-1], state)
self.plot('loss', l, train=False)
def apply_init(self, dummy_input, **kwargs):
2 changes: 1 addition & 1 deletion chapter_multilayer-perceptrons/dropout.md
@@ -436,7 +436,7 @@ mask internally.
@d2l.add_to_class(d2l.Classifier) #@save
@partial(jax.jit, static_argnums=(0, 5))
def loss(self, params, X, Y, state, averaged=True):
-Y_hat = state.apply_fn({'params': params}, X,
+Y_hat = state.apply_fn({'params': params}, *X,
mutable=False, # To be used later (e.g., batch norm)
rngs={'dropout': jax.random.PRNGKey(0)})
Y_hat = d2l.reshape(Y_hat, (-1, Y_hat.shape[-1]))
4 changes: 2 additions & 2 deletions chapter_recurrent-neural-networks/rnn-scratch.md
@@ -245,13 +245,13 @@ class RNNLMScratch(d2l.Classifier): #@save
def training_step(self, params, batch, state):
value, grads = jax.value_and_grad(
-self.loss, has_aux=True)(params, *batch[:-1], batch[-1], state)
+self.loss, has_aux=True)(params, batch[:-1], batch[-1], state)
l, _ = value
self.plot('ppl', d2l.exp(l), train=True)
return value, grads
def validation_step(self, params, batch, state):
-l, _ = self.loss(params, *batch[:-1], batch[-1], state)
+l, _ = self.loss(params, batch[:-1], batch[-1], state)
self.plot('ppl', d2l.exp(l), train=False)
```

26 changes: 13 additions & 13 deletions d2l/jax.py
@@ -219,13 +219,13 @@ def plot(self, key, value, train):
every_n=int(n))

def training_step(self, params, batch, state):
-l, grads = jax.value_and_grad(self.loss)(params, *batch[:-1],
+l, grads = jax.value_and_grad(self.loss)(params, batch[:-1],
batch[-1], state)
self.plot("loss", l, train=True)
return l, grads

def validation_step(self, params, batch, state):
-l = self.loss(params, *batch[:-1], batch[-1], state)
+l = self.loss(params, batch[:-1], batch[-1], state)
self.plot('loss', l, train=False)

def apply_init(self, dummy_input, **kwargs):
@@ -414,7 +414,7 @@ def forward(self, X):

def loss(self, params, X, y, state):
"""Defined in :numref:`sec_linear_scratch`"""
-y_hat = state.apply_fn({'params': params}, X)
+y_hat = state.apply_fn({'params': params}, *X) # X unpacked from a tuple
l = (y_hat - d2l.reshape(y, y_hat.shape)) ** 2 / 2
return d2l.reduce_mean(l)

@@ -466,7 +466,7 @@ def forward(self, X):

def loss(self, params, X, y, state):
"""Defined in :numref:`sec_linear_concise`"""
-y_hat = state.apply_fn({'params': params}, X)
+y_hat = state.apply_fn({'params': params}, *X)
return d2l.reduce_mean(optax.l2_loss(y_hat, y))

def configure_optimizers(self):
@@ -523,17 +523,17 @@ def training_step(self, params, batch, state):
# Here value is a tuple since models with BatchNorm layers require
# the loss to return auxiliary data
value, grads = jax.value_and_grad(
-self.loss, has_aux=True)(params, *batch[:-1], batch[-1], state)
+self.loss, has_aux=True)(params, batch[:-1], batch[-1], state)
l, _ = value
self.plot("loss", l, train=True)
return value, grads

def validation_step(self, params, batch, state):
# Discard the second returned value. It is used for training models
# with BatchNorm layers since loss also returns auxiliary data
-l, _ = self.loss(params, *batch[:-1], batch[-1], state)
+l, _ = self.loss(params, batch[:-1], batch[-1], state)
self.plot('loss', l, train=False)
-self.plot('acc', self.accuracy(params, *batch[:-1], batch[-1], state),
+self.plot('acc', self.accuracy(params, batch[:-1], batch[-1], state),
train=False)

@partial(jax.jit, static_argnums=(0, 5))
@@ -543,7 +543,7 @@ def accuracy(self, params, X, Y, state, averaged=True):
Defined in :numref:`sec_classification`"""
Y_hat = state.apply_fn({'params': params,
'batch_stats': state.batch_stats}, # BatchNorm Only
-X)
+*X)
Y_hat = d2l.reshape(Y_hat, (-1, Y_hat.shape[-1]))
preds = d2l.astype(d2l.argmax(Y_hat, axis=1), Y.dtype)
compare = d2l.astype(preds == d2l.reshape(Y, -1), d2l.float32)
@@ -552,7 +552,7 @@ def accuracy(self, params, X, Y, state, averaged=True):
@partial(jax.jit, static_argnums=(0, 5))
def loss(self, params, X, Y, state, averaged=True):
"""Defined in :numref:`sec_softmax_concise`"""
-Y_hat = state.apply_fn({'params': params}, X,
+Y_hat = state.apply_fn({'params': params}, *X,
mutable=False, rngs=None) # To be used later (e.g., for batch norm)
Y_hat = d2l.reshape(Y_hat, (-1, Y_hat.shape[-1]))
Y = d2l.reshape(Y, (-1,))
@@ -564,7 +564,7 @@ def loss(self, params, X, Y, state, averaged=True):
@partial(jax.jit, static_argnums=(0, 5))
def loss(self, params, X, Y, state, averaged=True):
"""Defined in :numref:`sec_dropout`"""
-Y_hat = state.apply_fn({'params': params}, X,
+Y_hat = state.apply_fn({'params': params}, *X,
mutable=False, # To be used later (e.g., batch norm)
rngs={'dropout': jax.random.PRNGKey(0)})
Y_hat = d2l.reshape(Y_hat, (-1, Y_hat.shape[-1]))
@@ -589,7 +589,7 @@ def loss(self, params, X, Y, state, averaged=True):
"""Defined in :numref:`subsec_layer-normalization-in-bn`"""
Y_hat, updates = state.apply_fn({'params': params,
'batch_stats': state.batch_stats},
-X, mutable=['batch_stats'],
+*X, mutable=['batch_stats'],
rngs={'dropout': jax.random.PRNGKey(0)})
Y_hat = d2l.reshape(Y_hat, (-1, Y_hat.shape[-1]))
Y = d2l.reshape(Y, (-1,))
@@ -821,13 +821,13 @@ def setup(self):

def training_step(self, params, batch, state):
value, grads = jax.value_and_grad(
-self.loss, has_aux=True)(params, *batch[:-1], batch[-1], state)
+self.loss, has_aux=True)(params, batch[:-1], batch[-1], state)
l, _ = value
self.plot('ppl', d2l.exp(l), train=True)
return value, grads

def validation_step(self, params, batch, state):
-l, _ = self.loss(params, *batch[:-1], batch[-1], state)
+l, _ = self.loss(params, batch[:-1], batch[-1], state)
self.plot('ppl', d2l.exp(l), train=False)

def one_hot(self, X):
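For reference, a hypothetical Seq2Seq-style batch under the new convention could look like the following; the names and shapes are illustrative, not taken from the d2l data pipeline.

```python
import jax.numpy as jnp

# Illustrative machine-translation batch: two model inputs plus the labels.
src = jnp.zeros((32, 9), dtype=jnp.int32)    # source token ids
tgt = jnp.zeros((32, 9), dtype=jnp.int32)    # decoder-input token ids
label = jnp.zeros((32, 9), dtype=jnp.int32)  # target labels
batch = (src, tgt, label)

# training_step / validation_step now forward batch[:-1] as ONE tuple:
X, Y = batch[:-1], batch[-1]                 # X == (src, tgt), Y == label
# ...and the loss unpacks it only when calling the model:
# Y_hat = state.apply_fn({'params': params}, *X)   # the model sees (src, tgt)
```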
