cpm.hierarchical
cpm.hierarchical.EmpiricalBayes(optimiser=None, objective='minimise', iteration=1000, tolerance=1e-06, chain=4, quiet=False, **kwargs)
Implements an Expectation-Maximisation algorithm for the optimisation of the group-level distributions of the parameters of a model from subject-level parameter estimations.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
optimiser | object | The initialized Optimiser object. It must use an optimisation algorithm that also returns the Hessian matrix. | None |
objective | str | The objective of the optimisation, either 'maximise' or 'minimise'. Default is 'minimise'. Only affects how we arrive at the participant-level a posteriori parameter estimates. | 'minimise' |
iteration | int, optional | The maximum number of iterations. Default is 1000. | 1000 |
tolerance | float, optional | The tolerance for convergence. Default is 1e-6. | 1e-06 |
chain | int, optional | The number of random parameter initialisations. Default is 4. | 4 |
quiet | bool, optional | Whether to suppress the output. Default is False. | False |
Notes
The EmpiricalBayes class implements an Expectation-Maximisation algorithm for the optimisation of the group-level distributions of the parameters of a model from subject-level parameter estimations. For the complete description of the method, please see Gershman (2016).
The fitting function must return the Hessian matrix of the optimisation, which is used to establish the within-subject variance of the parameters. Note that the method requires the Hessian matrix of second derivatives of the negative log posterior (Gershman, 2016, p. 3). Participant-level parameters must therefore be estimated by minimising or maximising the log posterior density rather than a simple log likelihood.
In the current implementation, we try to calculate the second derivative of the negative log posterior density function according to the following algorithm:
- Attempt to use Cholesky decomposition.
- If that fails, attempt to use LU decomposition.
- If that fails, attempt to use QR decomposition.
- If the result is a complex number with zero imaginary part, keep the real part.
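The fallback chain above can be sketched as follows. This is a minimal illustration, not the library's code: it assumes the quantity being computed is the log-determinant of the Hessian (the edge cases further down mention a log determinant), the function name `log_determinant` is hypothetical, and NumPy's LU-based `slogdet` stands in for an explicit LU decomposition.

```python
import numpy as np


def log_determinant(hessian):
    """Log-determinant of a square matrix with Cholesky -> LU -> QR fallbacks."""
    hessian = np.asarray(hessian, dtype=float)
    try:
        # Cholesky succeeds only for positive-definite matrices.
        lower = np.linalg.cholesky(hessian)
        return float(2.0 * np.sum(np.log(np.diag(lower))))
    except np.linalg.LinAlgError:
        pass
    # slogdet is LU-based; a non-positive sign means det(H) <= 0.
    sign, logdet = np.linalg.slogdet(hessian)
    if sign > 0 and np.isfinite(logdet):
        return float(logdet)
    # QR: |det(H)| equals the product of |diag(R)|.
    _, upper = np.linalg.qr(hessian)
    logdet = np.sum(np.log(np.abs(np.diag(upper))))
    # Keep only the real part if the result is complex with zero imaginary part.
    return float(np.real_if_close(logdet))
```

For a positive-definite Hessian the first branch returns directly; the later branches only matter when the optimiser hands back an indefinite or badly conditioned matrix.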
In addition, because the Hessian matrix should correspond to the precision matrix, its inverse is the variance-covariance matrix, and we use this inverse to calculate the within-subject variance of the parameters. If the inverse of the Hessian matrix cannot be calculated, the Moore-Penrose pseudoinverse is used instead.
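A minimal sketch of this inversion step, assuming the within-subject variances are read off the diagonal of the inverse; the function name `within_subject_variance` is hypothetical and not part of cpm.

```python
import numpy as np


def within_subject_variance(hessian):
    """Diagonal of the inverse Hessian, with a pseudoinverse fallback."""
    hessian = np.asarray(hessian, dtype=float)
    try:
        covariance = np.linalg.inv(hessian)
    except np.linalg.LinAlgError:
        # Singular Hessian: fall back to the Moore-Penrose pseudoinverse.
        covariance = np.linalg.pinv(hessian)
    return np.diag(covariance)
```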
The current implementation also handles some edge cases that are not covered by the algorithm above:
- When calculating the within-subject variance via the Hessian matrix, the algorithm clips the variance to a minimum value of 1e-6 to avoid numerical instability.
- When calculating the within-subject variance via the Hessian matrix, the algorithm sets any non-finite or non-positive values to NaN.
- If the second derivative of the negative log posterior density function is not finite, we set the log determinant to -1e6.
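These safeguards might look like the sketch below. The order of the first two steps is an assumption (invalid values are marked as NaN first, then the remaining values are clipped), and both function names are hypothetical.

```python
import numpy as np


def sanitise_variance(variance, floor=1e-6):
    """Mark unusable variances as NaN, then clip the rest from below."""
    variance = np.array(variance, dtype=float)
    # Non-finite or non-positive entries are set to NaN.
    invalid = ~np.isfinite(variance) | (variance <= 0)
    variance[invalid] = np.nan
    # np.clip leaves NaN untouched and raises small values to the floor.
    return np.clip(variance, floor, None)


def sanitise_log_determinant(log_det, fallback=-1e6):
    """Replace a non-finite log determinant with a large negative constant."""
    return float(log_det) if np.isfinite(log_det) else fallback
```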
References
Gershman, S. J. (2016). Empirical priors for reinforcement learning models. Journal of Mathematical Psychology, 71, 1-6.
Examples:
>>> from cpm.hierarchical import EmpiricalBayes
>>> from cpm.models import DeltaRule
>>> from cpm.optimisation import FminBound, minimise
>>>
>>> # `data` is assumed to be defined beforehand.
>>> model = DeltaRule()
>>> optimiser = FminBound(
...     model=model,
...     data=data,
...     initial_guess=None,
...     number_of_starts=2,
...     minimisation=minimise.LogLikelihood.bernoulli,
...     parallel=False,
...     prior=True,
...     ppt_identifier="ppt",
...     display=False,
...     maxiter=200,
...     approx_grad=True
... )
>>> eb = EmpiricalBayes(optimiser=optimiser, iteration=1000, tolerance=1e-6, chain=4)
>>> eb.optimise()
diagnostics(show=True, save=False, path=None)
Returns the convergence diagnostics plots for the group-level hyperparameters.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
show | bool, optional | Whether to show the plots. Default is True. | True |
save | bool, optional | Whether to save the plots. Default is False. | False |
path | str, optional | The path to save the plots. Default is None. | None |
Notes
The convergence diagnostics plots show the convergence of the log model evidence and of the means and standard deviations of the group-level hyperparameters. They also show the distribution of the means and standard deviations of the group-level hyperparameters sampled for each chain.
optimise()
This method runs the Expectation-Maximisation algorithm for the optimisation of the group-level distributions of the parameters of a model from subject-level parameter estimations. This is essentially the main function that runs the algorithm for multiple chains with random starting points for the priors.
parameters()
Returns the estimated individual-level parameters for each iteration and chain.
Returns:
Type | Description |
---|---|
pandas.DataFrame | The estimated individual-level parameters for each iteration and chain. |
stair(chain_index=0)
The main function that runs the Expectation-Maximisation algorithm for the optimisation of the group-level distributions of the parameters of a model from subject-level parameter estimations. This is essentially a single chain.
Returns:
Type | Description |
---|---|
dict | A dictionary containing the log model evidence, the hyperparameters of the group-level distributions, and the parameters of the model. |
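To illustrate what a single Expectation-Maximisation iteration does at the group level, here is a minimal sketch of the maximisation step in the spirit of Gershman (2016). It assumes independent normal group-level distributions per parameter; the function name `m_step` and the variance floor are illustrative choices, not cpm's code.

```python
import numpy as np


def m_step(estimates, variances):
    """Update group-level means and SDs from subject-level MAP estimates.

    estimates, variances: arrays of shape (n_subjects, n_parameters), where
    variances holds the within-subject variances from the inverse Hessian.
    """
    estimates = np.asarray(estimates, dtype=float)
    variances = np.asarray(variances, dtype=float)
    mean = np.nanmean(estimates, axis=0)
    # Second moment = squared estimate + within-subject variance.
    second_moment = np.nanmean(estimates**2 + variances, axis=0)
    # Floor the group variance to keep the SD strictly positive (assumed value).
    sd = np.sqrt(np.maximum(second_moment - mean**2, 1e-6))
    return mean, sd
```

The updated means and standard deviations would then serve as the priors for the next round of participant-level MAP estimation.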
cpm.hierarchical.VariationalBayes(optimiser=None, objective='minimise', iteration=50, tolerance_lme=0.001, tolerance_param=0.001, chain=4, hyperpriors=None, convergence='parameters', quiet=False, **kwargs)
Performs hierarchical Bayesian estimation of a given model using variational (approximate) inference. The method is a reduced version of the Hierarchical Bayesian Inference (HBI) algorithm proposed by Piray et al. (2019) that omits model comparison and selection.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
optimiser | object | The initialized Optimiser object. It must use an optimisation algorithm that also returns the Hessian matrix. | None |
objective | str | The objective of the optimisation, either 'maximise' or 'minimise'. Default is 'minimise'. Only affects how we arrive at the participant-level a posteriori parameter estimates. | 'minimise' |
iteration | int, optional | The maximum number of iterations. Default is 50. | 50 |
tolerance_lme | float, optional | The tolerance for convergence with respect to the log model evidence. Default is 1e-3. | 0.001 |
tolerance_param | float, optional | The tolerance for convergence with respect to the "normalized" means of parameters. Default is 1e-3. | 0.001 |
chain | int, optional | The number of random parameter initialisations. Default is 4. | 4 |
hyperpriors | dict, optional | A dictionary of given parameter values of the prior distributions on the population-level parameters (means mu and precisions tau). See Notes for details. Default is None. | None |
convergence | str, optional | The convergence criterion. Default is 'parameters'. Options are 'lme' and 'parameters'. | 'parameters' |
Notes
The hyperpriors are as follows:
- a0 (array-like): Vector of means of the normal prior on the population-level means, mu.
- b (float): Scalar value that is multiplied with population-level precisions, tau, to determine the standard deviations of the normal prior on the population-level means, mu.
- v (float): Scalar value that is used to determine the shape parameter (nu) of the gamma prior on population-level precisions, tau.
- s (array-like): Vector of values that serve as lower bounds on the scale parameters (sigma) of the gamma prior on population-level precisions, tau.
With the number of parameters as N, the default values are as follows:
- a0: np.zeros(N)
- b: 1
- v: 0.5
- s: np.repeat(0.01, N)
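As a sketch, a hyperpriors dictionary matching the defaults listed above could be built like this. It assumes the dictionary keys are the names given in these notes, and `n_parameters` is a hypothetical stand-in for N.

```python
import numpy as np

n_parameters = 2  # hypothetical number of model parameters (N)

# Assumed key names, taken from the list of hyperpriors in the notes.
hyperpriors = {
    "a0": np.zeros(n_parameters),        # prior means of mu
    "b": 1,                              # scaling of tau in the prior on mu
    "v": 0.5,                            # determines the gamma shape (nu)
    "s": np.repeat(0.01, n_parameters),  # lower bounds on the gamma scale (sigma)
}
```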
The convergence criterion can be set to 'lme' or 'parameters'. If set to 'lme', the algorithm will stop when the log model evidence converges. If set to 'parameters', the algorithm will stop when the "normalized" means of the population-level parameters converge.
References
Piray, P., Dezfouli, A., Heskes, T., Frank, M. J., & Daw, N. D. (2019). Hierarchical Bayesian inference for concurrent model fitting and comparison for group studies. PLoS computational biology, 15(6), e1007043.
check_convergence(lme_new, lme_old, param_snr_new, param_snr_old, iter_idx=0, use_lme=True, use_param=True)
Function to check if the algorithm has converged.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
lme_new | float | The new log model evidence. | required |
lme_old | float | The old log model evidence. | required |
param_snr_new | array-like | The new standardised estimates of population-level means. | required |
param_snr_old | array-like | The old standardised estimates of population-level means. | required |
iter_idx | int, optional | The iteration index. Default is 0. | 0 |
use_lme | bool, optional | Whether to use the log model evidence for checking convergence. Default is True. | True |
use_param | bool, optional | Whether to use the standardised estimates of population-level means for checking convergence. Default is True. | True |
Returns:
Name | Type | Description |
---|---|---|
convergence | bool | Whether the algorithm has converged. |
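A convergence check of this kind might be sketched as below. The exact rule used by cpm is not stated here, so this assumes an absolute-change criterion, requires every enabled criterion to hold, and uses a hypothetical function name and tolerance arguments.

```python
import numpy as np


def has_converged(lme_new, lme_old, snr_new, snr_old,
                  tol_lme=1e-3, tol_param=1e-3, use_lme=True, use_param=True):
    """Return True when all enabled criteria changed by less than their tolerance."""
    checks = []
    if use_lme:
        checks.append(abs(lme_new - lme_old) < tol_lme)
    if use_param:
        change = np.abs(np.asarray(snr_new) - np.asarray(snr_old))
        checks.append(bool(np.all(change < tol_param)))
    # Assumption: with no criterion enabled, nothing counts as converged.
    return bool(checks) and all(checks)
```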
diagnostics(show=True, save=False, path=None)
Returns the convergence diagnostics plots for the group-level hyperparameters.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
show | bool, optional | Whether to show the plots. Default is True. | True |
save | bool, optional | Whether to save the plots. Default is False. | False |
path | str, optional | The path to save the plots. Default is None. | None |
Notes
The convergence diagnostics plots show the convergence of the log model evidence and of the means and standard deviations of the group-level hyperparameters. They also show the distribution of the means and standard deviations of the group-level hyperparameters sampled for each chain.
get_lme(log_post, hessian)
Function to approximate the participant-wise log model evidence using Laplace's approximation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
log_post | array-like | Participant-wise value of log posterior density function at the mode (i.e., MAP parameter estimates). | required |
hessian | array-like | Participant-wise Hessian matrix of log posterior density function evaluated at the mode (i.e., MAP parameter estimates). | required |
Returns:
Name | Type | Description |
---|---|---|
lme | array | Participant-wise log model evidence. |
lme_sum | float | Summed log model evidence. |
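For reference, Laplace's approximation to the log model evidence at the mode is log p(D) ≈ log p(D, θ*) + (d/2) log(2π) − ½ log|H|, where d is the number of parameters and H is the Hessian of the negative log posterior. A minimal sketch under that sign convention (an assumption; the function name `laplace_lme` is hypothetical):

```python
import numpy as np


def laplace_lme(log_post, hessian):
    """Participant-wise and summed log model evidence via Laplace's approximation.

    log_post: shape (n_subjects,), log posterior density at each mode.
    hessian: shape (n_subjects, d, d), Hessians of the negative log posterior.
    """
    log_post = np.asarray(log_post, dtype=float)
    hessian = np.asarray(hessian, dtype=float)
    n_dim = hessian.shape[-1]
    # slogdet operates on the stack of matrices, one per participant.
    _, log_det = np.linalg.slogdet(hessian)
    lme = log_post + 0.5 * n_dim * np.log(2 * np.pi) - 0.5 * log_det
    return lme, float(np.sum(lme))
```

For a one-dimensional unnormalised standard normal density (log posterior 0 at the mode, Hessian 1), this recovers the exact value ½ log(2π).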
optimise()
Run the Variational Bayes algorithm for multiple chains.
run_vb(chain_index=0)
Run the hierarchical Bayesian inference algorithm.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
chain_index | int, optional | The chain index. Default is 0. | 0 |
Returns:
Name | Type | Description |
---|---|---|
output | dict | Dictionary of results. |
ttest(null=None)
Perform a one-sample Student's t-test on the estimated values of population-level means with respect to given null hypothesis values.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
null | dict or pd.DataFrame | The null hypothesis values for the population-level means of each parameter. | None |
Returns:
Name | Type | Description |
---|---|---|
t_df | pd.DataFrame | The results of the t-test. |
update_population(param, hessian, lme, iter_idx=0, chain_idx=0)
Function to update the population-level parameters based on the results of participant-wise optimisation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
param | array-like | Participant-wise parameter estimates. | required |
hessian | array-like | Participant-wise Hessian matrices of the log posterior density function evaluated at the mode (i.e., MAP parameter estimates). | required |
lme | float | Summed log model evidence. | required |
iter_idx | int, optional | The iteration index. Default is 0. | 0 |
chain_idx | int, optional | The chain index. Default is 0. | 0 |
Returns:
Name | Type | Description |
---|---|---|
population_updates | dict | Dictionary of updated population-level parameters. |
param_snr | array-like | Standardised estimates of population-level means. |
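To give a feel for this update, the sketch below applies a standard normal-gamma conjugate update of the kind described by Piray et al. (2019), using the hyperpriors a0, b, v and s. It is an illustration under assumptions, not cpm's implementation: the function name, the `variance` argument (within-subject variances, for example the diagonal of the inverse Hessian) and the output keys are all hypothetical.

```python
import numpy as np


def update_population_sketch(param, variance, a0, b, v, s):
    """Normal-gamma update of population-level means and precisions."""
    param = np.asarray(param, dtype=float)        # (n_subjects, n_parameters)
    variance = np.asarray(variance, dtype=float)  # same shape as param
    n = param.shape[0]
    theta_bar = param.mean(axis=0)
    # Posterior mean of mu: weighted mix of the prior mean and the sample mean.
    a = (n * theta_bar + b * a0) / (n + b)
    beta = b + n
    nu = v + 0.5 * n
    # Spread around the sample mean, including within-subject uncertainty.
    spread = np.sum(param**2 + variance, axis=0) - n * theta_bar**2
    sigma = s + 0.5 * (spread + b * n / (b + n) * (theta_bar - a0) ** 2)
    # nu / sigma is the expected population-level precision (tau).
    return {"a": a, "beta": beta, "nu": nu, "sigma": sigma, "tau": nu / sigma}
```

Note how `s` acts as a lower bound on `sigma`, since the added term is never negative.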