Note
Go to the end to download the full example code.
True and estimated posteriors¶
The code below plots the true posterior and the EP estimate of the posterior for a sample of two data points drawn from the clutter problem.
Implementation Details¶
Below is the code for the EP functions used in this example:
import numpy as np
from scipy.stats import multivariate_normal
def init(b, D, N):
    """Return the initial EP state for the clutter problem.

    :param b: prior variance, also used as the initial variance of q
    :type b: float
    :param D: dimension of the parameter theta
    :type D: int
    :param N: number of data points (one factor per point)
    :type N: int
    :return: ``(m, v, m_f, v_f, s_f)``: mean of q (length D), variance of q,
        factor means (N x D, zeros), factor variances (length N, all +inf)
        and factor scales (length N, all ones)
    :rtype: tuple
    """
    # q starts at the prior; its mean must have length D (it used to be
    # hard-coded to a length-1 array regardless of D).
    m = np.zeros(D)
    v = b
    m_f = np.zeros(shape=(N, D))
    # +inf factor variance: get_cavity_var then leaves v untouched, i.e. the
    # factor initially contributes nothing to q.
    v_f = np.full(N, np.inf)
    s_f = np.ones(shape=N)
    return m, v, m_f, v_f, s_f
def get_cavity_var(v, v_fn):
    """Return the cavity variance :math:`v^{\\setminus n}`.

    Removing factor n from q subtracts its precision:
    :math:`1 / v^{\\setminus n} = 1 / v - 1 / v_n`.

    :param v: variance of the approximate posterior q
    :type v: float
    :param v_fn: variance of factor n
    :type v_fn: float
    :return: cavity variance :math:`v^{\\setminus n}`
    :rtype: float
    """
    # A factor with +inf variance carries zero precision, so removing it
    # leaves the variance of q unchanged.
    if not np.isposinf(v_fn):
        return v * v_fn / (v_fn - v)
    return v
def get_cavity_mean(m, m_fn, v_fn, v_cn):
    """Return the cavity mean :math:`m^{\\setminus n}`.

    :math:`m^{\\setminus n} = m + v^{\\setminus n} v_n^{-1} (m - m_n)`.

    :param m: mean of the approximate posterior q
    :type m: array of length D
    :param m_fn: mean of factor n
    :type m_fn: array of length D
    :param v_fn: variance of factor n
    :type v_fn: float
    :param v_cn: cavity variance
    :type v_cn: float
    :return: cavity mean :math:`m^{\\setminus n}`
    :rtype: array of length D
    """
    # With v_fn = +inf the ratio is 0 and the cavity mean equals m.
    ratio = v_cn / v_fn
    return m + ratio * (m - m_fn)
def get_zeroth_moment(w, a, m_cn, v_cn, x_n):
    """Return the zeroth moment :math:`Z_n`.

    :math:`Z_n = (1 - w) \\mathcal{N}(x_n; m^{\\setminus n}, (v^{\\setminus n} + 1) I)
    + w \\mathcal{N}(x_n; 0, a I)`.

    :param w: weight of the zero-mean (clutter) term of the likelihood
    :type w: float
    :param a: variance of the zero-mean (clutter) term of the likelihood
    :type a: float
    :param m_cn: cavity mean
    :type m_cn: array of length D
    :param v_cn: cavity variance
    :type v_cn: float
    :param x_n: data point n
    :type x_n: array of length D
    :return: zeroth moment :math:`Z_n`
    :rtype: float
    """
    dim = len(m_cn)
    identity = np.eye(dim)
    signal_pdf = multivariate_normal(
        mean=m_cn, cov=(v_cn + 1) * identity).pdf(x=x_n)
    clutter_pdf = multivariate_normal(
        mean=np.zeros(dim), cov=a * identity).pdf(x=x_n)
    return (1 - w) * signal_pdf + w * clutter_pdf
def get_site_strength(w, a, D, Z_n, x_n):
    """Return the site strength :math:`\\rho_n`.

    :math:`\\rho_n = 1 - \\frac{w}{Z_n} \\mathcal{N}(x_n; 0, a I)`, i.e. one
    minus the share of :math:`Z_n` contributed by the w-weighted zero-mean
    term (see ``get_zeroth_moment``).

    :param w: weight of the zero-mean (clutter) term of the likelihood
    :type w: float
    :param a: variance of the zero-mean (clutter) term of the likelihood
    :type a: float
    :param D: dimension of the data
    :type D: int
    :param Z_n: zeroth moment
    :type Z_n: float
    :param x_n: data point n
    :type x_n: array of length D
    :return: site strength :math:`\\rho_n`
    :rtype: float
    """
    clutter_density = multivariate_normal(
        mean=np.zeros(D), cov=a * np.eye(D)).pdf(x_n)
    return 1 - w / Z_n * clutter_density
def get_q_mean(m_cn, v_cn, rho_n, x_n):
    """Return the updated mean of the approximate posterior q.

    :math:`m = m^{\\setminus n} + \\rho_n \\frac{v^{\\setminus n}}{1 + v^{\\setminus n}}
    (x_n - m^{\\setminus n})`.

    :param m_cn: cavity mean
    :type m_cn: array of length D
    :param v_cn: cavity variance
    :type v_cn: float
    :param rho_n: site strength
    :type rho_n: float
    :param x_n: data point n
    :type x_n: array of length D
    :return: mean of q
    :rtype: array of length D
    """
    gain = v_cn / (1 + v_cn) * rho_n
    return m_cn + gain * (x_n - m_cn)
def get_q_var(m_cn, v_cn, rho_n, x_n):
    """Return the updated variance of the approximate posterior q.

    :math:`v = v^{\\setminus n} - \\rho_n \\frac{(v^{\\setminus n})^2}{1 + v^{\\setminus n}}
    + \\rho_n (1 - \\rho_n) \\frac{(v^{\\setminus n})^2 \\|x_n - m^{\\setminus n}\\|^2}
    {D (1 + v^{\\setminus n})^2}`.

    :param m_cn: cavity mean
    :type m_cn: array of length D
    :param v_cn: cavity variance
    :type v_cn: float
    :param rho_n: site strength
    :type rho_n: float
    :param x_n: data point n
    :type x_n: array of length D
    :return: variance of q
    :rtype: float
    """
    dim = len(m_cn)
    sq_resid = np.linalg.norm(x_n - m_cn) ** 2
    shrink = -v_cn ** 2 * rho_n / (1 + v_cn)
    spread = (rho_n * (1 - rho_n) * v_cn ** 2 * sq_resid
              / (dim * (1 + v_cn) ** 2))
    return shrink + spread + v_cn
def get_factor_var(v_cn, v, tol=1e-9):
    """Return the variance :math:`v_n` of factor n.

    Factor precision is q's precision minus the cavity precision:
    :math:`1 / v_n = 1 / v - 1 / v^{\\setminus n}`. The result may be negative.

    :param v_cn: cavity variance
    :type v_cn: float
    :param v: variance of the approximate posterior q
    :type v: float
    :param tol: if ``|v_cn - v| < tol`` the factor is treated as flat
    :type tol: float
    :return: factor variance
    :rtype: float
    :raises ValueError: if the computed variance is infinite
    """
    # When q's variance barely changed, a large finite value (1e10) stands in
    # for an infinite factor variance instead of np.inf.
    unchanged = np.abs(v_cn - v) < tol
    v_n = 1e10 if unchanged else v_cn * v / (v_cn - v)
    if np.isinf(v_n):
        raise ValueError("v_n is infty")
    return v_n
def get_factor_mean(m_cn, v_cn, v_fn, m, tol=1e-9):
    """Return the mean :math:`m_n` of factor n.

    :math:`m_n = m^{\\setminus n} + (v_n + v^{\\setminus n}) (v^{\\setminus n})^{-1}
    (m - m^{\\setminus n})`.

    :param m_cn: cavity mean
    :type m_cn: array of length D
    :param v_cn: cavity variance
    :type v_cn: float
    :param v_fn: factor variance
    :type v_fn: float
    :param m: mean of the approximate posterior q
    :type m: array of length D
    :param tol: when ``v_fn`` is infinite, every component of ``m - m_cn``
        smaller than ``tol`` in magnitude is treated as zero
    :type tol: float
    :return: factor mean
    :rtype: array of length D
    """
    # inf x 0 problem: with an infinite factor variance the update term is
    # inf * (m - m_cn), which is NaN (or inf) when q's mean did not move.
    # In that case the factor mean is left at the cavity mean.  (The old guard
    # compared the *bool* any(m - m_cn) with tol and also required v_fn == 0,
    # so the tolerance was never applied.)
    if np.isinf(v_fn) and np.all(np.abs(m - m_cn) < tol):
        return m_cn
    return m_cn + (v_fn + v_cn) / v_cn * (m - m_cn)
def get_factor_scale(Z_n, m_fn, v_fn, m_cn, v_cn):
    """Return the scale :math:`s_n` of factor n.

    :math:`s_n = Z_n / \\mathcal{N}(m_n; m^{\\setminus n}, (v_n + v^{\\setminus n}) I)`.

    ``v_n + v^{\\setminus n}`` can be negative because factor variances are not
    constrained to be positive. In that case the Gaussian formula is evaluated
    literally:

    * the normaliser uses the absolute variance;
    * the exponent ``-||m_n - m^{\\setminus n}||^2 / (2 var)`` keeps the sign of
      the variance.

    The returned scale is negated when the variance is negative, as before.

    :param Z_n: zeroth moment
    :type Z_n: float
    :param m_fn: factor mean
    :type m_fn: array of length D
    :param v_fn: factor variance
    :type v_fn: float
    :param m_cn: cavity mean
    :type m_cn: array of length D
    :param v_cn: cavity variance
    :type v_cn: float
    :return: factor scale
    :rtype: float
    :raises ValueError: if ``v_cn + v_fn`` is zero
    """
    D = len(m_cn)
    var = v_cn + v_fn
    if var == 0:
        raise ValueError("v_cn + v_fn must be non-zero in get_factor_scale")
    sq_dist = np.sum((np.asarray(m_fn) - np.asarray(m_cn)) ** 2)
    # log of |2 pi var|^(-D/2) * exp(-sq_dist / (2 var)).  The previous code
    # used -N(m_fn; m_cn, |var| I) for var < 0, which flips the exponent's
    # sign; get_log_evidence uses log|s_n| and signed v_f, so only the sign
    # of s_n (not its magnitude) may be dropped.
    log_pdf = -0.5 * D * np.log(2 * np.pi * np.abs(var)) - sq_dist / (2 * var)
    sign = 1.0 if var > 0 else -1.0
    sn = sign * Z_n * np.exp(-log_pdf)
    return sn
def get_log_evidence(m, v, m_f, v_f, s_f, b):
    """Return the EP estimate of the log model evidence.

    Sum of three parts:

    * ``D/2 (log v - log b - N log 2 pi)``;
    * ``sum_n log|s_n| - D/2 sum_n log|v_n|``;
    * ``(|m|^2 / v - sum_n |m_n|^2 / v_n) / 2``.

    :param m: mean of the approximate posterior q
    :type m: array of length D
    :param v: variance of q
    :type v: float
    :param m_f: factor means
    :type m_f: N x D array
    :param v_f: factor variances
    :type v_f: array of length N
    :param s_f: factor scales
    :type s_f: array of length N
    :param b: prior variance
    :type b: float
    :return: log model evidence
    :rtype: float
    """
    dim = len(m_f[0])
    n_factors = len(s_f)
    log_norm = dim / 2 * (np.log(v) - np.log(b) - n_factors * np.log(2 * np.pi))
    # Absolute values: factor scales and variances may be negative.
    log_scales = (np.sum(np.log(np.abs(s_f)))
                  - dim / 2 * np.sum(np.log(np.abs(v_f))))
    factor_quad = np.sum(np.sum(m_f ** 2, axis=1) / v_f)
    quad = 0.5 * (np.sum(m ** 2) / v - factor_quad)
    return log_norm + log_scales + quad
Import requirements¶
import numpy as np
from scipy.stats import multivariate_normal
from scipy.stats import norm
from IPython.display import display
import plotly.graph_objects as go
import ep.examples.clutter.utils
import ep.examples.clutter.core
import ep.examples.clutter.plot
Generate two samples from the clutter model¶
# Clutter-problem parameters shared by the sampler, the exact posterior and
# the EP updates below.
theta = 3.0  # value of theta passed to the sampler
a = 10.0  # variance of the zero-mean, w-weighted (clutter) likelihood term
b = 100.0  # prior variance
w = 0.5  # weight of the clutter term in the likelihood
N = 2  # number of samples
# Plot range; these values are reassigned before the plotting cells.
x_min = -5.0
x_max = 10.0
x_dt = 0.01
samples_list = ep.examples.clutter.utils.sample(theta=theta, a=a, w=w, n_samples=N)
# Wrap each draw in a 1-element array; presumably the draws are scalars, so
# D = len(samples[0]) == 1 below. TODO confirm against utils.sample.
samples = [np.array([sample]) for sample in samples_list]
Compute true posterior¶
# Exact posterior for two points: p(theta | x1, x2) is a mixture of four
# Gaussians, one per assignment of (x1, x2) to signal (S) or clutter (C).
D = len(samples[0])
x1 = samples[0]
x2 = samples[1]
# SS: both points from the signal -> posterior precision 1/b + 2.
sigma2_SS = b / (1 + 2 * b)
mu_SS = sigma2_SS * (x1 + x2)
# SC / CS: exactly one point from the signal -> precision 1/b + 1.
sigma2_SC = b / (1 + b)
mu_SC = sigma2_SC * x1
sigma2_CS = sigma2_SC
mu_CS = sigma2_CS * x2
# CC: both points are clutter, so the posterior is the prior N(0, b I).
sigma2_CC = b
mu_CC = 0.0
# Unnormalised mixture weights: assignment probability times the marginal
# likelihood of (x1, x2) under that assignment.  For SS the marginal is a
# joint Gaussian with per-dimension covariance [[1+b, b], [b, 1+b]]
# (determinant 1 + 2b), written out explicitly.
pi_SS = ((1 - w)**2 * (1 / (2 * np.pi)**D) * (1 / (1 + 2 * b)**(D / 2)) *
         np.exp((b * np.linalg.norm(x1 + x2)**2 -
                 (1 + 2 * b) * (np.linalg.norm(x1)**2 +
                                np.linalg.norm(x2)**2)) /
                (2 * (1 + 2 * b))))
pi_SC = ((1 - w) * w *
         multivariate_normal(np.zeros(shape=D),
                             a * np.eye(D)).pdf(x2) *
         multivariate_normal(np.zeros(shape=D),
                             (b + 1) * np.eye(D)).pdf(x1))
pi_CS = (w * (1 - w) *
         multivariate_normal(np.zeros(shape=D),
                             a * np.eye(D)).pdf(x1) *
         multivariate_normal(np.zeros(shape=D),
                             (b + 1) * np.eye(D)).pdf(x2))
pi_CC = (w**2 *
         multivariate_normal(np.zeros(shape=D),
                             a * np.eye(D)).pdf(x1) *
         multivariate_normal(np.zeros(shape=D),
                             a * np.eye(D)).pdf(x2))
# K normalises the four weights so the mixture integrates to one.
K = 1.0 / (pi_SS + pi_SC + pi_CS + pi_CC)
def two_points_true_posterior(theta):
    """Evaluate the exact two-point posterior density at ``theta``.

    The density is the four-component Gaussian mixture (SS, SC, CS, CC)
    defined by the module-level ``pi_*``, ``mu_*`` and ``sigma2_*`` values,
    scaled by the normaliser ``K``.
    """
    identity = np.eye(D)
    components = (
        (pi_SS, mu_SS, sigma2_SS),
        (pi_SC, mu_SC, sigma2_SC),
        (pi_CS, mu_CS, sigma2_CS),
        (pi_CC, mu_CC, sigma2_CC),
    )
    mixture = 0.0
    for weight, mean, var in components:
        mixture = mixture + weight * multivariate_normal(
            mean, var * identity).pdf(theta)
    return K * mixture
Expectation Propagation script¶
num_iter = 20
D = len(samples[0])
m, v, m_f, v_f, s_f = ep.examples.clutter.core.init(b=b, D=D, N=N)
log_evidences = []
# One snapshot per (iteration, factor) pair, appended in that order, so the
# list holds num_iter * N entries.  The state after the last factor of
# iteration i is snapshots[i * N + N - 1].
snapshots = []
for iter_num in range(num_iter):
    # One EP sweep: refine each factor n in turn.
    for n in range(N):
        # Remove factor n from q to get the cavity distribution.
        v_cn = ep.examples.clutter.core.get_cavity_var(v=v, v_fn=v_f[n])
        m_cn = ep.examples.clutter.core.get_cavity_mean(m=m, m_fn=m_f[n], v_fn=v_f[n], v_cn=v_cn)
        # Update q from the cavity and data point x_n.
        Z_n = ep.examples.clutter.core.get_zeroth_moment(w=w, a=a, m_cn=m_cn, v_cn=v_cn, x_n=samples[n])
        rho_n = ep.examples.clutter.core.get_site_strength(w=w, a=a, D=D, Z_n=Z_n, x_n=samples[n])
        m = ep.examples.clutter.core.get_q_mean(m_cn=m_cn, v_cn=v_cn, rho_n=rho_n, x_n=samples[n])
        v = ep.examples.clutter.core.get_q_var(m_cn=m_cn, v_cn=v_cn, rho_n=rho_n, x_n=samples[n])
        # Recompute factor n from the new q and the cavity.
        v_f[n] = ep.examples.clutter.core.get_factor_var(v_cn=v_cn, v=v)
        m_f[n] = ep.examples.clutter.core.get_factor_mean(m_cn=m_cn, v_cn=v_cn, v_fn=v_f[n], m=m)
        s_f[n] = ep.examples.clutter.core.get_factor_scale(Z_n=Z_n, m_fn=m_f[n],
                                                           v_fn=v_f[n], m_cn=m_cn, v_cn=v_cn)
        # Copies, because m, m_cn and m_f[n] are arrays updated on later passes.
        snapshots.append({
            "iter": iter_num,
            "v_cn": v_cn,
            "m_cn": m_cn.copy(),
            "v": v,
            "m": m.copy(),
            "v_fn": v_f[n],
            "m_fn": m_f[n].copy(),
        })
    # Log-evidence estimate after each full sweep over the factors.
    log_evidence = ep.examples.clutter.core.get_log_evidence(m=m, v=v, m_f=m_f, v_f=v_f, s_f=s_f,
                                                             b=b)
    log_evidences.append(log_evidence)
Plot EP probability density functions after iteration 0¶
x_min = -10
x_max = 10
x_dt = 0.1
iter_num = 0
# snapshots holds N entries per iteration (one per factor), so the state after
# the last factor (N - 1) of iteration `iter_num` is at this index; indexing
# with iter_num alone picks the wrong iteration/factor.
s = snapshots[iter_num * N + (N - 1)]
assert s["iter"] == iter_num
ep.examples.clutter.plot.plot_pdfs(theta=theta, m_cn=s["m_cn"], v_cn=s["v_cn"],
                                   m=s["m"], v=s["v"], m_fn=s["m_fn"],
                                   v_fn=s["v_fn"], samples=samples[:N],
                                   x_min=x_min, x_max=x_max, x_dt=x_dt,
                                   title=f"Iteration {iter_num}, Factor {N-1}",
                                   true_posterior_func=two_points_true_posterior)
Plot EP probability density functions after iteration 3¶
iter_num = 3
# One snapshot per (iteration, factor): pick the last factor of iter_num.
s = snapshots[iter_num * N + (N - 1)]
assert s["iter"] == iter_num
ep.examples.clutter.plot.plot_pdfs(theta=theta, m_cn=s["m_cn"], v_cn=s["v_cn"],
                                   m=s["m"], v=s["v"], m_fn=s["m_fn"],
                                   v_fn=s["v_fn"], samples=samples[:N],
                                   x_min=x_min, x_max=x_max, x_dt=x_dt,
                                   title=f"Iteration {iter_num}, Factor {N-1}",
                                   true_posterior_func=two_points_true_posterior)
Plot EP probability density functions after iteration 6¶
iter_num = 6
# One snapshot per (iteration, factor): pick the last factor of iter_num.
s = snapshots[iter_num * N + (N - 1)]
assert s["iter"] == iter_num
ep.examples.clutter.plot.plot_pdfs(theta=theta, m_cn=s["m_cn"], v_cn=s["v_cn"],
                                   m=s["m"], v=s["v"], m_fn=s["m_fn"],
                                   v_fn=s["v_fn"], samples=samples[:N],
                                   x_min=x_min, x_max=x_max, x_dt=x_dt,
                                   title=f"Iteration {iter_num}, Factor {N-1}",
                                   true_posterior_func=two_points_true_posterior)
Plot EP probability density functions after iteration 9¶
iter_num = 9
# One snapshot per (iteration, factor): pick the last factor of iter_num.
s = snapshots[iter_num * N + (N - 1)]
assert s["iter"] == iter_num
ep.examples.clutter.plot.plot_pdfs(theta=theta, m_cn=s["m_cn"], v_cn=s["v_cn"],
                                   m=s["m"], v=s["v"], m_fn=s["m_fn"],
                                   v_fn=s["v_fn"], samples=samples[:N],
                                   x_min=x_min, x_max=x_max, x_dt=x_dt,
                                   title=f"Iteration {iter_num}, Factor {N-1}",
                                   true_posterior_func=two_points_true_posterior)
Plot EP probability density functions after iteration 19¶
iter_num = 19
# One snapshot per (iteration, factor): pick the last factor of iter_num.
s = snapshots[iter_num * N + (N - 1)]
assert s["iter"] == iter_num
ep.examples.clutter.plot.plot_pdfs(theta=theta, m_cn=s["m_cn"], v_cn=s["v_cn"],
                                   m=s["m"], v=s["v"], m_fn=s["m_fn"],
                                   v_fn=s["v_fn"], samples=samples[:N],
                                   x_min=x_min, x_max=x_max, x_dt=x_dt,
                                   title=f"Iteration {iter_num}, Factor {N-1}",
                                   true_posterior_func=two_points_true_posterior)
Plot log evidences¶
# Plot the log-evidence estimates collected after each EP sweep above.
ep.examples.clutter.plot.plot_log_evidences(log_evidences)
Total running time of the script: (0 minutes 0.157 seconds)