cpm.hierarchical
cpm.hierarchical.EmpiricalBayes(optimiser=None, objective='minimise', iteration=1000, tolerance=1e-06, chain=4, quiet=False, **kwargs)
Implements an Expectation-Maximisation algorithm for the optimisation of the group-level distributions of the parameters of a model from subject-level parameter estimations.
Notes
The EmpiricalBayes class implements an Expectation-Maximisation algorithm for the optimisation of the group-level distributions of the parameters of a model from subject-level parameter estimations. For the complete description of the method, please see Gershman (2016).
The fitting function must return the Hessian matrix of the optimisation, which is then used to establish the within-subject variance of the parameters. Note that the method requires the Hessian of second derivatives of the negative log posterior (Gershman, 2016, p. 3), which means that participant-level parameters must be estimated by minimising or maximising the log posterior density, not just the log likelihood.
In the current implementation, we try to calculate the second derivative of the negative log posterior density function according to the following algorithm (sketched in code after this list):
- Attempt a Cholesky decomposition.
- If that fails, attempt an LU decomposition.
- If that fails, attempt a QR decomposition.
- If the result is a complex number with zero imaginary part, keep the real part.
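A minimal sketch of this fallback chain, assuming the Hessian arrives as a NumPy array; the function name logdet and the exact determinant formulas are illustrative rather than the library's internals:
>>> import numpy as np
>>> from scipy import linalg
>>>
>>> def logdet(hessian):
...     """Log determinant of the Hessian via Cholesky, then LU, then QR."""
...     try:
...         L = np.linalg.cholesky(hessian)  # fails unless positive definite
...         return 2.0 * np.sum(np.log(np.diag(L)))
...     except np.linalg.LinAlgError:
...         pass
...     try:
...         _, _, U = linalg.lu(hessian)  # LU decomposition
...         diagonal = np.diag(U)
...     except Exception:
...         _, R = np.linalg.qr(hessian)  # QR decomposition
...         diagonal = np.diag(R)
...     result = np.sum(np.log(diagonal.astype(complex)))  # diagonal may contain negatives
...     if result.imag == 0:
...         result = result.real  # keep the real part when the imaginary part is zero
...     return result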
In addition, because the Hessian matrix corresponds to the precision matrix, its inverse is the variance-covariance matrix, which we use to calculate the within-subject variance of the parameters. If the algorithm fails to calculate the inverse of the Hessian matrix, it falls back on the Moore-Penrose pseudoinverse instead.
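Continuing the sketch above, this fallback amounts to something like the following (a sketch, not the library's exact implementation):
>>> try:
...     covariance = np.linalg.inv(hessian)   # the Hessian is the precision matrix
... except np.linalg.LinAlgError:
...     covariance = np.linalg.pinv(hessian)  # Moore-Penrose pseudoinverse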
The current implementation also handles some edge cases that are not covered by the algorithm above (see the sketch after this list):
- When calculating the within-subject variance via the Hessian matrix, the algorithm clips the variance to a minimum value of 1e-6 to avoid numerical instability.
- When calculating the within-subject variance via the Hessian matrix, the algorithm sets any non-finite or non-positive values to NaN.
- If the second derivative of the negative log posterior density function is not finite, we set the log determinant to -1e6.
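A sketch of these safeguards, continuing the snippets above and assuming the within-subject variances are read off the diagonal of the inverted Hessian (variable names are illustrative):
>>> variances = np.diag(covariance).copy()
>>> variances[~np.isfinite(variances) | (variances <= 0)] = np.nan  # invalid entries become NaN
>>> variances = np.clip(variances, 1e-6, None)  # floor at 1e-6 to avoid numerical instability
>>> log_determinant = logdet(hessian)
>>> if not np.isfinite(log_determinant):
...     log_determinant = -1e6  # fall back when the second derivative is not finite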
References
Gershman, S. J. (2016). Empirical priors for reinforcement learning models. Journal of Mathematical Psychology, 71, 1-6.
Examples:
>>> from cpm.hierarchical import EmpiricalBayes
>>> from cpm.models import DeltaRule
>>> from cpm.optimisation import FminBound, minimise
>>>
>>> model = DeltaRule()
>>> optimiser = FminBound(
...     model=model,
...     data=data,  # participant-level data prepared beforehand
...     initial_guess=None,
...     number_of_starts=2,
...     minimisation=minimise.LogLikelihood.bernoulli,
...     parallel=False,
...     prior=True,
...     ppt_identifier="ppt",
...     display=False,
...     maxiter=200,
...     approx_grad=True
... )
>>> eb = EmpiricalBayes(optimiser=optimiser, iteration=1000, tolerance=1e-6, chain=4)
>>> eb.optimise()
diagnostics(show=True, save=False, path=None)
Returns the convergence diagnostics plots for the group-level hyperparameters.
Notes
The convergence diagnostics plots show the convergence of the log model evidence, the means, and the standard deviations of the group-level hyperparameters. It also shows the distribution of the means and the standard deviations of the group-level hyperparameters sampled for each chain.
optimise()
This method runs the Expectation-Maximisation algorithm for the optimisation of the group-level distributions of the parameters of a model from subject-level parameter estimations. This is essentially the main function that runs the algorithm for multiple chains with random starting points for the priors.
parameters()
Returns the estimated individual-level parameters for each iteration and chain.
stair(chain_index=0)
Runs the Expectation-Maximisation algorithm for the optimisation of the group-level distributions of the parameters of a model from subject-level parameter estimations, for a single chain.
cpm.hierarchical.VariationalBayes(optimiser=None, objective='minimise', iteration=50, tolerance_lme=0.001, tolerance_param=0.001, chain=4, hyperpriors=None, convergence='parameters', quiet=False, **kwargs)
Performs hierarchical Bayesian estimation of a given model using variational (approximate) inference. This is a reduced version of the Hierarchical Bayesian Inference (HBI) algorithm proposed by Piray et al. (2019), excluding model comparison and selection.
Notes
The hyperpriors are as follows:
a0
: array-like. Vector of means of the normal prior on the population-level means, mu.
b
: float. Scalar value that is multiplied with the population-level precisions, tau, to determine the standard deviations of the normal prior on the population-level means, mu.
v
: float. Scalar value that is used to determine the shape parameter (nu) of the gamma prior on the population-level precisions, tau.
s
: array-like. Vector of values that serve as lower bounds on the scale parameters (sigma) of the gamma prior on the population-level precisions, tau.
With the number of parameters as N, the default values are as follows:
a0
: np.zeros(N)
b
: 1
v
: 0.5
s
: np.repeat(0.01, N)
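As an illustration, these defaults could be assembled by hand as below; the dictionary keys and the assumption that hyperpriors accepts a dict are illustrative rather than confirmed by the signature:
>>> import numpy as np
>>> N = 2  # number of free parameters in the model (hypothetical)
>>> hyperpriors = {
...     "a0": np.zeros(N),        # means of the normal prior on mu
...     "b": 1,                   # scales the standard deviations of the prior on mu
...     "v": 0.5,                 # determines the shape (nu) of the gamma prior on tau
...     "s": np.repeat(0.01, N),  # lower bounds on the gamma scale parameters (sigma)
... }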
The convergence criterion can be set to 'lme' or 'parameters'. If set to 'lme', the algorithm will stop when the log model evidence converges. If set to 'parameters', the algorithm will stop when the "normalized" means of the population-level parameters converge.
References
Piray, P., Dezfouli, A., Heskes, T., Frank, M. J., & Daw, N. D. (2019). Hierarchical Bayesian inference for concurrent model fitting and comparison for group studies. PLoS Computational Biology, 15(6), e1007043.
Examples:
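A usage sketch by analogy with the EmpiricalBayes example above; the optimiser object is assumed to be set up in the same way:
>>> from cpm.hierarchical import VariationalBayes
>>>
>>> vb = VariationalBayes(
...     optimiser=optimiser,  # e.g. the FminBound instance from the example above
...     iteration=50,
...     tolerance_lme=0.001,
...     chain=4,
...     convergence='parameters',
... )
>>> vb.optimise()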
check_convergence(lme_new, lme_old, param_snr_new, param_snr_old, iter_idx=0, use_lme=True, use_param=True)
Function to check if the algorithm has converged.
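A plausible reading of the check, given the tolerances described above; the variable names mirror the function's parameters, and the exact comparisons are an assumption, not the library's code:
>>> lme_converged = abs(lme_new - lme_old) < tolerance_lme
>>> param_converged = np.all(np.abs(param_snr_new - param_snr_old) < tolerance_param)
>>> converged = (not use_lme or lme_converged) and (not use_param or param_converged)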
diagnostics(show=True, save=False, path=None)
Returns the convergence diagnostics plots for the group-level hyperparameters.
Notes
The convergence diagnostics plots show the convergence of the log model evidence, the means, and the standard deviations of the group-level hyperparameters. It also shows the distribution of the means and the standard deviations of the group-level hyperparameters sampled for each chain.
get_lme(log_post, hessian)
Function to approximate the participant-wise log model evidence using Laplace's approximation.
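For reference, Laplace's approximation to the participant-wise log model evidence has the general form below, where $\hat{\theta}$ is the participant's MAP estimate over $d$ parameters and $H$ is the Hessian of the negative log posterior at $\hat{\theta}$; the sign and constant conventions used internally are an assumption:
$$\log p(\mathcal{D}) \approx \log p(\mathcal{D} \mid \hat{\theta}) + \log p(\hat{\theta}) + \frac{d}{2}\log 2\pi - \frac{1}{2}\log\lvert H \rvert$$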
optimise()
Run the Variational Bayes algorithm for multiple chains.
run_vb(chain_index=0)
Run the hierarchical Bayesian inference algorithm.
ttest(null=None)
Perform a one-sample Student's t-test on the estimated values of population-level means with respect to given null hypothesis values.
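For example (a sketch; the null values are hypothetical, one per population-level parameter):
>>> vb.ttest(null=[0.5, 0.1])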
update_population(param, hessian, lme, iter_idx=0, chain_idx=0)
Function to update the population-level parameters based on the results of participant-wise optimisation.