Offline filtering of a simulated mouse trajectory

The code below performs online Kalman filtering of a simulated mouse trajectory.

Import packages

import numpy as np
import plotly.graph_objects as go

import lds.tracking.utils
import lds.simulation
import lds.inference

Set initial conditions and parameters

pos_x0 = 0.0
pos_y0 = 0.0
vel_x0 = 0.0
vel_y0 = 0.0
ace_x0 = 0.0
ace_y0 = 0.0
dt = 1e-3
num_pos = 1000
sigma_a = 1.0
sigma_x = 1.0
sigma_y = 1.0
sqrt_diag_V0_value = 1e-03

Set LDS parameters

B, Q, Z, R, Qe = lds.tracking.utils.getLDSmatricesForTracking(
    dt=dt, sigma_a=sigma_a, sigma_x=sigma_x, sigma_y=sigma_y)
m0 = np.array([pos_x0, vel_x0, ace_x0, pos_y0, vel_y0, ace_y0],
              dtype=np.double)
V0 = np.diag(np.ones(len(m0))*sqrt_diag_V0_value**2)

Sample from the LDS

View source code of lds.simulation.simulateLDS

x0, x, y = lds.simulation.simulateLDS(N=num_pos, B=B, Q=Q, Z=Z, R=R,
                                             m0=m0, V0=V0)

Perform batch filtering

View source code of lds.inference.filterLDS_SS_withMissingValues_np

Q = sigma_a*Qe
filterRes = lds.inference.filterLDS_SS_withMissingValues_np(
    y=y, B=B, Q=Q, m0=m0, V0=V0, Z=Z, R=R)

Set variables for plotting

N = y.shape[1]
time = np.arange(0, N*dt, dt)
filtered_means = filterRes["xnn"]
filtered_covs = filterRes["Vnn"]
filter_std_x_y = np.sqrt(np.diagonal(a=filtered_covs, axis1=0, axis2=1))
color_true = "blue"
color_measured = "black"
color_filtered_pattern = "rgba(255,0,0,{:f})"
cb_alpha = 0.3

Define function for plotting

def get_fig_kinematics_vs_time(
    time,
    true_x, true_y,
    measured_x, measured_y,
    estimated_mean_x, estimated_mean_y,
    estimated_ci_x_upper, estimated_ci_y_upper,
    estimated_ci_x_lower, estimated_ci_y_lower,
    cb_alpha,
    color_true,
    color_measured,
    color_estimated_pattern,
    xlabel, ylabel):

    fig = go.Figure()
    trace_true_x = go.Scatter(
        x=time, y=true_x,
        mode="markers",
        marker={"color": color_true},
        name="true x",
        showlegend=True,
    )
    fig.add_trace(trace_true_x)
    trace_true_y = go.Scatter(
        x=time, y=true_y,
        mode="markers",
        marker={"color": color_true},
        name="true y",
        showlegend=True,
    )
    fig.add_trace(trace_true_y)
    if measured_x is not None:
        trace_mes_x = go.Scatter(
            x=time, y=measured_x,
            mode="markers",
            marker={"color": color_measured},
            name="measured x",
            showlegend=True,
        )
        fig.add_trace(trace_mes_x)
    if measured_y is not None:
        trace_mes_y = go.Scatter(
            x=time, y=measured_y,
            mode="markers",
            marker={"color": color_measured},
            name="measured y",
            showlegend=True,
        )
        fig.add_trace(trace_mes_y)
    trace_est_x = go.Scatter(
        x=time, y=estimated_mean_x,
        mode="markers",
        marker={"color": color_estimated_pattern.format(1.0)},
        name="estimated x",
        showlegend=True,
        legendgroup="estimated_x",
    )
    fig.add_trace(trace_est_x)
    trace_est_x_cb = go.Scatter(
        x=np.concatenate([time, time[::-1]]),
        y=np.concatenate([estimated_ci_x_upper, estimated_ci_x_lower[::-1]]),
        fill="toself",
        fillcolor=color_estimated_pattern.format(cb_alpha),
        line=dict(color=color_estimated_pattern.format(0.0)),
        showlegend=False,
        legendgroup="estimated_x",
    )
    fig.add_trace(trace_est_x_cb)
    trace_est_y = go.Scatter(
        x=time, y=estimated_mean_y,
        mode="markers",
        marker={"color": color_estimated_pattern.format(1.0)},
        name="estimated y",
        showlegend=True,
        legendgroup="estimated_y",
    )
    fig.add_trace(trace_est_y)
    trace_est_y_cb = go.Scatter(
        x=np.concatenate([time, time[::-1]]),
        y=np.concatenate([estimated_ci_y_upper, estimated_ci_y_lower[::-1]]),
        fill="toself",
        fillcolor=color_estimated_pattern.format(cb_alpha),
        line=dict(color=color_estimated_pattern.format(0.0)),
        showlegend=False,
        legendgroup="estimated_y",
    )
    fig.add_trace(trace_est_y_cb)

    fig.update_layout(xaxis_title=xlabel,
                      yaxis_title=ylabel,
                      paper_bgcolor='rgba(0,0,0,0)',
                      plot_bgcolor='rgba(0,0,0,0)',
                      yaxis_range=[estimated_mean_x.min(),
                                   estimated_mean_x.max()],
                     )
    return fig

Plot true, measured and filtered positions (with 95% confidence band)

true_x = x[0, :]
true_y = x[3, :]
measured_x = y[0, :]
measured_y = y[1, :]
filter_mean_x = filtered_means[0, 0, :]
filter_mean_y = filtered_means[3, 0, :]

filter_ci_x_upper = filter_mean_x + 1.96*filter_std_x_y[:, 0]
filter_ci_x_lower = filter_mean_x - 1.96*filter_std_x_y[:, 0]
filter_ci_y_upper = filter_mean_y + 1.96*filter_std_x_y[:, 3]
filter_ci_y_lower = filter_mean_y - 1.96*filter_std_x_y[:, 3]

fig = get_fig_kinematics_vs_time(
    time=time,
    true_x=true_x, true_y=true_y,
    measured_x=measured_x, measured_y=measured_y,
    estimated_mean_x=filter_mean_x, estimated_mean_y=filter_mean_y,
    estimated_ci_x_upper=filter_ci_x_upper,
    estimated_ci_y_upper=filter_ci_y_upper,
    estimated_ci_x_lower=filter_ci_x_lower,
    estimated_ci_y_lower=filter_ci_y_lower,
    cb_alpha=cb_alpha,
    color_true=color_true, color_measured=color_measured,
    color_estimated_pattern=color_filtered_pattern,
    xlabel="Time (sec)", ylabel="Position (pixels)")
# fig_filename_pattern = "../../figures/filtered_pos.{:s}"
# fig.write_image(fig_filename_pattern.format("png"))
# fig.write_html(fig_filename_pattern.format("html"))
fig


Plot true and filtered velocities (with 95% confidence band)

true_x = x[1, :]
true_y = x[4, :]
measured_x = None
measured_y = None
filter_mean_x = filtered_means[1, 0, :]
filter_mean_y = filtered_means[4, 0, :]

filter_ci_x_upper = filter_mean_x + 1.96*filter_std_x_y[:, 1]
filter_ci_x_lower = filter_mean_x - 1.96*filter_std_x_y[:, 1]
filter_ci_y_upper = filter_mean_y + 1.96*filter_std_x_y[:, 4]
filter_ci_y_lower = filter_mean_y - 1.96*filter_std_x_y[:, 4]

fig = get_fig_kinematics_vs_time(
    time=time,
    true_x=true_x, true_y=true_y,
    measured_x=measured_x, measured_y=measured_y,
    estimated_mean_x=filter_mean_x, estimated_mean_y=filter_mean_y,
    estimated_ci_x_upper=filter_ci_x_upper,
    estimated_ci_y_upper=filter_ci_y_upper,
    estimated_ci_x_lower=filter_ci_x_lower,
    estimated_ci_y_lower=filter_ci_y_lower,
    cb_alpha=cb_alpha,
    color_true=color_true, color_measured=color_measured,
    color_estimated_pattern=color_filtered_pattern,
    xlabel="Time (sec)", ylabel="Velocity (pixels/sec)")
# fig_filename_pattern = "../../figures/filtered_vel.{:s}"
# fig.write_image(fig_filename_pattern.format("png"))
# fig.write_html(fig_filename_pattern.format("html"))
fig


Plot true and filtered accelerations (with 95% confidence band)

true_x = x[2, :]
true_y = x[5, :]
measured_x = None
measured_y = None
filter_mean_x = filtered_means[2, 0, :]
filter_mean_y = filtered_means[5, 0, :]

filter_ci_x_upper = filter_mean_x + 1.96*filter_std_x_y[:, 2]
filter_ci_x_lower = filter_mean_x - 1.96*filter_std_x_y[:, 2]
filter_ci_y_upper = filter_mean_y + 1.96*filter_std_x_y[:, 5]
filter_ci_y_lower = filter_mean_y - 1.96*filter_std_x_y[:, 5]

fig = get_fig_kinematics_vs_time(
    time=time,
    true_x=true_x, true_y=true_y,
    measured_x=measured_x, measured_y=measured_y,
    estimated_mean_x=filter_mean_x, estimated_mean_y=filter_mean_y,
    estimated_ci_x_upper=filter_ci_x_upper,
    estimated_ci_y_upper=filter_ci_y_upper,
    estimated_ci_x_lower=filter_ci_x_lower,
    estimated_ci_y_lower=filter_ci_y_lower,
    cb_alpha=cb_alpha,
    color_true=color_true, color_measured=color_measured,
    color_estimated_pattern=color_filtered_pattern,
    xlabel="Time (sec)", ylabel="Acceleration (pixels/sec^2)")
# fig_filename_pattern = "../../figures/filtered_acc.{:s}"
# fig.write_image(fig_filename_pattern.format("png"))
# fig.write_html(fig_filename_pattern.format("html"))
fig


Total running time of the script: (0 minutes 0.139 seconds)

Gallery generated by Sphinx-Gallery