Note
Go to the end to download the full example code
4.1. Simulated data and default params#
In this notebook we use simulated data to estimate an svGPFA model using the default initial parameters.
4.1.1. Estimate model#
4.1.1.1. Import required packages#
import sys
import time
import warnings
import torch
import pickle
import gcnu_common.stats.pointProcesses.tests
import svGPFA.stats.kernels
import svGPFA.stats.svGPFAModelFactory
import svGPFA.stats.svEM
import svGPFA.utils.miscUtils
import svGPFA.utils.initUtils
import svGPFA.plot.plotUtilsPlotly
4.1.1.2. Get spikes times#
The spike times of all neurons in all trials should be stored in nested lists: spikes_times[r][n]
should contain a list of the spike times of neuron n in trial r.
# Load the results of a previous simulation; its "spikes" entry holds the
# nested spike-times structure described above (spikes_times[r][n]).
sim_res_filename = "../../examples/data/32451751_simRes.pickle"  # simulation results filename
with open(sim_res_filename, "rb") as f:
    sim_res = pickle.load(f)

spikes_times = sim_res["spikes"]
# The number of neurons is read from the first trial.
n_trials, n_neurons = len(spikes_times), len(spikes_times[0])

# Every trial spans the same interval [trials_start_time, trials_end_time].
trials_start_time, trials_end_time = 0.0, 1.0
trials_start_times = [trials_start_time for _ in range(n_trials)]
trials_end_times = [trials_end_time for _ in range(n_trials)]
4.1.1.3. Check that spikes have been epoched correctly#
4.1.1.4. Plot spikes#
Plot the spikes of all trials of a randomly chosen neuron. Most trials should contain at least one spike.
# Draw a random neuron index in [0, n_neurons); torch.randint excludes `high`.
neuron_to_plot_index = int(torch.randint(low=0, high=n_neurons, size=(1,))[0])
# Spike raster of that neuron across all trials.
fig = svGPFA.plot.plotUtilsPlotly.getSpikesTimesPlotOneNeuron(
    spikes_times=spikes_times,
    neuron_index=neuron_to_plot_index,
    title="Neuron index: {}".format(neuron_to_plot_index),
)
fig  # bare expression so the notebook / sphinx-gallery renders the figure
4.1.1.5. Run some simple checks on spikes#
The function checkEpochedSpikesTimes tests that:
every neuron fired at least one spike across all trials, and
for each trial, the spike times of every neuron are between the trial start and end times.
If any check fails, a ValueError will be raised. Otherwise, a "Checks passed" message is printed.
# The top-of-file imports only bring in gcnu_common.stats.pointProcesses.tests;
# import the submodule used here explicitly instead of relying on it having
# been imported transitively by some other module.
import gcnu_common.utils.neural_data_analysis

# Validate the epoched spike times against the trial boundaries. Any
# ValueError raised by the check propagates unchanged (the former
# `try/except ValueError: raise` wrapper was a no-op and is dropped).
gcnu_common.utils.neural_data_analysis.checkEpochedSpikesTimes(
    spikes_times=spikes_times,
    trials_start_times=trials_start_times,
    trials_end_times=trials_end_times,
)
print("Checks passed")
Checks passed
4.1.1.6. Set estimation hyperparameters#
# Latent dimensionality and maximum number of EM iterations.
n_latents, em_max_iter = 2, 30
# Destination of the pickled estimation results.
model_save_filename = "../results/simulation_model.pickle"
4.1.1.7. Get parameters#
4.1.1.7.1. Build default parameters specification#
# Default parameter specification sized for this data set.
default_params_spec = svGPFA.utils.initUtils.getDefaultParamsDict(
    n_neurons=n_neurons,
    n_trials=n_trials,
    n_latents=n_latents,
    em_max_iter=em_max_iter,
)
4.1.1.7.2. Get parameters and kernels types from the parameters specification#
# Resolve the initial/optimization parameters and the kernel types. Only the
# default specification is supplied here; the "Extracted ..." lines printed
# below report which values were taken from it.
params, kernels_types = svGPFA.utils.initUtils.getParamsAndKernelsTypes(
    n_trials=n_trials,
    n_neurons=n_neurons,
    n_latents=n_latents,
    trials_start_times=trials_start_times,
    trials_end_times=trials_end_times,
    default_params_spec=default_params_spec,
)
Extracted default_params_spec[optim_params][n_quad]=200
Extracted default_params_spec[ind_points_locs_params0][n_ind_points]=[10, 10]
Extracted from default c0_distribution=Normal, c0_loc=0.0, c0_scale=1.0, c0_random_seed=None
Extracted from default d0_distribution=Normal, d0_loc=0.0, d0_scale=1.0, d0_random_seed=None
Extracted from default k_type=exponentialQuadratic and k_lengthsales0=1.0
Extracted from default ind_points_locs0_layout=equidistant
Extracted from default variational_mean0
Extracted from default variational_cov0
Extracted default_params_spec[optim_params][n_quad]=200
Extracted default_params_spec[optim_params][prior_cov_reg_param]=0.001
Extracted default_params_spec[optim_params][optim_method]=ecm
Extracted default_params_spec[optim_params][em_max_iter]=30
Extracted default_params_spec[optim_params][verbose]=True
Extracted default_params_spec[optim_params][estep_estimate]=True
Extracted default_params_spec[optim_params][estep_max_iter]=20
Extracted default_params_spec[optim_params][estep_lr]=1.0
Extracted default_params_spec[optim_params][estep_tolerance_grad]=1e-07
Extracted default_params_spec[optim_params][estep_tolerance_change]=1e-09
Extracted default_params_spec[optim_params][estep_line_search_fn]=strong_wolfe
Extracted default_params_spec[optim_params][mstep_embedding_estimate]=True
Extracted default_params_spec[optim_params][mstep_embedding_max_iter]=20
Extracted default_params_spec[optim_params][mstep_embedding_lr]=1.0
Extracted default_params_spec[optim_params][mstep_embedding_tolerance_grad]=1e-07
Extracted default_params_spec[optim_params][mstep_embedding_tolerance_change]=1e-09
Extracted default_params_spec[optim_params][mstep_embedding_line_search_fn]=strong_wolfe
Extracted default_params_spec[optim_params][mstep_kernels_estimate]=True
Extracted default_params_spec[optim_params][mstep_kernels_max_iter]=20
Extracted default_params_spec[optim_params][mstep_kernels_lr]=1.0
Extracted default_params_spec[optim_params][mstep_kernels_tolerance_grad]=1e-07
Extracted default_params_spec[optim_params][mstep_kernels_tolerance_change]=1e-09
Extracted default_params_spec[optim_params][mstep_kernels_line_search_fn]=strong_wolfe
Extracted default_params_spec[optim_params][mstep_indpointslocs_estimate]=True
Extracted default_params_spec[optim_params][mstep_indpointslocs_max_iter]=20
Extracted default_params_spec[optim_params][mstep_indpointslocs_lr]=1.0
Extracted default_params_spec[optim_params][mstep_indpointslocs_tolerance_grad]=1e-07
Extracted default_params_spec[optim_params][mstep_indpointslocs_tolerance_change]=1e-09
Extracted default_params_spec[optim_params][mstep_indpointslocs_line_search_fn]=strong_wolfe
4.1.1.8. Create kernels, a model and set its initial parameters#
4.1.1.8.1. Build kernels#
# Initial kernel parameters live under
# initial_params -> posterior_on_latents -> kernels_matrices_store.
kernels_params0 = params["initial_params"]["posterior_on_latents"][
    "kernels_matrices_store"]["kernels_params0"]
kernels = svGPFA.utils.miscUtils.buildKernels(
    kernels_types=kernels_types,
    kernels_params=kernels_params0,
)
4.1.1.8.2. Create model#
# PyTorch implementation of the svGPFA model, built from the kernels above.
model = svGPFA.stats.svGPFAModelFactory.SVGPFAModelFactory.buildModelPyTorch(
    kernels=kernels,
)
4.1.1.8.3. Set initial parameters#
# Attach the observed spike times, the initial parameters, the expected
# log-likelihood calculation parameters and the prior-covariance
# regularization parameter to the model.
model.setParamsAndData(
    measurements=spikes_times,
    initial_params=params["initial_params"],
    eLLCalculationParams=params["ell_calculation_params"],
    priorCovRegParam=params["optim_params"]["prior_cov_reg_param"],
)
4.1.1.9. Maximize the Lower Bound#
(Warning: with the parameters above, this step takes several minutes for em_max_iter=30; it took about 7 minutes in the run shown below.)
import pathlib

# Maximize the lower bound with the method named in
# params["optim_params"]["optim_method"], logging progress to stdout and
# timing the whole optimization.
svEM = svGPFA.stats.svEM.SVEM_PyTorch()
tic = time.perf_counter()
lowerBoundHist, elapsedTimeHist, terminationInfo, iterationsModelParams = \
    svEM.maximize(model=model, optim_params=params["optim_params"],
                  method=params["optim_params"]["optim_method"],
                  out=sys.stdout)
toc = time.perf_counter()
print(f"Elapsed time {toc - tic:0.4f} seconds")

# NOTE: the key "iterationModelParams" (singular) differs from the variable
# name; it is kept as-is because readers of the pickle may depend on it.
resultsToSave = {"lowerBoundHist": lowerBoundHist,
                 "elapsedTimeHist": elapsedTimeHist,
                 "terminationInfo": terminationInfo,
                 "iterationModelParams": iterationsModelParams,
                 "model": model}
# Create the output directory if it does not exist yet, so a missing folder
# does not discard the result of the (several-minute) optimization above.
pathlib.Path(model_save_filename).parent.mkdir(parents=True, exist_ok=True)
with open(model_save_filename, "wb") as f:
    pickle.dump(resultsToSave, f)
print(f"Saved results to {model_save_filename:s}")
Iteration 01, estep start: -inf
Iteration 01, estep end: 4128.037260, niter: 20, nfeval: 25
Iteration 01, mstep_embedding start: 4128.037260
Iteration 01, mstep_embedding end: 839456.097316, niter: 12, nfeval: 25
Iteration 01, mstep_kernels start: 839456.097316
Iteration 01, mstep_kernels end: 843016.309502, niter: 11, nfeval: 12
Iteration 01, mstep_indpointslocs start: 843016.309502
Iteration 01, mstep_indpointslocs end: 912289.726056, niter: 20, nfeval: 25
Iteration 02, estep start: 912289.726056
Iteration 02, estep end: 919788.831435, niter: 20, nfeval: 23
Iteration 02, mstep_embedding start: 919788.831435
Iteration 02, mstep_embedding end: 995501.214158, niter: 20, nfeval: 23
Iteration 02, mstep_kernels start: 995501.214158
Iteration 02, mstep_kernels end: 993240.034656, niter: 13, nfeval: 14
Iteration 02, mstep_indpointslocs start: 993240.034656
Iteration 02, mstep_indpointslocs end: 994696.626136, niter: 20, nfeval: 23
Iteration 03, estep start: 994696.626136
Iteration 03, estep end: 995099.477660, niter: 20, nfeval: 23
Iteration 03, mstep_embedding start: 995099.477660
Iteration 03, mstep_embedding end: 1001617.538126, niter: 20, nfeval: 22
Iteration 03, mstep_kernels start: 1001617.538126
Iteration 03, mstep_kernels end: 999592.433795, niter: 9, nfeval: 12
Iteration 03, mstep_indpointslocs start: 999592.433795
Iteration 03, mstep_indpointslocs end: 1000759.900391, niter: 20, nfeval: 23
Iteration 04, estep start: 1000759.900391
Iteration 04, estep end: 1001380.780466, niter: 18, nfeval: 25
Iteration 04, mstep_embedding start: 1001380.780466
Iteration 04, mstep_embedding end: 1004977.603299, niter: 20, nfeval: 23
Iteration 04, mstep_kernels start: 1004977.603299
Iteration 04, mstep_kernels end: 1003180.366095, niter: 10, nfeval: 12
Iteration 04, mstep_indpointslocs start: 1003180.366095
Iteration 04, mstep_indpointslocs end: 1003665.321165, niter: 20, nfeval: 23
Iteration 05, estep start: 1003665.321165
Iteration 05, estep end: 1003832.761878, niter: 20, nfeval: 25
Iteration 05, mstep_embedding start: 1003832.761878
Iteration 05, mstep_embedding end: 1006137.009643, niter: 20, nfeval: 22
Iteration 05, mstep_kernels start: 1006137.009643
Iteration 05, mstep_kernels end: 1004709.401130, niter: 9, nfeval: 10
Iteration 05, mstep_indpointslocs start: 1004709.401130
Iteration 05, mstep_indpointslocs end: 1004937.915194, niter: 20, nfeval: 22
Iteration 06, estep start: 1004937.915194
Iteration 06, estep end: 1005003.032510, niter: 20, nfeval: 24
Iteration 06, mstep_embedding start: 1005003.032510
Iteration 06, mstep_embedding end: 1006710.901285, niter: 20, nfeval: 23
Iteration 06, mstep_kernels start: 1006710.901285
Iteration 06, mstep_kernels end: 1005447.549660, niter: 8, nfeval: 9
Iteration 06, mstep_indpointslocs start: 1005447.549660
Iteration 06, mstep_indpointslocs end: 1005607.217553, niter: 20, nfeval: 23
Iteration 07, estep start: 1005607.217553
Iteration 07, estep end: 1005656.279229, niter: 20, nfeval: 22
Iteration 07, mstep_embedding start: 1005656.279229
Iteration 07, mstep_embedding end: 1007074.302837, niter: 20, nfeval: 23
Iteration 07, mstep_kernels start: 1007074.302837
Iteration 07, mstep_kernels end: 1005949.975639, niter: 8, nfeval: 9
Iteration 07, mstep_indpointslocs start: 1005949.975639
Iteration 07, mstep_indpointslocs end: 1006079.227378, niter: 20, nfeval: 22
Iteration 08, estep start: 1006079.227378
Iteration 08, estep end: 1006111.687889, niter: 20, nfeval: 22
Iteration 08, mstep_embedding start: 1006111.687889
Iteration 08, mstep_embedding end: 1007315.055609, niter: 20, nfeval: 24
Iteration 08, mstep_kernels start: 1007315.055609
Iteration 08, mstep_kernels end: 1006289.421176, niter: 7, nfeval: 9
Iteration 08, mstep_indpointslocs start: 1006289.421176
Iteration 08, mstep_indpointslocs end: 1006389.702885, niter: 20, nfeval: 22
Iteration 09, estep start: 1006389.702885
Iteration 09, estep end: 1006417.660659, niter: 20, nfeval: 22
Iteration 09, mstep_embedding start: 1006417.660659
Iteration 09, mstep_embedding end: 1007496.626271, niter: 20, nfeval: 22
Iteration 09, mstep_kernels start: 1007496.626271
Iteration 09, mstep_kernels end: 1006550.819127, niter: 7, nfeval: 9
Iteration 09, mstep_indpointslocs start: 1006550.819127
Iteration 09, mstep_indpointslocs end: 1006645.995529, niter: 20, nfeval: 23
Iteration 10, estep start: 1006645.995529
Iteration 10, estep end: 1006672.383425, niter: 20, nfeval: 22
Iteration 10, mstep_embedding start: 1006672.383425
Iteration 10, mstep_embedding end: 1007672.161427, niter: 20, nfeval: 23
Iteration 10, mstep_kernels start: 1007672.161427
Iteration 10, mstep_kernels end: 1006786.495230, niter: 7, nfeval: 9
Iteration 10, mstep_indpointslocs start: 1006786.495230
Iteration 10, mstep_indpointslocs end: 1006884.330343, niter: 20, nfeval: 23
Iteration 11, estep start: 1006884.330343
Iteration 11, estep end: 1006917.137028, niter: 20, nfeval: 22
Iteration 11, mstep_embedding start: 1006917.137028
Iteration 11, mstep_embedding end: 1007851.171472, niter: 20, nfeval: 25
Iteration 11, mstep_kernels start: 1007851.171472
Iteration 11, mstep_kernels end: 1007021.449478, niter: 6, nfeval: 8
Iteration 11, mstep_indpointslocs start: 1007021.449478
Iteration 11, mstep_indpointslocs end: 1007119.368865, niter: 20, nfeval: 22
Iteration 12, estep start: 1007119.368865
Iteration 12, estep end: 1007153.750191, niter: 20, nfeval: 22
Iteration 12, mstep_embedding start: 1007153.750191
Iteration 12, mstep_embedding end: 1008037.916298, niter: 20, nfeval: 23
Iteration 12, mstep_kernels start: 1008037.916298
Iteration 12, mstep_kernels end: 1007248.340255, niter: 6, nfeval: 8
Iteration 12, mstep_indpointslocs start: 1007248.340255
Iteration 12, mstep_indpointslocs end: 1007346.250877, niter: 20, nfeval: 23
Iteration 13, estep start: 1007346.250877
Iteration 13, estep end: 1007379.656071, niter: 20, nfeval: 24
Iteration 13, mstep_embedding start: 1007379.656071
Iteration 13, mstep_embedding end: 1008215.292073, niter: 20, nfeval: 22
Iteration 13, mstep_kernels start: 1008215.292073
Iteration 13, mstep_kernels end: 1007463.915059, niter: 5, nfeval: 8
Iteration 13, mstep_indpointslocs start: 1007463.915059
Iteration 13, mstep_indpointslocs end: 1007550.261962, niter: 20, nfeval: 24
Iteration 14, estep start: 1007550.261962
Iteration 14, estep end: 1007578.947433, niter: 20, nfeval: 25
Iteration 14, mstep_embedding start: 1007578.947433
Iteration 14, mstep_embedding end: 1008366.767131, niter: 20, nfeval: 25
Iteration 14, mstep_kernels start: 1008366.767131
Iteration 14, mstep_kernels end: 1007646.441650, niter: 6, nfeval: 9
Iteration 14, mstep_indpointslocs start: 1007646.441650
Iteration 14, mstep_indpointslocs end: 1007718.297510, niter: 20, nfeval: 24
Iteration 15, estep start: 1007718.297510
Iteration 15, estep end: 1007740.224843, niter: 20, nfeval: 22
Iteration 15, mstep_embedding start: 1007740.224843
Iteration 15, mstep_embedding end: 1008490.160603, niter: 20, nfeval: 22
Iteration 15, mstep_kernels start: 1008490.160603
Iteration 15, mstep_kernels end: 1007801.872879, niter: 6, nfeval: 9
Iteration 15, mstep_indpointslocs start: 1007801.872879
Iteration 15, mstep_indpointslocs end: 1007863.921695, niter: 20, nfeval: 24
Iteration 16, estep start: 1007863.921695
Iteration 16, estep end: 1007886.165951, niter: 20, nfeval: 25
Iteration 16, mstep_embedding start: 1007886.165951
Iteration 16, mstep_embedding end: 1008600.136247, niter: 20, nfeval: 24
Iteration 16, mstep_kernels start: 1008600.136247
Iteration 16, mstep_kernels end: 1007935.349276, niter: 6, nfeval: 9
Iteration 16, mstep_indpointslocs start: 1007935.349276
Iteration 16, mstep_indpointslocs end: 1007989.681645, niter: 20, nfeval: 23
Iteration 17, estep start: 1007989.681645
Iteration 17, estep end: 1008008.478958, niter: 20, nfeval: 23
Iteration 17, mstep_embedding start: 1008008.478958
Iteration 17, mstep_embedding end: 1008693.379550, niter: 20, nfeval: 23
Iteration 17, mstep_kernels start: 1008693.379550
Iteration 17, mstep_kernels end: 1008053.692009, niter: 6, nfeval: 9
Iteration 17, mstep_indpointslocs start: 1008053.692009
Iteration 17, mstep_indpointslocs end: 1008099.126966, niter: 20, nfeval: 23
Iteration 18, estep start: 1008099.126966
Iteration 18, estep end: 1008118.751830, niter: 20, nfeval: 24
Iteration 18, mstep_embedding start: 1008118.751830
Iteration 18, mstep_embedding end: 1008776.543436, niter: 20, nfeval: 22
Iteration 18, mstep_kernels start: 1008776.543436
Iteration 18, mstep_kernels end: 1008156.230533, niter: 6, nfeval: 9
Iteration 18, mstep_indpointslocs start: 1008156.230533
Iteration 18, mstep_indpointslocs end: 1008195.639122, niter: 20, nfeval: 22
Iteration 19, estep start: 1008195.639122
Iteration 19, estep end: 1008214.251499, niter: 20, nfeval: 23
Iteration 19, mstep_embedding start: 1008214.251499
Iteration 19, mstep_embedding end: 1008851.780522, niter: 20, nfeval: 22
Iteration 19, mstep_kernels start: 1008851.780522
Iteration 19, mstep_kernels end: 1008245.164088, niter: 6, nfeval: 9
Iteration 19, mstep_indpointslocs start: 1008245.164088
Iteration 19, mstep_indpointslocs end: 1008284.550419, niter: 20, nfeval: 24
Iteration 20, estep start: 1008284.550419
Iteration 20, estep end: 1008301.322737, niter: 20, nfeval: 23
Iteration 20, mstep_embedding start: 1008301.322737
Iteration 20, mstep_embedding end: 1008920.775276, niter: 20, nfeval: 22
Iteration 20, mstep_kernels start: 1008920.775276
Iteration 20, mstep_kernels end: 1008327.217120, niter: 6, nfeval: 9
Iteration 20, mstep_indpointslocs start: 1008327.217120
Iteration 20, mstep_indpointslocs end: 1008358.274073, niter: 20, nfeval: 24
Iteration 21, estep start: 1008358.274073
Iteration 21, estep end: 1008376.283336, niter: 20, nfeval: 23
Iteration 21, mstep_embedding start: 1008376.283336
Iteration 21, mstep_embedding end: 1008983.847546, niter: 20, nfeval: 23
Iteration 21, mstep_kernels start: 1008983.847546
Iteration 21, mstep_kernels end: 1008397.696074, niter: 6, nfeval: 9
Iteration 21, mstep_indpointslocs start: 1008397.696074
Iteration 21, mstep_indpointslocs end: 1008424.824064, niter: 20, nfeval: 24
Iteration 22, estep start: 1008424.824064
Iteration 22, estep end: 1008440.204958, niter: 20, nfeval: 23
Iteration 22, mstep_embedding start: 1008440.204958
Iteration 22, mstep_embedding end: 1009037.280333, niter: 20, nfeval: 22
Iteration 22, mstep_kernels start: 1009037.280333
Iteration 22, mstep_kernels end: 1008457.754513, niter: 6, nfeval: 7
Iteration 22, mstep_indpointslocs start: 1008457.754513
Iteration 22, mstep_indpointslocs end: 1008480.977928, niter: 20, nfeval: 24
Iteration 23, estep start: 1008480.977928
Iteration 23, estep end: 1008495.047297, niter: 20, nfeval: 23
Iteration 23, mstep_embedding start: 1008495.047297
Iteration 23, mstep_embedding end: 1009085.528036, niter: 20, nfeval: 23
Iteration 23, mstep_kernels start: 1009085.528036
Iteration 23, mstep_kernels end: 1008511.178146, niter: 7, nfeval: 10
Iteration 23, mstep_indpointslocs start: 1008511.178146
Iteration 23, mstep_indpointslocs end: 1008530.123453, niter: 20, nfeval: 22
Iteration 24, estep start: 1008530.123453
Iteration 24, estep end: 1008542.404188, niter: 20, nfeval: 22
Iteration 24, mstep_embedding start: 1008542.404188
Iteration 24, mstep_embedding end: 1009124.133484, niter: 20, nfeval: 23
Iteration 24, mstep_kernels start: 1009124.133484
Iteration 24, mstep_kernels end: 1008555.185563, niter: 6, nfeval: 10
Iteration 24, mstep_indpointslocs start: 1008555.185563
Iteration 24, mstep_indpointslocs end: 1008571.307263, niter: 20, nfeval: 24
Iteration 25, estep start: 1008571.307263
Iteration 25, estep end: 1008582.389255, niter: 20, nfeval: 23
Iteration 25, mstep_embedding start: 1008582.389255
Iteration 25, mstep_embedding end: 1009157.419597, niter: 20, nfeval: 24
Iteration 25, mstep_kernels start: 1009157.419597
Iteration 25, mstep_kernels end: 1008592.962719, niter: 6, nfeval: 10
Iteration 25, mstep_indpointslocs start: 1008592.962719
Iteration 25, mstep_indpointslocs end: 1008606.840877, niter: 20, nfeval: 25
Iteration 26, estep start: 1008606.840877
Iteration 26, estep end: 1008617.468583, niter: 20, nfeval: 22
Iteration 26, mstep_embedding start: 1008617.468583
Iteration 26, mstep_embedding end: 1009186.617546, niter: 20, nfeval: 24
Iteration 26, mstep_kernels start: 1009186.617546
Iteration 26, mstep_kernels end: 1008626.929016, niter: 6, nfeval: 10
Iteration 26, mstep_indpointslocs start: 1008626.929016
Iteration 26, mstep_indpointslocs end: 1008639.182860, niter: 20, nfeval: 23
Iteration 27, estep start: 1008639.182860
Iteration 27, estep end: 1008649.724493, niter: 20, nfeval: 23
Iteration 27, mstep_embedding start: 1008649.724493
Iteration 27, mstep_embedding end: 1009213.672263, niter: 20, nfeval: 24
Iteration 27, mstep_kernels start: 1009213.672263
Iteration 27, mstep_kernels end: 1008657.615090, niter: 6, nfeval: 10
Iteration 27, mstep_indpointslocs start: 1008657.615090
Iteration 27, mstep_indpointslocs end: 1008668.096148, niter: 20, nfeval: 22
Iteration 28, estep start: 1008668.096148
Iteration 28, estep end: 1008678.372142, niter: 20, nfeval: 24
Iteration 28, mstep_embedding start: 1008678.372142
Iteration 28, mstep_embedding end: 1009237.871204, niter: 20, nfeval: 22
Iteration 28, mstep_kernels start: 1009237.871204
Iteration 28, mstep_kernels end: 1008685.218372, niter: 7, nfeval: 11
Iteration 28, mstep_indpointslocs start: 1008685.218372
Iteration 28, mstep_indpointslocs end: 1008695.137825, niter: 20, nfeval: 23
Iteration 29, estep start: 1008695.137825
Iteration 29, estep end: 1008704.934258, niter: 20, nfeval: 23
Iteration 29, mstep_embedding start: 1008704.934258
Iteration 29, mstep_embedding end: 1009261.632924, niter: 20, nfeval: 23
Iteration 29, mstep_kernels start: 1009261.632924
Iteration 29, mstep_kernels end: 1008711.160850, niter: 7, nfeval: 11
Iteration 29, mstep_indpointslocs start: 1008711.160850
Iteration 29, mstep_indpointslocs end: 1008720.664673, niter: 20, nfeval: 24
Iteration 30, estep start: 1008720.664673
Iteration 30, estep end: 1008730.066452, niter: 20, nfeval: 22
Iteration 30, mstep_embedding start: 1008730.066452
Iteration 30, mstep_embedding end: 1009283.491195, niter: 20, nfeval: 22
Iteration 30, mstep_kernels start: 1009283.491195
Iteration 30, mstep_kernels end: 1008735.656425, niter: 6, nfeval: 10
Iteration 30, mstep_indpointslocs start: 1008735.656425
Iteration 30, mstep_indpointslocs end: 1008744.229288, niter: 20, nfeval: 23
Elapsed time 438.7607 seconds
Saved results to ../results/simulation_model.pickle
4.1.2. Goodness-of-fit analysis#
# Settings for the goodness-of-fit (GOF) analysis of one neuron in one trial.
ks_test_gamma = 10  # number of simulations for the KS test numerical correction
trial_for_gof = 0
cluster_id_for_gof = 1
n_time_steps_IF = 100

# Time grid on which intensity functions are evaluated; it is indexed below
# as trials_times[trial, step, 0].
trials_times = svGPFA.utils.miscUtils.getTrialsTimes(
    start_times=trials_start_times,
    end_times=trials_end_times,
    n_steps=n_time_steps_IF,
)
# Expected posterior CIFs; gradients are not needed for evaluation.
with torch.no_grad():
    epm_cif_values = model.computeExpectedPosteriorCIFs(times=trials_times)

cif_values_GOF = epm_cif_values[trial_for_gof][cluster_id_for_gof]
trial_times_GOF = trials_times[trial_for_gof, :, 0]
spikes_times_GOF = spikes_times[trial_for_gof][cluster_id_for_gof].numpy()
if not len(spikes_times_GOF):
    raise ValueError("No spikes found for goodness-of-fit analysis")

# Time-rescaling KS test with numerical correction; warnings emitted during
# the test are deliberately silenced.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    (diffECDFsX, diffECDFsY, estECDFx, estECDFy,
     simECDFx, simECDFy, cb) = (
        gcnu_common.stats.pointProcesses.tests
        .KSTestTimeRescalingNumericalCorrection(
            spikes_times=spikes_times_GOF,
            cif_times=trial_times_GOF,
            cif_values=cif_values_GOF,
            gamma=ks_test_gamma,
        )
    )
title = f"Trial {trial_for_gof:d}, Neuron {cluster_id_for_gof:d}"
fig = svGPFA.plot.plotUtilsPlotly.getPlotResKSTestTimeRescalingNumericalCorrection(
    diffECDFsX=diffECDFsX, diffECDFsY=diffECDFsY,
    estECDFx=estECDFx, estECDFy=estECDFy,
    simECDFx=simECDFx, simECDFy=simECDFy,
    cb=cb, title=title,
)
fig
Processing given ISIs
Processing iter 0/9
Processing iter 1/9
Processing iter 2/9
Processing iter 3/9
Processing iter 4/9
Processing iter 5/9
Processing iter 6/9
Processing iter 7/9
Processing iter 8/9
Processing iter 9/9
4.1.2.1. ROC predictive analysis#
# ROC analysis of the same neuron/trial used for the KS test, computed from
# its spike times and expected posterior CIF (warnings silenced).
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    fpr, tpr, roc_auc = svGPFA.utils.miscUtils.computeSpikeClassificationROC(
        spikes_times=spikes_times_GOF,
        cif_times=trial_times_GOF,
        cif_values=cif_values_GOF,
    )
fig = svGPFA.plot.plotUtilsPlotly.getPlotResROCAnalysis(
    fpr=fpr,
    tpr=tpr,
    auc=roc_auc,
    title=title,
)
fig
4.1.3. Plotting#
4.1.3.1. Imports for plotting#
import numpy as np
import plotly.express as px
4.1.3.2. Set plotting parameters#
# Which neuron and latent to display, and the Plotly colorscale for trials.
neuron_to_plot, latent_to_plot = 0, 0
trials_colorscale = "hot"
4.1.3.2.1. Set trials colors#
# One color per trial sampled from the colorscale, as "rgb(r, g, b)" strings.
trials_colors = px.colors.sample_colorscale(
    colorscale=trials_colorscale,
    samplepoints=n_trials,
    colortype="rgb",
)
# Rewrite each "rgb(r, g, b)" as the template "rgba(r, g, b, {:f})"; the
# remaining placeholder is presumably filled with an opacity by the plotting
# helpers that receive these patterns.
trials_colors_patterns = [
    "rgba" + color[3:-1] + ", {:f})" for color in trials_colors
]
4.1.3.2.2. Set trials ids#
# Trial identifiers 0, 1, ..., n_trials - 1 (list(range()) replaces the
# unnecessary identity comprehension).
trials_ids = list(range(n_trials))
4.1.3.3. Lower bound history#
# History of lower-bound values returned by svEM.maximize.
fig = svGPFA.plot.plotUtilsPlotly.getPlotLowerBoundHist(lowerBoundHist=lowerBoundHist)
fig
4.1.3.4. Latent across trials#
# Predicted latent means and variances on the trials_times grid; the plot
# receives standard deviations (square roots of the predicted variances).
test_mu_k, test_var_k = model.predictLatents(times=trials_times)
fig = svGPFA.plot.plotUtilsPlotly.getPlotLatentAcrossTrials(
    times=trials_times.numpy(),
    latentsMeans=test_mu_k,
    latentsSTDs=torch.sqrt(test_var_k),
    latentToPlot=latent_to_plot,
    trials_colors_patterns=trials_colors_patterns,
    xlabel="Time (msec)",
)
fig
4.1.3.4.1. Orthonormalized latent across trials#
# Latent means converted to NumPy, together with the estimated embedding
# matrix C; the plotting helper presumably orthonormalizes the latents using C.
testMuK, testVarK = model.predictLatents(times=trials_times)
# Iterate the per-trial means directly instead of indexing with range(len()).
testMuK_np = [mu_k.detach().numpy() for mu_k in testMuK]
estimatedC, estimatedD = model.getSVEmbeddingParams()
estimatedC_np = estimatedC.detach().numpy()
fig = svGPFA.plot.plotUtilsPlotly.getPlotOrthonormalizedLatentAcrossTrials(
    trials_times=trials_times,
    latentsMeans=testMuK_np,
    C=estimatedC_np,
    trials_ids=trials_ids,
    latentToPlot=latent_to_plot,
    trials_colors=trials_colors,
    xlabel="Time (msec)",
)
fig
4.1.3.5. Embedding#
# Predicted embedding means and variances on the trials_times grid, converted
# to NumPy; the last axis indexes neurons.
embedding_means, embedding_vars = model.predictEmbedding(times=trials_times)
embedding_means = embedding_means.detach().numpy()
embedding_vars = embedding_vars.detach().numpy()
title = f"Neuron {neuron_to_plot:d}"
fig = svGPFA.plot.plotUtilsPlotly.getPlotEmbeddingAcrossTrials(
    times=trials_times.numpy(),
    embeddingsMeans=embedding_means[:, :, neuron_to_plot],
    embeddingsSTDs=np.sqrt(embedding_vars[:, :, neuron_to_plot]),
    trials_colors_patterns=trials_colors_patterns,
    title=title,
)
fig
4.1.3.6. IFs#
# Expected posterior conditional intensity functions of every trial, plotted
# for one neuron; evaluated without gradient tracking.
with torch.no_grad():
    ePos_IF_values = model.computeExpectedPosteriorCIFs(times=trials_times)
fig = svGPFA.plot.plotUtilsPlotly.getPlotCIFsOneNeuronAllTrials(
    trials_times=trials_times,
    cif_values=ePos_IF_values,
    trials_ids=trials_ids,
    neuron_index=neuron_to_plot,
    trials_colors=trials_colors,
)
fig
4.1.3.7. Embedding parameters#
# Estimated embedding parameters C and d. Detach before converting to NumPy
# (as the orthonormalized-latents cell already does for C): Tensor.numpy()
# raises on tensors that require grad, and detach() is harmless otherwise.
estimatedC, estimatedD = model.getSVEmbeddingParams()
fig = svGPFA.plot.plotUtilsPlotly.getPlotEmbeddingParams(
    C=estimatedC.detach().numpy(),
    d=estimatedD.detach().numpy(),
)
fig
4.1.3.8. Kernels parameters#
# Estimated kernel parameters, labelled by each kernel's class name.
kernelsParams = model.getKernelsParams()
kernelsTypes = [type(k).__name__ for k in model.getKernels()]
fig = svGPFA.plot.plotUtilsPlotly.getPlotKernelsParams(
    kernelsTypes=kernelsTypes,
    kernelsParams=kernelsParams,
)
fig
To run the Python script or Jupyter notebook below, please download them to the examples/sphinx_gallery folder of the repository and execute them from there.
# sphinx_gallery_thumbnail_path = '_static/model.png'
Total running time of the script: ( 7 minutes 46.571 seconds)