Change multidiff to multigrad
AlanPearl committed Jul 11, 2024
1 parent 6c8b6c2 commit 321a3ca
Showing 7 changed files with 50 additions and 30 deletions.
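For downstream users the rename is mechanical: the pip package, the import, and the base class all change name. A minimal before/after sketch, using only identifiers that appear in this diff (`multidiff.MultiDiffOnePointModel` becomes `multigrad.OnePointModel`):

```python
# Before this commit:
#   import multidiff
#   class MyModel(multidiff.MultiDiffOnePointModel): ...

# After this commit (prerequisite: pip install multigrad):
from dataclasses import dataclass

import multigrad


@dataclass
class MyModel(multigrad.OnePointModel):
    ...  # subclass body unchanged; only the package and base-class names differ
```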
3 changes: 3 additions & 0 deletions .gitignore
@@ -128,6 +128,9 @@ ENV/
env.bak/
venv.bak/

+# VSCode project settings
+.vscode
+
# Spyder project settings
.spyderproject
.spyproject
2 changes: 1 addition & 1 deletion docs/source/index.rst
@@ -12,8 +12,8 @@ Welcome to kdescent's documentation!

installation.rst
notebooks/intro.ipynb
-notebooks/integration.ipynb
notebooks/hmf_upweight.ipynb
+notebooks/integration.ipynb
reference.rst

Indices and tables
13 changes: 10 additions & 3 deletions docs/source/notebooks/hmf_upweight.ipynb
@@ -25,7 +25,7 @@
" ...\n",
"```\n",
"\n",
"At the end of this tutorial, we will show that we arrive at identical results with and without upweighting."
"At the end of this tutorial, we will show that we arrive at essentially identical results with and without upweighting."
]
},
{
@@ -44,6 +44,13 @@
"import kdescent"
]
},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"## Define the model"
+]
+},
{
"cell_type": "code",
"execution_count": null,
@@ -140,7 +147,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Define \"true\" parameters to generate training data"
"### Define \"true\" parameters to generate training data"
]
},
{
@@ -204,7 +211,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Define plotting function\n",
"### Define plotting function\n",
"\n",
"- Plot the mass distribution + the color-color distribution in three separate mass bins"
]
30 changes: 13 additions & 17 deletions docs/source/notebooks/integration.ipynb
@@ -4,18 +4,18 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Integration with `multidiff`"
"# Integration with `multigrad`"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This notebook will show an example and discuss some of the complexity arising from performing `kdescent` in parallel with the aid of the `multidiff` package. We will be following an identical example to that found in the [Advanced Usage](./intro.ipynb#Advanced-Usage) section of the tutorial. All procedural differences will be flagged with a `# NOTE: ...` comment.\n",
"This notebook will show an example and discuss some of the complexity arising from performing `kdescent` in parallel with the aid of the `multigrad` package. We will be following an identical example to that found in the [Advanced Usage](./intro.ipynb#Advanced-Usage) section of the tutorial. All procedural differences will be flagged with a `# NOTE: ...` comment.\n",
"\n",
"Prerequisite: `pip install multidiff`\n",
"Prerequisite: `pip install multigrad`\n",
"\n",
"In the following script, the `generate_model()` function will do the same thing as before, except it will divide the sample size by the number of MPI ranks available, so that a fully-sized sample will be generated across all the ranks (and we will split the randkey between the ranks so the subsamples are not identical). Let's save this script as `kdescent-multidiff-integration.py`:"
"In the following script, the `generate_model()` function will do the same thing as before, except it will divide the sample size by the number of MPI ranks available, so that a fully-sized sample will be generated across all the ranks (and we will split the randkey between the ranks so the subsamples are not identical). Let's save this script as `kdescent-multigrad-integration.py`:"
]
},
{
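The per-rank bookkeeping described above can be sketched on its own. This is a hypothetical helper, not code from the commit; only `comm`, `randkey`, and `nsample` mirror names used in the script:

```python
from mpi4py import MPI
import jax

comm = MPI.COMM_WORLD


def split_across_ranks(randkey, nsample):
    # Each rank generates an equal share, so the union over all
    # ranks reconstructs a fully-sized sample
    nsample_per_rank = nsample // comm.size
    # Give each rank its own subkey so the subsamples are not identical
    randkey_per_rank = jax.random.split(randkey, comm.size)[comm.rank]
    return randkey_per_rank, nsample_per_rank
```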
@@ -24,7 +24,7 @@
"source": [
"```python\n",
"\"\"\"\n",
"kdescent-multidiff-integration.py\n",
"kdescent-multigrad-integration.py\n",
"\"\"\"\n",
"\n",
"import functools\n",
@@ -37,16 +37,14 @@
"from mpi4py import MPI\n",
"\n",
"import kdescent\n",
"import multidiff\n",
"import multigrad\n",
"\n",
"comm = MPI.COMM_WORLD\n",
"\n",
"model_nsample = 40_000\n",
"data_nsample = 20_000 # same volume, but undersampled below logM* < 10.5\n",
"\n",
"# Generate data weighted from two mass-dependent multivariate normals\n",
"\n",
"\n",
"@functools.partial(jax.jit, static_argnames=[\"undersample\", \"nsample\"])\n",
"def generate_model(params, randkey, undersample=False, nsample=model_nsample):\n",
" # NOTE: Divide nsample and split randkey across MPI ranks:\n",
@@ -312,7 +310,7 @@
" return left * right\n",
"\n",
"\n",
"# NOTE: For MultiDiff, we have to explicitly define sumstats_from_params()\n",
"# NOTE: For multigrad, we have to explicitly define sumstats_from_params()\n",
"# and loss_from_sumstats() to replace the old lossfunc()\n",
"@jax.jit\n",
"def sumstats_from_params(params, randkey):\n",
@@ -413,7 +411,7 @@
" truth_massfunc = jnp.array([training_w_highmass.sum(),]) / volume\n",
"\n",
" # Must abs() the Fourier residuals so that the loss is real\n",
" # NOTE: We even have to abs() the PDF residuals due to MultiDiff\n",
" # NOTE: We even have to abs() the PDF residuals due to multigrad\n",
" # combining all sumstats into a single complex-typed array\n",
" sqerrs = jnp.abs(jnp.concatenate([\n",
" (model_low_condprob - truth_low_condprob)**2,\n",
@@ -427,11 +425,9 @@
"\n",
" return jnp.mean(sqerrs)\n",
"\n",
"# NOTE: Define MultiDiff class using the sumstats + loss funcs we just defined\n",
"\n",
"\n",
"# NOTE: Define multigrad class using the sumstats + loss funcs we just defined\n",
"@dataclass\n",
"class MyModel(multidiff.MultiDiffOnePointModel):\n",
"class MyModel(multigrad.OnePointModel):\n",
" sumstats_func_has_aux: bool = True # override param default set by parent\n",
"\n",
" def calc_partial_sumstats_from_params(self, params, randkey):\n",
@@ -472,7 +468,7 @@
" make_sumstat_plot(\n",
" adam_results[-1],\n",
" txt=f\"Solution after {nsteps} evalulations\", fig=figs[1])\n",
" plt.savefig(\"kdescent-multidiff-results.png\")\n",
" plt.savefig(\"kdescent-multigrad-results.png\")\n",
" else:\n",
" # All other ranks need to do this for make_sumstat_plot() to work...\n",
" generate_model_into_mass_bins(guess, jax.random.key(13))\n",
@@ -484,7 +480,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's run this with four MPI ranks. Executing `mpiexec -n 4 python kdescent-multidiff-integration.py` yields the following results (about 2x speedup on my laptop):"
"Now let's run this with four MPI ranks. Executing `mpiexec -n 4 python kdescent-multigrad-integration.py` yields the following results (about 2x speedup on my laptop):"
]
},
{
@@ -498,7 +494,7 @@
" 2.1207855 0.40100217 0.52279544 0.69977784 0.8434925 -0.41961044\n",
" 0.76283926 0.9049511 ]\n",
"```\n",
"![Results plot](./kdescent-multidiff-results.png)"
"![Results plot](./kdescent-multigrad-results.png)"
]
}
],
18 changes: 16 additions & 2 deletions docs/source/notebooks/intro.ipynb
@@ -64,7 +64,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Define \"true\" parameters to generate training data"
"### Define \"true\" parameters to generate training data"
]
},
{
@@ -122,6 +122,13 @@
" return jnp.mean((model_kde_density - truth_kde_density)**2)"
]
},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"## Run gradient descent"
+]
+},
{
"cell_type": "code",
"execution_count": null,
@@ -266,7 +273,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Define \"true\" parameters to generate training data"
"### Define \"true\" parameters to generate training data"
]
},
{
@@ -555,6 +562,13 @@
" return jnp.mean(sqerrs)"
]
},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"## Run gradient descent"
+]
+},
{
"cell_type": "code",
"execution_count": null,
@@ -1,5 +1,5 @@
"""
-kdescent-multidiff-integration.py
+kdescent-multigrad-integration.py
"""

import functools
@@ -12,7 +12,7 @@
from mpi4py import MPI

import kdescent
-import multidiff
+import multigrad

comm = MPI.COMM_WORLD

@@ -287,7 +287,7 @@ def soft_tophat(x, low, high, squish=25.0):
return left * right


-# NOTE: For MultiDiff, we have to explicitly define sumstats_from_params()
+# NOTE: For multigrad, we have to explicitly define sumstats_from_params()
# and loss_from_sumstats() to replace the old lossfunc()
@jax.jit
def sumstats_from_params(params, randkey):
@@ -388,7 +388,7 @@ def loss_from_sumstats(sumstats, sumstats_aux):
truth_massfunc = jnp.array([training_w_highmass.sum(),]) / volume

# Must abs() the Fourier residuals so that the loss is real
-# NOTE: We even have to abs() the PDF residuals due to MultiDiff
+# NOTE: We even have to abs() the PDF residuals due to multigrad
# combining all sumstats into a single complex-typed array
sqerrs = jnp.abs(jnp.concatenate([
(model_low_condprob - truth_low_condprob)**2,
@@ -402,11 +402,11 @@

return jnp.mean(sqerrs)

-# NOTE: Define MultiDiff class using the sumstats + loss funcs we just defined
+# NOTE: Define multigrad class using the sumstats + loss funcs we just defined


@dataclass
-class MyModel(multidiff.MultiDiffOnePointModel):
+class MyModel(multigrad.OnePointModel):
sumstats_func_has_aux: bool = True # override param default set by parent

def calc_partial_sumstats_from_params(self, params, randkey):
@@ -447,7 +447,7 @@ def calc_loss_from_sumstats(self, sumstats, sumstats_aux, randkey=None):
make_sumstat_plot(
adam_results[-1],
txt=f"Solution after {nsteps} evalulations", fig=figs[1])
plt.savefig("kdescent-multidiff-results.png")
plt.savefig("kdescent-multigrad-results.png")
else:
# All other ranks need to do this for make_sumstat_plot() to work...
generate_model_into_mass_bins(guess, jax.random.key(13))
Expand Down
