Skip to content

Commit

Permalink
Add model for tutorial
Browse files Browse the repository at this point in the history
Summary:

Test Plan:
  • Loading branch information
EntilZha committed Mar 13, 2024
1 parent c8562bf commit d9b58b0
Showing 1 changed file with 280 additions and 0 deletions.
280 changes: 280 additions & 0 deletions py_irt/models/tutorial_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,280 @@
# MIT License

# Copyright (c) 2019 John Lalor <[email protected]> and Pedro Rodriguez <[email protected]>

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.


# pylint: disable=unused-argument,unused-variable,not-callable,no-name-in-module,no-member,protected-access
from functools import partial
from py_irt.models import abstract_model

import pandas as pd
import pyro
import pyro.distributions as dist
import torch
import torch.distributions.constraints as constraints
from pyro.infer import EmpiricalMarginal
from rich.console import Console

import numpy as np

console = Console()


@abstract_model.IrtModel.register("tutorial")
class FourParamLog(abstract_model.IrtModel):
"""3PL IRT Model, with disc, diff, skill"""

# pylint: disable=not-callable
def __init__(
self,
*,
device: str,
num_items: int,
num_subjects: int,
verbose: bool = False,
**kwargs,
):
super().__init__(
num_items=num_items,
num_subjects=num_subjects,
device=device,
verbose=verbose,
)

def model_hierarchical(self, subjects, items, obs):
mu_b = pyro.sample(
"mu_b",
dist.Normal(
torch.tensor(0.0, device=self.device),
torch.tensor(1.0e6, device=self.device),
),
)
u_b = pyro.sample(
"u_b",
dist.Gamma(
torch.tensor(1.0, device=self.device),
torch.tensor(1.0, device=self.device),
),
)

mu_theta = pyro.sample(
"mu_theta",
dist.Normal(
torch.tensor(0.0, device=self.device),
torch.tensor(1.0e6, device=self.device),
),
)
u_theta = pyro.sample(
"u_theta",
dist.Gamma(
torch.tensor(1.0, device=self.device),
torch.tensor(1.0, device=self.device),
),
)

mu_gamma = pyro.sample(
"mu_gamma",
dist.Normal(
torch.tensor(0.0, device=self.device),
torch.tensor(1.0e6, device=self.device),
),
)
u_gamma = pyro.sample(
"u_gamma",
dist.Gamma(
torch.tensor(1.0, device=self.device),
torch.tensor(1.0, device=self.device),
),
)

with pyro.plate("thetas", self.num_subjects, device=self.device):
ability = pyro.sample("theta", dist.Normal(mu_theta, 1.0 / u_theta))

with pyro.plate("bs", self.num_items, device=self.device):
diff = pyro.sample("b", dist.Normal(mu_b, 1.0 / u_b))

with pyro.plate("gammas", self.num_items, device=self.device):
# We want to make disc non-negative, which we can implement by changing from:
# disc ~ Normal
# To:
# log disc ~ Normal
# Re-arranging:
# disc ~ exp(Normal)
# Or just draw from LogNormal
disc = pyro.sample("gamma", dist.LogNormal(mu_gamma, 1.0 / u_gamma))

with pyro.plate("observe_data", obs.size(0)):
p_star = torch.sigmoid(disc[items] * (ability[subjects] - diff[items]))
pyro.sample(
"obs",
dist.Bernoulli(probs=p_star),
obs=obs,
)

def guide_hierarchical(self, subjects, items, obs):
loc_mu_b_param = pyro.param("loc_mu_b", torch.tensor(0.0, device=self.device))
scale_mu_b_param = pyro.param(
"scale_mu_b",
torch.tensor(1.0e2, device=self.device),
constraint=constraints.positive,
)
loc_mu_gamma_param = pyro.param(
"loc_mu_gamma", torch.tensor(0.0, device=self.device)
)
scale_mu_gamma_param = pyro.param(
"scale_mu_gamma",
torch.tensor(1.0e2, device=self.device),
constraint=constraints.positive,
)
loc_mu_theta_param = pyro.param(
"loc_mu_theta", torch.tensor(0.0, device=self.device)
)
scale_mu_theta_param = pyro.param(
"scale_mu_theta",
torch.tensor(1.0e2, device=self.device),
constraint=constraints.positive,
)
alpha_b_param = pyro.param(
"alpha_b",
torch.tensor(1.0, device=self.device),
constraint=constraints.positive,
)
beta_b_param = pyro.param(
"beta_b",
torch.tensor(1.0, device=self.device),
constraint=constraints.positive,
)
alpha_gamma_param = pyro.param(
"alpha_gamma",
torch.tensor(1.0, device=self.device),
constraint=constraints.positive,
)
beta_gamma_param = pyro.param(
"beta_gamma",
torch.tensor(1.0, device=self.device),
constraint=constraints.positive,
)
alpha_theta_param = pyro.param(
"alpha_theta",
torch.tensor(1.0, device=self.device),
constraint=constraints.positive,
)
beta_theta_param = pyro.param(
"beta_theta",
torch.tensor(1.0, device=self.device),
constraint=constraints.positive,
)
m_theta_param = pyro.param(
"loc_ability", torch.zeros(self.num_subjects, device=self.device)
)
s_theta_param = pyro.param(
"scale_ability",
torch.ones(self.num_subjects, device=self.device),
constraint=constraints.positive,
)
m_b_param = pyro.param(
"loc_diff", torch.zeros(self.num_items, device=self.device)
)
s_b_param = pyro.param(
"scale_diff",
torch.ones(self.num_items, device=self.device),
constraint=constraints.positive,
)
m_gamma_param = pyro.param(
"loc_disc", torch.zeros(self.num_items, device=self.device)
)
s_gamma_param = pyro.param(
"scale_disc",
torch.ones(self.num_items, device=self.device),
constraint=constraints.positive,
)

# sample statements
mu_b = pyro.sample("mu_b", dist.Normal(loc_mu_b_param, scale_mu_b_param))
u_b = pyro.sample("u_b", dist.Gamma(alpha_b_param, beta_b_param))

mu_gamma = pyro.sample(
"mu_gamma", dist.Normal(loc_mu_gamma_param, scale_mu_gamma_param)
)
u_gamma = pyro.sample(
"u_gamma", dist.Gamma(alpha_gamma_param, beta_gamma_param)
)

mu_theta = pyro.sample(
"mu_theta", dist.Normal(loc_mu_theta_param, scale_mu_theta_param)
)
u_theta = pyro.sample(
"u_theta", dist.Gamma(alpha_theta_param, beta_theta_param)
)

with pyro.plate("thetas", self.num_subjects, device=self.device):
pyro.sample("theta", dist.Normal(m_theta_param, s_theta_param))

with pyro.plate("bs", self.num_items, device=self.device):
pyro.sample("b", dist.Normal(m_b_param, s_b_param))

with pyro.plate("gammas", self.num_items, device=self.device):
pyro.sample("gamma", dist.LogNormal(m_gamma_param, s_gamma_param))

def export(self):
return {
"ability": pyro.param("loc_ability").data.tolist(),
"diff": pyro.param("loc_diff").data.tolist(),
"disc": pyro.param("loc_disc").data.tolist(),
}

def predict(self, subjects, items, params_from_file=None):
"""predict p(correct | params) for a specified list of model, item pairs"""
if params_from_file is not None:
model_params = params_from_file
else:
model_params = self.export()
abilities = np.array([model_params["ability"][i] for i in subjects])
diffs = np.array([model_params["diff"][i] for i in items])
discs = np.array([model_params["disc"][i] for i in items])
return 1 / (1 + np.exp(-discs * (abilities - diffs)))

def get_guide(self):
return self.guide_hierarchical

def get_model(self):
return self.model_hierarchical

def summary(self, traces, sites):
marginal = (
EmpiricalMarginal(traces, sites)
._get_samples_and_weights()[0]
.detach()
.cpu()
.numpy()
)
print(marginal)
site_stats = {}
for i in range(marginal.shape[1]):
site_name = sites[i]
marginal_site = pd.DataFrame(marginal[:, i]).transpose()
describe = partial(
pd.Series.describe, percentiles=[0.05, 0.25, 0.5, 0.75, 0.95]
)
site_stats[site_name] = marginal_site.apply(describe, axis=1)[
["mean", "std", "5%", "25%", "50%", "75%", "95%"]
]
return site_stats

0 comments on commit d9b58b0

Please sign in to comment.