Exploring hierarchical estimation of hyperparameters (cpm.hierarchical)¶
Now we will move on to estimating hyperparameters for the priors of the parameters in our model. This can take quite a bit of time to run, so I have already run the simulation for you; here we can focus on exploring the results.
If you are interested in how to run it, check the file 04-fitting-hierarchical.py in the exercises or solutions folder.
More information about the method can be found in the documentation of the cpm package, in the section about hierarchical models. Briefly, we use an Expectation Maximisation algorithm to estimate the hyperparameters of the priors for the parameters in our model. The algorithm iteratively updates the hyperparameters until convergence. In plain language, it runs four separate chains, each from a different starting point for the hyperparameters; within each chain, it checks whether the hyperparameters change from one iteration to the next, and stops once they no longer do. The result is a set of hyperparameters that we can use to define the priors for the parameters in our model.
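To build some intuition for what such a loop does, here is a deliberately simplified toy sketch of Expectation Maximisation for a hierarchical model. This is not the cpm implementation (which also tracks the log model evidence and the uncertainty of each individual estimate); it estimates a normal prior over per-participant means in a made-up dataset, purely for illustration:

import numpy

# Toy data: 20 "participants", each with a latent mean and 50 noisy observations
rng = numpy.random.default_rng(42)
true_means = rng.normal(0.5, 0.2, size=20)
data = [rng.normal(m, 0.5, size=50) for m in true_means]

mu, sigma = rng.uniform(0, 1), 1.0   # one chain = one random starting point
previous = numpy.array([mu, sigma])
for iteration in range(40):          # at most 40 iterations, as in the run above
    # E-step: MAP estimate of each participant's mean under the current prior
    # (closed form for the normal-normal case; observation noise sd is 0.5)
    maps = numpy.array([
        (d.mean() * len(d) / 0.5**2 + mu / sigma**2)
        / (len(d) / 0.5**2 + 1 / sigma**2)
        for d in data
    ])
    # M-step: re-estimate the hyperparameters from the individual estimates
    mu, sigma = maps.mean(), maps.std()
    # stop early once the hyperparameters no longer change between iterations
    if numpy.allclose([mu, sigma], previous, atol=1e-6):
        break
    previous = numpy.array([mu, sigma])

print(f"stopped after {iteration + 1} iterations: mean={mu:.3f}, sd={sigma:.3f}")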
Let us move on to the code:
import warnings
import ipyparallel as ipp # for parallel computing with ipython (specific for Jupyter Notebook)
import cpm
import cpm.datasets as datasets
from cpm.generators import Parameters, Value
import numpy
import pandas as pd
import functions as f ## I prepped some models here, so you can use them directly
## plotting libraries
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline
## Hierarchical fitting can take a long time depending on your setup, so I ran the simulation in advance; here we only load and explore the results.
hyperparameters = pd.read_csv('04-fitting-hierarchical-hyperparameters.csv')
hyperparameters.head()
| | Unnamed: 0 | chain | iteration | parameter | mean | mean_errorbar | sd | lme | mean_se |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 1 | 1 | alpha | 0.253034 | 0 | 0.353337 | -1293.670284 | 0.070667 |
| 1 | 0 | 1 | 1 | temperature | 4.451520 | 0 | 5.354804 | -1293.670284 | 1.070961 |
| 2 | 0 | 1 | 2 | alpha | 0.219261 | 0 | 0.321423 | -1292.875412 | 0.064285 |
| 3 | 0 | 1 | 2 | temperature | 4.652763 | 0 | 5.398617 | -1292.875412 | 1.079723 |
| 4 | 0 | 1 | 3 | alpha | 0.219261 | 0 | 0.321407 | -1292.879355 | 0.064281 |
So, the first column (Unnamed: 0) is a leftover from the way I saved the file; ignore it for now.
Then we have the following columns of interest:

chain
: the chain number; we ran 4 chains in total, each starting from a different set of hyperparameters.

iteration
: the iteration number; we ran at most 40 iterations, but the optimisation can stop earlier.

parameter
: the name of the parameter; we have 2 parameters in our model.

mean
: the mean of the hyperparameter for the parameter.

sd
: the standard deviation of the hyperparameter for the parameter.

lme
: the log model evidence (an approximation of the model's marginal likelihood) that takes into account both the individual fits and the hyperparameters.

The remaining columns give some extra information about the optimisation process; to keep things simple, we will not focus on them here.
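As a first quick look, we can check how many iterations each chain actually ran before it stopped. Since the file stores one row per parameter per iteration, the maximum iteration number per chain tells us:

# number of iterations each chain completed before stopping
hyperparameters.groupby('chain')['iteration'].max()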
Explore the results¶
First, we will need to see whether our hyperparameters converged. We can start by plotting the log model evidence over the iterations for each chain. The lme is identical for both parameters within an iteration, which is why we drop duplicate chain and iteration rows below.
plt.figure(figsize=(10, 6))
sns.lineplot(
data=hyperparameters.drop_duplicates(subset=['chain', 'iteration']),
x='iteration',
y='lme',
hue='chain',
marker='o',
markersize=10,
palette='tab10',
)
plt.title('Log Model Evidence (lme) over Iterations for Each Chain')
plt.xlabel('Iteration')
plt.ylabel('Log Model Evidence (lme)')
plt.legend(title='Chain')
plt.tight_layout()
So, what is the first thing you notice here? All chains converged quite fast. That is a good thing, although in practice it often takes much longer. Now, let us look at how the mean of each hyperparameter evolved over the iterations in each chain. Let's plot them.
fig, axes = plt.subplots(2, 1, figsize=(7, 10), sharey=False)
sns.lineplot(
data=hyperparameters[hyperparameters['parameter'] == 'alpha'],
x='iteration',
y='mean',
hue='chain',
marker='o',
palette='tab10',
ax=axes[0]
)
axes[0].set_title('Alpha Hyperparameter Mean by Iteration')
axes[0].set_xlabel('Iteration')
axes[0].set_ylabel('Mean (alpha)')
axes[0].legend(title='Chain')
axes[0].set_ylim(0, 1)
sns.lineplot(
data=hyperparameters[hyperparameters['parameter'] == 'temperature'],
x='iteration',
y='mean',
hue='chain',
marker='o',
palette='tab10',
ax=axes[1]
)
axes[1].set_title('Temperature Hyperparameter Mean by Iteration')
axes[1].set_xlabel('Iteration')
axes[1].set_ylabel('Mean (temperature)')
axes[1].legend(title='Chain')
axes[1].set_ylim(0, 10)
plt.tight_layout()
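The plots suggest the chains end up in roughly the same place; we can confirm that numerically by pulling out the last recorded estimate of each chain:

# final mean and sd reached by each chain, for both parameters
final = (
    hyperparameters
    .sort_values('iteration')
    .groupby(['chain', 'parameter'])
    .tail(1)
)
final[['chain', 'parameter', 'mean', 'sd', 'lme']]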
Refit model with the best hyperparameters¶
First, you will need to find the best hyperparameters for the model and set up the parameters object with the right values. You can use the mean, sd, and lme columns for that.
hyperparameters[hyperparameters["lme"] == hyperparameters["lme"].max()]
| | Unnamed: 0 | chain | iteration | parameter | mean | mean_errorbar | sd | lme | mean_se |
|---|---|---|---|---|---|---|---|---|---|
| 22 | 0 | 2 | 9 | alpha | 0.219264 | 0 | 0.321424 | -1292.862916 | 0.064285 |
| 23 | 0 | 2 | 9 | temperature | 4.651864 | 0 | 5.399851 | -1292.862916 | 1.079970 |
Now that we have those, we will need to refit the model with the best hyperparameters. We will use cpm.generators.Parameters, where we can specify our priors. Then, we will use the FminBound class to fit the model to the data. In your everyday workflow, you can use the update_prior(**kwargs) method of cpm.generators.Parameters, which will look something like this:
wrapper.parameters.update_prior(
    alpha={"mean": 0.21221, "sd": 0.3214},
    temperature={"mean": 4.66, "sd": 5.3899},
)
For now, we have to reinitialise everything, because we ran the simulations outside of this Jupyter Notebook.
data = cpm.datasets.load_bandit_data()
data["observed"] = data["response"].astype(int) # convert response to int
parameters = Parameters(
# free parameters are indicated by specifying priors
alpha=Value(
value=0.5,
lower=1e-10,
upper=1,
prior="truncated_normal",
args={"mean": 0.219264, "sd": 0.321424}, # specify the mean and standard deviation of the prior
),
temperature=Value(
value=1,
lower=0,
upper=10,
prior="truncated_normal",
args={"mean": 4.651864, "sd": 5.399851}, # specify the mean and standard deviation of the prior
),
# everything without a prior is part of the initial state of the
# model or constructs fixed throughout the simulation
# (e.g. exemplars in general-context models of categorizations)
    # initial q-values, starting from a non-zero value
    # these are the same for all 4 stimuli (1 / 4)
    values=numpy.array([0.25, 0.25, 0.25, 0.25])
)
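If you want to convince yourself what such a truncated-normal prior looks like before fitting, you can sample from it directly with scipy. This is a standalone sanity check using scipy.stats.truncnorm, not part of the cpm workflow; note that scipy expects the truncation bounds a and b in units of standard deviations from the mean:

from scipy.stats import truncnorm

# bounds in standardised units, as scipy's parameterisation requires
a = (1e-10 - 0.219264) / 0.321424
b = (1 - 0.219264) / 0.321424
alpha_prior = truncnorm.rvs(a=a, b=b, loc=0.219264, scale=0.321424, size=5000)
# truncation shifts the realised mean and sd away from the raw hyperparameters
print(alpha_prior.mean(), alpha_prior.std())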
@ipp.require("numpy")
def model(parameters, trial):
# pull out the parameters
alpha = parameters.alpha
temperature = parameters.temperature
values = numpy.array(parameters.values)
# pull out the trial information
stimulus = numpy.array([trial.arm_left, trial.arm_right]).astype(int)
feedback = numpy.array([trial.reward_left, trial.reward_right])
human_choice = trial.observed.astype(int)
# Equation 1. - get the value of each available action
# Note that because python counts from 0, we need to shift
# the stimulus identifiers by -1
expected_rewards = values[stimulus - 1]
    # reshape into a column vector (one value per available arm)
    expected_rewards = expected_rewards.reshape(2, 1)
# calculate a policy based on the activations
# Equation 2.
choice_rule = cpm.models.decision.Softmax(
activations=expected_rewards,
temperature=temperature
)
choice_rule.compute() # compute the policy
# if the policy is NaN for an action, then we need to set it to 1
# this corrects some numerical issues with python and infinities
if numpy.isnan(choice_rule.policies).any():
choice_rule.policies[numpy.isnan(choice_rule.policies)] = 1
# get the received reward for the choice
reward = feedback[human_choice]
teacher = numpy.array([reward])
# we now create a vector that tells our learning rule what...
# ... stimulus to update according to the participant's choice
what_to_update = numpy.zeros(4)
chosen_stimulus = stimulus[human_choice] - 1
what_to_update[chosen_stimulus] = 1
# Equation 4.
update = cpm.models.learning.SeparableRule(
weights=values,
feedback=teacher,
input=what_to_update,
alpha=alpha
)
update.compute()
# Equation 5.
values += update.weights.flatten()
# compile output
output = {
"trial" : trial.trial.astype(int), # trial numbers
"activation" : expected_rewards.flatten(), # expected reward of arms
"policy" : choice_rule.policies, # policies
"reward" : reward, # received reward
"error" : update.weights, # prediction error
"values" : values, # updated values
# dependent variable
"dependent" : numpy.array([choice_rule.policies[1]]),
}
return output
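As a quick check on the decision rule, here is the softmax from Equation 2 written out in plain numpy. This assumes the common formulation in which the temperature parameter acts as an inverse temperature, so larger values make choices more deterministic; consult the cpm documentation for the exact form cpm.models.decision.Softmax uses:

import numpy

def softmax(activations, temperature):
    # scale by the (inverse) temperature, subtract the max for numerical stability
    x = temperature * numpy.asarray(activations, dtype=float)
    x = numpy.exp(x - x.max())
    return x / x.sum()

softmax([0.25, 0.75], temperature=4.65)  # roughly [0.09, 0.91]: favours the better arm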
generative_model = cpm.generators.Wrapper(
model=model,
parameters=parameters,
data=data[data.ppt == 1],
)
from cpm.optimisation import minimise, FminBound
# Set up the fitting procedure
fit = FminBound(
model=generative_model, # Wrapper class with the model we specified from before
    data=data.groupby('ppt'), # the data, grouped by participant
minimisation=minimise.LogLikelihood.bernoulli,
parallel=True,
libraries=["numpy", "cpm", "pandas"],
prior=True,
ppt_identifier="ppt",
display=False,
number_of_starts=5,
# everything below is optional and passed directly to the scipy implementation of the optimiser
approx_grad=True
)
fit.optimise()
parameters_hierarchical = fit.export()
parameters_hierarchical.rename(
columns={
"x_0": "alpha",
"x_1": "temperature",
},
inplace=True
)
parameters_hierarchical.head()
Starting optimization 1/5 from [0.16404429 3.73148691] Starting 10 engines with <class 'ipyparallel.cluster.launcher.LocalEngineSetLauncher'>
Starting optimization 2/5 from [0.89653545 9.0386942 ] Starting 10 engines with <class 'ipyparallel.cluster.launcher.LocalEngineSetLauncher'>
Starting optimization 3/5 from [0.52102888 6.38880791] Starting 10 engines with <class 'ipyparallel.cluster.launcher.LocalEngineSetLauncher'>
Starting optimization 4/5 from [0.95700352 7.71708422] Starting 10 engines with <class 'ipyparallel.cluster.launcher.LocalEngineSetLauncher'>
Starting optimization 5/5 from [0.19076827 0.67803325] Starting 10 engines with <class 'ipyparallel.cluster.launcher.LocalEngineSetLauncher'>
| | alpha | temperature | grad_0 | grad_1 | task | funcalls | nit | warnflag | hessian_0 | hessian_1 | hessian_2 | hessian_3 | ppt | fun |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.082634 | 4.762181 | -7.105427e-07 | 0.000000 | CONVERGENCE: NORM_OF_PROJECTED_GRADIENT_<=_PGTOL | 63 | 13 | 0 | 854.348341 | 25.037331 | 25.037331 | 0.999656 | 1 | 50.636195 |
| 1 | 0.070499 | 5.380910 | -4.973799e-06 | 0.000000 | CONVERGENCE: NORM_OF_PROJECTED_GRADIENT_<=_PGTOL | 42 | 11 | 0 | 1205.527689 | 29.093655 | 29.093655 | 0.837651 | 2 | 49.321948 |
| 2 | 0.286731 | 1.195427 | -1.421085e-06 | -0.000003 | CONVERGENCE: NORM_OF_PROJECTED_GRADIENT_<=_PGTOL | 27 | 7 | 0 | 25.664101 | 3.566809 | 3.566809 | 2.739243 | 3 | 66.528813 |
| 3 | 0.404007 | 1.855438 | 0.000000e+00 | 0.000000 | CONVERGENCE: NORM_OF_PROJECTED_GRADIENT_<=_PGTOL | 57 | 10 | 0 | 63.797819 | 3.819208 | 3.819208 | 3.368577 | 4 | 61.809662 |
| 4 | 0.091344 | 10.000000 | -1.065814e-06 | -0.625223 | CONVERGENCE: NORM_OF_PROJECTED_GRADIENT_<=_PGTOL | 48 | 10 | 0 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 5 | 25.335649 |
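One thing worth noticing in this output: participant 5's temperature sits exactly at the upper bound of 10 we set earlier, which usually means the optimiser could not find an interior optimum for that participant. A quick filter flags such cases:

# flag participants whose estimates ended up pinned at the bounds we set
at_bounds = parameters_hierarchical[
    (parameters_hierarchical['temperature'] >= 10)
    | (parameters_hierarchical['alpha'] <= 1e-10)
    | (parameters_hierarchical['alpha'] >= 1)
]
at_bounds[['ppt', 'alpha', 'temperature', 'fun']]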
Great. Now, just for the sake of completeness, let us estimate the parameters again, but without the prior.
# Set up the fitting procedure
fit_no_prior = FminBound(
model=generative_model, # Wrapper class with the model we specified from before
    data=data.groupby('ppt'), # the data, grouped by participant
minimisation=minimise.LogLikelihood.bernoulli,
parallel=True,
libraries=["numpy", "cpm", "pandas"],
prior=False, # no prior
ppt_identifier="ppt",
display=False,
number_of_starts=5,
# everything below is optional and passed directly to the scipy implementation of the optimiser
approx_grad=True
)
fit_no_prior.optimise()
parameters_no_prior = fit_no_prior.export()
parameters_no_prior.rename(
columns={
"x_0": "alpha",
"x_1": "temperature",
},
inplace=True
)
parameters_no_prior.head()
Starting optimization 1/5 from [0.46636063 0.2707254 ] Starting 10 engines with <class 'ipyparallel.cluster.launcher.LocalEngineSetLauncher'>
Starting optimization 2/5 from [0.16537263 1.61395966] Starting 10 engines with <class 'ipyparallel.cluster.launcher.LocalEngineSetLauncher'>
Starting optimization 3/5 from [0.748808 1.51560535] Starting 10 engines with <class 'ipyparallel.cluster.launcher.LocalEngineSetLauncher'>
Starting optimization 4/5 from [0.99763419 2.56325616] Starting 10 engines with <class 'ipyparallel.cluster.launcher.LocalEngineSetLauncher'>
Starting optimization 5/5 from [0.65124312 2.59952301] Starting 10 engines with <class 'ipyparallel.cluster.launcher.LocalEngineSetLauncher'>
| | alpha | temperature | grad_0 | grad_1 | task | funcalls | nit | warnflag | hessian_0 | hessian_1 | hessian_2 | hessian_3 | ppt | fun |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.076392 | 4.943618 | -3.552714e-06 | 7.105427e-07 | CONVERGENCE: NORM_OF_PROJECTED_GRADIENT_<=_PGTOL | 57 | 13 | 0 | 1038.690956 | 27.070576 | 27.070576 | 0.890515 | 1 | 48.885637 |
| 1 | 0.056893 | 6.042440 | -6.394885e-06 | -7.105427e-07 | CONVERGENCE: NORM_OF_PROJECTED_GRADIENT_<=_PGTOL | 66 | 16 | 0 | 2153.658871 | 35.755985 | 35.755985 | 0.629376 | 2 | 47.533037 |
| 2 | 0.356632 | 1.065563 | 2.842171e-06 | 0.000000e+00 | CONVERGENCE: NORM_OF_PROJECTED_GRADIENT_<=_PGTOL | 30 | 8 | 0 | 16.604947 | 4.121967 | 4.121967 | 3.233532 | 3 | 64.615781 |
| 3 | 0.446523 | 1.775728 | 1.847411e-05 | 7.105427e-07 | CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH | 27 | 7 | 0 | 45.754484 | 4.659371 | 4.659371 | 3.575182 | 4 | 59.814704 |
| 4 | 0.090663 | 10.000000 | 3.552714e-07 | -8.154305e-01 | CONVERGENCE: NORM_OF_PROJECTED_GRADIENT_<=_PGTOL | 51 | 10 | 0 | 1792.445875 | 9.953954 | 9.953954 | 0.262455 | 5 | 23.110100 |
Now that that is done, we can compare the results of the hierarchical model with those of the non-hierarchical model. We can do that by plotting the parameter estimates of the hierarchical model against those of the non-hierarchical model. And because we feel adventurous, we will also plot the priors of the hierarchical model in the background. Let's run the code below to see the results.
from scipy.stats import truncnorm
best_hyperparameters = hyperparameters[hyperparameters["lme"] == hyperparameters["lme"].max()]
alpha_mean = best_hyperparameters[best_hyperparameters["parameter"] == "alpha"]["mean"].values[0]
alpha_sd = best_hyperparameters[best_hyperparameters["parameter"] == "alpha"]["sd"].values[0]
temperature_mean = best_hyperparameters[best_hyperparameters["parameter"] == "temperature"]["mean"].values[0]
temperature_sd = best_hyperparameters[best_hyperparameters["parameter"] == "temperature"]["sd"].values[0]
num_samples = 5000
fig = plt.figure(figsize=(10, 10))
numpy.random.seed(984777324)
# For alpha: lower=0, upper=1
a_alpha = (0 - alpha_mean) / alpha_sd
b_alpha = (1 - alpha_mean) / alpha_sd
samples_alpha = truncnorm.rvs(a=a_alpha, b=b_alpha, loc=alpha_mean, scale=alpha_sd, size=num_samples)
# For temperature: lower=0, upper=10
a_temp = (0 - temperature_mean) / temperature_sd
b_temp = (10 - temperature_mean) / temperature_sd
samples_temperature = truncnorm.rvs(a=a_temp, b=b_temp, loc=temperature_mean, scale=temperature_sd, size=num_samples)
plt.scatter(x=parameters_hierarchical.alpha, y=parameters_hierarchical.temperature, color='red', label='Hierarchical fit', s=150, zorder=4, marker='+')
plt.scatter(x=parameters_no_prior.alpha, y=parameters_no_prior.temperature, color='black', label='Non-hierarchical fit', s=100, zorder=3, marker='o', alpha=0.5)
sns.kdeplot(x=samples_alpha, y=samples_temperature, cmap="Blues", thresh=0.05, linewidths=3)
plt.xlabel('Alpha')
plt.ylabel('Temperature')
plt.title('Comparison of Fitted Parameters and Hierarchical Priors')
plt.legend(['Hierarchical fit', 'Non-hierarchical fit', 'Hierarchical prior (KDE)'])
plt.tight_layout()
So let me explain the figure. The two axes are the two parameters. The blue density contours show the prior distribution of the parameters, with darker areas indicating higher density. The red crosses are the parameter estimates from the hierarchical model, and the black dots are the estimates from the non-hierarchical model.
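If you want to quantify how much the prior pulled the individual estimates around, rather than just eyeballing the figure, you can merge the two sets of estimates on participant and inspect the differences directly:

# merge the two fits on participant; suffixes keep the shared column names apart
comparison = parameters_hierarchical.merge(
    parameters_no_prior, on='ppt', suffixes=('_prior', '_no_prior')
)
comparison['alpha_shift'] = comparison['alpha_prior'] - comparison['alpha_no_prior']
comparison['temperature_shift'] = comparison['temperature_prior'] - comparison['temperature_no_prior']
comparison[['ppt', 'alpha_shift', 'temperature_shift']].head()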