Learning models with discrete Fisher divergence

High-dimensional discrete models often suffer from intractable normalising constants. A recent manuscript introducing the discrete Fisher divergence helps to resolve this issue.
Author

Paweł Czyż

Published

July 11, 2024

ISBA 2024 was a phenomenal experience, with a lot of great people, conversations, and presented projects. Summarising it would take a separate post. Or two. Or three.

In this one, however, we’ll take a look at just one topic, presented by Takuo Matsubara. I remember sitting amazed throughout the talk: not only was it great and engaging, but the method itself was particularly elegant. In January I was thinking about using Fisher’s noncentral hypergeometric distribution, but its normalising constant is computationally too expensive. I wish I had known Takuo’s paper back then!

The paper (or the preprint) focuses on the following problem: we have a distribution defined on a discrete space \(\mathcal Y = \{0, 1, \dotsc, K-1\}^G\) whose PMF is everywhere positive and known up to the normalising constant. That is, we have

\[ p_\theta(y) = \frac{1}{\mathcal Z(\theta)} q_\theta(y), \]

where \(q_\theta(y) > 0\) everywhere and this function is easy to evaluate. However, evaluating \(\mathcal Z(\theta)\) is prohibitively expensive, as it usually requires \(O(K^G)\) evaluations of \(q_\theta\).
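To make this cost concrete, here is a toy illustration (mine, not from the paper) of computing \(\mathcal Z(\theta)\) by brute-force enumeration; the closed-form check at the end is specific to the independent model used as an example.

Code
from itertools import product

import numpy as np

def brute_force_Z(log_q, K: int, G: int) -> float:
  # Sum q over all K^G configurations -- feasible only for tiny K and G.
  return sum(np.exp(log_q(np.asarray(y))) for y in product(range(K), repeat=G))

# Example: an unnormalised independent model q(y) = exp(alpha . y) with
# K = 2 and G = 10, i.e. "only" 2^10 = 1024 evaluations.
alpha = np.linspace(-1.0, 1.0, 10)
Z = brute_force_Z(lambda y: float(alpha @ y), K=2, G=10)
np.testing.assert_allclose(Z, np.prod(1 + np.exp(alpha)))
# Already at G = 50 the same enumeration would need ~10^15 evaluations.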

Takuo’s framework helps to do inference in this model without the need to calculate \(\mathcal Z(\theta)\) at all! In fact, it applies to more general spaces, although in this blog post we restrict our attention to the special case above.

Before we discuss this method, let’s quickly summarise how likelihood-based inference works.

Kullback–Leibler divergence

As Takuo explained in his presentation, once we observe data points \(y_1, \dotsc, y_N\), we can form the empirical distribution \(p_\text{emp} = \frac{1}{N} \sum_{n=1}^N \delta_{y_n}\) and consider the Kullback–Leibler divergence \[ \mathrm{KL}(p_\text{emp} \parallel p_\theta ) = -\frac{1}{N} \sum_{n=1}^N \log p_\theta(y_n) + \frac{1}{N} \sum_{n=1}^N \log p_\text{emp}(y_n) = -\frac{1}{N} \sum_{n=1}^N \log p_\theta(y_n) - H(p_\text{emp}). \]

Hence, \(N \cdot \mathrm{KL}(p_\text{emp} \parallel p_\theta)\) is the negative log-likelihood (up to an additive constant, which depends on the entropy of the empirical data distribution), which can then be optimised in maximum likelihood approaches. A nice property is that \(\mathrm{KL}(p_1\parallel p_2) \ge 0\), with equality if and only if \(p_1 = p_2\). When \(N\) is large enough that \(p_\text{emp}\) is close to the data distribution, maximum likelihood should result in the distribution “closest” to the data distribution among the family \(p_\theta\). In particular, under no misspecification we should rediscover the data distribution.

As Takuo explained, Bayesian inference also proceeds in this manner: \[ p( \theta \mid y_1, \dotsc, y_N ) \propto p(\theta) \cdot \exp(-N\cdot \mathrm{KL}(p_\text{emp} \parallel p_\theta) ), \]

where the entropy of \(p_\text{emp}\) is effectively hidden in the proportionality constant.
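To spell this out (a one-line check using the decomposition above): \[ \exp(-N\cdot \mathrm{KL}(p_\text{emp} \parallel p_\theta)) = \exp\left( \sum_{n=1}^N \log p_\theta(y_n) + N H(p_\text{emp}) \right) = e^{N H(p_\text{emp})} \prod_{n=1}^N p_\theta(y_n), \] and the factor \(e^{N H(p_\text{emp})}\) does not depend on \(\theta\), so multiplying the prior by \(\exp(-N\cdot\mathrm{KL})\) indeed recovers the usual posterior.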

Before Takuo’s talk, I hadn’t thought about Bayesian inference in terms of the Kullback–Leibler divergence from the empirical data distribution \(p_\text{emp}\) to the model \(p_\theta\): on continuous spaces \(p_\text{emp}\) is atomic, while \(p_\theta\) (usually) is not, so \(\mathrm{KL}(p_\text{emp} \parallel p_\theta) = +\infty\) for all parameters \(\theta\). However, on a discrete space \(\mathcal Y\) both measures are necessarily atomic and this works nicely.

However, maximum likelihood, minimisation of the Kullback–Leibler divergence, and Bayesian inference all rely on having access to \(\log p_\theta(y) = \log q_\theta(y) - \log \mathcal Z(\theta)\), which is not tractable in our case.

Discrete Fisher divergence

Define operators on \(\mathcal Y\) incrementing and decrementing a specific position: \[ (\mathcal I_g y)_h = \begin{cases} (y_g + 1) \mod K &\text{ if } g = h\\ y_h &\text{ otherwise} \end{cases} \] and \[ (\mathcal D_g y)_h = \begin{cases} (y_g - 1) \mod K &\text{ if } g = h\\ y_h &\text{ otherwise} \end{cases} \]

where the “mod \(K\)” means that we “increment” \(K-1\) to \(0\) (and, conversely, “decrement” \(0\) to \(K-1\)). When the data are binary (\(K=2\)), we have \(\mathcal I_g = \mathcal D_g\) and both reduce to flipping the \(g\)-th bit.
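A minimal JAX sketch of these operators (my own, for general \(K\); the binary bitflip used later is the special case \(K=2\)):

Code
import jax.numpy as jnp

def increment(g: int, y, K: int):
  # (I_g y): add 1 modulo K at position g, leaving the other positions unchanged.
  return y.at[g].set((y[g] + 1) % K)

def decrement(g: int, y, K: int):
  # (D_g y): subtract 1 modulo K at position g.
  return y.at[g].set((y[g] - 1) % K)

y = jnp.array([0, 2, 1])
assert jnp.all(increment(1, y, K=3) == jnp.array([0, 0, 1]))  # 2 + 1 = 0 (mod 3)
assert jnp.all(decrement(0, y, K=3) == jnp.array([2, 2, 1]))  # 0 - 1 = 2 (mod 3)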

The discrete Fisher divergence is then given by \[ \mathrm{DFD}(p_1\parallel p_2) = \mathbb E_{y\sim p_2}\left[ \sum_{g=1}^G \left( \left( \frac{ p_1( \mathcal D_gy ) }{ p_1(y) } \right)^2 - 2\, \frac{ p_1(y) }{ p_1(\mathcal I_g y) } \right) \right] + C(p_2), \]

where \(C(p_2)\) is a particular expression which does not depend on \(p_1\). In the paper you can find its formula, as well as a proof that \(\mathrm{DFD}(p_1\parallel p_2)\ge 0\), with equality if and only if \(p_1 = p_2\). Hence, a generalised posterior is proposed: \[ p^{\mathrm{DFD}}(\theta) \propto p(\theta) \cdot \exp(-\tau N\cdot \mathrm{DFD}(p_\theta \parallel p_\text{emp} )), \]

where \(\tau\) is a temperature parameter, used in generalised Bayesian inference partly as a remedy for model misspecification.

Note that \[\begin{align*} \mathrm{DFD}(p_\theta \parallel p_\text{emp}) &= \frac{1}{N} \sum_{n=1}^N\sum_{g=1}^G \left( \left(\frac{ p_\theta(\mathcal D_g y_n) }{ p_\theta(y_n) }\right)^2 - 2\, \frac{p_\theta(y_n)}{p_\theta(\mathcal{I}_{g}y_n)} \right) + C(p_\text{emp}) \\ &= \frac{1}{N} \sum_{n=1}^N\sum_{g=1}^G \left( \left(\frac{ q_\theta(\mathcal D_g y_n) }{ q_\theta(y_n) }\right)^2 - 2\, \frac{q_\theta(y_n)}{q_\theta(\mathcal{I}_{g}y_n)} \right) + C(p_\text{emp}), \end{align*} \]

meaning that it does not need the (intractable) normalising constant \(\mathcal Z(\theta)\) at all! This comes at the price of \(O(NG)\) evaluations of \(q_\theta\) (rather than the \(O(N)\) calls to \(p_\theta\) of likelihood-based methods) and of using a generalised Bayesian approach rather than the standard one (with the usual question: what value of the temperature \(\tau\) should one use?).

Overall, I very much like this idea, and it could have a large impact on applications: in computational biology we use many discrete models for which likelihood-based methods cannot be applied because of intractable normalising constants.

Experiments on binary data

Let’s consider \(K=2\), so that \(\mathcal Y = \{0, 1\}^G\). If \(\mathcal F_g := \mathcal I_g = \mathcal D_g\) is the bitflip operator, the discrete Fisher divergence takes the form

\[ \mathrm{DFD}(p_\theta \parallel p_\text{emp}) = \frac{1}{N} \sum_{n=1}^N\sum_{g=1}^G \left( \left(\frac{ q_\theta(\mathcal F_g y_n) }{ q_\theta(y_n) }\right)^2 - 2\, \frac{q_\theta(y_n)}{q_\theta(\mathcal{F}_{g}y_n)} \right) + C(p_\text{emp}). \]

I will simulate the data by tossing \(G\) independent coins, each with its own bias \(\pi_g\). In this case the likelihood is tractable, as it is just \[ p_\pi( y ) = \prod_{g=1}^G \pi_g^{y_g} (1-\pi_g)^{1-y_g}. \]

It is convenient to write this model in the exponential family form, by reparameterising it into log-odds, \(\alpha_g = \log\frac{\pi_g}{1-\pi_g}\). Then we have \[ q_\alpha(y) = \exp\left( \sum_{g=1}^G \alpha_g y_g \right). \]
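As a sanity check (standard exponential-family algebra, not specific to the paper), the normalising constant of this family is available in closed form: \[ \mathcal Z(\alpha) = \sum_{y\in\{0,1\}^G} \exp\left(\sum_{g=1}^G \alpha_g y_g\right) = \prod_{g=1}^G \left(1 + e^{\alpha_g}\right), \] and \(p_\pi(y) = q_\alpha(y)/\mathcal Z(\alpha)\) with \(\pi_g = \operatorname{expit}(\alpha_g)\). This is what makes the exact likelihood-based baselines below possible.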

Code
from functools import partial
from typing import Callable
from jaxtyping import Float, Int, Array

import numpy as np
import jax
import jax.numpy as jnp
from jax.scipy import optimize

rng = np.random.default_rng(42)

n_samples: int = 100
n_genes: int = 10

def logit(p):
  return np.log(p) - np.log1p(-p)

def expit(x):
  return 1.0 / (1 + np.exp(-x))

true_bias = np.linspace(0.2, 0.7, n_genes)  # per-coin biases pi_g
true_alpha = logit(true_bias)  # corresponding log-odds alpha_g

# Binary data matrix of shape (n_samples, n_genes).
Y = rng.binomial(1, p=true_bias, size=(n_samples, n_genes))

The simplest option is to compute the maximum likelihood solution, which for independent coins is just the empirical frequency of each coin:

Code
def print_estimate_summary(alphas, name: str | None = None) -> None:
  if name is not None:
    print(f"----- {name} -----")
  print("Absolute error on the bias:")
  found_bias = expit(alphas)
  print(true_bias - found_bias)

  print("True bias:")
  print(true_bias)

mle_bias = np.mean(Y, axis=0)
mle_alpha = logit(mle_bias)

print_estimate_summary(mle_alpha, name="maximum likelihood")
----- maximum likelihood -----
Absolute error on the bias:
[ 0.04       -0.07444444 -0.03888889  0.04666667 -0.02777778  0.00777778
  0.01333333 -0.02111111 -0.06555556 -0.03      ]
True bias:
[0.2        0.25555556 0.31111111 0.36666667 0.42222222 0.47777778
 0.53333333 0.58888889 0.64444444 0.7       ]

As we have the luxury of doing full Bayesian inference in this tractable toy model, let’s try it. We will use a hierarchical model in which the prior on \(\alpha\) is a normal distribution with learned mean and scale (written in a non-centred parameterisation):

Code
import numpyro
import numpyro.distributions as dist


def independent_model(mutations):
    n_samples, n_genes = mutations.shape
    
    mu = numpyro.sample("_mu", dist.Normal(0.0, 5.0))
    sigma = numpyro.sample("_sigma", dist.HalfCauchy(scale=3.0))

    z = numpyro.sample("_z", dist.Normal(jnp.zeros(n_genes), 1.0))

    alpha = numpyro.deterministic("alpha", mu + sigma * z)    

    with numpyro.plate("samples", n_samples, dim=-2):
        with numpyro.plate("genes", n_genes, dim=-1):
            numpyro.sample("obs", dist.BernoulliLogits(alpha[None, :]), obs=mutations)

key = jax.random.PRNGKey(0)

key, subkey = jax.random.split(key)

posterior = numpyro.infer.MCMC(
  numpyro.infer.NUTS(independent_model),
  num_chains=4,
  num_samples=500,
  num_warmup=500,
)
posterior.run(subkey, mutations=Y)

posterior.print_summary()

alpha_samples = posterior.get_samples()["alpha"]
bias_samples = expit(alpha_samples)

posterior_mean_alpha = alpha_samples.mean(axis=0)

print_estimate_summary(posterior_mean_alpha, name="posterior mean")

                mean       std    median      5.0%     95.0%     n_eff     r_hat
       _mu     -0.18      0.31     -0.18     -0.67      0.30    319.12      1.01
    _sigma      0.92      0.29      0.86      0.50      1.34    290.48      1.01
     _z[0]     -1.59      0.56     -1.55     -2.44     -0.68    428.13      1.01
     _z[1]     -0.58      0.41     -0.57     -1.21      0.11    439.12      1.01
     _z[2]     -0.48      0.40     -0.48     -1.12      0.16    452.72      1.01
     _z[3]     -0.64      0.42     -0.63     -1.37      0.01    472.24      1.00
     _z[4]     -0.03      0.39     -0.02     -0.65      0.60    450.87      1.01
     _z[5]      0.07      0.39      0.08     -0.51      0.70    432.61      1.01
     _z[6]      0.28      0.39      0.29     -0.37      0.90    408.89      1.01
     _z[7]      0.70      0.42      0.70      0.02      1.39    363.06      1.01
     _z[8]      1.19      0.49      1.19      0.39      2.00    343.35      1.00
     _z[9]      1.29      0.49      1.29      0.42      2.07    350.32      1.00

Number of divergences: 0
----- posterior mean -----
Absolute error on the bias:
[ 0.02347436 -0.08082746 -0.04482673  0.04107978 -0.02729267  0.00835171
  0.01834139 -0.01336413 -0.05330887 -0.01561599]
True bias:
[0.2        0.25555556 0.31111111 0.36666667 0.42222222 0.47777778
 0.53333333 0.58888889 0.64444444 0.7       ]

It looks like both maximum likelihood and Bayesian inference do a reasonable job on this problem. Let’s now implement the discrete Fisher divergence:

Code
DataPoint = Int[Array, " G"]


def bitflip(g: int, y: DataPoint) -> DataPoint:
  # Flip the g-th bit: the K = 2 case, where I_g = D_g.
  return y.at[g].set(1 - y[g])


def dfd_onepoint(
  log_q: Callable[[DataPoint], float],
  y: DataPoint,
) -> float:
  """DFD contribution of a single data point (up to the constant C)."""
  log_qy: float = log_q(y)

  def log_q_flip_fn(g: int):
    return log_q(bitflip(g, y))

  # log q(F_g y) for every position g.
  log_qflipped = jax.vmap(log_q_flip_fn)(jnp.arange(y.shape[0]))

  # log-ratio log q(F_g y) - log q(y).
  log_ratio = log_qflipped - log_qy

  # Sum over g of (q(F_g y)/q(y))^2 - 2 q(y)/q(F_g y).
  return jnp.sum(jnp.exp(2 * log_ratio) - 2 * jnp.exp(-log_ratio))


def dfd(log_q, ys) -> float:
  f = partial(dfd_onepoint, log_q)

  return jnp.mean(jax.vmap(f)(ys))


def linear(alpha: Float[Array, " G"], y: DataPoint) -> float:
  return jnp.sum(alpha * y)

@jax.jit
def loss(alpha):
  return dfd(partial(linear, alpha), Y)

result = optimize.minimize(loss, jnp.zeros(n_genes), method="BFGS")

dfd_alpha = result.x

print_estimate_summary(dfd_alpha, name="DFD")
----- DFD -----
Absolute error on the bias:
[ 0.03998992 -0.07451569 -0.0388778   0.04673009 -0.02779001  0.00777489
  0.01333484 -0.02109128 -0.06553831 -0.02999656]
True bias:
[0.2        0.25555556 0.31111111 0.36666667 0.42222222 0.47777778
 0.53333333 0.58888889 0.64444444 0.7       ]

Wow, this looks pretty good to me! In fact, the DFD estimate is almost identical to the maximum likelihood one.
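This is not a coincidence: for the independent-coin model, a quick calculation (mine, not from the paper) shows that the DFD minimiser is exactly the MLE. The DFD objective separates across positions, and for a single coin with empirical frequency \(f\) of ones it equals, up to the constant \(C(p_\text{emp})\), \[ f\left(e^{-2\alpha} - 2e^{\alpha}\right) + (1-f)\left(e^{2\alpha} - 2e^{-\alpha}\right). \] Setting the derivative with respect to \(\alpha\) to zero and substituting \(t = e^\alpha\) gives \(f\,(1 + t^3) = (1-f)\,t\,(1 + t^3)\), hence \(t = f/(1-f)\), i.e. \(\alpha = \operatorname{logit}(f)\): the maximum likelihood estimate. The tiny discrepancies above are just optimiser tolerance.

Let’s now visualise the performance of all methods: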

Code
import matplotlib.pyplot as plt
plt.style.use("dark_background")

fig, axs = plt.subplots(1, 2, sharex=True, sharey=False, figsize=(4, 2), dpi=300)
for ax in axs:
  ax.spines[["top", "right"]].set_visible(False)

colors = {
  "true": "white",
  "mle": "maroon",
  "posterior_sample": "lightgrey",
  "dfd": "gold",
}

ax = axs[0]
x_axis = np.arange(1, n_genes + 1)
ax.set_xticks(x_axis)
ax.set_title("$\\alpha$")

ax = axs[1]
ax.set_xticks(x_axis)
ax.set_title("$\\pi$")


def plot(alpha_values, color, alpha=1.0, scatter: bool = True):
  ax = axs[0]
  ax.plot(x_axis, alpha_values, c=color, alpha=alpha)
  if scatter:
    ax.scatter(x_axis, alpha_values, c=color, alpha=alpha)

  ax = axs[1]
  bias_values = expit(alpha_values)
  ax.plot(x_axis, bias_values, c=color, alpha=alpha)
  if scatter:
    ax.scatter(x_axis, bias_values, c=color, alpha=alpha)


for sample in alpha_samples[::40]:
  plot(sample, colors["posterior_sample"], alpha=0.1, scatter=False)

plot(true_alpha, colors["true"])
plot(mle_alpha, colors["mle"])
plot(dfd_alpha, colors["dfd"])

fig.tight_layout()

The agreement is excellent. Let’s do one more thing: sample from the DFD generalised posterior, using `numpyro.factor` to add the \(-\tau N\cdot \mathrm{DFD}(p_\theta \parallel p_\text{emp})\) term to the model’s log-density:

Code
def dfd_model(mutations, temperature):
    n_samples, n_genes = mutations.shape
    
    mu = numpyro.sample("_mu", dist.Normal(0.0, 5.0))
    sigma = numpyro.sample("_sigma", dist.HalfCauchy(scale=3.0))

    z = numpyro.sample('_z', dist.Normal(jnp.zeros(n_genes), 1))

    alpha = numpyro.deterministic("alpha", mu + sigma * z)    

    numpyro.factor("dfd", -temperature * n_samples * loss(alpha))


dfd_samples = {}

temperature_range = [0.01, 0.1, 0.5, 1.0, 2.0]

for temperature in temperature_range:
  key, subkey = jax.random.split(key)

  posterior = numpyro.infer.MCMC(
    numpyro.infer.NUTS(dfd_model),
    num_chains=4,
    num_samples=500,
    num_warmup=500,
  )
  
  posterior.run(
    subkey,
    mutations=Y,
    temperature=temperature,
  )
  print(f"----- Temperature: {temperature:.2f} -----")
  posterior.print_summary()

  dfd_samples[temperature] = posterior.get_samples()["alpha"]
----- Temperature: 0.01 -----

                mean       std    median      5.0%     95.0%     n_eff     r_hat
       _mu     -0.13      0.26     -0.14     -0.52      0.31    727.16      1.00
    _sigma      0.43      0.29      0.39      0.00      0.82    547.15      1.00
     _z[0]     -0.85      0.99     -0.92     -2.65      0.62   1237.86      1.00
     _z[1]     -0.30      0.82     -0.32     -1.66      1.03   1398.45      1.00
     _z[2]     -0.26      0.84     -0.28     -1.67      1.07   1504.57      1.00
     _z[3]     -0.33      0.87     -0.34     -1.72      1.10   1825.18      1.00
     _z[4]     -0.02      0.85     -0.02     -1.50      1.29   1658.80      1.00
     _z[5]      0.01      0.84      0.01     -1.28      1.36   1918.23      1.00
     _z[6]      0.14      0.81      0.13     -1.14      1.43   1666.27      1.00
     _z[7]      0.30      0.79      0.32     -1.08      1.52   1669.95      1.00
     _z[8]      0.56      0.85      0.57     -0.82      1.99   1633.24      1.00
     _z[9]      0.57      0.86      0.59     -0.86      1.94   1463.64      1.00

Number of divergences: 21
----- Temperature: 0.10 -----

                mean       std    median      5.0%     95.0%     n_eff     r_hat
       _mu     -0.14      0.33     -0.13     -0.61      0.40    228.95      1.00
    _sigma      0.91      0.28      0.86      0.52      1.29    226.41      1.03
     _z[0]     -1.75      0.58     -1.72     -2.65     -0.76    228.53      1.02
     _z[1]     -0.60      0.43     -0.59     -1.32      0.05    280.44      1.01
     _z[2]     -0.50      0.42     -0.50     -1.17      0.23    356.88      1.00
     _z[3]     -0.65      0.42     -0.65     -1.29      0.07    299.21      1.01
     _z[4]     -0.05      0.41     -0.06     -0.75      0.59    342.85      1.00
     _z[5]      0.03      0.40      0.04     -0.60      0.68    315.56      1.00
     _z[6]      0.23      0.41      0.22     -0.42      0.88    315.50      1.00
     _z[7]      0.62      0.43      0.62     -0.13      1.26    330.74      1.01
     _z[8]      1.13      0.47      1.14      0.34      1.85    310.08      1.01
     _z[9]      1.25      0.49      1.24      0.45      2.03    286.92      1.02

Number of divergences: 0
----- Temperature: 0.50 -----

                mean       std    median      5.0%     95.0%     n_eff     r_hat
       _mu     -0.13      0.27     -0.14     -0.55      0.33    193.62      1.03
    _sigma      0.92      0.25      0.87      0.54      1.26    216.76      1.00
     _z[0]     -1.76      0.51     -1.73     -2.68     -1.00    258.98      1.01
     _z[1]     -0.64      0.33     -0.64     -1.21     -0.10    278.92      1.02
     _z[2]     -0.54      0.32     -0.54     -1.06     -0.01    272.43      1.02
     _z[3]     -0.70      0.33     -0.70     -1.21     -0.09    275.79      1.02
     _z[4]     -0.06      0.30     -0.07     -0.53      0.43    239.25      1.02
     _z[5]      0.03      0.31      0.03     -0.46      0.53    240.01      1.02
     _z[6]      0.25      0.32      0.24     -0.28      0.79    220.40      1.02
     _z[7]      0.67      0.36      0.67      0.06      1.23    193.99      1.02
     _z[8]      1.19      0.43      1.19      0.48      1.89    181.31      1.01
     _z[9]      1.31      0.46      1.31      0.51      2.01    181.15      1.01

Number of divergences: 1
----- Temperature: 1.00 -----

                mean       std    median      5.0%     95.0%     n_eff     r_hat
       _mu     -0.19      0.27     -0.17     -0.56      0.29    176.79      1.01
    _sigma      0.92      0.24      0.87      0.57      1.26    215.09      1.01
     _z[0]     -1.70      0.51     -1.70     -2.63     -0.93    211.78      1.00
     _z[1]     -0.60      0.33     -0.60     -1.20     -0.12    197.15      1.00
     _z[2]     -0.50      0.32     -0.51     -1.01      0.02    200.85      1.00
     _z[3]     -0.65      0.34     -0.66     -1.17     -0.05    203.18      1.00
     _z[4]     -0.02      0.28     -0.03     -0.45      0.50    188.09      1.01
     _z[5]      0.07      0.28      0.06     -0.36      0.59    189.28      1.01
     _z[6]      0.30      0.28      0.28     -0.17      0.78    195.43      1.01
     _z[7]      0.72      0.31      0.70      0.17      1.21    177.20      1.02
     _z[8]      1.24      0.37      1.22      0.61      1.84    186.58      1.02
     _z[9]      1.35      0.39      1.33      0.67      1.97    187.23      1.02

Number of divergences: 0
----- Temperature: 2.00 -----

                mean       std    median      5.0%     95.0%     n_eff     r_hat
       _mu     -0.21      0.33     -0.20     -0.72      0.28    173.57      1.01
    _sigma      0.93      0.25      0.89      0.57      1.30    182.60      1.02
     _z[0]     -1.66      0.54     -1.62     -2.54     -0.85    153.43      1.04
     _z[1]     -0.57      0.37     -0.57     -1.12      0.03    158.81      1.03
     _z[2]     -0.47      0.35     -0.47     -1.01      0.10    161.59      1.03
     _z[3]     -0.62      0.37     -0.63     -1.20     -0.03    158.92      1.03
     _z[4]      0.01      0.33      0.00     -0.55      0.58    176.08      1.02
     _z[5]      0.10      0.33      0.09     -0.42      0.70    186.36      1.02
     _z[6]      0.33      0.34      0.32     -0.26      0.87    191.94      1.01
     _z[7]      0.75      0.38      0.74      0.08      1.32    200.20      1.01
     _z[8]      1.26      0.45      1.26      0.49      1.96    200.88      1.00
     _z[9]      1.37      0.47      1.37      0.57      2.10    199.99      1.00

Number of divergences: 1

We can visualise the samples and compare them with the usual Bayesian approach:

Code
fig, axs = plt.subplots(n_genes, len(temperature_range) + 1)

for i in range(n_genes):
  for j in range(len(temperature_range) + 1):
    samples = alpha_samples if j == len(temperature_range) else  dfd_samples[temperature_range[j]]

    color = colors["posterior_sample"] if j == len(temperature_range) else colors["dfd"]
    if j == len(temperature_range):
      c_alpha = 1.0
    else:
      c_alpha = 0.3 + 0.7 * j / len(temperature_range)

    bias_samples = expit(samples)

    ax = axs[i, j]
    ax.hist(bias_samples[:, i], bins=np.linspace(0, 1, 20), density=True, alpha=c_alpha, color=color)

    ax.set_yticks([]) 
    ax.set_xticks([])
    ax.spines[["top", "right"]].set_visible(False)

    ax.axvline(true_bias[i], color=colors["true"], linestyle="--"  )

for i, ax in enumerate(axs[:, 0]):
  ax.set_ylabel(f"{i}")

for j, ax in enumerate(axs[0, :]):
  name = "Bayes" if j == len(temperature_range) else  f"$T$={temperature_range[j]:.2f}"
  ax.set_title(name)

for ax in axs[-1, :]:
  ax.set_xticks([0, 0.5, 1])

If the temperature is too low, the prior has too large an influence. Here \(T = 0.1\) seems to be the best choice, yielding sensible uncertainty quantification.

Summary

Overall, I have to say that I very much like the discrete Fisher divergence idea and I congratulate the authors on this article!

The manuscript also contains experiments with the Ising model; I will need to try this method on related high-dimensional problems, as it showed very promising performance on the toy problem above. The method is also easy to implement and fast to run.