Note
Go to the end to download the full example code. or to run this example in your browser via Binder
Online filtering of a foraging mouse trajectory
The code below performs online Kalman filtering of a trajectory of the mouse shown on the left image below, as it forages in the arena shown on the right image below. Click on the images to see their larger versions.
Import packages
import sys
import os
import random
import pickle
import configparser
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import lds.tracking.utils
import lds.inference
Setup configuration variables
start_position = 0
number_positions = 10000
color_measures = "black"
color_pattern_filtered = "rgba(255,0,0,{:f})"
cb_alpha = 0.3
data_filename = "http://www.gatsby.ucl.ac.uk/~rapela/svGPFA/data/positions_session003_start0.00_end15548.27.csv"
Get the mouse position measurements
data = pd.read_csv(data_filename)
data = data.iloc[start_position:start_position+number_positions,:]
y = np.transpose(data[["x", "y"]].to_numpy())
Get the Kalman filter configuration parameters
pos_x0 = y[0, 0]
pos_y0 = y[0, 0]
vel_x0 = 0.0
vel_y0 = 0.0
acc_x0 = 0.0
acc_y0 = 0.0
sigma_a = 1e4
sigma_x = 1e2
sigma_y = 1e2
sqrt_diag_V0_value = 1e-3
Build the Kalman filter matrices for tracking
date_times = pd.to_datetime(data["time"])
dt = (date_times.iloc[1]-date_times.iloc[0]).total_seconds()
B, Q, Z, R, Qe = lds.tracking.utils.getLDSmatricesForTracking(
dt=dt, sigma_a=sigma_a, sigma_x=sigma_x, sigma_y=sigma_y)
m0 = np.array([[y[0, 0], 0, 0, y[1, 0], 0, 0]], dtype=np.double).T
V0 = np.diag(np.ones(len(m0))*sqrt_diag_V0_value**2)
Apply the Kalman filter to the mouse position measurements
onlineKF = lds.inference.OnlineKalmanFilter(B=B, Q=Q, m0=m0, V0=V0, Z=Z, R=R)
filtered_means = np.empty((6, 1, y.shape[1]), dtype=np.double)
filtered_covs = np.empty((6, 6, y.shape[1]), dtype=np.double)
for i in range(y.shape[1]):
_, _ = onlineKF.predict()
filtered_means[:, :, i], filtered_covs[:, :, i] = \
onlineKF.update(y=y[:, i])
Plot the filtered positions with their 95% confidence bands
time = (pd.to_datetime(data["time"]) - pd.to_datetime(data["time"][0])).dt.total_seconds().to_numpy()
filter_mean_x = filtered_means[0, 0, :]
filter_mean_y = filtered_means[3, 0, :]
filter_std_x_y = np.sqrt(np.diagonal(a=filtered_covs, axis1=0, axis2=1))
filter_ci_x_upper = filter_mean_x + 1.96*filter_std_x_y[:, 0]
filter_ci_x_lower = filter_mean_x - 1.96*filter_std_x_y[:, 0]
filter_ci_y_upper = filter_mean_y + 1.96*filter_std_x_y[:, 3]
filter_ci_y_lower = filter_mean_y - 1.96*filter_std_x_y[:, 3]
fig = go.Figure()
trace_mes_x = go.Scatter(
x=time, y=y[0, :],
mode="markers",
marker={"color": color_measures},
name="measured x",
showlegend=True,
)
trace_mes_y = go.Scatter(
x=time, y=y[1, :],
mode="markers",
marker={"color": color_measures},
name="measured y",
showlegend=True,
)
trace_filter_x = go.Scatter(
x=time, y=filter_mean_x,
mode="markers",
marker={"color": color_pattern_filtered.format(1.0)},
name="filtered x",
showlegend=True,
legendgroup="filtered_x",
)
trace_filter_x_cb = go.Scatter(
x=np.concatenate([time, time[::-1]]),
y=np.concatenate([filter_ci_x_upper,
filter_ci_x_lower[::-1]]),
fill="toself",
fillcolor=color_pattern_filtered.format(cb_alpha),
line=dict(color=color_pattern_filtered.format(0.0)),
showlegend=False,
legendgroup="filtered_x",
)
trace_filter_y = go.Scatter(
x=time, y=filter_mean_y,
mode="markers",
marker={"color":
color_pattern_filtered.format(1.0)},
name="filtered y",
showlegend=True,
legendgroup="filtered_y",
)
trace_filter_y_cb = go.Scatter(
x=np.concatenate([time, time[::-1]]),
y=np.concatenate([filter_ci_y_upper,
filter_ci_y_lower[::-1]]),
fill="toself",
fillcolor=color_pattern_filtered.format(cb_alpha),
line=dict(color=color_pattern_filtered.format(0.0)),
showlegend=False,
legendgroup="filtered_y",
)
fig.add_trace(trace_mes_x)
fig.add_trace(trace_mes_y)
fig.add_trace(trace_filter_x)
fig.add_trace(trace_filter_x_cb)
fig.add_trace(trace_filter_y)
fig.add_trace(trace_filter_y_cb)
fig.update_layout(xaxis_title="time (seconds)", yaxis_title="position (pixels)",
paper_bgcolor='rgba(0,0,0,0)', plot_bgcolor='rgba(0,0,0,0)',
yaxis_range=[filter_mean_x.min(), filter_mean_x.max()])
fig
Total running time of the script: (0 minutes 5.972 seconds)