A Bayesian Perspective on Missing Data Imputation

Apr 26, 2021
Mar 11, 2022 15:59 UTC
This lecture discusses some approaches to handling missing data, primarily when missingness occurs completely randomly. We discuss a procedure, MICE, which uses Gibbs sampling to create multiple "copies" of filled-in datasets.
Missing piece in a jigsaw.

Missing data is a prevalent issue in many applied data sets. This lecture is a light introduction to one way of performing missing data imputation from a Bayesian perspective. For more information, the book Statistical Analysis with Missing Data by Roderick Little and Donald Rubin is a great resource.

Diabetes example

The motivating example comes from Hoff’s book, A First Course in Bayesian Statistical Methods (page 115). Consider two measurements on 200 Arizona women with Pima Indian heritage:

  • skin, the skin fold thickness, and
  • bmi, the body mass index.

Our goal is to study the relationship between these two variables, accounting for the missing data present in the table.

library(tidyverse)
library(ggpubr)
library(rjags)
diab <- data.table::fread(
  "http://www2.stat.duke.edu/~pdh10/FCBS/Exercises/diabetes_200_miss.dat",
  header = TRUE
) %>%
  select(skin, bmi)

diab %>%
  head() %>%
  knitr::kable()
| skin | bmi |
|-----:|----:|
| 28 | 30.2 |
| 33 | NA |
| NA | 35.8 |
| 43 | 47.9 |
| NA | NA |
| 27 | NA |

Exploratory data analysis

As shown in the table above, missing values in our data set are coded as NA¹.
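Footnote 1 notes that some software codes missing values with a sentinel such as -999. As a minimal sketch (the data frame `raw` and its values are hypothetical), such codes should be converted to NA before any of the analyses below, since R only treats NA as missing:

```r
# Hypothetical raw data in which -999 was used as a missing-value code
raw <- data.frame(
  skin = c(28, -999, 43),
  bmi  = c(-999, 35.8, 47.9)
)

# Replace every sentinel value with NA so that R's missing-data tools see it
raw[raw == -999] <- NA

# Number of missing values per column after recoding
print(colSums(is.na(raw)))
```

When reading a file, `data.table::fread` can also do this at load time through its `na.strings` argument, e.g. `na.strings = c("NA", "-999")`.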

We can use the summary() function in R to get the 5-number summary, mean, and number of missing values for the columns.

| | skin | bmi |
|---|---:|---:|
| Min. | 7.0 | 18.20 |
| 1st Qu. | 20.5 | 27.60 |
| Median | 29.0 | 32.80 |
| Mean | 29.1 | 32.21 |
| 3rd Qu. | 36.0 | 36.48 |
| Max. | 99.0 | 47.90 |
| NA’s | 25 | 22 |

Histograms² are helpful for assessing the distributions of the variables. Both are roughly symmetric, and there is one extreme observation for skin fold thickness (the maximum of 99).


Figure 1: Observed BMI measurements and skin fold thicknesses.

We can also use a scatter plot to show the relationship between the two variables.

We observe a relatively strong positive relationship apart from the one outlier, and we might want to quantify it in terms of how correlated the two variables are.

Finally, when we have missing values in our data set, it can be helpful to visualize the pattern of missingness. The aggr() function from the VIM R package shows the proportion of missing values for each variable, as well as how frequently missingness occurs for various combinations of variables.

VIM::aggr(diab)

There is about 13% missingness for the skin variable and about 11% for the bmi variable. The blue color on the right indicates observed data, whereas red indicates missing values. Both values are observed for the large majority of rows, but a few rows are missing one or both values.
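The proportions that aggr() plots can also be computed directly. A minimal sketch in base R, using only the six rows shown in the first table (so the numbers differ from those for the full data set):

```r
# The six rows of diab shown in the first table
toy <- data.frame(
  skin = c(28, 33, NA, 43, NA, 27),
  bmi  = c(30.2, NA, 35.8, 47.9, NA, NA)
)

# Proportion of missing values in each column
prop_missing <- colMeans(is.na(toy))
print(prop_missing)

# Cross-tabulate the missingness pattern (TRUE = missing)
patterns <- table(skin_missing = is.na(toy$skin), bmi_missing = is.na(toy$bmi))
print(patterns)
```

On the full data, `colMeans(is.na(diab))` should give 25/200 = 0.125 for skin and 22/200 = 0.11 for bmi, matching the NA counts in the summary table.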

Different strategies might be appropriate for different patterns of missingness, for example when one variable is always observed but another has missing entries, or when the two variables are always missing together.

Bayesian models

Our two variables appear correlated with each other, so we specify a bivariate normal data distribution for the sampling model:

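$$
\mathbf{Y}_i = \begin{bmatrix} Y_{i1} \\ Y_{i2} \end{bmatrix} \,\Big|\, \boldsymbol{\theta}, \Sigma \;\overset{iid}{\sim}\; \text{MVN}\left( \begin{bmatrix} \theta_1 \\ \theta_2 \end{bmatrix}, \begin{bmatrix} \Sigma_{11} & \Sigma_{12} \\ \Sigma_{21} & \Sigma_{22} \end{bmatrix} \right)
$$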

where $Y_{i1}$ is the $i$th person’s skin fold thickness and $Y_{i2}$ is the $i$th person’s BMI. The two parameters of the MVN distribution are its mean vector $\boldsymbol{\theta}$ and covariance matrix $\Sigma$. The multivariate normal is an extension of the normal distribution that accounts for correlation between the components of a random vector³.
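To make the two MVN parameters concrete, here is a minimal sketch that simulates from a bivariate normal in base R. The values of `theta` and `Sigma` are made up for illustration (they are not estimates from the diabetes data). It also computes the correlation implied by $\Sigma$, $\rho = \Sigma_{12} / \sqrt{\Sigma_{11}\Sigma_{22}}$, which is the same transformation used for `rho` in the JAGS code of footnote 4:

```r
set.seed(1)

# Hypothetical mean vector and covariance matrix (illustrative values only)
theta <- c(29, 32)
Sigma <- matrix(c(100, 30,
                  30,  40), nrow = 2, byrow = TRUE)

# If Z has iid N(0, 1) entries and Sigma = t(R) %*% R (Cholesky factor R),
# then each row of Z %*% R has covariance Sigma; adding theta shifts the mean
n <- 5000
R <- chol(Sigma)
Z <- matrix(rnorm(n * 2), nrow = n, ncol = 2)
Y <- Z %*% R + matrix(theta, nrow = n, ncol = 2, byrow = TRUE)
colnames(Y) <- c("skin", "bmi")

# Correlation implied by Sigma versus the sample estimates
rho <- Sigma[1, 2] / sqrt(Sigma[1, 1] * Sigma[2, 2])
print(rho)            # implied by Sigma
print(cor(Y)[1, 2])   # sample correlation, close to rho
print(colMeans(Y))    # sample means, close to theta
```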

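For the prior model, we have five free parameters: the two means $\theta_1$ and $\theta_2$, and three unique entries of the covariance matrix ($\Sigma_{11}$, $\Sigma_{22}$, and $\Sigma_{12} = \Sigma_{21}$, since $\Sigma$ is symmetric). $\theta_1$ and $\theta_2$ are means, so we specify a normal distribution with a large variance: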

$$
\theta_1, \theta_2 \overset{iid}{\sim} N(0, 10^{10})
$$

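For the covariance matrix $\Sigma$ of the multivariate normal distribution, there is a conjugate prior called the inverse Wishart distribution, a distribution over positive definite matrices with two parameters: a scale matrix and the degrees of freedom. (The JAGS code in footnote 4 equivalently places a Wishart prior on the precision matrix $\Sigma^{-1}$.)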
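A minimal sketch of what this prior implies, using `stats::rWishart`: we draw precision matrices from a Wishart with 3 degrees of freedom and an identity scale matrix, mirroring `dwish(I, 3)` in footnote 4 (JAGS parameterizes `dwish` by the inverse of the scale matrix, which makes no difference when it is the identity), then invert each draw to obtain inverse Wishart draws of $\Sigma$ and the correlation each one implies:

```r
set.seed(2)

# Wishart prior on the precision matrix tau, as in the JAGS code of footnote 4
nu <- 3           # degrees of freedom
S  <- diag(2)     # identity scale matrix

# stats::rWishart returns a 2 x 2 x n_draws array of precision-matrix draws
n_draws   <- 2000
tau_draws <- rWishart(n_draws, df = nu, Sigma = S)

# Invert each precision draw to get an inverse Wishart draw of Sigma;
# apply() returns a 4 x n_draws matrix, reshaped back into a 2 x 2 x n_draws array
Sigma_draws <- array(apply(tau_draws, 3, solve), dim = c(2, 2, n_draws))

# Prior draws of the correlation implied by each covariance matrix
rho_draws <- Sigma_draws[1, 2, ] / sqrt(Sigma_draws[1, 1, ] * Sigma_draws[2, 2, ])
print(summary(rho_draws))
```

With $\nu = 3 = p + 1$ and an identity scale matrix, this prior is known to be marginally uniform on $\rho$ over $(-1, 1)$, so a histogram of `rho_draws` should look roughly flat.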

Posterior distributions for our parameters are given by Bayes’ Theorem:

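$$
p(\boldsymbol{\theta}, \Sigma \mid \mathbf{y}) \propto p(\mathbf{y} \mid \boldsymbol{\theta}, \Sigma)\, p(\boldsymbol{\theta}, \Sigma) \propto \left[\prod_{i=1}^{n} p(\mathbf{y}_i \mid \boldsymbol{\theta}, \Sigma)\right] p(\boldsymbol{\theta})\, p(\Sigma)
$$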

A problem arises when we evaluate the likelihood of the sampling model, since it depends on all of the data, but some values are missing. For example, if $Y_{31}$ were missing, we could not evaluate $p(\mathbf{y}_3 \mid \boldsymbol{\theta}, \Sigma)$.

Simple approaches

There are many ways to handle missing data. The simplest approach is a complete case analysis, where we delete every observation (row) with a missing value in at least one variable. If the data are missing completely at random (MCAR), this gives unbiased parameter estimates. However:

  • Credible intervals will usually be wider than necessary, since the effective sample size can be reduced drastically.
  • If the data are not MCAR, complete case analysis can yield biased estimates.
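A minimal sketch of complete case analysis on simulated data (the data-generating values and the 12% MCAR deletion rate are assumptions chosen to loosely mimic this example, not the diabetes data). `complete.cases()` flags rows with no missing values, and a row is lost whenever either variable is missing:

```r
set.seed(3)

# Simulated bivariate data (illustrative values only, not the diabetes data)
n <- 200
skin <- rnorm(n, mean = 29, sd = 10)
bmi  <- 20 + 0.4 * skin + rnorm(n, sd = 5)
full <- data.frame(skin = skin, bmi = bmi)

# Impose MCAR missingness: each value is deleted independently with prob. 0.12
obs <- full
obs$skin[runif(n) < 0.12] <- NA
obs$bmi[runif(n) < 0.12]  <- NA

# Complete case analysis: keep only rows where every variable is observed
cc <- obs[complete.cases(obs), ]
print(nrow(cc))        # fewer than n rows remain
print(colMeans(cc))    # unbiased under MCAR, so close to the full-data means
print(colMeans(full))
```

With two variables each missing independently 12% of the time, only about $0.88^2 \approx 77\%$ of rows (roughly 155 of 200) are expected to survive, even though each variable individually is 88% observed.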

A step up is to replace missing values of a feature with the mean or median of its observed values, rather than dropping the incomplete rows. This is called single mean/median imputation. It is often used in machine learning tasks and, when the data are MCAR, can still give unbiased estimates of feature means.

This method also has drawbacks, especially if you are interested in studying variability. Single mean/median imputation artificially reduces the variance of features, which results in credible intervals that are too narrow. As shown in Figure 2, the peak at the mean is much higher relative to the surrounding values. The mean of the distribution can still be estimated appropriately, but the spread is understated.

diab %>%
  select(bmi) %>%
  mutate(Imputed = bmi) %>%
  replace_na(list(Imputed = mean(.$Imputed, na.rm = TRUE))) %>%
  select(`No imputation` = bmi, `With mean imputation` = Imputed) %>%
  pivot_longer(everything(), names_to = "Variable", values_to = "BMI") %>%
  gghistogram(x = "BMI", bins = 10, facet.by = "Variable")

Figure 2: Mean imputation with the BMI feature.

Another drawback is that it ignores relationships between variables, and thus attenuates correlations. For example, the correlation between the two variables is 0.665 if we only look at the complete cases, but drops to 0.589 after single mean imputation.
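Both effects can be reproduced in a minimal sketch on simulated data (the data-generating values are assumptions for illustration, not the diabetes data). Every imputed value sits exactly at the observed mean, so it adds nothing to the sum of squared deviations or to the cross-products that drive the covariance:

```r
set.seed(4)

# Simulated correlated data with MCAR missingness (illustrative values only)
n <- 1000
skin <- rnorm(n, mean = 29, sd = 10)
bmi  <- 20 + 0.4 * skin + rnorm(n, sd = 5)
skin[runif(n) < 0.12] <- NA
bmi[runif(n) < 0.12]  <- NA

# Single mean imputation: replace each NA with the mean of the observed values
mean_impute <- function(x) replace(x, is.na(x), mean(x, na.rm = TRUE))
skin_imp <- mean_impute(skin)
bmi_imp  <- mean_impute(bmi)

# The variance shrinks because the imputed values add no spread
print(c(observed = var(bmi, na.rm = TRUE), mean_imputed = var(bmi_imp)))

# The correlation is attenuated relative to the complete cases
print(c(complete_cases = cor(skin, bmi, use = "complete.obs"),
        mean_imputed   = cor(skin_imp, bmi_imp)))
```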

Note that if you don’t care about assessing variability in your estimates, which is often the case for classification tasks, then mean imputation can work quite well.

A Bayesian approach

A popular approach to address the issues of single mean imputation is called multiple imputation by chained equations (MICE), and proceeds as follows.

  1. Specify separate imputation models for each feature with missing values, each conditional on all other features. This allows us to exploit relationships between variables.
  2. Use Gibbs sampling with the imputation models to generate multiple imputed data sets. This allows us to more appropriately assess uncertainty in the missing values.
  3. Analyze each imputed data set separately, e.g. by sampling from the posterior using JAGS.
  4. Combine the results across imputed data sets at the end.

Recall that in the Gibbs sampler we generate values at step $s+1$ by sampling from each full conditional distribution, given the most recent values of the other parameters. Since missing values are unknown just like parameters, we can treat them as additional quantities to be sampled in the Gibbs sampler.

Instead of deriving the full conditional distributions from a joint model, MICE takes a shortcut for missing data by specifying the full conditional distributions directly. One hopes that there is a proper joint distribution that would yield these full conditional distributions.

A common choice is to specify a Bayesian linear regression model for each variable, in order to exploit relationships between variables. Let $Y_{i1}$ denote the skin variable and $Y_{i2}$ the BMI:

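$$
\begin{aligned}
Y_{i1} \mid Y_{i2}, \boldsymbol{\beta}^{(2)}, \sigma^2_{(2)} &\sim N\left(\beta_1^{(2)} + \beta_2^{(2)} Y_{i2},\; \sigma^2_{(2)}\right) \\
Y_{i2} \mid Y_{i1}, \boldsymbol{\beta}^{(1)}, \sigma^2_{(1)} &\sim N\left(\beta_1^{(1)} + \beta_2^{(1)} Y_{i1},\; \sigma^2_{(1)}\right)
\end{aligned}
$$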

where the first mean term is a linear regression of skin on BMI and the second is a linear regression of BMI on skin. We then specify appropriate priors on $\boldsymbol{\beta}^{(1)}$, $\boldsymbol{\beta}^{(2)}$, $\sigma^2_{(1)}$, and $\sigma^2_{(2)}$. This gives us a Bayesian model to impute the missing skin values $Y_{i1}$ and missing BMI values $Y_{i2}$.
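The chained-equations idea for these two regressions can be sketched from scratch in base R. This is a minimal sketch, not the mice package’s implementation: it assumes the noninformative prior $p(\boldsymbol{\beta}, \sigma^2) \propto 1/\sigma^2$ (the priors are left unspecified above), uses simulated data, and the helper `impute_blr` is hypothetical. Each sweep draws $\sigma^2$ and $\boldsymbol{\beta}$ from their posterior given the currently filled-in data, then redraws the missing values, first for skin given BMI and then for BMI given the updated skin:

```r
set.seed(5)

# Simulated data with MCAR missingness (illustrative values only)
n <- 200
skin <- rnorm(n, mean = 29, sd = 10)
bmi  <- 20 + 0.4 * skin + rnorm(n, sd = 5)
dat  <- data.frame(skin = skin, bmi = bmi)
dat$skin[runif(n) < 0.12] <- NA
dat$bmi[runif(n) < 0.12]  <- NA
dat_obs <- dat                 # copy that keeps the NAs, for reference
miss    <- is.na(dat)          # logical matrix: which cells were missing

# Hypothetical helper: impute missing y from a Bayesian linear regression of
# y on x, assuming the noninformative prior p(beta, sigma^2) ~ 1 / sigma^2
impute_blr <- function(y, x, y_miss) {
  X_obs <- cbind(1, x[!y_miss])
  y_obs <- y[!y_miss]
  fit   <- lm.fit(X_obs, y_obs)
  # Draw sigma^2 from its marginal posterior (scaled inverse chi-square)
  sigma2 <- sum(fit$residuals^2) / rchisq(1, df = length(y_obs) - ncol(X_obs))
  # Draw beta | sigma^2 ~ N(beta_hat, sigma^2 (X'X)^{-1})
  V    <- sigma2 * solve(crossprod(X_obs))
  beta <- fit$coefficients + drop(t(chol(V)) %*% rnorm(ncol(X_obs)))
  # Draw the missing y values from the regression given beta and sigma^2
  X_mis <- cbind(1, x[y_miss])
  y[y_miss] <- drop(X_mis %*% beta) + rnorm(sum(y_miss), sd = sqrt(sigma2))
  y
}

# Start the sampler by filling each missing cell with a random observed value
for (j in names(dat)) {
  dat[[j]][miss[, j]] <- sample(dat_obs[[j]][!miss[, j]], sum(miss[, j]), replace = TRUE)
}

# Gibbs-style sweeps: update skin given bmi, then bmi given the updated skin
n_iter <- 5
for (s in seq_len(n_iter)) {
  dat$skin <- impute_blr(dat$skin, dat$bmi, miss[, "skin"])
  dat$bmi  <- impute_blr(dat$bmi, dat$skin, miss[, "bmi"])
}

print(summary(dat))
```

Only the originally missing cells are ever overwritten. Repeating the whole procedure from $D$ independent random starts would yield $D$ imputed data sets, which is the bookkeeping that `mice::mice()` automates.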

What’s a bit different from before is that we are only generating a small number of samples, because we are filling in missing data instead of parameter values. This is usually only run for 5-10 iterations.

Implementation in R

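This imputation procedure can be done automatically in R using the mice() function in the mice package. Below we set m = D = 5 to obtain D separate imputed data sets; each one is produced by running the chained-equations sampler for a small number of iterations (controlled by the maxit argument, which defaults to 5):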

set.seed(42)
D <- 5
diab_mice <- mice::mice(
  diab,            # dataset with missing values
  m = D,           # number of imputed datasets
  method = "norm"  # use Bayesian linear regression discussed above
)

Y_imp <- map(seq(D), ~mice::complete(diab_mice, .x))

The Y_imp variable is a list of five data frames, each with missing values filled in by the corresponding values in diab_mice$imp.

Just like with other MCMC methods, we can assess convergence by looking at trace plots. With only five iterations, we want the curves for the different imputed data sets to be randomly intermingled; e.g. it would be undesirable if the red curve were always above the other curves.


Figure 3: Mean and standard deviation of imputed values across the five generated data sets.

Drawing posterior samples

We can then treat each of the five imputed data sets separately, and use JAGS⁴ to generate separate posterior samples for the parameters of interest. The table below has 95% credible intervals for the specified parameters:

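| Imputed data set | $\theta_1$ | $\theta_2$ | $\rho$ |
|---|---|---|---|
| 1 | (27.59, 31.02) | (31.24, 32.93) | (0.57, 0.73) |
| 2 | (27.89, 31.16) | (31.17, 32.86) | (0.58, 0.74) |
| 3 | (27.48, 30.71) | (31.25, 32.93) | (0.58, 0.74) |
| 4 | (27.84, 31.23) | (31.23, 32.91) | (0.61, 0.75) |
| 5 | (27.67, 30.88) | (31.62, 33.34) | (0.55, 0.71) |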

These values are fairly similar, mainly because the sample size is large, and there is a relatively small proportion of missing values.

Combining results

Generally, we do not care about posterior inference of parameters for each imputed data set. Instead, we would rather pool our results together to account for imputation variability.

In frequentist inference, this is done using Rubin's rules; in Bayesian inference, we just aggregate all of our posterior samples together across the imputed data sets. The credible intervals after aggregating posterior samples are given below.

| Parameter | 95% CI |
|---|---|
| $\theta_1$ | (27.67, 31.04) |
| $\theta_2$ | (31.26, 33.07) |
| $\rho$ | (0.57, 0.74) |

From the 95% CI of ρ, we conclude there is fairly strong positive correlation between skin and BMI.
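The pooling step can be sketched as follows. The draws here are simulated stand-ins (hypothetical, not the JAGS output of this analysis). For comparison, the block also applies Rubin’s rules: the pooled estimate is the average of the per-data-set estimates, and the total variance is $W + (1 + 1/D)B$, where $W$ is the average within-imputation variance and $B$ the between-imputation variance of the estimates:

```r
set.seed(6)

# Hypothetical posterior draws of rho from D = 5 imputed data sets
# (simulated stand-ins for the mcmcChain[[d]] output of footnote 4)
D <- 5
draws <- lapply(seq_len(D), function(d) rnorm(1000, mean = 0.64 + 0.005 * d, sd = 0.04))

# Bayesian pooling: stack all posterior draws, then summarise them together
pooled <- unlist(draws)
print(quantile(pooled, c(0.025, 0.975)))   # pooled 95% credible interval

# Rubin's rules, applied to each data set's posterior mean and variance
est   <- sapply(draws, mean)      # estimate from each imputed data set
U     <- sapply(draws, var)       # within-imputation variance for each set
q_bar <- mean(est)                # pooled point estimate
W     <- mean(U)                  # average within-imputation variance
B     <- var(est)                 # between-imputation variance
T_var <- W + (1 + 1 / D) * B      # total variance
print(c(estimate = q_bar, se = sqrt(T_var)))
```

With the real output of footnote 4, `do.call(rbind, mcmcChain)` stacks the per-data-set draw matrices (they share the same columns), and `apply(do.call(rbind, mcmcChain), 2, quantile, probs = c(0.025, 0.975))` would give pooled intervals for every monitored parameter at once.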

Concluding remarks

For many inferential tasks, simple imputation methods can perform well! Advanced methods are necessary if one is focused on quantifying uncertainty in estimates correctly, particularly for data sets with low sample sizes or a large amount of missingness.

Missing data is very application-specific, so you should think carefully about what the variables represent, e.g. use domain knowledge and explore the data before deciding how to impute. Missing data often does not occur completely at random, e.g.:

  • Non-response to survey questions – it might be a sensitive question, or maybe the survey is too long.
  • Patient dropout in a longitudinal study – maybe the patients died or moved to somewhere else.

Finally, never, ever replace all of the missing values with zeros (or other arbitrary values)!


  1. This is not the only way missing data is coded. For example, in some software -999 could be used to represent a missing value since it’s “obviously” unrealistic.

  2. R code for generating Figure 1:

    diab %>%
      pivot_longer(everything(), names_to = "Variable", values_to = "Value") %>%
      mutate(
        Label = case_when(
          Variable == "skin" ~ "Observed skin fold thickness",
          TRUE ~ "Observed BMI measurements"
        )
      ) %>%
      gghistogram(x = "Value", bins = 10, xlab = "") %>%
      facet(facet.by = "Label", scales = "free")
    
  3. If we did not assume a relationship between the two variables, we could simply specify separate normal distributions for each.

  4. R code for running JAGS:

    n <- nrow(Y_imp[[1]])  # sample size
        
    # Specify parameters, initial values, and JAGS settings
    parameters <- c("theta", "sig21", "sig22", "rho")
        
    initValues <- list(
      "theta" = unname(colMeans(diab, na.rm = T)),
      "tau" = matrix(c(1, 0, 0, 1), nrow = 2)
    )
        
    # JAGS settings
    adaptSteps <- 10000     # number of steps to "tune" the samplers
    burnInSteps <- 10000    # number of steps to "burn-in" the samplers
    nChains <- 3            # number of chains to run
    numSavedSteps <- 10000  # total number of steps in chains to save
    thinSteps <- 100        # number of steps to "thin" (1 = keep every step)
    nIter <- ceiling((numSavedSteps*thinSteps)/nChains) 	# steps per chain
        
    mcmcChain <- vector("list", D)  # one matrix of posterior draws per imputed data set
    for (d in seq(D)) {
      dataList <- list(
        "n" = n,
        "Y_imp" = as.matrix(Y_imp[[d]]),
        "I" = matrix(c(1, 0, 0, 1), nrow = 2)
      )
          
      m <- textConnection("
    model {
      # Sampling model - for multivariate normal, second argument is the
      # precision matrix, which is the inverse of the covariance matrix
      for (i in 1:n) {
      	Y_imp[i,1:2] ~ dmnorm(theta[], tau[, ])
      }
        
      # Prior model
      theta[1] ~ dnorm(0, 1e-10)
      theta[2] ~ dnorm(0, 1e-10)
      tau[1:2, 1:2] ~ dwish(I, 3)
        
      # Transformation
      # Obtain covariance matrix from precision matrix
      sigma[1:2, 1:2] <- inverse(tau[, ])
        
      # Extract variances and correlation
      sig21 <- sigma[1, 1]
      sig22 <- sigma[2, 2]
      rho <- sigma[1, 2] / (sqrt(sig21)*sqrt(sig22))
    }")
      jagsModel <- jags.model(m, 
                              data = dataList, 
                              inits = initValues, 
                              n.chains = nChains, 
                              n.adapt = adaptSteps)
      close(m)
          
      if (burnInSteps > 0) {
        update(jagsModel, n.iter = burnInSteps)
      }
      codaSamples <- coda.samples(jagsModel, 
                                  variable.names = parameters, 
                                  n.iter = nIter, 
                                  thin = thinSteps)
      mcmcChain[[d]] <- as.matrix(codaSamples)
    }
    