
calculating gradient of loss wrt each variable in the backward pass
osamja committed Mar 6, 2023
1 parent 129d1b8 commit c9d4055
Showing 1 changed file with 118 additions and 102 deletions.
220 changes: 118 additions & 102 deletions build_makemore_backprop_ninja.ipynb
@@ -22,7 +22,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 1,
"metadata": {
"id": "ChBbac4y8PPq"
},
@@ -36,102 +36,30 @@
},
{
"cell_type": "code",
"execution_count": 65,
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"total error: 0.2983711063861847\n",
"partial_e_wrt_w5: 0.08216704428195953\n",
"old w5: 0.4\n",
"updated w5: 0.3589164912700653\n",
"total error: 0.29498523473739624\n"
]
}
],
"source": [
"i1 = 0.05\n",
"i2 = 0.1\n",
"\n",
"w1 = 0.15\n",
"w2 = 0.2\n",
"w3 = 0.25\n",
"w4 = 0.3\n",
"w5 = .4\n",
"w6 = .45\n",
"w7 = .5\n",
"w8 = .55\n",
"\n",
"b1 = 0.35\n",
"b2 = 0.6\n",
"\n",
"h1 = i1 * w1 + i2 * w2 + b1\n",
"h2 = i1 * w3 + i2 * w4 + b1\n",
"\n",
"h1 = torch.sigmoid(torch.tensor(h1))\n",
"h2 = torch.sigmoid(torch.tensor(h2))\n",
"\n",
"o1 = h1 * w5 + h2 * w6 + b2\n",
"o2 = h1 * w7 + h2 * w8 + b2\n",
"\n",
"o1 = torch.sigmoid((o1))\n",
"o2 = torch.sigmoid((o2))\n",
"\n",
"target1 = 0.01\n",
"target2 = 0.99\n",
"\n",
"e1 = 0.5 * (target1 - o1)**2 # MSE\n",
"e2 = 0.5 * (target2 - o2)**2\n",
"\n",
"e = e1 + e2\n",
"print(f'total error: {e}')\n",
"\n",
"partial_e_wrt_o1 = -(target1 - o1)\n",
"partial_out1_wrt_net1 = o1 * (1 - o1)\n",
"partial_net1_wrt_w5 = h1\n",
"partial_e_wrt_w5 = partial_e_wrt_o1 * partial_out1_wrt_net1 * partial_net1_wrt_w5\n",
"\n",
"print(f'partial_e_wrt_w5: {partial_e_wrt_w5}')\n",
"\n",
"\n"
]
},
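As a sanity check on the chain-rule derivation in the cell above, autograd should recover the same dE/dw5. A minimal sketch, assuming the same toy constants and making only w5 a leaf tensor:

```python
import torch

i1, i2 = 0.05, 0.1
w1, w2, w3, w4 = 0.15, 0.2, 0.25, 0.3
w6, w7, w8 = 0.45, 0.5, 0.55
b1, b2 = 0.35, 0.6
w5 = torch.tensor(0.4, requires_grad=True)  # the weight we differentiate wrt

h1 = torch.sigmoid(torch.tensor(i1 * w1 + i2 * w2 + b1))
h2 = torch.sigmoid(torch.tensor(i1 * w3 + i2 * w4 + b1))

o1 = torch.sigmoid(h1 * w5 + h2 * w6 + b2)
o2 = torch.sigmoid(h1 * w7 + h2 * w8 + b2)

target1, target2 = 0.01, 0.99
e = 0.5 * (target1 - o1)**2 + 0.5 * (target2 - o2)**2

e.backward()
print(w5.grad)  # should match partial_e_wrt_w5 above (~0.0822)
```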
{
"cell_type": "code",
"execution_count": 213,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"old w5: -5.680357456207275\n",
"updated w5: -5.72144079208374\n",
"total error: 0.02561262622475624\n"
"32033\n",
"15\n",
"['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia']\n"
]
}
],
"source": [
"# Let's update w5 and recalculate the total error\n",
"LR = 0.5\n",
"print(f'old w5: {w5}')\n",
"w5 = w5 - LR * partial_e_wrt_w5\n",
"print(f'updated w5: {w5}')\n",
"\n",
"o1 = h1 * w5 + h2 * w6 + b2\n",
"o1 = torch.sigmoid((o1))\n",
"e1 = 0.5 * (target1 - o1)**2 # MSE\n",
"e = e1 + e2\n",
"print(f'total error: {e}')\n",
"\n",
"# Let's update w6 and recalculate the total error"
"# read in all the words\n",
"words = open('names.txt', 'r').read().splitlines()\n",
"print(len(words))\n",
"print(max(len(w) for w in words))\n",
"print(words[:8])"
]
},
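The single update to w5 above generalizes to a loop: repeating the same gradient step keeps shrinking e1. A minimal sketch, assuming h1, h2, w6, b2, and target1 from the cells above are still in scope and everything except w5 stays fixed:

```python
LR = 0.5
for step in range(5):
    o1 = torch.sigmoid(h1 * w5 + h2 * w6 + b2)
    e1 = 0.5 * (target1 - o1)**2
    grad_w5 = -(target1 - o1) * o1 * (1 - o1) * h1  # same chain rule as above
    w5 = w5 - LR * grad_w5
    print(f'step {step}: e1 = {e1.item():.6f}')
```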
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 4,
"metadata": {
"id": "klmu3ZG08PPr"
},
@@ -143,8 +71,7 @@
"muhammadibrahim\n",
"muhammadmustafa\n",
"32033\n",
"<generator object <genexpr> at 0x7f7d90e685f0>\n",
"camila\n",
"<generator object <genexpr> at 0x7fcd19304820>\n",
"['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia']\n"
]
}
@@ -174,7 +101,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 5,
"metadata": {
"id": "BCQomLE_8PPs"
},
@@ -201,7 +128,7 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 6,
"metadata": {
"id": "V_zt2QHr8PPs"
},
@@ -249,7 +176,7 @@
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": 7,
"metadata": {
"id": "eg20-vsg8PPt"
},
@@ -260,7 +187,7 @@
},
{
"cell_type": "code",
"execution_count": 23,
"execution_count": 8,
"metadata": {
"id": "MJPU8HT08PPu"
},
@@ -276,7 +203,7 @@
},
{
"cell_type": "code",
"execution_count": 27,
"execution_count": 9,
"metadata": {
"id": "ZlFLjQyT8PPu"
},
@@ -317,7 +244,7 @@
},
{
"cell_type": "code",
"execution_count": 28,
"execution_count": 10,
"metadata": {
"id": "QY-y96Y48PPv"
},
@@ -332,7 +259,7 @@
},
{
"cell_type": "code",
"execution_count": 43,
"execution_count": 77,
"metadata": {
"id": "8ofj1s6d8PPv"
},
@@ -341,8 +268,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
"loss: 3.343810558319092\n",
"loss: 3.343810558319092\n"
"loss: 3.338346242904663\n",
"loss: 3.338346242904663\n"
]
}
],
@@ -373,9 +300,8 @@
"counts_sum_inv = counts_sum**-1 # if I use (1.0 / counts_sum) instead then I can't get backprop to be bit exact...\n",
"probs = counts * counts_sum_inv\n",
"logprobs = probs.log()\n",
"loss = -logprobs[range(n), Yb].mean()\n",
"print(f'loss: {loss.item()}')\n",
"\n",
"loss = -logprobs[range(n), Yb].mean() # negative log-likelihood loss\n",
"print(f'loss: {loss.item()}') # Loss per minibatch\n",
"\n",
"# PyTorch backward pass\n",
"for p in parameters:\n",
@@ -392,23 +318,99 @@
},
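The decomposed forward pass above (counts, counts_sum_inv, probs, logprobs) should reproduce PyTorch's fused cross-entropy. A minimal equivalence check, using a hypothetical random batch rather than the notebook's actual Xb/Yb:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n = 32
logits = torch.randn(n, 27)      # stand-in for the notebook's logits
Yb = torch.randint(0, 27, (n,))  # stand-in targets

logit_maxes = logits.max(1, keepdim=True).values
norm_logits = logits - logit_maxes  # subtract row max for numerical stability
counts = norm_logits.exp()
counts_sum = counts.sum(1, keepdim=True)
counts_sum_inv = counts_sum**-1
probs = counts * counts_sum_inv
logprobs = probs.log()
loss_chunked = -logprobs[range(n), Yb].mean()

loss_fused = F.cross_entropy(logits, Yb)
print(loss_chunked.item(), loss_fused.item())  # agree to float precision
```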
{
"cell_type": "code",
"execution_count": null,
"execution_count": 123,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"a shape: torch.Size([3, 2])\n",
"b shape: torch.Size([3, 1])\n"
]
},
{
"data": {
"text/plain": [
"tensor([[ 1, 2],\n",
" [ 6, 8],\n",
" [15, 18]])"
]
},
"execution_count": 123,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a = torch.tensor([[1, 2],\n",
" [3, 4],\n",
" [5, 6]])\n",
"\n",
"b = torch.tensor([[1], \n",
" [2],\n",
" [3]])\n",
"\n",
"print(f'a shape: {a.shape}')\n",
"print(f'b shape: {b.shape}')\n",
"\n",
"a*b"
]
},
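This [3, 2] * [3, 1] broadcast is exactly the pattern behind the dcounts_sum_inv step below: b is replicated across columns in the forward pass, so its gradient must be summed across columns in the backward pass. A small autograd check of that claim, reusing the same toy shapes (floats, so gradients are defined):

```python
import torch

a = torch.tensor([[1., 2.], [3., 4.], [5., 6.]])
b = torch.tensor([[1.], [2.], [3.]], requires_grad=True)

(a * b).sum().backward()
print(b.grad)                  # tensor([[ 3.], [ 7.], [11.]])
print(a.sum(1, keepdim=True))  # identical: b's gradient is a summed along dim 1
```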
{
"cell_type": "code",
"execution_count": 128,
"metadata": {
"id": "mO-8aqxK8PPw"
},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"counts shape: torch.Size([32, 27])\n",
"counts_sum_inv shape: torch.Size([32, 1])\n",
"logprobs | exact: True | approximate: True | maxdiff: 0.0\n",
"probs | exact: True | approximate: True | maxdiff: 0.0\n",
"counts_sum_inv | exact: True | approximate: True | maxdiff: 0.0\n"
]
}
],
"source": [
"# Exercise 1: backprop through the whole thing manually, \n",
"# backpropagating through exactly all of the variables \n",
"# as they are defined in the forward pass above, one by one\n",
"\n",
"# -----------------\n",
"# YOUR CODE HERE :)\n",
"\n",
"\"\"\"\n",
"Suppose a = logprobs[0], b = logprobs[1], c = logprobs[2]\n",
"loss = -(a + b + c) / n\n",
"dloss/da = -1/n\n",
"dloss/db = -1/n\n",
"dloss/dc = -1/n\n",
"\n",
"So we can generalize this as follows:\n",
"dloss/dlogprobs = -1/n * I (for each input i)\n",
"\"\"\"\n",
"dlogprobs = torch.zeros_like(logprobs) # dlogprobs should be the same shape as logprobs\n",
"\n",
"# Only the correct class should have a gradient of -1/n since other guesses don't affect our loss\n",
"dlogprobs[range(n), Yb] = -1/n\n",
"\n",
"dprobs = dlogprobs * (1.0/probs)\n",
"\n",
"# Here we are accounting for the multiplication by counts and the replication across columns during broadcasting\n",
"dcounts_sum_inv = (dprobs * counts).sum(1, keepdim=True)\n",
"print(f'counts shape: {counts.shape}')\n",
"print(f'counts_sum_inv shape: {counts_sum_inv.shape}')\n",
"\n",
"# -----------------\n",
"\n",
"# cmp('logprobs', dlogprobs, logprobs)\n",
"# cmp('probs', dprobs, probs)\n",
"# cmp('counts_sum_inv', dcounts_sum_inv, counts_sum_inv)\n",
"cmp('logprobs', dlogprobs, logprobs)\n",
"cmp('probs', dprobs, probs)\n",
"cmp('counts_sum_inv', dcounts_sum_inv, counts_sum_inv)\n",
"# cmp('counts_sum', dcounts_sum, counts_sum)\n",
"# cmp('counts', dcounts, counts)\n",
"# cmp('norm_logits', dnorm_logits, norm_logits)\n",
@@ -744,6 +746,20 @@
" \n",
" print(''.join(itos[i] for i in out))"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Questions:\n",
"- How we calculated `dcounts_sum_inv = (dprobs * counts).sum(1, keepdim=True)` was a bit confusing in regards to the sum along the 1st dimension. I don't fully understand the details here although I understand that we are taking into account the replication of counts along the columns by doing this summation. Covered around 26:27 in the video\n"
]
},
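One way to make that sum concrete: counts_sum_inv has shape [32, 1], so all 27 columns of probs read the same entry of it, and that entry's total gradient is the sum of the 27 per-column contributions dprobs * counts. A small autograd check with random stand-in tensors (not the notebook's actual values):

```python
import torch

counts = torch.rand(32, 27)
counts_sum_inv = torch.rand(32, 1, requires_grad=True)

probs = counts * counts_sum_inv  # broadcasts counts_sum_inv across 27 columns
dprobs = torch.rand_like(probs)  # stand-in upstream gradient

probs.backward(dprobs)
manual = (dprobs * counts).sum(1, keepdim=True)
print(torch.allclose(counts_sum_inv.grad, manual))  # True
```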
{
"cell_type": "markdown",
"metadata": {},
"source": []
}
],
"metadata": {
