Skip to content

Commit

Permalink
Merge pull request #26 from fehiepsi/update-copy
Browse files Browse the repository at this point in the history
Update pattern to copy jax array to numpy array
  • Loading branch information
fehiepsi authored May 6, 2023
2 parents bde704a + bec3188 commit 7333d38
Show file tree
Hide file tree
Showing 14 changed files with 158 additions and 90 deletions.
3 changes: 2 additions & 1 deletion notebooks/00_preface.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@
}
],
"source": [
"print(jnp.log(0.01 ** 200))\n",
"print(jnp.log(0.01**200))\n",
"print(200 * jnp.log(0.01))"
]
},
Expand Down Expand Up @@ -182,6 +182,7 @@
"# see cars.info() for details\n",
"cars = pd.read_csv(\"../data/cars.csv\", index_col=0)\n",
"\n",
"\n",
"# fit a linear regression of distance on speed\n",
"def model(speed, dist_):\n",
" mu = numpyro.param(\"a\", 0.0) + numpyro.param(\"b\", 1.0) * speed\n",
Expand Down
14 changes: 11 additions & 3 deletions notebooks/03_sampling_the_imaginary.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
"\n",
"import arviz as az\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"from scipy.stats import gaussian_kde\n",
"\n",
Expand Down Expand Up @@ -661,7 +662,7 @@
],
"source": [
"dummy_w = dist.Binomial(total_count=9, probs=0.7).sample(random.PRNGKey(0), (100000,))\n",
"ax = az.plot_dist(dummy_w.copy(), kind=\"hist\", hist_kwargs={\"rwidth\": 0.1})\n",
"ax = az.plot_dist(np.asarray(dummy_w), kind=\"hist\", hist_kwargs={\"rwidth\": 0.1})\n",
"ax.set_xlabel(\"dummy water count\", fontsize=14)\n",
"plt.show()"
]
Expand Down Expand Up @@ -793,7 +794,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
Expand All @@ -807,7 +808,14 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
"version": "3.9.10"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"state": {},
"version_major": 2,
"version_minor": 0
}
}
},
"nbformat": 4,
Expand Down
20 changes: 14 additions & 6 deletions notebooks/04_geocentric_models.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1278,6 +1278,7 @@
"# define the average weight, x-bar\n",
"xbar = d2.weight.mean()\n",
"\n",
"\n",
"# fit model\n",
"def model(weight, height):\n",
" a = numpyro.sample(\"a\", dist.Normal(178, 20))\n",
Expand Down Expand Up @@ -2072,7 +2073,7 @@
],
"source": [
"d[\"weight_s\"] = (d.weight - d.weight.mean()) / d.weight.std()\n",
"d[\"weight_s2\"] = d.weight_s ** 2\n",
"d[\"weight_s2\"] = d.weight_s**2\n",
"\n",
"\n",
"def model(weight_s, weight_s2, height=None):\n",
Expand Down Expand Up @@ -2143,7 +2144,7 @@
"outputs": [],
"source": [
"weight_seq = jnp.linspace(start=-2.2, stop=2, num=30)\n",
"pred_dat = {\"weight_s\": weight_seq, \"weight_s2\": weight_seq ** 2}\n",
"pred_dat = {\"weight_s\": weight_seq, \"weight_s2\": weight_seq**2}\n",
"post = m4_5.sample_posterior(random.PRNGKey(1), p4_5, (1000,))\n",
"predictive = Predictive(m4_5.model, post)\n",
"mu = predictive(random.PRNGKey(2), **pred_dat)[\"mu\"]\n",
Expand Down Expand Up @@ -2207,7 +2208,7 @@
}
],
"source": [
"d[\"weight_s3\"] = d.weight_s ** 3\n",
"d[\"weight_s3\"] = d.weight_s**3\n",
"\n",
"\n",
"def model(weight_s, weight_s2, weight_s3, height):\n",
Expand Down Expand Up @@ -2518,7 +2519,7 @@
"mu_PI = jnp.percentile(mu, q=jnp.array([1.5, 98.5]), axis=0)\n",
"az.plot_pair(\n",
" d2[[\"year\", \"doy\"]].astype(float).to_dict(orient=\"list\"),\n",
" scatter_kwargs={\"c\": \"royalblue\", \"alpha\": 0.3, \"markersize\": 10},\n",
" scatter_kwargs={\"c\": \"royalblue\", \"alpha\": 0.3, \"s\": 10},\n",
")\n",
"plt.fill_between(d2.year, mu_PI[0], mu_PI[1], color=\"k\", alpha=0.5)\n",
"plt.show()"
Expand Down Expand Up @@ -2563,7 +2564,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
Expand All @@ -2577,7 +2578,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
"version": "3.9.10"
},
"toc": {
"base_numbering": 1,
Expand All @@ -2591,6 +2592,13 @@
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"state": {},
"version_major": 2,
"version_minor": 0
}
}
},
"nbformat": 4,
Expand Down
3 changes: 2 additions & 1 deletion notebooks/06_the_haunted_dag_and_the_causal_terror.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,7 @@
"svi_result = svi.run(random.PRNGKey(0), 1000)\n",
"p6_3 = svi_result.params\n",
"\n",
"\n",
"# kcal.per.g regressed on perc.lactose\n",
"def model(L, K):\n",
" a = numpyro.sample(\"a\", dist.Normal(0, 0.2))\n",
Expand Down Expand Up @@ -523,7 +524,7 @@
"\n",
"\n",
"def sim_coll(i, r=0.9):\n",
" sd = jnp.sqrt((1 - r ** 2) * jnp.var(d[\"perc.fat\"].values))\n",
" sd = jnp.sqrt((1 - r**2) * jnp.var(d[\"perc.fat\"].values))\n",
" x = dist.Normal(r * d[\"perc.fat\"].values, sd).sample(random.PRNGKey(3 * i))\n",
"\n",
" def model(perc_fat, kcal_per_g):\n",
Expand Down
4 changes: 2 additions & 2 deletions notebooks/07_ulysses_compass.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@
" a = numpyro.sample(\"a\", dist.Normal(0.5, 1))\n",
" b = numpyro.sample(\"b\", dist.Normal(0, 10).expand([2]))\n",
" log_sigma = numpyro.sample(\"log_sigma\", dist.Normal(0, 1))\n",
" mu = numpyro.deterministic(\"mu\", a + b[0] * mass_std + b[1] * mass_std ** 2)\n",
" mu = numpyro.deterministic(\"mu\", a + b[0] * mass_std + b[1] * mass_std**2)\n",
" numpyro.sample(\"brain_std\", dist.Normal(mu, jnp.exp(log_sigma)), obs=brain_std)\n",
"\n",
"\n",
Expand Down Expand Up @@ -303,7 +303,7 @@
" b = numpyro.sample(\"b\", dist.Normal(0, 10).expand([3]))\n",
" log_sigma = numpyro.sample(\"log_sigma\", dist.Normal(0, 1))\n",
" mu = numpyro.deterministic(\n",
" \"mu\", a + b[0] * mass_std + b[1] * mass_std ** 2 + b[2] * mass_std ** 3\n",
" \"mu\", a + b[0] * mass_std + b[1] * mass_std**2 + b[2] * mass_std**3\n",
" )\n",
" numpyro.sample(\"brain_std\", dist.Normal(mu, jnp.exp(log_sigma)), obs=brain_std)\n",
"\n",
Expand Down
28 changes: 18 additions & 10 deletions notebooks/09_markov_chain_monte_carlo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
"\n",
"import arviz as az\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"import jax.numpy as jnp\n",
Expand Down Expand Up @@ -143,7 +144,7 @@
}
],
"source": [
"plt.hist(positions.copy(), bins=range(1, 12), rwidth=0.1, align=\"left\")\n",
"plt.hist(np.asarray(positions), bins=range(1, 12), rwidth=0.1, align=\"left\")\n",
"plt.show()"
]
},
Expand Down Expand Up @@ -176,7 +177,7 @@
"Y = dist.MultivariateNormal(jnp.repeat(0, D), jnp.identity(D)).sample(\n",
" random.PRNGKey(0), (T,)\n",
")\n",
"rad_dist = lambda Y: jnp.sqrt(jnp.sum(Y ** 2))\n",
"rad_dist = lambda Y: jnp.sqrt(jnp.sum(Y**2))\n",
"Rd = lax.map(lambda i: rad_dist(Y[i]), jnp.arange(T))\n",
"az.plot_kde(Rd, bw=0.18)\n",
"plt.show()"
Expand Down Expand Up @@ -225,8 +226,8 @@
"def U_gradient(q, a=0, b=1, k=0, d=1):\n",
" muy = q[0]\n",
" mux = q[1]\n",
" G1 = jnp.sum(y - muy) + (a - muy) / b ** 2 # dU/dmuy\n",
" G2 = jnp.sum(x - mux) + (k - mux) / b ** 2 # dU/dmux\n",
" G1 = jnp.sum(y - muy) + (a - muy) / b**2 # dU/dmuy\n",
" G2 = jnp.sum(x - mux) + (k - mux) / b**2 # dU/dmux\n",
" return jnp.stack([-G1, -G2]) # negative bc energy is neg-log-prob\n",
"\n",
"\n",
Expand Down Expand Up @@ -291,9 +292,9 @@
" p = -p\n",
" # Evaluate potential and kinetic energies at start and end of trajectory\n",
" current_U = U(current_q)\n",
" current_K = jnp.sum(current_p ** 2) / 2\n",
" current_K = jnp.sum(current_p**2) / 2\n",
" proposed_U = U(q)\n",
" proposed_K = jnp.sum(p ** 2) / 2\n",
" proposed_K = jnp.sum(p**2) / 2\n",
" # Accept or reject the state at end of trajectory, returning either\n",
" # the position at the end of the trajectory or the initial position\n",
" accept = 0\n",
Expand Down Expand Up @@ -338,7 +339,7 @@
" # for fancy arrows\n",
" dx = Q[\"traj\"][L, 0] - Q[\"traj\"][L - 1, 0]\n",
" dy = Q[\"traj\"][L, 1] - Q[\"traj\"][L - 1, 1]\n",
" d = jnp.sqrt(dx ** 2 + dy ** 2)\n",
" d = jnp.sqrt(dx**2 + dy**2)\n",
" plt.annotate(\n",
" \"\",\n",
" (Q[\"traj\"][L - 1, 0], Q[\"traj\"][L - 1, 1]),\n",
Expand Down Expand Up @@ -1167,7 +1168,7 @@
"m9_4.print_summary()\n",
"print(\n",
" \"There were {} transitions that exceeded the maximum treedepth.\".format(\n",
" (m9_4.get_extra_fields()[\"num_steps\"] + 1 == 2 ** 10).sum()\n",
" (m9_4.get_extra_fields()[\"num_steps\"] + 1 == 2**10).sum()\n",
" )\n",
")"
]
Expand Down Expand Up @@ -1475,7 +1476,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
Expand All @@ -1489,7 +1490,14 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
"version": "3.9.10"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"state": {},
"version_major": 2,
"version_minor": 0
}
}
},
"nbformat": 4,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@
],
"source": [
"p = 0.7\n",
"A = jnp.array([(1 - p) ** 2, p * (1 - p), (1 - p) * p, p ** 2])\n",
"A = jnp.array([(1 - p) ** 2, p * (1 - p), (1 - p) * p, p**2])\n",
"A"
]
},
Expand Down
18 changes: 14 additions & 4 deletions notebooks/11_god_spiked_the_integers.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
"\n",
"import arviz as az\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"import jax.numpy as jnp\n",
Expand Down Expand Up @@ -1708,8 +1709,8 @@
" list(\n",
" map(\n",
" lambda k: jnp.divide(\n",
" d.applications[dat_list[\"dept_id\"].copy() == k].values,\n",
" d.applications[dat_list[\"dept_id\"].copy() == k].sum(),\n",
" d.applications[np.asarray(dat_list[\"dept_id\"]) == k].values,\n",
" d.applications[np.asarray(dat_list[\"dept_id\"]) == k].sum(),\n",
" ),\n",
" range(6),\n",
" )\n",
Expand Down Expand Up @@ -2267,6 +2268,7 @@
"source": [
"dat = dict(T=d.total_tools.values, P=d.P.values, cid=d.contact_id.values)\n",
"\n",
"\n",
"# intercept only\n",
"def model(T=None):\n",
" a = numpyro.sample(\"a\", dist.Normal(3, 0.5))\n",
Expand All @@ -2277,6 +2279,7 @@
"m11_9 = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)\n",
"m11_9.run(random.PRNGKey(0), dat[\"T\"])\n",
"\n",
"\n",
"# interaction model\n",
"def model(cid, P, T=None):\n",
" a = numpyro.sample(\"a\", dist.Normal(3, 0.5).expand([2]))\n",
Expand Down Expand Up @@ -3199,7 +3202,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
Expand All @@ -3213,7 +3216,14 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
"version": "3.9.10"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"state": {},
"version_major": 2,
"version_minor": 0
}
}
},
"nbformat": 4,
Expand Down
18 changes: 13 additions & 5 deletions notebooks/12_monsters_and_mixtures.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
"\n",
"import arviz as az\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"from IPython.display import set_matplotlib_formats\n",
"\n",
Expand Down Expand Up @@ -435,7 +436,7 @@
}
],
"source": [
"plt.hist(y.copy(), color=\"k\", bins=jnp.arange(-0.5, 6), rwidth=0.1)\n",
"plt.hist(np.asarray(y), color=\"k\", bins=jnp.arange(-0.5, 6), rwidth=0.1)\n",
"plt.gca().set(xlabel=\"manuscripts completed\")\n",
"zeros_drink = jnp.sum(drink)\n",
"zeros_work = jnp.sum((y == 0) & (drink == 0))\n",
Expand Down Expand Up @@ -659,15 +660,15 @@
"\n",
"\n",
"m12_3_alt = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)\n",
"m12_3_alt.run(random.PRNGKey(0), y=y.copy())\n",
"m12_3_alt.run(random.PRNGKey(0), y=np.asarray(y))\n",
"m12_3_alt.print_summary(0.89)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Note:** JAX 0.2 requires that array boolean indices must be concrete. So to make `log_prob[y > 0]` work, we need to use a concrete NumPy ndarray `y` (obtained by `y.copy()`) instead of JAX's `DeviceArray`."
"**Note:** JAX 0.2 requires that array boolean indices must be concrete. So to make `log_prob[y > 0]` work, we need to use a concrete NumPy ndarray `y` (obtained by `np.asarray(y)`) instead of JAX's `DeviceArray`."
]
},
{
Expand Down Expand Up @@ -1817,7 +1818,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
Expand All @@ -1831,7 +1832,14 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
"version": "3.9.10"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"state": {},
"version_major": 2,
"version_minor": 0
}
}
},
"nbformat": 4,
Expand Down
Loading

0 comments on commit 7333d38

Please sign in to comment.