.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/neuralLatents/plot_simulations.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end `
        to download the full example code or to run this example in your browser via Binder

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_neuralLatents_plot_simulations.py:


Learning and inference of latents with simple simulated data
============================================================

This example simulates measurements from a two-dimensional linear dynamical
system, learns the model parameters from those measurements with the
expectation-maximization (EM) algorithm, and then uses a Kalman filter with the
learned parameters to infer the latent states and to forecast the measurements
one step ahead.

.. GENERATED FROM PYTHON SOURCE LINES 12-14

Import packages
~~~~~~~~~~~~~~~

.. GENERATED FROM PYTHON SOURCE LINES 14-22

.. code-block:: default

    import numpy as np
    import plotly.graph_objs as go
    import ssm.simulation
    import ssm.learning
    import ssm.inference
    import ssm.tracking.plotting

.. GENERATED FROM PYTHON SOURCE LINES 23-25

Define variables for simulation
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

``B`` and ``Q`` are the state-transition matrix and the state-noise covariance,
``Z`` and ``R`` are the observation matrix and the observation-noise covariance,
and ``m0`` and ``V0`` are the mean and covariance of the initial state.

.. GENERATED FROM PYTHON SOURCE LINES 25-34

.. code-block:: default

    # Initial-state mean and covariance
    m0 = np.array([0.0, 0.0], dtype=np.double)
    V0 = np.array([[1e-3, 0], [0, 1e-3]], dtype=np.double)
    # State-transition matrix and state-noise covariance
    B = np.array([[0.9872, -0.0272], [0.0080, 1.0128]], dtype=np.double)
    Q = np.array([[1e-3, 0], [0, 1e-3]], dtype=np.double)
    # Observation matrix and observation-noise covariance
    Z = np.array([[1, 0], [0, 1]], dtype=np.double)
    R = np.array([[.08, 0], [0, .08]], dtype=np.double)
    # Number of simulated time steps
    num_pos = 2000

.. GENERATED FROM PYTHON SOURCE LINES 35-37

Perform simulation
~~~~~~~~~~~~~~~~~~

.. GENERATED FROM PYTHON SOURCE LINES 37-42

.. code-block:: default

    # x0: initial state; x: latent states; y: measurements
    # (x and y have one column per time step)
    x0, x, y = ssm.simulation.simulateLDS(T=num_pos, B=B, Q=Q, Z=Z, R=R,
                                          m0=m0, V0=V0)
    simulation_step = np.arange(x.shape[1])

.. GENERATED FROM PYTHON SOURCE LINES 43-45

Plot simulation
~~~~~~~~~~~~~~~

.. GENERATED FROM PYTHON SOURCE LINES 45-59

.. code-block:: default

    fig = go.Figure()
    # Latent trajectory, noisy measurements and initial state in the plane
    trace_x = go.Scatter(x=x[0, :], y=x[1, :], mode="lines+markers",
                         showlegend=True, name="x")
    trace_y = go.Scatter(x=y[0, :], y=y[1, :], mode="lines+markers",
                         showlegend=True, name="y", opacity=0.3)
    trace_start = go.Scatter(x=[x0[0]], y=[x0[1]], mode="markers",
                             text="x0", marker={"size": 7},
                             showlegend=False)
    fig.add_trace(trace_x)
    fig.add_trace(trace_y)
    fig.add_trace(trace_start)
    fig

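For readers who want to see what is being sampled, the following is a minimal
NumPy sketch of a linear dynamical system simulator. It is a hypothetical
stand-in, not the implementation of ``ssm.simulation.simulateLDS``: the name
``simulate_lds`` and the ``rng`` argument are ours, and we assume that ``x0`` is
drawn from :math:`N(m_0, V_0)` and that the first simulated state ``x[:, 0]`` is
obtained by propagating ``x0`` once through the dynamics.

.. code-block:: python

    import numpy as np


    def simulate_lds(T, B, Q, Z, R, m0, V0, rng=None):
        """Hypothetical LDS sampler (a sketch, not ssm's implementation).

        Assumed model:
            x0  ~ N(m0, V0)
            x_n = B x_{n-1} + w_n,  w_n ~ N(0, Q)
            y_n = Z x_n + v_n,      v_n ~ N(0, R)
        Returns x0 with shape (M,), x with shape (M, T) and y with shape (P, T).
        """
        rng = np.random.default_rng() if rng is None else rng
        M, P = B.shape[0], Z.shape[0]
        x0 = rng.multivariate_normal(m0, V0)
        x = np.empty((M, T))
        y = np.empty((P, T))
        prev = x0
        for n in range(T):
            # State transition plus Gaussian state noise
            x[:, n] = B @ prev + rng.multivariate_normal(np.zeros(M), Q)
            # Linear observation plus Gaussian measurement noise
            y[:, n] = Z @ x[:, n] + rng.multivariate_normal(np.zeros(P), R)
            prev = x[:, n]
        return x0, x, y

With the variables defined above,
``simulate_lds(num_pos, B, Q, Z, R, m0, V0, rng=np.random.default_rng(0))``
returns states and measurements in the same state-by-time layout used in this
example.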
.. GENERATED FROM PYTHON SOURCE LINES 60-62

Define initial conditions and control variables for EM learning
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. GENERATED FROM PYTHON SOURCE LINES 62-82

.. code-block:: default

    # Initial parameter guesses for EM: the initial-state mean is set to the
    # first measurement, and the other guesses differ from the true values
    m0_0 = y[:, 0]
    V0_0 = np.array([[1e-2, 0], [0, 1e-2]], np.double)
    B0 = np.array([[1.0, -0.1], [0.0080, 1.5]], np.double)
    Q0 = np.array([[1e-4, 0], [0, 1e-2]], np.double)
    Z0 = np.array([[1.0, 0.1], [-0.1, 1.0]], np.double)
    R0 = np.array([[0.1, 0], [0, 0.1]], np.double)

    # True values used in the simulation (uncomment to start EM from them)
    # V0_0 = np.array([[1e-3, 0], [0, 1e-3]], np.double)
    # B0 = np.array([[.9872, -0.0272], [0.0080, 1.0128]], np.double)
    # Q0 = np.array([[1e-3, 0], [0, 1e-3]], np.double)
    # Z0 = np.array([[1.0, 0.0], [0.0, 1.0]], np.double)
    # R0 = np.array([[0.08, 0], [0, 0.08]], np.double)

    max_iter = 50  # maximum number of EM iterations
    tol = 1e-6  # convergence tolerance
    # Which parameters EM should estimate (all of them here)
    vars_to_estimate = {"B": True, "Q": True, "Z": True, "R": True,
                        "m0": True, "V0": True}

.. GENERATED FROM PYTHON SOURCE LINES 83-85

Run EM
~~~~~~

.. GENERATED FROM PYTHON SOURCE LINES 85-91

.. code-block:: default

    optim_res = ssm.learning.em_SS_LDS(
        y=y, B0=B0, Q0=Q0, Z0=Z0, R0=R0, m0_0=m0_0, V0_0=V0_0,
        max_iter=max_iter, tol=tol, vars_to_estimate=vars_to_estimate,
    )

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    LogLike[0000]=-8039.883161
    LogLike[0001]=-1468.555193
    LogLike[0002]=-1206.301514
    LogLike[0003]=-1098.275612
    LogLike[0004]=-1053.390481
    LogLike[0005]=-1030.138445
    LogLike[0006]=-1013.434567
    LogLike[0007]=-999.268393
    LogLike[0008]=-986.695106
    LogLike[0009]=-975.422165
    LogLike[0010]=-965.280255
    LogLike[0011]=-956.131935
    LogLike[0012]=-947.858185
    LogLike[0013]=-940.355527
    LogLike[0014]=-933.534138
    LogLike[0015]=-927.316084
    LogLike[0016]=-921.633692
    LogLike[0017]=-916.428122
    LogLike[0018]=-911.648115
    LogLike[0019]=-907.248923
    LogLike[0020]=-903.191394
    LogLike[0021]=-899.441185
    LogLike[0022]=-895.968101
    LogLike[0023]=-892.745526
    LogLike[0024]=-889.749939
    LogLike[0025]=-886.960502
    LogLike[0026]=-884.358709
    LogLike[0027]=-881.928084
    LogLike[0028]=-879.653925
    LogLike[0029]=-877.523078
    LogLike[0030]=-875.523751
    LogLike[0031]=-873.645347
    LogLike[0032]=-871.878324
    LogLike[0033]=-870.214069
    LogLike[0034]=-868.644793
    LogLike[0035]=-867.163440
    LogLike[0036]=-865.763602
    LogLike[0037]=-864.439452
    LogLike[0038]=-863.185681
    LogLike[0039]=-861.997441
    LogLike[0040]=-860.870302
    LogLike[0041]=-859.800206
    LogLike[0042]=-858.783432
    LogLike[0043]=-857.816562
    LogLike[0044]=-856.896453
    LogLike[0045]=-856.020208
    LogLike[0046]=-855.185158
    LogLike[0047]=-854.388836
    LogLike[0048]=-853.628961
    LogLike[0049]=-852.903424

.. GENERATED FROM PYTHON SOURCE LINES 92-94

Plot log likelihood vs iteration number
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

EM guarantees that the log-likelihood does not decrease from one iteration to
the next. In the run above it rises from about -8040 at the first iteration to
about -853 at the 50th, and it is still increasing by roughly 0.7 per iteration
when the run stops at ``max_iter = 50``, so additional iterations would likely
improve the fit further.

.. GENERATED FROM PYTHON SOURCE LINES 94-106

.. code-block:: default

    N = len(optim_res["log_like"])
    iter_no = np.arange(0, N)
    fig = go.Figure()
    trace = go.Scatter(x=iter_no, y=optim_res["log_like"], mode="lines+markers")
    fig.add_trace(trace)
    fig.update_layout(xaxis=dict(title="Iteration Number"),
                      yaxis=dict(title="Log Likelihood"))
    fig

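The quantity that EM increases is the log-likelihood of the measurements,
:math:`\log p(y_1, \ldots, y_N)`, which a Kalman filter evaluates through the
prediction-error decomposition: each step contributes the Gaussian log-density
of the innovation :math:`y_n - Z\hat{x}_{n|n-1}`, whose covariance is
:math:`S_n = Z P_{n|n-1} Z^\top + R`. The sketch below is a plain-NumPy
illustration, assuming :math:`x_0 \sim N(m_0, V_0)` precedes the first
measurement, :math:`x_n = B x_{n-1} + w_n` and :math:`y_n = Z x_n + v_n`. The
function name ``kalman_log_likelihood`` is hypothetical, and ``ssm`` may use a
different convention for the initial state, so its reported values need not
match this sketch exactly.

.. code-block:: python

    import numpy as np


    def kalman_log_likelihood(y, B, Q, Z, R, m0, V0):
        """Log-likelihood of measurements y (P x N) under a linear Gaussian model.

        Assumed model: x0 ~ N(m0, V0), x_n = B x_{n-1} + w_n with w_n ~ N(0, Q),
        and y_n = Z x_n + v_n with v_n ~ N(0, R). m0 must be a 1-D array.
        """
        P, N = y.shape
        mean, cov = m0, V0
        log_like = 0.0
        for n in range(N):
            # Prediction step: moments of p(x_n | y_1, ..., y_{n-1})
            pred_mean = B @ mean
            pred_cov = B @ cov @ B.T + Q
            # Innovation and its covariance S_n = Z P_{n|n-1} Z^T + R
            innov = y[:, n] - Z @ pred_mean
            S = Z @ pred_cov @ Z.T + R
            # Add log N(innov; 0, S)
            _, logdet = np.linalg.slogdet(S)
            log_like -= 0.5 * (P * np.log(2.0 * np.pi) + logdet
                               + innov @ np.linalg.solve(S, innov))
            # Update step with Kalman gain K = P_{n|n-1} Z^T S^{-1}
            # (uses the symmetry of S and P_{n|n-1})
            K = np.linalg.solve(S, Z @ pred_cov).T
            mean = pred_mean + K @ innov
            cov = pred_cov - K @ Z @ pred_cov
        return log_like

Summing the per-step innovation densities gives the exact joint log-density of
all measurements, which is why the filter is the standard way to evaluate the
likelihood that EM monitors.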
.. GENERATED FROM PYTHON SOURCE LINES 107-109

Filter simulations with the estimated parameters
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. GENERATED FROM PYTHON SOURCE LINES 109-114

.. code-block:: default

    filter_res = ssm.inference.filterLDS_SS_withMissingValues_np(
        y=y, B=optim_res["B"], Q=optim_res["Q"],
        m0=optim_res["m0"], V0=optim_res["V0"],
        Z=optim_res["Z"], R=optim_res["R"])

.. GENERATED FROM PYTHON SOURCE LINES 115-117

Plot true and filtered states
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. GENERATED FROM PYTHON SOURCE LINES 117-131

.. code-block:: default

    true_values = x[(0, 1), :]
    # Filtered means of both state dimensions, one column per time step
    filtered_means = filter_res["xnn"][(0, 1), 0, :]
    # Filtered standard deviations: square roots of the diagonals of the
    # filtered covariances Pnn (time on the last axis)
    filtered_stds = np.sqrt(np.diagonal(a=filter_res["Pnn"], axis1=0, axis2=1)[:, (0, 1)].T)
    fig = ssm.tracking.plotting.get_x_and_y_time_series_vs_time_fig(
        time=simulation_step, ylabel="State", xlabel="Simulation Step",
        true_values=true_values, filtered_means=filtered_means,
        filtered_stds=filtered_stds)
    fig.update_layout(title=f'Log-Likelihood: {filter_res["logLike"].squeeze()}')
    fig

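The filtered-state figure can be summarized by one number per state dimension:
the fraction of time steps at which the true state lies inside the filtered
mean :math:`\pm 1.96` standard deviations. If the learned model were correct and
the filtered distributions Gaussian, this empirical coverage should be close to
0.95. The helper ``interval_coverage`` below is a hypothetical addition for
illustration, not part of ``ssm``.

.. code-block:: python

    import numpy as np


    def interval_coverage(true_values, means, stds, z=1.96):
        """Per-dimension fraction of steps with |true - mean| <= z * std.

        All inputs have shape (n_dims, n_steps); the result has shape (n_dims,).
        """
        inside = np.abs(true_values - means) <= z * stds
        return inside.mean(axis=1)

In this example it would be called as
``interval_coverage(true_values, filtered_means, filtered_stds)``. Values well
below 0.95 suggest overconfident filtered uncertainties.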
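.. GENERATED FROM PYTHON SOURCE LINES 132-134

Calculate the one-step ahead forecasts
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The one-step-ahead forecast of :math:`y_n` given :math:`y_1, \ldots, y_{n-1}`
is Gaussian with mean :math:`Z \hat{x}_{n|n-1}` and covariance
:math:`Z P_{n|n-1} Z^\top + R`, where :math:`\hat{x}_{n|n-1}` and
:math:`P_{n|n-1}` are the predicted state mean and covariance returned by the
filter as ``xnn1`` and ``Pnn1``, and :math:`Z` and :math:`R` are the estimated
parameters.

.. GENERATED FROM PYTHON SOURCE LINES 134-143

.. code-block:: default

    # Forecast means Z x_{n|n-1}, one column per time step
    one_step_ahead_mean = optim_res["Z"] @ filter_res["xnn1"][:, 0, :]
    # Forecast covariances Z P_{n|n-1} Z^T + R for every step. Pnn1 stores one
    # covariance per step along its last axis, so the per-step products are
    # formed with einsum; @ would treat the last two axes as the matrix axes.
    aux_covs = np.einsum("ij,jml,km->ikl", optim_res["Z"], filter_res["Pnn1"],
                         optim_res["Z"]) + np.expand_dims(optim_res["R"], 2)
    # Forecast variances: diagonals of the covariances, one column per step
    one_step_ahead_var = np.diagonal(aux_covs, axis1=0, axis2=1).T

.. GENERATED FROM PYTHON SOURCE LINES 144-146

Plot measurements and one-step ahead forecasts
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. GENERATED FROM PYTHON SOURCE LINES 146-156

.. code-block:: default

    fig = ssm.tracking.plotting.get_x_and_y_time_series_vs_time_fig(
        time=simulation_step, ylabel="One-Step Ahead Forecasts",
        xlabel="Simulation Step", measurements=y,
        filtered_means=one_step_ahead_mean,
        filtered_stds=np.sqrt(one_step_ahead_var))
    fig
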
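Forming the forecast covariances requires one product
:math:`Z P_{n|n-1} Z^\top` per time step. Assuming, as the indexing in this
example suggests, that the predicted covariances are stored with time on the
last axis (shape :math:`M \times M \times N`), a single ``np.einsum`` call
computes all of them at once. Note that ``@`` treats the *last* two axes of a
3-D array as the matrix axes, so applying it directly to such an array does not
give the per-step products. The sketch below uses a hypothetical helper,
``forecast_covariances``, and checks it against an explicit loop on random
covariances.

.. code-block:: python

    import numpy as np


    def forecast_covariances(Z, R, pred_covs):
        """Return Z @ pred_covs[:, :, n] @ Z.T + R for every step n.

        Z: (P, M), R: (P, P), pred_covs: (M, M, N) with time on the last axis.
        Result: (P, P, N).
        """
        # out[i, k, n] = sum_{j, m} Z[i, j] * pred_covs[j, m, n] * Z[k, m]
        ZPZt = np.einsum("ij,jmn,km->ikn", Z, pred_covs, Z)
        return ZPZt + R[:, :, np.newaxis]


    # Check against an explicit loop on random symmetric covariances
    rng = np.random.default_rng(0)
    Z_demo = rng.standard_normal((3, 2))
    R_demo = 0.1 * np.eye(3)
    A = rng.standard_normal((2, 2, 4))
    pred_covs_demo = np.einsum("imn,jmn->ijn", A, A)  # A_n A_n^T per step
    covs_demo = forecast_covariances(Z_demo, R_demo, pred_covs_demo)
    for n in range(pred_covs_demo.shape[2]):
        expected = Z_demo @ pred_covs_demo[:, :, n] @ Z_demo.T + R_demo
        assert np.allclose(covs_demo[:, :, n], expected)
    # Per-step forecast variances, shape (P, N)
    forecast_vars = np.diagonal(covs_demo, axis1=0, axis2=1).T

An alternative with the same result is to move the time axis to the front with
``np.moveaxis(pred_covs, -1, 0)``, apply ``Z @ ... @ Z.T`` as a batched matrix
product, and move the axis back.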
.. rst-class:: sphx-glr-timing

**Total running time of the script:** ( 0 minutes 13.744 seconds)


.. _sphx_glr_download_auto_examples_neuralLatents_plot_simulations.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: binder-badge

      .. image:: images/binder_badge_logo.svg
        :target: https://mybinder.org/v2/gh/joacorapela/lds_python/gh-pages?filepath=notebooks/auto_examples/neuralLatents/plot_simulations.ipynb
        :alt: Launch binder
        :width: 150 px

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_simulations.py `

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_simulations.ipynb `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_