Clutter example for expectation propagation

The code below uses the Expectation Propagation algorithm to solve the Clutter problem, Bishop et al., 2006, section 10.7.1

Implementation Details

Below is the code for the EP functions used in this example:

import numpy as np
from scipy.stats import multivariate_normal


def init(b, D, N):
    m_f = np.zeros(shape=(N, D))
    v_f = np.ones(shape=N)*np.inf
    s_f = np.ones(shape=N)
    m = np.array([0.0])
    v = b
    return m, v, m_f, v_f, s_f


def get_cavity_var(v, v_fn):
    """Returns the cavity variance :math:`v^{\\setduf n}`.

    : param v: approximate posterior q variance
    : type v: float
    : param v_fn: factor variance
    : type v_fn: float
    : return: cavity variance :math:`v^{\\setduf n}`
    : rtype: float
    """
    if np.isposinf(v_fn):
        v_c = v
    else:
        v_c = v_fn * v / (v_fn - v)
    return v_c


def get_cavity_mean(m, m_fn, v_fn, v_cn):
    """Returns the cavity mean :math:`m^{\\setduf n}`.

    : param m: approximate posterior q mean
    : type m: array of length D
    : param m_fn: factor mean
    : type m_fn: array of length D
    : param v_fn: factor variance
    : type v_fn: float
    : param v_cn: cavity variance
    : type v_cn: float
    : return: cavity mean :math:`m^{\\setduf b}`
    : rtype: array of length D
    """
    m_c = m + v_cn / v_fn * (m - m_fn)
    return m_c


def get_zeroth_moment(w, a, m_cn, v_cn, x_n):
    """Returns the zeroth moment :math:`Z_n`.

    : param w: weight in the clutter problem likelihood
    : type w: float
    : param a: signal variance in clutter problem
    : type a: float
    : param m_c: cavity mean
    : type m: array of length D
    : param v_c: cavity variance
    : type v_c: float
    : return: zeorth moment :math:`Z_n`
    : rtype: float
    """
    D = len(m_cn)
    Z_n = ((1 - w) * multivariate_normal(
        mean=m_cn, cov=(v_cn + 1) * np.eye(D)).pdf(x=x_n) +
        w * multivariate_normal(mean=np.zeros(D),
                                cov=a * np.eye(D)).pdf(x=x_n))
    return Z_n


def get_site_strength(w, a, D, Z_n, x_n):
    """Returns the site strength :math:`\\rho_n`.

    : param m: approximate posterior q mean
    : type m: float
    : param m_fn: factor mean
    : type m_fn: float
    : param v_fn: factor variance
    : type v_fn: float
    : param v_cn: cavity variance
    : type v_cn: float
    : return: site strength :math:`\\rho_n`
    : rtype: float
    """
    rho_n = 1 - w / Z_n * multivariate_normal(mean=np.zeros(D),
                                              cov=a * np.eye(D)).pdf(x_n)
    return rho_n


def get_q_mean(m_cn, v_cn, rho_n, x_n):
    """Returns the mean of the approximate distribution :math:`q`.

    : param m_cn: cavity mean
    : type m_cn: array of length D
    : param v_cn: cavity variance
    : type v_cn: float
    : param rho_n: site strength
    : type rho_n: float
    : param x_n: sample
    : type x_n: array of dimension D
    : return: approximate posterior distribution mean
    : rtype: array of length D
    """
    m = m_cn + v_cn / (1 + v_cn) * rho_n * (x_n - m_cn)
    return m


def get_q_var(m_cn, v_cn, rho_n, x_n):
    """Returns the variance of the approximate distribution :math:`q`.

    : param m_cn: cavity mean
    : type m_cn: array of length D
    : param v_cn: cavity variance
    : type v_cn: float
    : param rho_n: site strength
    : type rho_n: float
    : param x_n: sample
    : type x_n: array of dimension D
    : return: approximate posterior variance
    : rtype: float
    """
    D = len(m_cn)
    v = (-(v_cn)**2 * rho_n / (1 + v_cn) + rho_n * (1 - rho_n) * (v_cn)**2 *
         np.linalg.norm(x_n - m_cn)**2 / (D * (1 + v_cn)**2) + v_cn)
    return v


def get_factor_var(v_cn, v, tol=1e-9):
    """Returns the factor variance.

    : param v_cn: cavity variance
    : type v_cn: float
    : param v: approximate posterior distribution variance
    : rtype: float
    : return: factor variance
    : rtype: float
    """

    if np.abs(v_cn - v) < tol:
        # v_n = np.inf
        v_n = 1e10
    else:
        v_n = v_cn * v / (v_cn - v)
    if np.isinf(v_n):
        raise ValueError("v_n is infty")
    return v_n


def get_factor_mean(m_cn, v_cn, v_fn, m, tol=1e-9):
    """Returns the factor mean.

    : param m_cn: cavity mean
    : type m_cn: array of dimension D
    : param v_cn: cavity variance
    : type v_cn: float
    : param v_fn: factor variance
    : rtype: float
    : param m: approximate posterior
    : type m: array of dimension D
    : return: factor variance
    : rtype: float
    """

    # inf x 0 problem
    if v_fn == np.inf and (v_fn == 0 or any(m - m_cn) < tol):
        m_n = m_cn
    else:
        m_n = m_cn + (v_fn + v_cn) / v_cn * (m - m_cn)
    return m_n


def get_factor_scale(Z_n, m_fn, v_fn, m_cn, v_cn):
    """Returns the factor scale.

    : param Z_n: zeroth moment
    : type Z_n: float
    : param m_fn: factor mean
    : type m_fn: array of length D
    : param v_fn: factor variance
    : type v_fn: float
    : param m_cn: cavity mean
    : type m_cn: array of dimenion D
    : param v_cn: cavity variance
    : type v_cn: float
    : return: factor scale
    : rtype: float
    """

    D = len(m_cn)
    var = v_cn + v_fn
    if var > 0:
        pdf_value = multivariate_normal(m_cn, var * np.eye(D)).pdf(m_fn)
    else:
        pdf_value = -multivariate_normal(m_cn, -var * np.eye(D)).pdf(m_fn)
    # pdf_value = (1.0 / (2 * np.pi * var)**(D/2.0) *
    #              np.exp(-1.0 / (2 * var) * np.linalg.norm(m_fn - m_cn)**2))
    # if np.isnan(pdf_value):
    #     raise ValueError("NaN pdf_value in get_factor_scale")
    sn = Z_n / pdf_value
    return sn


def get_log_evidence(m, v, m_f, v_f, s_f, b):
    """Returns the model evidence

    : param m: approximate posterior mean
    : type m: array of dimension D
    : param v: approximate posterior variance
    : type v: float
    : param m_f: factor means
    : type m_f: N \\times D array
    : param v_f: factor variance
    : type v_f: array of dimension N
    : param s_f: factor scale
    : type s_f: array of dimension N
    : param b: prior variance
    : type b: float
    : return: model evidence
    : rtype: float
    """

    D = len(m_f[0])
    N = len(s_f)

    constants_terms = D / 2 * (np.log(v) - np.log(b) - N * np.log(2*np.pi))

    log_s_term = (np.sum(np.log(np.abs(s_f))) -
                  D / 2 * np.sum(np.log(np.abs(v_f))))

    exp_terms = 0.5 * (np.sum(m**2) / v - np.sum(np.sum(m_f**2, axis=1) / v_f))

    log_evidence = constants_terms + log_s_term + exp_terms

    return log_evidence


Import requirments

import numpy as np
from scipy.stats import norm
from IPython.display import display
import plotly.graph_objects as go

import ep.examples.clutter.utils
import ep.examples.clutter.core
import ep.examples.clutter.plot

Sample from the clutter model

theta = 3.0
a = 10.0
b = 100.0
w = 0.5
N = 30
num_iter = 20

samples = ep.examples.clutter.utils.sample(theta=theta, a=a, w=w, n_samples=N)

Plot sampled data

x_min = -10
x_max = 10
x_dt = 0.1

x_dense = np.arange(x_min, x_max, x_dt)
signal_pdf_values = norm.pdf(x_dense, loc=theta, scale=1.0)
noise_pdf_values = norm.pdf(x_dense, loc=0, scale=np.sqrt(a))

fig = go.Figure()
trace = go.Scatter(x=samples, y=np.zeros(shape=samples.shape),
                   mode="markers", marker=dict(symbol="x", color="black"))
fig.add_trace(trace)
trace = go.Scatter(x=x_dense, y=signal_pdf_values, mode="lines",
                   line=dict(color="green"))
fig.add_trace(trace)
trace = go.Scatter(x=x_dense, y=noise_pdf_values, mode="lines",
                   line=dict(color="red"))
fig.add_trace(trace)
fig.update_xaxes(title=r"$\theta$")
fig.update_layout(showlegend=False)
fig


Expectation Propagation script

samples = [np.array([sample]) for sample in samples]
D = len(samples[0])
m, v, m_f, v_f, s_f = ep.examples.clutter.core.init(b=b, D=D, N=N)
log_evidences = []
snapshots = []

for iter_num in range(num_iter):
    for n in range(N):
        v_cn = ep.examples.clutter.core.get_cavity_var(v=v, v_fn=v_f[n])
        m_cn = ep.examples.clutter.core.get_cavity_mean(m=m, m_fn=m_f[n], v_fn=v_f[n],
                                                        v_cn=v_cn)
        Z_n = ep.examples.clutter.core.get_zeroth_moment(w=w, a=a, m_cn=m_cn,
                                                         v_cn=v_cn,
                                                         x_n=samples[n])
        rho_n = ep.examples.clutter.core.get_site_strength(w=w, a=a, D=D,
                                                           Z_n=Z_n,
                                                           x_n=samples[n])
        m = ep.examples.clutter.core.get_q_mean(m_cn=m_cn, v_cn=v_cn,
                                                rho_n=rho_n, x_n=samples[n])
        v = ep.examples.clutter.core.get_q_var(m_cn=m_cn, v_cn=v_cn,
                                               rho_n=rho_n, x_n=samples[n])
        v_f[n] = ep.examples.clutter.core.get_factor_var(v_cn=v_cn, v=v)
        m_f[n] = ep.examples.clutter.core.get_factor_mean(m_cn=m_cn, v_cn=v_cn,
                                                          v_fn=v_f[n], m=m)
        s_f[n] = ep.examples.clutter.core.get_factor_scale(Z_n=Z_n, m_fn=m_f[n],
                                                           v_fn=v_f[n],
                                                           m_cn=m_cn, v_cn=v_cn)
    snapshots.append({
        "iter": iter_num,
        "v_cn": v_cn,
        "m_cn": m_cn.copy(),
        "v": v,
        "m": m.copy(),
        "v_fn": v_f[n],
        "m_fn": m_f[n].copy(),
    })
    log_evidence = ep.examples.clutter.core.get_log_evidence(m=m, v=v, m_f=m_f,
                                                    v_f=v_f, s_f=s_f, b=b)
    log_evidences.append(log_evidence)

Plot EP probability density functions after iteration 0

iter_num = 0
s = snapshots[iter_num]
ep.examples.clutter.plot.plot_pdfs(theta=theta, m_cn=s["m_cn"], v_cn=s["v_cn"],
                                   m=s["m"], v=s["v"], m_fn=s["m_fn"],
                                   v_fn=s["v_fn"], samples=samples[:N],
                                   x_min=x_min, x_max=x_max, x_dt=x_dt,
                                   title=f"Iteration {iter_num}, Factor {N-1}")


Plot EP probability density functions after iteration 3

iter_num = 3
s = snapshots[iter_num]
ep.examples.clutter.plot.plot_pdfs(theta=theta, m_cn=s["m_cn"], v_cn=s["v_cn"],
                                   m=s["m"], v=s["v"], m_fn=s["m_fn"],
                                   v_fn=s["v_fn"], samples=samples[:N],
                                   x_min=x_min, x_max=x_max, x_dt=x_dt,
                                   title=f"Iteration {iter_num}, Factor {N-1}")


Plot EP probability density functions after iteration 6

iter_num = 6
s = snapshots[iter_num]
ep.examples.clutter.plot.plot_pdfs(theta=theta, m_cn=s["m_cn"], v_cn=s["v_cn"],
                                   m=s["m"], v=s["v"], m_fn=s["m_fn"],
                                   v_fn=s["v_fn"], samples=samples[:N],
                                   x_min=x_min, x_max=x_max, x_dt=x_dt,
                                   title=f"Iteration {iter_num}, Factor {N-1}")


Plot EP probability density functions after iteration 9

iter_num = 9
s = snapshots[iter_num]
ep.examples.clutter.plot.plot_pdfs(theta=theta, m_cn=s["m_cn"], v_cn=s["v_cn"],
                                   m=s["m"], v=s["v"], m_fn=s["m_fn"],
                                   v_fn=s["v_fn"], samples=samples[:N],
                                   x_min=x_min, x_max=x_max, x_dt=x_dt,
                                   title=f"Iteration {iter_num}, Factor {N-1}")


Plot EP probability density functions after iteration 19

iter_num = 19
s = snapshots[iter_num]
ep.examples.clutter.plot.plot_pdfs(theta=theta, m_cn=s["m_cn"], v_cn=s["v_cn"],
                                   m=s["m"], v=s["v"], m_fn=s["m_fn"],
                                   v_fn=s["v_fn"], samples=samples[:N],
                                   x_min=x_min, x_max=x_max, x_dt=x_dt,
                                   title=f"Iteration {iter_num}, Factor {N-1}")


Plot log evidences

ep.examples.clutter.plot.plot_log_evidences(log_evidences)


Total running time of the script: (0 minutes 1.249 seconds)

Gallery generated by Sphinx-Gallery