Commit d5926ac

JAX: Add section softmax-regression-scratch.md (d2l-ai#2294)
AnirudhDagar authored Sep 7, 2022
1 parent f348aec commit d5926ac
Showing 1 changed file with 46 additions and 6 deletions.
chapter_linear-classification/softmax-regression-scratch.md (52 changes: 46 additions & 6 deletions)
@@ -1,6 +1,6 @@
```{.python .input n=1}
%load_ext d2lbook.tab
-tab.interact_select(['mxnet', 'pytorch', 'tensorflow'])
+tab.interact_select(['mxnet', 'pytorch', 'tensorflow', 'jax'])
```

# Softmax Regression Implementation from Scratch
@@ -34,6 +34,14 @@ from d2l import tensorflow as d2l
import tensorflow as tf
```

```{.python .input}
%%tab jax
from d2l import jax as d2l
from flax import linen as nn
import jax
from jax import numpy as jnp
```

## The Softmax

Let's begin with the most important part:
@@ -95,6 +103,13 @@ X_prob = softmax(X)
X_prob, d2l.reduce_sum(X_prob, 1)
```

```{.python .input}
%%tab jax
X = jax.random.uniform(jax.random.PRNGKey(d2l.get_seed()), (2, 5))
X_prob = softmax(X)
X_prob, d2l.reduce_sum(X_prob, 1)
```
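Unlike the other tabs, the JAX cell threads randomness through an explicit PRNG key (seeded here via `d2l.get_seed()`) instead of a global random state. The definition of `softmax` itself sits in a collapsed hunk above; a minimal sketch consistent with these calls, assuming the book's usual from-scratch recipe built on the backend-agnostic `d2l` ops:

```python
def softmax(X):
    # Exponentiate every entry, then normalize each row to sum to 1.
    X_exp = d2l.exp(X)
    partition = d2l.reduce_sum(X_exp, 1, keepdims=True)
    return X_exp / partition  # broadcasting divides each row by its sum
```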

## The Model

We now have everything that we need
@@ -165,6 +180,20 @@ class SoftmaxRegressionScratch(d2l.Classifier):
self.b = tf.Variable(self.b)
```

```{.python .input}
%%tab jax
class SoftmaxRegressionScratch(d2l.Classifier):
    num_inputs: int
    num_outputs: int
    lr: float
    sigma: float = 0.01

    def setup(self):
        self.W = self.param('W', nn.initializers.normal(self.sigma),
                            (self.num_inputs, self.num_outputs))
        self.b = self.param('b', nn.initializers.zeros, self.num_outputs)
```
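Here the hyperparameters are Flax dataclass fields, and `setup` registers `W` and `b` lazily via `self.param`: nothing is allocated until the module is initialized with a PRNG key and a dummy batch. A hypothetical initialization sketch (assuming the forward pass added below):

```python
model = SoftmaxRegressionScratch(num_inputs=28 * 28, num_outputs=10, lr=0.1)
params = model.init(jax.random.PRNGKey(d2l.get_seed()),
                    jnp.ones((1, 28 * 28)))  # dummy input fixes the shapes
```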

The code below defines how the network
maps each input to an output.
Note that we flatten each $28 \times 28$ pixel image in the batch
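(The rest of this cell sits in the collapsed hunk below. For the new `jax` tab, a minimal sketch of the forward pass, assuming the book's usual flatten-then-affine recipe and the `forward` hook that `d2l.add_to_class` targets elsewhere in this diff:)

```python
@d2l.add_to_class(SoftmaxRegressionScratch)
def forward(self, X):
    # Flatten each 28x28 image into a length-784 row vector, then
    # apply the affine map followed by our from-scratch softmax.
    X = jnp.reshape(X, (-1, self.num_inputs))
    return softmax(jnp.matmul(X, self.W) + self.b)
```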
@@ -202,7 +231,7 @@ The correct labels are $1$ and $2$ respectively.
we can pick out terms efficiently.

```{.python .input}
-%%tab mxnet, pytorch
+%%tab mxnet, pytorch, jax
y = d2l.tensor([0, 2])
y_hat = d2l.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
y_hat[[0, 1], y]
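The same gather works for the newly added `jax` tab, with the caveat that JAX is happiest when the row index is an array rather than a bare Python list; a standalone check (a sketch, not part of the diff):

```python
import jax.numpy as jnp

y = jnp.array([0, 2])
y_hat = jnp.array([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
print(y_hat[jnp.array([0, 1]), y])  # [0.1 0.5]
```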
@@ -218,9 +247,9 @@ tf.boolean_mask(y_hat, tf.one_hot(y, depth=y_hat.shape[-1]))
Now we can (**implement the cross-entropy loss function**) by averaging over the logarithms of the selected probabilities.

```{.python .input}
-%%tab mxnet, pytorch
+%%tab mxnet, pytorch, jax
def cross_entropy(y_hat, y):
-    return - d2l.reduce_mean(d2l.log(y_hat[range(len(y_hat)), y]))
+    return - d2l.reduce_mean(d2l.log(y_hat[list(range(len(y_hat))), y]))
cross_entropy(y_hat, y)
```
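The switch to `list(range(...))` keeps the row index a concrete sequence, presumably so the same line also serves the newly added `jax` tab. Averaging the negative log-probabilities of the two examples above gives $-\frac{1}{2}(\log 0.1 + \log 0.5) \approx 1.4979$, which a quick hand check confirms:

```python
import math

# Probabilities assigned to the true classes: 0.1 (row 0) and 0.5 (row 1).
print(-(math.log(0.1) + math.log(0.5)) / 2)  # 1.4978661...
```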
@@ -235,12 +264,20 @@ cross_entropy(y_hat, y)
```

```{.python .input}
-%%tab all
+%%tab pytorch, mxnet, tensorflow
@d2l.add_to_class(SoftmaxRegressionScratch)
def loss(self, y_hat, y):
    return cross_entropy(y_hat, y)
```

```{.python .input}
%%tab jax
@d2l.add_to_class(SoftmaxRegressionScratch)
def loss(self, params, X, y):
    y_hat = self.apply(params, X)
    return cross_entropy(y_hat, y)
```
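The `jax` variant of `loss` takes `params` and the raw inputs because Flax modules are functional: the training step differentiates the loss with respect to the parameters instead of mutating state in place. Roughly, under that assumption (a hypothetical fragment, not part of the diff):

```python
# Gradient of the scalar loss with respect to the whole parameter pytree.
grads = jax.grad(lambda p: model.loss(p, X, y))(params)
```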

## Training

We reuse the `fit` method defined in :numref:`sec_linear_scratch` to [**train the model with 10 epochs.**]
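The collapsed cell presumably mirrors :numref:`sec_linear_scratch`; a sketch under that assumption, with hyperparameter values taken from the book's usual setup rather than from this diff:

```python
data = d2l.FashionMNIST(batch_size=256)
model = SoftmaxRegressionScratch(num_inputs=784, num_outputs=10, lr=0.1)
trainer = d2l.Trainer(max_epochs=10)
trainer.fit(model, data)
```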
@@ -279,7 +316,10 @@ our model is ready to [**classify some images.**]
```{.python .input}
%%tab all
X, y = next(iter(data.val_dataloader()))
-preds = d2l.argmax(model(X), axis=1)
+if tab.selected('pytorch', 'mxnet', 'tensorflow'):
+    preds = d2l.argmax(model(X), axis=1)
+if tab.selected('jax'):
+    preds = d2l.argmax(model.apply(trainer.state.params, X), axis=1)
preds.shape
```
