Note
Go to the end to download the full example code or to run this example in your browser via Binder.
1.4. Online Bayesian linear regression¶
Online (sequential) estimation of the posterior distribution of the weights of a linear regression model. It reproduces Figure 3.7 from Bishop (2006), Pattern Recognition and Machine Learning.
1.4.1. Import requirements¶
import numpy as np
import scipy.stats
import plotly.subplots
import plotly.graph_objects as go
import bayesianLinearRegression
1.4.2. Define data generation variables¶
# Ground-truth generative model t = a0 + a1 * x + noise (Bishop, Figure 3.7).
n_samples = 20
a0, a1 = -0.3, 0.5  # true intercept and true slope
# Noise precision beta = 1 / sigma**2 with noise standard deviation sigma = 0.2.
likelihood_precision_coef = (1 / 0.2) ** 2
# Numbers of processed observations after which the state is plotted.
n_samples_to_plot = (1, 2, 20)
1.4.3. Generate data¶
# Draw inputs uniformly on [-1, 1] and build targets from the true line plus
# Gaussian noise: t = a0 + a1 * x + eps, eps ~ N(0, 1 / beta), where
# beta = likelihood_precision_coef is the noise *precision*.
x = np.random.uniform(low=-1, high=1, size=n_samples)
y = a0 + a1 * x
# Because beta is a precision, the noise standard deviation is 1 / sqrt(beta)
# (0.2 for beta = (1/0.2)**2). Scaling by 1 / beta would give 0.04 instead.
t = y + np.random.standard_normal(size=y.shape) / np.sqrt(likelihood_precision_coef)
1.4.4. Define plotting variables¶
# Number of regression lines drawn from each prior/posterior panel.
n_post_samples = 6
# Marker style for the true (intercept, slope) pair in weight-space panels.
marker_true, size_true, color_true = "cross", 10, "white"
# Marker style for the observed data points in data-space panels.
marker_data, size_data, color_data = "circle-open", 10, "blue"
line_width_data = 5
1.4.5. Define estimation variables¶
# Precision alpha of the isotropic Gaussian prior over the weights (S0 = I / alpha).
prior_precision_coef: float = 2.0
1.4.6. Estimate and plot posterior¶
# Weight-space grid: intercept along the horizontal axis, slope along the
# vertical axis. Densities are evaluated on it for the contour panels.
x_grid = np.linspace(-1, 1, 100)
y_grid = np.linspace(-1, 1, 100)
X_grid, Y_grid = np.meshgrid(x_grid, y_grid)
# Points of shape (100, 100, 2) so multivariate_normal.pdf can evaluate the grid.
pos = np.stack((X_grid, Y_grid), axis=-1)
# Design matrix for the basis phi(x) = (1, x): row n is [1, x[n]].
Phi = np.stack((np.ones_like(x), x), axis=1)
# Prior N(m0, S0): zero mean, isotropic covariance I / alpha.
m0 = np.zeros(2)
S0 = np.eye(2) / prior_precision_coef
# Figure layout: row 1 shows the prior; each entry of n_samples_to_plot gets
# one further row. Column 1 is the likelihood, column 2 the prior/posterior
# density over weights, column 3 sampled regression lines in data space.
fig = plotly.subplots.make_subplots(rows=len(n_samples_to_plot)+1, cols=3)
# Inputs at which sampled regression lines are drawn. linspace includes both
# endpoints, so the lines cover the full [-1, 1] data range
# (np.arange(-1.0, 1.0, 0.1) stopped at 0.9).
x_dense = np.linspace(-1.0, 1.0, 21)

# Marker for the true (intercept, slope), overlaid on every weight-space panel.
trace_true_coef = go.Scatter(x=[a0], y=[a1], mode="markers",
                             marker_symbol=marker_true,
                             marker_size=size_true,
                             marker_color=color_true,
                             name="true mean",
                             showlegend=False)

# Row 1, column 2: density of the prior N(m0, S0) on the weight grid.
rv = scipy.stats.multivariate_normal(m0, S0)
Z = rv.pdf(pos)
trace_post = go.Contour(x=x_grid, y=y_grid, z=Z, showscale=False)
fig.add_trace(trace_post, row=1, col=2)
fig.add_trace(trace_true_coef, row=1, col=2)
fig.update_xaxes(title_text="Intercept", row=1, col=2)
fig.update_yaxes(title_text="Slope", row=1, col=2)

# Row 1, column 3: regression lines for weights sampled from the prior.
samples = rv.rvs(size=n_post_samples)
for a_sample in samples:
    sample_intercept, sample_slope = a_sample
    sample_y = sample_intercept + sample_slope * x_dense
    trace = go.Scatter(x=x_dense, y=sample_y, mode="lines",
                       line_color="red", showlegend=False)
    fig.add_trace(trace, row=1, col=3)
fig.update_xaxes(title_text="x", row=1, col=3)
fig.update_yaxes(title_text="y", row=1, col=3)
# Sequential Bayesian learning: start from the prior N(m0, S0) and fold in one
# noisy observation (x[n], t[n]) at a time. When the number of processed
# points matches an entry of n_samples_to_plot, plot the likelihood of the
# latest point, the updated posterior, and posterior-sampled regression lines.
mn = m0
Sn = S0
# likelihood_precision_coef is a precision (beta), but scipy.stats.norm's
# `scale` is a standard deviation, so convert once: sigma = 1 / sqrt(beta).
noise_std = 1.0 / np.sqrt(likelihood_precision_coef)
# Iterate over the *noisy* targets t (not the noise-free y). Use t_n so the
# loop variable does not shadow the target array t.
for n, t_n in enumerate(t):
    print(f"Processing {n}/({len(t)})")
    # update posterior with the single observation (Phi[n, :], t_n)
    mn, Sn = bayesianLinearRegression.onlineUpdate(
        mn=mn, Sn=Sn, phi=Phi[n, :], y=t_n, alpha=prior_precision_coef,
        beta=likelihood_precision_coef)
    if n + 1 not in n_samples_to_plot:
        continue
    # Row 1 holds the prior, so the k-th snapshot goes to row k + 2.
    row = n_samples_to_plot.index(n + 1) + 2

    # Column 1: likelihood p(t_n | w) of the latest observation over the
    # weight grid, evaluated in one vectorized call. np.meshgrid gives
    # X_grid[j, i] = x_grid[i] (intercept) and Y_grid[j, i] = y_grid[j]
    # (slope), which matches the Z[j, i] layout go.Contour expects.
    Z = scipy.stats.norm.pdf(t_n, loc=X_grid + Y_grid * x[n], scale=noise_std)
    trace_like = go.Contour(x=x_grid, y=y_grid, z=Z, showscale=False)
    fig.add_trace(trace_like, row=row, col=1)
    fig.add_trace(trace_true_coef, row=row, col=1)
    fig.update_xaxes(title_text="Intercept", row=row, col=1)
    fig.update_yaxes(title_text="Slope", row=row, col=1)

    # Column 2: updated posterior density N(mn, Sn).
    rv = scipy.stats.multivariate_normal(mn, Sn)
    Z = rv.pdf(pos)
    trace_post = go.Contour(x=x_grid, y=y_grid, z=Z, showscale=False)
    fig.add_trace(trace_post, row=row, col=2)
    fig.add_trace(trace_true_coef, row=row, col=2)
    fig.update_xaxes(title_text="Intercept", row=row, col=2)
    fig.update_yaxes(title_text="Slope", row=row, col=2)

    # Column 3: regression lines for weights sampled from the posterior,
    # together with the noisy observations processed so far.
    samples = rv.rvs(size=n_post_samples)
    for sample_intercept, sample_slope in samples:
        sample_y = sample_intercept + sample_slope * x_dense
        trace = go.Scatter(x=x_dense, y=sample_y, mode="lines",
                           line_color="red", showlegend=False)
        fig.add_trace(trace, row=row, col=3)
    trace_data = go.Scatter(x=x[:(n + 1)], y=t[:(n + 1)],
                            mode="markers",
                            marker_symbol=marker_data,
                            marker_size=size_data,
                            marker_color=color_data,
                            marker_line_width=line_width_data,
                            showlegend=False,
                            )
    fig.add_trace(trace_data, row=row, col=3)
    fig.update_xaxes(title_text="x", row=row, col=3)
    fig.update_yaxes(title_text="y", row=row, col=3)
# Display the assembled figure. Left as a bare trailing expression on purpose:
# sphinx-gallery captures the value of the last expression in a code cell.
fig
# sphinx_gallery_thumbnail_path = '_static/oblr.png'
Processing 0/(20)
Processing 1/(20)
Processing 2/(20)
Processing 3/(20)
Processing 4/(20)
Processing 5/(20)
Processing 6/(20)
Processing 7/(20)
Processing 8/(20)
Processing 9/(20)
Processing 10/(20)
Processing 11/(20)
Processing 12/(20)
Processing 13/(20)
Processing 14/(20)
Processing 15/(20)
Processing 16/(20)
Processing 17/(20)
Processing 18/(20)
Processing 19/(20)
Total running time of the script: (0 minutes 18.775 seconds)