Sometimes you are developing a model that has multiple variants: maybe you want to consider several different link functions somewhere deep in your model, or you want to switch between estimating a quantity and getting it as data, or something completely different. In these cases, you might want optional parameters and/or data that apply only to some variants of your model. Sadly, Stan does not support this feature directly, but you can implement it yourself with just a bit of additional code. In this post I will show how.
The Base Model
Let’s start with a very simple model: just estimating the mean and standard deviation of a normal distribution:
```r
library(rstan)
library(knitr)
library(tidyverse)
options(mc.cores = parallel::detectCores())
rstan_options(auto_write = TRUE)
set.seed(3145678)

model_fixed_code <- "
data {
  int N;
  vector[N] X;
}
parameters {
  real mu;
  real<lower=0> sigma;
}
model {
  X ~ normal(mu, sigma);
  //And some priors
  mu ~ normal(0, 10);
  sigma ~ student_t(3, 0, 1);
}
"
model_fixed <- stan_model(model_code = model_fixed_code)
```
And let’s simulate some data and see that it fits:
```r
mu_true <- 8
sigma_true <- 2
N <- 10
X <- rnorm(N, mean = mu_true, sd = sigma_true)
data_fixed <- list(N = N, X = X)
fit_fixed <- sampling(model_fixed, data = data_fixed, iter = 500)
summary(fit_fixed, probs = c(0.1, 0.9))$summary %>% kable()
```
| | mean | se_mean | sd | 10% | 90% | n_eff | Rhat |
|---|---|---|---|---|---|---|---|
| mu | 7.855031 | 0.0256139 | 0.5632183 | 7.162485 | 8.548415 | 483.5059 | 1.007501 |
| sigma | 1.774158 | 0.0206974 | 0.4400573 | 1.302616 | 2.350727 | 452.0508 | 1.003409 |
| lp__ | -12.103350 | 0.0555738 | 1.1132479 | -13.664610 | -11.091775 | 401.2768 | 1.004955 |
Now With Optional Parameters
Let’s say we now want to handle the case where the standard deviation is known. Obviously we could write a new model. But what if the full model has several hundred lines and the only thing we want to change is to let the user specify the known standard deviation? The simplest solution is to declare all parameters and data needed by any of the variants and use `if` conditions in the model block to ignore some of them, but that is a bit unsatisfactory (and those unused parameters may in some cases hinder sampling).

For a better solution, we can take advantage of the fact that Stan allows zero-sized arrays/vectors and features the ternary operator `?:`. The ternary operator has the syntax `(condition) ? (true value) : (false value)` and works like an `if`-`else` statement, but within an expression. The last piece of the puzzle is that Stan allows the sizes of data and parameter arrays to depend on arbitrary expressions computed from data. The model that can handle both known and unknown standard deviation follows:
```r
model_optional_code <- "
data {
  int N;
  vector[N] X;
  //Just a verbose way to specify a boolean variable
  int<lower = 0, upper = 1> sigma_known;
  //sigma_data is size 0 if sigma_known is FALSE
  real<lower=0> sigma_data[sigma_known ? 1 : 0];
}
parameters {
  real mu;
  //sigma_param is size 0 if sigma_known is TRUE
  real<lower=0> sigma_param[sigma_known ? 0 : 1];
}
transformed parameters {
  real<lower=0> sigma;
  if (sigma_known) {
    sigma = sigma_data[1];
  } else {
    sigma = sigma_param[1];
  }
}
model {
  X ~ normal(mu, sigma);
  //And some priors
  mu ~ normal(0, 10);
  if (!sigma_known) {
    sigma_param ~ student_t(3, 0, 1);
  }
}
"
model_optional <- stan_model(model_code = model_optional_code)
```
We had to add some boilerplate code, but now we don’t have to maintain two separate models. This trick is also useful when testing multiple variants during development: as the model compiles only once, you can test both variants after each change to other parts of your code, which reduces the time spent waiting for compilation.
Just to make sure the model works and see how to correctly specify the data, let’s fit it assuming the standard deviation is to be estimated:
```r
data_optional <- list(
  N = N,
  X = X,
  sigma_known = 0,
  sigma_data = numeric(0) # This produces an array of size 0
)
fit_optional <- sampling(model_optional,
                         data = data_optional,
                         iter = 500, pars = c("mu", "sigma"))
summary(fit_optional, probs = c(0.1, 0.9))$summary %>% kable()
```
| | mean | se_mean | sd | 10% | 90% | n_eff | Rhat |
|---|---|---|---|---|---|---|---|
| mu | 7.854036 | 0.0198265 | 0.5440900 | 7.181837 | 8.531780 | 753.0924 | 0.9981102 |
| sigma | 1.730077 | 0.0152808 | 0.3918781 | 1.308565 | 2.270505 | 657.6701 | 0.9989029 |
| lp__ | -11.992770 | 0.0503044 | 0.9811551 | -13.383729 | -11.089657 | 380.4199 | 1.0016842 |
And now let’s run the model and give it the correct standard deviation:
```r
data_optional_sigma_known <- list(
  N = N,
  X = X,
  sigma_known = 1,
  # The array conversion is necessary, otherwise Stan complains about dimensions
  sigma_data = array(sigma_true, 1)
)
fit_optional_sigma_known <- sampling(model_optional,
                                     data = data_optional_sigma_known,
                                     iter = 500, pars = c("mu", "sigma"))
summary(fit_optional_sigma_known, probs = c(0.1, 0.9))$summary %>% kable()
```
| | mean | se_mean | sd | 10% | 90% | n_eff | Rhat |
|---|---|---|---|---|---|---|---|
| mu | 7.808058 | 0.0292710 | 0.6273565 | 7.017766 | 8.622762 | 459.3600 | 1.006164 |
| sigma | 2.000000 | 0.0000000 | 0.0000000 | 2.000000 | 2.000000 | 1000.0000 | NaN |
| lp__ | -11.072234 | 0.0321233 | 0.6750295 | -11.917321 | -10.585280 | 441.5753 | 1.002187 |
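Building the two data lists by hand is easy to get wrong: forgetting `numeric(0)` or the `array()` wrapper leads to dimension errors from Stan. One option is to keep that logic in a single place. The function `make_data_optional()` below is a hypothetical helper (it is not part of rstan), written as a sketch that assumes the data block of `model_optional_code` above and uses `sigma = NULL` to mean "estimate sigma".

```r
# Hypothetical helper (not part of rstan): builds the data list for
# model_optional, whose data block declares sigma_data with size
# sigma_known ? 1 : 0. Passing sigma = NULL means "estimate sigma".
make_data_optional <- function(X, sigma = NULL) {
  sigma_known <- if (is.null(sigma)) 0L else 1L
  if (sigma_known == 1L) {
    # A known sigma must be a single positive number
    stopifnot(length(sigma) == 1, sigma > 0)
  }
  list(
    N = length(X),
    X = X,
    sigma_known = sigma_known,
    # numeric(0) gives a size-0 array; array(..., dim = 1) keeps a single
    # value from being passed as a scalar
    sigma_data = if (sigma_known == 1L) array(sigma, dim = 1) else numeric(0)
  )
}
```

With this helper, `data_optional` above corresponds to `make_data_optional(X)` and `data_optional_sigma_known` to `make_data_optional(X, sigma = sigma_true)`.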
Extending
Obviously this method lets you do all sorts of more complicated things, in particular:

- When the optional parameter is a vector, you can have something like `vector[sigma_known ? 0 : n_sigma] sigma;`
- You can have more than two variants to choose from and then use something akin to `real param[variant == 5 ? 0 : 1];`
- If your conditions become more complex, you can always put them into a user-defined function (for optional data) or the `transformed data` block (for optional parameters), as in:
```stan
functions {
  int compute_whatever_size(int X, int Y, int Z) {
    //do stuff
  }
}
data {
  ...
  real whatever[compute_whatever_size(X,Y,Z)];
  real<lower = 0> whatever_sigma[compute_whatever_size(X,Y,Z)];
}
transformed data {
  int carebear_size;
  //do stuff
  carebear_size = magic_result;
}
parameters {
  vector[carebear_size] carebear;
  matrix[carebear_size,carebear_size] spatial_carebear;
}
```
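When the size expressions for optional data get more involved, it can help to repeat them on the R side and check the data list before calling `sampling()`, so that a mismatch produces a readable error in R rather than a dimension error from Stan. The sketch below is hypothetical: `expected_sizes()` and `check_sizes()` are illustrative names, and the code assumes a data block with a `variant` switch and an array `whatever_data` of size `variant == 5 ? 0 : n_whatever`. The R rules must be kept in sync with the Stan declarations by hand.

```r
# Hypothetical sketch: repeat the Stan size expressions in R so that the data
# list can be checked before calling sampling(). It assumes a data block like
#   int variant;
#   int n_whatever;
#   real whatever_data[variant == 5 ? 0 : n_whatever];
# and must be kept in sync with the Stan code by hand.
expected_sizes <- function(data) {
  list(
    # R's `if (cond) a else b` expression plays the role of Stan's `cond ? a : b`
    whatever_data = if (data$variant == 5) 0L else as.integer(data$n_whatever)
  )
}

check_sizes <- function(data) {
  sizes <- expected_sizes(data)
  for (name in names(sizes)) {
    # This sketch requires size-0 entries to be present (as numeric(0)),
    # following the data lists used earlier in this post
    if (!name %in% names(data)) {
      stop(sprintf("'%s' is missing; pass numeric(0) for a size-0 array", name),
           call. = FALSE)
    }
    actual <- length(data[[name]])
    if (actual != sizes[[name]]) {
      stop(sprintf("'%s' has length %d, but the model expects %d",
                   name, actual, sizes[[name]]), call. = FALSE)
    }
  }
  invisible(data)
}
```

For example, `check_sizes(list(variant = 5, n_whatever = 3, whatever_data = numeric(0)))` passes, while the same list with `variant = 1` stops with a message naming `whatever_data`. The check only compares lengths, so a single value still needs the `array()` wrapper shown earlier.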