Monday, October 17, 2016

Finite mixture models in Stan

Jim Savage

17 October 2016

A very fun class of problems has the following structure: we observe some data, but there are several mechanisms that may have generated it. We see someone’s height, and need to infer their gender. Or we see some monthly unemployment figures and have to work out whether we’re in a low-growth scenario or not. One modeling framework that helps us think about these problems is the finite mixture model. Explicitly, the model looks like this:
We first draw an index $z_i$ that tells us which model observation $i$ will be drawn from. Each state $k = 1, \dots, K$ has probability $\theta_k$, collected in $\Theta$ with $\sum_{k=1}^{K} \theta_k = 1$:
$$z_i \sim \text{Categorical}(\Theta)$$
Once we know which model observation $i$ is going to come from, we can draw the observations. If the data generating process is a normal with location $\mu_k$ and scale $\sigma_k$, then
$$y_i \sim \text{Normal}(\mu_{z_i}, \sigma_{z_i})$$
Of course the generative model for $y_i$ needn’t be such a simple model, or even a model of the same structure. All it needs to do is have distributional implications for the outcome data $y$.
So that’s the generative model. As you know, I’m a big fan of Modern Statistical Workflow. So let’s simulate some data in R so that there’s no confusion.
library(dplyr); library(ggplot2); library(ggthemes)

# Number of data points
N <- 400

# Let's make three states
mu <- c(3, 6, 9)
sigma <- c(2, 4, 3)

# with probability
Theta <- c(.5, .2, .3)

# Draw which model each belongs to
z <- sample(1:3, size = N, prob = Theta, replace = TRUE)

# Some white noise
epsilon <- rnorm(N)

# Simulate the data using the fact that y ~ normal(mu, sigma) can be 
# expressed as y = mu + sigma*epsilon for epsilon ~ normal(0, 1)
y <- mu[z] + sigma[z]*epsilon

data.frame(y, z = as.factor(z)) %>% 
  ggplot(aes(x = y, fill = z)) +
  geom_density(alpha = 0.3) +
  theme_economist() +
  ggtitle("Three data generating processes")

What should you take from the above? The big point is that if you observe a value of 7, it might very plausibly have come from any of the models; likewise a value of 15 probably did not come from the first model, and a value of 0 probably came from the first. This brings us to the likelihood.
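To make that concrete, here is a small sketch (my own addition, not from the original analysis). It applies Bayes' rule with the true simulation parameters to compute the probability that a given value came from each state. The helper name state_probs is hypothetical.

```r
# True parameters from the simulation above
mu    <- c(3, 6, 9)
sigma <- c(2, 4, 3)
Theta <- c(.5, .2, .3)

# P(z = k | y) is proportional to Theta[k] * Normal(y | mu[k], sigma[k])
state_probs <- function(y, mu, sigma, Theta) {
  weighted <- Theta * dnorm(y, mean = mu, sd = sigma)
  weighted / sum(weighted)  # normalise so the probabilities sum to 1
}

# Rows are the observed values, columns are the three states
probs <- t(sapply(c(0, 7, 15), state_probs, mu = mu, sigma = sigma, Theta = Theta))
rownames(probs) <- c("y = 0", "y = 7", "y = 15")
round(probs, 3)
```

A value of 0 gets most of its probability from state 1, while 15 gets essentially none from state 1. A value of 7 is split across all three states.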

From generative model to likelihood

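Recall that the likelihood contribution of a data point $y_i$ is the height of the generative density for a given set of parameters. But here we have several sets of parameters. How do we deal with that?
Well, if the probability of being drawn from model 1 is $\theta_1$ and so on, we can write out the likelihood as
$$p(y_i \mid \mu, \sigma, \Theta) = \theta_1 \text{Normal}(y_i \mid \mu_1, \sigma_1) + \dots + \theta_K \text{Normal}(y_i \mid \mu_K, \sigma_K) = \sum_{k=1}^{K} \theta_k \text{Normal}(y_i \mid \mu_k, \sigma_k)$$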
where the notation $\text{Normal}(y_i \mid \mu_k, \sigma_k)$ refers to the height of the normal density at $y_i$ with the given parameters. The above is the likelihood contribution of a single data point. The likelihood of the whole model (under the assumption that observations are independent) is the product of the individual likelihood contributions. Of course, this product is a very small number (with a few hundred observations it can underflow to zero in floating point), which gives us computational problems. So we prefer to work on the log scale, which lets us sum log likelihood contributions rather than multiply likelihood contributions.
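Here is a quick R illustration of that computational problem, using data simulated as above (a sketch of my own, not part of the original workflow). The product of 400 likelihood contributions underflows, but the sum of their logs does not.

```r
# Simulate data as in the post
N <- 400
mu    <- c(3, 6, 9)
sigma <- c(2, 4, 3)
Theta <- c(.5, .2, .3)
z <- sample(1:3, size = N, prob = Theta, replace = TRUE)
y <- mu[z] + sigma[z] * rnorm(N)

# N x 3 matrix of normal densities, one column per state
densities <- outer(y, 1:3, function(yy, k) dnorm(yy, mu[k], sigma[k]))
# Likelihood contribution of each observation: sum_k Theta[k] * Normal(y_i | mu[k], sigma[k])
lik_i <- as.vector(densities %*% Theta)

# Multiplying 400 small numbers underflows to zero in double precision...
naive_likelihood <- prod(lik_i)
# ...whereas summing the logs gives a perfectly usable (negative) number
log_likelihood <- sum(log(lik_i))

naive_likelihood
log_likelihood
```

Expect naive_likelihood to print 0, while log_likelihood is a finite negative number. Its exact value depends on the random draw.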

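On the log scale, the log likelihood of the whole sample is
$$\log p(y \mid \mu, \sigma, \Theta) = \sum_{i=1}^{N} \log\left(\sum_{k=1}^{K} \theta_k \text{Normal}(y_i \mid \mu_k, \sigma_k)\right)$$
Now how should we evaluate the inside of the log() on the right?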

Log sum exp to the rescue

A very handy transformation here is the log_sum_exp() function, which, as its name suggests, is the log of summed exponents. It is defined as:
log_sum_exp(a, b) = log(exp(a) + exp(b))
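In practice, log_sum_exp is usually computed with a standard trick: subtract the maximum before exponentiating, so that nothing overflows and the largest term is never lost to underflow. Here is a minimal R sketch of that trick. It is my own helper, not Stan's implementation, and it generalises the two-argument version to a vector.

```r
# log(sum(exp(x))) = max(x) + log(sum(exp(x - max(x))))
# After subtracting the max, the largest term is exp(0) = 1.
log_sum_exp <- function(x) {
  m <- max(x)
  if (!is.finite(m)) return(m)  # all -Inf (or some Inf): nothing to rescale
  m + log(sum(exp(x - m)))
}

big   <- c(1000, 1000)
small <- c(-1000, -1000)

log(sum(exp(big)))    # Inf: exp(1000) overflows
log_sum_exp(big)      # 1000 + log(2)
log(sum(exp(small)))  # -Inf: exp(-1000) underflows to 0
log_sum_exp(small)    # -1000 + log(2)
```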
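Now look at our log likelihood above. Since $\theta_k \text{Normal}(y_i \mid \mu_k, \sigma_k) = \exp\left(\log\theta_k + \log\text{Normal}(y_i \mid \mu_k, \sigma_k)\right)$, we could rewrite each observation's contribution as
$$\log\left(\sum_{k=1}^{K} \theta_k \text{Normal}(y_i \mid \mu_k, \sigma_k)\right) = \text{log\_sum\_exp}\left(\log\theta_1 + \log\text{Normal}(y_i \mid \mu_1, \sigma_1), \dots, \log\theta_K + \log\text{Normal}(y_i \mid \mu_K, \sigma_K)\right)$$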
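A quick numerical check of this identity in R, using the simulation parameters and a single hypothetical observation $y_i = 7$. The contributions vector here plays the same role as the vector of the same name in the Stan model.

```r
# Parameters from the simulation and one hypothetical observation
mu    <- c(3, 6, 9)
sigma <- c(2, 4, 3)
Theta <- c(.5, .2, .3)
y_i   <- 7

# Stable log_sum_exp: subtract the max before exponentiating
log_sum_exp <- function(x) {
  m <- max(x)
  m + log(sum(exp(x - m)))
}

# Direct form: log of the weighted sum of densities
direct <- log(sum(Theta * dnorm(y_i, mu, sigma)))

# log_sum_exp form: one log-scale term per state, never leaving the log scale
contributions <- log(Theta) + dnorm(y_i, mu, sigma, log = TRUE)
via_lse <- log_sum_exp(contributions)

c(direct = direct, via_lse = via_lse)
```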

Now we’re going to use this trick to implement the finite mixture model in Stan.

Estimating the model in Stan

Here is the simple model. Note that we’re incrementing the log probability accumulator, target, using the log_sum_exp function just discussed.
// saved as finite_mixture_linear_regression.stan
data {
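  int<lower = 1> N;          // number of observations
  vector[N] y;               // outcome data
  int<lower = 1> n_groups;   // number of mixture components (states)
}
parameters {
  vector[n_groups] mu;                 // location of each state
  vector<lower = 0>[n_groups] sigma;   // scale of each state
  simplex[n_groups] Theta;             // mixing probabilities (sum to 1)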
}
model {
  vector[n_groups] contributions;
  // priors
  mu ~ normal(0, 10);
  sigma ~ cauchy(0, 2);
  Theta ~ dirichlet(rep_vector(2.0, n_groups));
  
  
  // likelihood
  for(i in 1:N) {
    for(k in 1:n_groups) {
      contributions[k] = log(Theta[k]) + normal_lpdf(y[i] | mu[k], sigma[k]);
    }
    target += log_sum_exp(contributions);
  }
}
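When debugging a model like this, it can help to transcribe the likelihood part of the model block into R and compare it with a direct calculation. The sketch below mirrors the loop above. The function name mixture_log_lik is my own, and the sketch covers only the likelihood, not the priors.

```r
# Likelihood part of the Stan model block, transcribed line by line into R
mixture_log_lik <- function(y, mu, sigma, Theta) {
  n_groups <- length(mu)
  target <- 0
  for (i in seq_along(y)) {
    contributions <- numeric(n_groups)
    for (k in 1:n_groups) {
      # log(Theta[k]) + normal_lpdf(y[i] | mu[k], sigma[k])
      contributions[k] <- log(Theta[k]) + dnorm(y[i], mu[k], sigma[k], log = TRUE)
    }
    # target += log_sum_exp(contributions), using the max-subtraction trick
    m <- max(contributions)
    target <- target + m + log(sum(exp(contributions - m)))
  }
  target
}

# Compare against the direct (less robust) formula on simulated data
mu    <- c(3, 6, 9)
sigma <- c(2, 4, 3)
Theta <- c(.5, .2, .3)
z <- sample(1:3, size = 50, prob = Theta, replace = TRUE)
y <- mu[z] + sigma[z] * rnorm(50)

loop_version   <- mixture_log_lik(y, mu, sigma, Theta)
direct_version <- sum(log(sapply(y, function(v) sum(Theta * dnorm(v, mu, sigma)))))
c(loop_version, direct_version)
```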
We estimate the model using:
library(rstan)
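options(mc.cores = parallel::detectCores()) # run the chains in parallel

# Compile the Stan program
compiled_model <- stan_model("finite_mixture_linear_regression.stan")

# 600 iterations per chain; by default the first half are warmup
estimated_model <- sampling(compiled_model, data = list(N = N, y = y, n_groups = 3), iter = 600)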

Label switching

That model appeared to estimate very cleanly. But if we print the parameter estimates, we see that all the Rhats are large and the n_effs are low. That’s bad! So we might check out the traceplots to see what’s going on.

Oh dear! We have a case of the dreaded label switching. With these mixture models, the parameters are only identified up to a permutation of the indices. That is, if we permuted the indices of all the parameters together, the likelihood would be unchanged. Think about it this way: we said that state 1 ($\theta_1$) had probability .5. But we could equally call it “state 2”, and it wouldn’t affect the model at all so long as we slot the old state 2 parameters into index 1. Label switching happens when the labels change while the model is running, for example when different chains settle on different labellings of the states.
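We can check this invariance numerically. The R sketch below (a hypothetical helper of my own) evaluates the mixture log likelihood on simulated data. It then relabels the states by permuting mu, sigma and Theta together and evaluates it again.

```r
# Mixture log likelihood: sum over i of log(sum_k Theta[k] * Normal(y_i | mu[k], sigma[k]))
mixture_log_lik <- function(y, mu, sigma, Theta) {
  sum(sapply(y, function(v) log(sum(Theta * dnorm(v, mu, sigma)))))
}

mu    <- c(3, 6, 9)
sigma <- c(2, 4, 3)
Theta <- c(.5, .2, .3)
z <- sample(1:3, size = 200, prob = Theta, replace = TRUE)
y <- mu[z] + sigma[z] * rnorm(200)

# Relabel the states: old state 2 becomes state 1, old 3 becomes 2, old 1 becomes 3
perm <- c(2, 3, 1)
original_labels <- mixture_log_lik(y, mu, sigma, Theta)
switched_labels <- mixture_log_lik(y, mu[perm], sigma[perm], Theta[perm])

c(original_labels, switched_labels)  # identical up to floating-point rounding
```

Because the two labellings are equally good, nothing in the likelihood tells the sampler which one to prefer.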
So what do we do? There are a few tricks. One is to declare one of the parameters as an ordered vector, so long as there is theoretical support for doing so. We can do this by replacing the line
vector[n_groups] mu;
with
ordered[n_groups] mu;
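Under the hood, Stan maps an unconstrained vector to an ordered one. It keeps the first element as is and adds the exponential of each subsequent unconstrained element to the previous value (the ordered-vector transform described in the Stan reference manual). Here is an R sketch of that map and its inverse; the function names are hypothetical. Because every unconstrained value lands on a vector with mu[1] < mu[2] < mu[3], only one labelling of the states can be represented, which is what stops the labels from swapping.

```r
# Unconstrained vector -> strictly increasing vector
# y[1] = x[1]; y[k] = y[k - 1] + exp(x[k]) for k > 1
to_ordered <- function(x) {
  cumsum(c(x[1], exp(x[-1])))
}

# Strictly increasing vector -> unconstrained vector (the inverse map)
from_ordered <- function(y) {
  c(y[1], log(diff(y)))
}

x  <- c(0.4, -1.2, 2.0)  # any real numbers at all
mu <- to_ordered(x)
mu                        # strictly increasing, whatever x was
from_ordered(mu)          # recovers x
```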
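With mu declared as an ordered vector, the model estimates far more cleanly. The 95% intervals now cover the values we used to generate the data, and the traceplot looks much nicer. A few Rhats (e.g. 1.08 for sigma[1]) are still above 1, so more iterations wouldn’t hurt.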
## Inference for Stan model: finite_mixture_linear_regression_2.
## 4 chains, each with iter=600; warmup=300; thin=1; 
## post-warmup draws per chain=300, total post-warmup draws=1200.
## 
##              mean se_mean   sd     2.5%      25%      50%      75%
## mu[1]        2.51    0.09 0.59     0.66     2.34     2.63     2.87
## mu[2]        4.87    0.13 1.42     2.78     3.53     4.87     6.04
## mu[3]        7.92    0.09 1.23     6.33     7.06     7.63     8.42
## sigma[1]     1.60    0.06 0.39     0.64     1.42     1.60     1.79
## sigma[2]     2.56    0.08 0.92     0.94     1.85     2.56     3.25
## sigma[3]     3.36    0.04 0.55     2.09     3.12     3.45     3.70
## Theta[1]     0.33    0.02 0.15     0.04     0.22     0.35     0.44
## Theta[2]     0.30    0.01 0.15     0.05     0.18     0.29     0.40
## Theta[3]     0.37    0.02 0.16     0.05     0.24     0.38     0.50
## lp__     -1060.70    0.15 2.19 -1066.08 -1061.92 -1060.33 -1059.09
##             97.5% n_eff Rhat
## mu[1]        3.27    46 1.07
## mu[2]        7.53   115 1.03
## mu[3]       11.25   181 1.03
## sigma[1]     2.37    50 1.08
## sigma[2]     4.23   140 1.02
## sigma[3]     4.28   185 1.04
## Theta[1]     0.59    86 1.05
## Theta[2]     0.60   213 1.01
## Theta[3]     0.65   111 1.03
## lp__     -1057.62   216 1.00
## 
## Samples were drawn using NUTS(diag_e) at Mon Oct 17 22:56:15 2016.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at 
## convergence, Rhat=1).

In the next post, I’ll build on these concepts and discuss Markov Switching.
