Estimating the mean vector

Let’s estimate the mean vector from multivariate normal data.
Author: Paweł Czyż

Published: August 16, 2024

I recently ended up building another Gibbs sampler¹. I had \(N\) vectors \((Y_n)\) such that each vector \(Y_n = (Y_{n1}, \dotsc, Y_{nG})\) was assumed to come from a multivariate normal distribution:

\[ Y_n\mid \mu \sim \mathcal N(\mu, \Sigma), \]

where \(\Sigma\) is a known \(G\times G\) covariance matrix and \(\mu \sim \mathcal N(0, B)\) is the unknown population mean, given a multivariate normal prior. In this case, it is important that \(\Sigma\) is known and that \(B\) is a fixed matrix, not necessarily built from \(\Sigma\): the Wikipedia derivation for Bayesian multivariate linear regression (which is a more general case) uses a different prior. I searched the internet for some time and found a nice project, The Book of Statistical Proofs, but I still could not find a derivation addressing the simple case above.

Let’s quickly derive it. Define \(\nu(x) = \exp(-x/2)\), which has two key properties. First, \(\nu(x)\cdot \nu(y) = \nu(x + y)\). Second, \[\begin{align*} \mathcal N(x\mid m, V) &\propto \nu\big( (x-m)^T V^{-1}(x-m) \big)\\ &\propto \nu( x^TV^{-1}x - 2m^TV^{-1}x), \end{align*} \]

which shows us how to recognise the mean and the covariance matrix of a multivariate normal distribution.
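To see the second property, expand the quadratic form (using the symmetry of \(V^{-1}\)):

\[ (x-m)^T V^{-1}(x-m) = x^T V^{-1} x - 2\, m^T V^{-1} x + m^T V^{-1} m, \]

and note that the last term does not depend on \(x\), so by the first property it only contributes a constant factor, which disappears under \(\propto\).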

Let’s define \(\bar Y = N^{-1}\sum_{n=1}^N Y_n\) to be the sample mean vector and \(V = (B^{-1} + N\Sigma^{-1})^{-1}\) to be an auxiliary matrix. (We see that \(V^{-1}\) is a sum of precision matrices, so it may well turn out to be the precision matrix of some distribution!) The posterior on \(\mu\) is given by \[\begin{align*} p\big(\mu \mid (Y_n), \Sigma, B\big) &\propto \mathcal N( \mu\mid 0, B) \cdot \prod_{n=1}^N \mathcal N(Y_n\mid \mu, \Sigma) \\ &\propto \nu( \mu^T B^{-1}\mu )\cdot \nu\left( \sum_{n=1}^N (Y_n - \mu)^T \Sigma^{-1} (Y_n - \mu) \right) \\ &\propto \nu\left( \mu^T \left(B^{-1} + N \Sigma^{-1}\right)\mu - 2 N \bar Y^T \Sigma^{-1} \mu \right) \\ & \propto \nu\left( \mu^T V^{-1} \mu - 2 N \bar Y^T \Sigma^{-1} (V V^{-1}) \mu \right) \\ & \propto \nu\left( \mu^T V^{-1} \mu - 2\left(N \bar Y^T \Sigma^{-1} V\right) V^{-1} \mu \right). \end{align*} \]

Let’s define \(m^T = N\bar Y^T \Sigma^{-1} V\), so that \(m = N \cdot V \Sigma^{-1} \bar Y\). In turn, we have \(p\big(\mu \mid (Y_n), \Sigma, B\big) = \mathcal N(\mu \mid m, V)\).

It may look surprising that \(m\) is proportional to \(N\): for \(N\gg 1\) we would expect \(m\approx \bar Y\). However, this is fine: for \(N\gg 1\) we have \(V \approx N^{-1}\Sigma\), and therefore \(m\approx \bar Y\). For a small sample size, on the other hand, the prior regularises the estimate.
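To spell out the large-\(N\) limit: factoring \(N\) out of \(V^{-1}\), the term \(N^{-1}B^{-1}\) becomes negligible, so

\[ V = N^{-1}\left( N^{-1} B^{-1} + \Sigma^{-1} \right)^{-1} \approx N^{-1}\Sigma \quad \text{and} \quad m = N V \Sigma^{-1}\bar Y \approx N \cdot N^{-1} \Sigma\, \Sigma^{-1} \bar Y = \bar Y. \]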

Let’s implement these equations in JAX:

Code
from typing import Callable

import jax
import jax.numpy as jnp
import jax.random as jrandom

import blackjax
from jaxtyping import Float, Array


def normal_logp(
  x: Float[Array, " G"],
  mean: Float[Array, " G"],
  precision: Float[Array, "G G"],
) -> Float[Array, ""]:
  # Multivariate normal log-density, up to an additive constant
  # (sufficient for MCMC).
  y = x - mean
  return -0.5 * jnp.einsum("g,gh,h->", y, precision, y)


def logposterior_fn(
  data: Float[Array, "N G"],
  precision_prior: Float[Array, "G G"],
  precision_likelihood: Float[Array, "G G"],
) -> Callable[[Float[Array, " G"]], Float[Array, ""]]:
  def fn(mu: Float[Array, " G"]) -> Float[Array, ""]:
    logprior = normal_logp(
      x=mu,
      mean=jnp.zeros_like(mu),
      precision=precision_prior,
    )
    loglike = jnp.sum(
      jax.vmap(
        normal_logp,
        in_axes=(0, None, None),)(
          data,
          mu,
          precision_likelihood,
        )
    )
    return logprior + loglike
  
  return fn


def get_y_bar(data: Float[Array, "N G"]) -> Float[Array, " G"]:
  return jnp.mean(data, axis=0)


def posterior_precision(
  data: Float[Array, "N G"],
  precision_prior: Float[Array, "G G"],
  precision_likelihood: Float[Array, "G G"],
):
  # Posterior precision V^{-1} = B^{-1} + N * Sigma^{-1}
  N = data.shape[0]
  return precision_prior + N * precision_likelihood


def posterior_mean(
  data: Float[Array, "N G"],
  precision_prior: Float[Array, "G G"],
  precision_likelihood: Float[Array, "G G"],
):
  # Posterior mean m = N * V * Sigma^{-1} * Ybar
  N = data.shape[0]
  posterior_cov = jnp.linalg.inv(
    posterior_precision(
      data=data,
      precision_prior=precision_prior,
      precision_likelihood=precision_likelihood,
    )
  )
  return (N * posterior_cov) @ precision_likelihood @ get_y_bar(data)


def posterior_sample(
  key,
  data: Float[Array, "N G"],
  precision_prior: Float[Array, "G G"],
  precision_likelihood: Float[Array, "G G"],
  size: int = 1_000,
):
  N = data.shape[0]

  m = posterior_mean(
    data=data,
    precision_prior=precision_prior,
    precision_likelihood=precision_likelihood,
  )
  V = jnp.linalg.inv(posterior_precision(
    data=data,
    precision_prior=precision_prior,
    precision_likelihood=precision_likelihood,
  ))

  return jrandom.multivariate_normal(
    key, mean=m, cov=V, shape=(size,)
  )
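As a quick sanity check of these helpers (a sketch only, reusing the imports above and an arbitrary toy dataset): the posterior is Gaussian, so its mode coincides with its mean, and the gradient of the unnormalised log-posterior should therefore vanish at the analytic posterior mean.

Code
# Illustrative check: for a Gaussian posterior the mode equals the mean,
# so the gradient of the log-posterior should vanish at the analytic mean.
check_key = jrandom.PRNGKey(0)
toy_data = jrandom.normal(check_key, shape=(50, 2))  # N=50 toy samples, G=2
toy_prec_prior = jnp.eye(2)             # plays the role of B^{-1}
toy_prec_likelihood = 2.0 * jnp.eye(2)  # plays the role of Sigma^{-1}

m_check = posterior_mean(
  data=toy_data,
  precision_prior=toy_prec_prior,
  precision_likelihood=toy_prec_likelihood,
)
grad_at_mean = jax.grad(
  logposterior_fn(toy_data, toy_prec_prior, toy_prec_likelihood)
)(m_check)
assert jnp.allclose(grad_at_mean, 0.0, atol=1e-3)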

We start by generating some data points:

Code
n_samples = 4_000
data_size: int = 3

corr = 0.95
Sigma = jnp.asarray([
  [1.0, 2 * corr],
  [2 * corr, 2.0**2 * 1.0],
])

B = 1.0**2 * jnp.eye(2)

mu = jnp.asarray([0.0, 1.5])

key = jrandom.PRNGKey(42)
key, subkey = jrandom.split(key)

data = jrandom.multivariate_normal(subkey, mu, Sigma, shape=(data_size,))  # use the fresh subkey; `key` is split again later

Now let’s do inference in three different ways:

  1. Sample directly from multivariate normal using the formula derived above.
  2. Use the NUTS sampler from the BlackJAX package.
  3. Assume a somewhat wrong \(\Sigma\) matrix, ignoring the off-diagonal terms and retaining only the diagonal ones.

Additionally, we will plot a sample from the prior. On top of each panel we plot three points: the ground-truth vector \(\mu^*\), the data mean \(\bar Y\), and the mean of the plotted distribution (the prior or the appropriate posterior).

Code
import matplotlib.pyplot as plt
plt.style.use("dark_background")


# Sample from the prior
key, subkey = jrandom.split(key)
prior = jrandom.multivariate_normal(
  subkey,
  mean=jnp.zeros(2),
  cov=B,
  shape=(n_samples,)
)

# Sample from the posterior using analytic formula
key, subkey = jrandom.split(key)
posterior = posterior_sample(
  subkey,
  data=data,
  precision_prior=jnp.linalg.inv(B),
  precision_likelihood=jnp.linalg.inv(Sigma),
  size=n_samples,
)


# Sample from the posterior using BlackJAX
logdensity_fn = logposterior_fn(
  data=data,
  precision_prior=jnp.linalg.inv(B),
  precision_likelihood=jnp.linalg.inv(Sigma),
)

nuts = blackjax.nuts(
  logdensity_fn,
  1e-2,
  jnp.ones(2),
)

n_warmup = 2_000

state = nuts.init(jnp.zeros_like(mu))
step_fn = jax.jit(nuts.step)

key, subkey = jrandom.split(key)
for i in range(n_warmup):
    nuts_key = jrandom.fold_in(subkey, i)
    state, _ = step_fn(nuts_key, state)

posterior_blackjax = []
key, subkey = jrandom.split(key)
for i in range(n_samples):
    nuts_key = jrandom.fold_in(subkey, i)
    state, _ = step_fn(nuts_key, state)
    posterior_blackjax.append(state.position)

posterior_blackjax = jnp.asarray(posterior_blackjax)

# Assume that errors are uncorrelated and use analytic formula
key, subkey = jrandom.split(key)
posterior_ind = posterior_sample(
  subkey,
  data=data,
  precision_prior=jnp.linalg.inv(B),
  precision_likelihood=jnp.diag(1.0 / jnp.diagonal(Sigma)),
  size=5_000,
)


def _annotate(ax, x, y, marker, color, label=None):
  ax.scatter([x], [y], s=6**2, c=color, marker=marker, label=label)

def annotate_axis(ax):
  _annotate(ax, mu[0], mu[1], marker="x", color="r", label="$\\mu^*$")
  _annotate(ax, data.mean(axis=0)[0], data.mean(axis=0)[1], marker="+", color="yellow", label="$\\bar Y$")


fig, axs = plt.subplots(2, 2, sharex=True, sharey=True, dpi=200)

ax = axs[0, 0]
ax.set_title("Prior")
ax.scatter(prior[:, 0], prior[:, 1], s=1, c="lightblue", alpha=0.3)
_annotate(ax, mu[0], mu[1], marker="x", color="r")
_annotate(ax, 0.0, 0.0, marker="*", color="salmon")

ax = axs[0, 1]
ax.set_title("Posterior (uncorrelated $\\Sigma$)")
ax.scatter(posterior_ind[:, 0], posterior_ind[:, 1], s=1, c="blue", alpha=0.3)
ax.scatter([mu[0]], [mu[1]], s=10, c="red", marker="x")
annotate_axis(ax)
_annotate(ax, posterior_ind[:, 0].mean(), posterior_ind[:, 1].mean(), marker="*", color="salmon")


ax = axs[1, 0]
ax.set_title("Posterior (analytic)")
ax.scatter(posterior[:, 0], posterior[:, 1], s=1, c="blue", alpha=0.3)
ax.scatter([mu[0]], [mu[1]], s=10, c="red", marker="x")
annotate_axis(ax)
_annotate(ax, posterior[:, 0].mean(), posterior[:, 1].mean(), marker="*", color="salmon")

ax = axs[1, 1]
ax.set_title("Posterior (BlackJAX)")
ax.scatter(posterior_blackjax[:, 0], posterior_blackjax[:, 1], s=1, c="blue", alpha=0.3)
ax.scatter([mu[0]], [mu[1]], s=10, c="red", marker="x")
annotate_axis(ax)
_annotate(ax, posterior_blackjax[:, 0].mean(), posterior_blackjax[:, 1].mean(), marker="*", color="salmon", label="Mean")
ax.legend(frameon=False)


for ax in axs.ravel():
  ax.set_xlabel("$\\mu_1$")
  ax.set_ylabel("$\\mu_2$")
  ax.spines[["top", "right"]].set_visible(False)

fig.tight_layout()

It looks like BlackJAX and the analytic formula give the same posterior, so perhaps there is no mistake in the algebra. We also see that using the proper \(\Sigma\) should help us estimate the mean vector better and that the prior should regularise the inference.
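To make this comparison slightly more quantitative, we can print the first two moments of both samples (a quick check using the posterior and posterior_blackjax arrays generated above); they should roughly agree up to Monte Carlo error.

Code
# Compare the first two moments of the analytic and BlackJAX posteriors.
print("Analytic mean:  ", posterior.mean(axis=0))
print("BlackJAX mean:  ", posterior_blackjax.mean(axis=0))
print("Analytic covariance:\n", jnp.cov(posterior.T))
print("BlackJAX covariance:\n", jnp.cov(posterior_blackjax.T))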

Let’s do several repetitions of this experiment and evaluate the distance from the point estimate to the ground-truth value:

Code
def distance(x1, x2):
  return jnp.sqrt(jnp.sum(jnp.square(x1 - x2)))


def make_repetition(key, data_size: int):
  key1, key2 = jrandom.split(key, 2)
  mu_true = jrandom.multivariate_normal(key1, jnp.zeros(2), B)

  data = jrandom.multivariate_normal(
    key2, mu_true, Sigma, shape=(data_size,)
  )

  y_bar = get_y_bar(data)
  
  mu_expected = posterior_mean(
    data=data,
    precision_prior=jnp.linalg.inv(B),
    precision_likelihood=jnp.linalg.inv(Sigma),
  )

  mu_diagonal = posterior_mean(
    data=data,
    precision_prior=jnp.linalg.inv(B),
    precision_likelihood=jnp.diag(1.0 / jnp.diagonal(Sigma)),
  )

  return {
    "prior": distance(jnp.zeros(2), mu_true),
    "posterior": distance(mu_expected, mu_true),
    "data": distance(y_bar, mu_true),
    "diagonal": distance(mu_diagonal, mu_true),
  }

n_reps = 2_000

def make_plots(key, axs, data_size: int):
  reps = [make_repetition(k, data_size=data_size) for k in jrandom.split(key, n_reps)]

  bins = jnp.linspace(0, 4, 20)

  def plot(ax, name, color):
    ax.hist(
      [r[name] for r in reps],
      color=color,
      density=True,
      bins=bins,
    )

  ax = axs[0]
  ax.set_title(f"$N={data_size}$")
  ax.set_xlabel("Prior mean")
  plot(ax, "prior", "white")

  ax = axs[1]
  ax.set_xlabel("Posterior mean")
  plot(ax, "posterior", "bisque")

  ax = axs[2]
  ax.set_xlabel("Data mean")
  plot(ax, "data", "darkorange")

  ax = axs[3]
  ax.set_xlabel("Diagonal model")
  plot(ax, "diagonal", "purple")


fig, axs = plt.subplots(4, 4, sharex=True, sharey="row")

for i, size in enumerate([2, 10, 50, 250]):
  key, subkey = jrandom.split(key)
  make_plots(
    subkey,
    axs=axs[:, i],
    data_size=size,
  )

for ax in axs.ravel():
  ax.spines[["top", "right", "left"]].set_visible(False)
  ax.set_yticks([])

fig.suptitle("Error")
fig.tight_layout()

We see what we should expect:

  1. Every method that uses the data (i.e., everything apart from the prior) improves as \(N\) increases.
  2. For small sample sizes, using the plain data mean can result in larger errors, and a reasonable prior regularises the posterior mean, so that the error is smaller.
  3. Given enough data, the performance of the posterior mean and of the plain data mean looks quite similar.

Additionally, we see that the model assuming a diagonal \(\Sigma\) (i.e., ignoring the correlations) also performs quite similarly to the true one.

This “performance looks similar” can actually be somewhat misleading: each of these distributions has quite a large variance, so minor differences can go unnoticed.

Let’s now repeat this experiment, but this time plotting the difference between distances, so that we can see any differences better. Namely, for the method \(M\) and the \(s\)-th simulation, write \(d^{(M)}_s\) for the obtained distance. Now, instead of plotting the data sets \(\{ d^{(M_1)}_{1}, \dotsc, d^{(M_1)}_S\}\) and \(\{ d^{(M_2)}_{1}, \dotsc, d^{(M_2)}_S\}\), we can plot the differences \(\{ d^{(M_2)}_{1} - d^{(M_1)}_{1}, \dotsc, d^{(M_2)}_{S} - d^{(M_1)}_{S} \}\).

Let’s use the posterior mean in the right model (potentially the best solution) as the baseline and compare it with the three other methods. In each of the plots, the samples to the right of zero represent a positive difference, i.e., the cases where the baseline method (here, the posterior in the right model) was better than the considered alternative. Apart from the raw samples, let’s plot the mean of this distribution (intuitively, we should expect it to be larger than zero) and report the percentage of samples to the right of zero.
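For concreteness, here is a minimal sketch of this paired comparison on hypothetical arrays of distances (the actual computation over the simulated repetitions follows in the block below).

Code
# Hypothetical per-simulation distances for the baseline and an alternative.
d_baseline = jnp.asarray([0.3, 0.5, 0.2, 0.4])
d_alternative = jnp.asarray([0.4, 0.6, 0.1, 0.7])

delta = d_alternative - d_baseline  # positive => the baseline did better
print(delta.mean())                 # mean extra error of the alternative
print(100 * jnp.mean(delta > 0))    # % of simulations where the baseline won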

Code
n_reps = 3_000

def compare_with_diagonal(key, data_size: int):
  key1, key2 = jrandom.split(key, 2)
  mu_true = jrandom.multivariate_normal(key1, jnp.zeros(2), B)

  data = jrandom.multivariate_normal(
    key2, mu_true, Sigma, shape=(data_size,)
  )

  y_bar = get_y_bar(data)
  
  mu_posterior = posterior_mean(
    data=data,
    precision_prior=jnp.linalg.inv(B),
    precision_likelihood=jnp.linalg.inv(Sigma),
  )
  mu_diagonal = posterior_mean(
    data=data,
    precision_prior=jnp.linalg.inv(B),
    precision_likelihood=jnp.diag(1.0 / jnp.diagonal(Sigma)),
  )

  baseline = distance(mu_posterior, mu_true)

  return {
    "delta_prior": distance(jnp.zeros(2), mu_true) - baseline,
    "delta_diagonal": distance(mu_diagonal, mu_true) - baseline,
    "delta_data": distance(y_bar, mu_true) - baseline,
  }


def make_plots(key, axs, data_size: int):
  reps = [compare_with_diagonal(k, data_size=data_size) for k in jrandom.split(key, n_reps)]

  bins = jnp.linspace(-2, 2, 20)

  def plot(ax, name, color):
    samples = jnp.array([r[name] for r in reps])
    ax.hist(
      samples,
      color=color,
      density=True,
      bins=bins,
    )
    p_worse = float(100 * jnp.mean(samples > 0))
    ax.axvline(jnp.mean(samples), linestyle=":", color="salmon")
    ax.axvline(0.0, linestyle=":", color="white")
    ax.annotate(f"{p_worse:.0f}%", xy=(0.05, 0.5), xycoords='axes fraction')

  ax = axs[0]
  ax.set_title(f"$N={data_size}$")
  ax.set_xlabel("Prior mean")
  plot(ax, "delta_prior", "white")

  ax = axs[1]
  ax.set_xlabel("Data mean")
  plot(ax, "delta_data", "darkorange")

  ax = axs[2]
  ax.set_xlabel("Diagonal model")
  plot(ax, "delta_diagonal", "purple")


fig, axs = plt.subplots(3, 4, sharex=True, sharey="row")

for i, size in enumerate([2, 10, 50, 250]):
  key, subkey = jrandom.split(key)
  make_plots(
    subkey,
    axs=axs[:, i],
    data_size=size,
  )

for ax in axs.ravel():
  ax.spines[["top", "right", "left"]].set_visible(False)
  ax.set_yticks([])

fig.suptitle("Extra error over the baseline")
fig.tight_layout()

As expected, the well-specified Bayesian model performs best. However, given “enough” data points, one can use the data mean as well (or the misspecified model without the off-diagonal terms in the covariance). An interesting question would be to check how this “enough” depends on the dimensionality of the problem.

Footnotes

  1. Probably I shouldn’t have, but I had to use a sparse prior over the space of positive definite matrices and I don’t know how to run Hamiltonian Monte Carlo with these choices…