.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/plot_sumProduct.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end `
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_plot_sumProduct.py:


Example in Figure 5.3 of `David Barber's book Bayesian Reasoning and Machine Learning `_
=====================================================================================================================================================

.. GENERATED FROM PYTHON SOURCE LINES 10-12

Import required packages
^^^^^^^^^^^^^^^^^^^^^^^^

.. GENERATED FROM PYTHON SOURCE LINES 12-17

.. code-block:: Python


    import numpy as np

    import rxMsgPassing.sumProduct

.. GENERATED FROM PYTHON SOURCE LINES 18-20

Define probability tables
^^^^^^^^^^^^^^^^^^^^^^^^^

.. GENERATED FROM PYTHON SOURCE LINES 20-40

.. code-block:: Python


    # In every table the child variable is axis 0 and its parents follow.
    paGb = np.array([[0.4, 0.0],  # p(a|b), indexed [a, b]
                     [0.2, 0.1],
                     [0.4, 0.2],
                     [0.0, 0.7]])
    pbGcd = np.array([[[0.8, 0.7, 0.9],  # p(b|c,d), indexed [b, c, d]
                       [0.5, 0.3, 0.2]],
                      [[0.2, 0.3, 0.1],
                       [0.5, 0.7, 0.8]]])
    pc = np.array([0.2, 0.8])  # p(c)
    pd = np.array([0.1, 0.3, 0.6])  # p(d)
    peGd = np.array([[0.1, 0.7, 0.0],  # p(e|d), indexed [e, d]
                     [0.1, 0.3, 0.0],
                     [0.2, 0.0, 0.0],
                     [0.3, 0.0, 0.0],
                     [0.3, 0.0, 1.0]])

.. GENERATED FROM PYTHON SOURCE LINES 41-43

Create factor nodes
^^^^^^^^^^^^^^^^^^^

.. GENERATED FROM PYTHON SOURCE LINES 43-70

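
Every conditional table handed to a factor node should sum to one over its child variable for each configuration of its parents, and each prior should sum to one. The check below is a minimal sketch in plain NumPy, and the original script does not run it. It re-declares the tables so that it is self-contained. It assumes the child variable sits on axis 0, which matches how the brute-force computation later in this example indexes them (``paGb[a, b]``, ``pbGcd[b, c, d]``, ``peGd[e, d]``). The ``tables`` dictionary and its labels are illustrative names.

.. code-block:: Python

    import numpy as np

    # Tables re-declared from the example; the child variable is axis 0.
    paGb = np.array([[0.4, 0.0],
                     [0.2, 0.1],
                     [0.4, 0.2],
                     [0.0, 0.7]])
    pbGcd = np.array([[[0.8, 0.7, 0.9],
                       [0.5, 0.3, 0.2]],
                      [[0.2, 0.3, 0.1],
                       [0.5, 0.7, 0.8]]])
    pc = np.array([0.2, 0.8])
    pd = np.array([0.1, 0.3, 0.6])
    peGd = np.array([[0.1, 0.7, 0.0],
                     [0.1, 0.3, 0.0],
                     [0.2, 0.0, 0.0],
                     [0.3, 0.0, 0.0],
                     [0.3, 0.0, 1.0]])

    # Illustrative mapping from a readable label to each table.
    tables = {"p(a|b)": paGb, "p(b|c,d)": pbGcd, "p(c)": pc,
              "p(d)": pd, "p(e|d)": peGd}

    # Summing out the child (axis 0) must leave 1 for every parent configuration;
    # np.allclose absorbs floating-point rounding in the sums.
    for label, table in tables.items():
        normalized = np.allclose(table.sum(axis=0), 1.0)
        print(f"{label}: normalized={normalized}")
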

.. code-block:: Python


    # f1 wraps p(a|b); var_names lists the variables of its axes (a, b).
    f1_probabilities = paGb
    f1_varNames = ["va", "vb"]
    f1 = rxMsgPassing.sumProduct.FactorNode(name="f1",
                                            probabilities=f1_probabilities,
                                            var_names=f1_varNames)

    # f2 wraps p(b|c,d); axes (b, c, d).
    f2_probabilities = pbGcd
    f2_varNames = ["vb", "vc", "vd"]
    f2 = rxMsgPassing.sumProduct.FactorNode(name="f2",
                                            probabilities=f2_probabilities,
                                            var_names=f2_varNames)

    # f3 wraps the prior p(c).
    f3_probabilities = pc
    f3_varNames = ["vc"]
    f3 = rxMsgPassing.sumProduct.FactorNode(name="f3",
                                            probabilities=f3_probabilities,
                                            var_names=f3_varNames)

    # f4 wraps p(e|d); axes (e, d).
    f4_probabilities = peGd
    f4_varNames = ["ve", "vd"]
    f4 = rxMsgPassing.sumProduct.FactorNode(name="f4",
                                            probabilities=f4_probabilities,
                                            var_names=f4_varNames)

    # f5 wraps the prior p(d).
    f5_probabilities = pd
    f5_varNames = ["vd"]
    f5 = rxMsgPassing.sumProduct.FactorNode(name="f5",
                                            probabilities=f5_probabilities,
                                            var_names=f5_varNames)

.. GENERATED FROM PYTHON SOURCE LINES 71-73

Create variable nodes
^^^^^^^^^^^^^^^^^^^^^

.. GENERATED FROM PYTHON SOURCE LINES 73-80

.. code-block:: Python


    va = rxMsgPassing.sumProduct.VariableNode(name="va")
    vb = rxMsgPassing.sumProduct.VariableNode(name="vb")
    vc = rxMsgPassing.sumProduct.VariableNode(name="vc")
    vd = rxMsgPassing.sumProduct.VariableNode(name="vd")
    ve = rxMsgPassing.sumProduct.VariableNode(name="ve")

.. GENERATED FROM PYTHON SOURCE LINES 81-83

Link variable nodes to factor nodes
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. GENERATED FROM PYTHON SOURCE LINES 83-90

.. code-block:: Python


    # Each factor lists the variable nodes it depends on.
    f1.neighbors = [va, vb]
    f2.neighbors = [vb, vc, vd]
    f3.neighbors = [vc]
    f4.neighbors = [vd, ve]
    f5.neighbors = [vd]

.. GENERATED FROM PYTHON SOURCE LINES 91-93

Link factor nodes to variable nodes
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. GENERATED FROM PYTHON SOURCE LINES 93-100

.. code-block:: Python


    # Each variable lists the factors it appears in.
    va.neighbors = [f1]
    vb.neighbors = [f1, f2]
    vc.neighbors = [f2, f3]
    vd.neighbors = [f2, f4, f5]
    ve.neighbors = [f4]

.. GENERATED FROM PYTHON SOURCE LINES 101-103

Compute the marginal of a by message passing
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. GENERATED FROM PYTHON SOURCE LINES 103-107

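
``va.marginal()`` logs each message as it computes it. Under the standard sum-product rules, a variable node multiplies the messages arriving from its other factors. A factor node multiplies its table by the incoming variable messages and sums out every variable except the recipient. The sketch below reproduces that schedule by hand with ``np.einsum``, independently of ``rxMsgPassing``, and is not necessarily how the package computes its messages. The ``msg_*`` and ``p_a`` names are illustrative. ``ve`` is assumed unobserved, so its message to ``f4`` is a vector of ones, as in the log.

.. code-block:: Python

    import numpy as np

    paGb = np.array([[0.4, 0.0], [0.2, 0.1], [0.4, 0.2], [0.0, 0.7]])  # p(a|b)
    pbGcd = np.array([[[0.8, 0.7, 0.9], [0.5, 0.3, 0.2]],
                      [[0.2, 0.3, 0.1], [0.5, 0.7, 0.8]]])  # p(b|c,d)
    pc = np.array([0.2, 0.8])  # p(c)
    pd = np.array([0.1, 0.3, 0.6])  # p(d)
    peGd = np.array([[0.1, 0.7, 0.0], [0.1, 0.3, 0.0], [0.2, 0.0, 0.0],
                     [0.3, 0.0, 0.0], [0.3, 0.0, 1.0]])  # p(e|d)

    # Leaf factors send their own tables; the unobserved leaf ve sends ones.
    msg_f3_vc = pc
    msg_ve_f4 = np.ones(peGd.shape[0])
    msg_f5_vd = pd

    # f4 sums out e: sum_e p(e|d) * msg(ve -> f4)[e]
    msg_f4_vd = np.einsum("ed,e->d", peGd, msg_ve_f4)

    # Variable-to-factor messages multiply the messages from the other factors.
    msg_vc_f2 = msg_f3_vc
    msg_vd_f2 = msg_f4_vd * msg_f5_vd

    # f2 sums out c and d: sum_{c,d} p(b|c,d) * m_c[c] * m_d[d]
    msg_f2_vb = np.einsum("bcd,c,d->b", pbGcd, msg_vc_f2, msg_vd_f2)
    msg_vb_f1 = msg_f2_vb

    # f1 sums out b; va has no other factor, so this message is p(a).
    p_a = np.einsum("ab,b->a", paGb, msg_vb_f1)
    print(msg_f2_vb)
    print(p_a)

Up to floating-point rounding, the two printed arrays should agree with the ``f2`` to ``vb`` message and the marginal logged by ``va.marginal()``.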

.. code-block:: Python


    m_a = va.marginal()
    print(f"message passing: p(a)={m_a}")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Computed msg from factor f3 to variable vc: [0.2 0.8]
    Computed msg from variable vc to factor f2: [0.2 0.8]
    Computed msg from variable ve to factor f4: [1. 1. 1. 1. 1.]
    Computed msg from factor f4 to variable vd: [1. 1. 1.]
    Computed msg from factor f5 to variable vd: [0.1 0.3 0.6]
    Computed msg from variable vd to factor f2: [0.1 0.3 0.6]
    Computed msg from factor f2 to variable vb: [0.374 0.626]
    Computed msg from variable vb to factor f1: [0.374 0.626]
    Computed msg from factor f1 to variable va: [0.1496 0.1374 0.2748 0.4382]
    message passing: p(a)=[0.1496 0.1374 0.2748 0.4382]

.. GENERATED FROM PYTHON SOURCE LINES 108-110

Compute the marginal of a by brute force
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. GENERATED FROM PYTHON SOURCE LINES 110-136

.. code-block:: Python


    # Index set of each variable, read off the table shapes.
    domain_a = np.arange(paGb.shape[0])
    domain_b = np.arange(paGb.shape[1])
    domain_c = np.arange(pbGcd.shape[1])
    domain_d = np.arange(pbGcd.shape[2])
    domain_e = np.arange(peGd.shape[0])


    def pabcde(a, b, c, d, e):
        """Joint probability p(a,b,c,d,e) = p(a|b) p(b|c,d) p(c) p(d) p(e|d)."""
        return (paGb[a, b].item() *
                pbGcd[b, c, d].item() *
                pc[c].item() *
                pd[d].item() *
                peGd[e, d].item())


    # p(a) is the joint summed over every configuration of b, c, d and e.
    bf_m_a = [None] * len(domain_a)
    for i, a in enumerate(domain_a):
        total = 0.0
        for b in domain_b:
            for c in domain_c:
                for d in domain_d:
                    for e in domain_e:
                        total += pabcde(a=a, b=b, c=c, d=d, e=e)
        bf_m_a[i] = total
    print(f"brute force: p(a)={bf_m_a}")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    brute force: p(a)=[0.1496, 0.1374, 0.2748, 0.4381999999999999]

.. GENERATED FROM PYTHON SOURCE LINES 137-139

Test agreement between message passing and brute force marginals
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. GENERATED FROM PYTHON SOURCE LINES 139-148

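
The brute-force loops above visit all 4 × 2 × 2 × 3 × 5 = 240 joint configurations in pure Python. The original script does not do this, but the whole joint can also be built as one array with ``np.einsum`` and the comparison done with ``np.allclose``. In this sketch ``joint`` is an illustrative name. ``m_a`` is hard-coded from the printed message-passing result as a stand-in for ``va.marginal()``, so the block runs without ``rxMsgPassing``. The absolute tolerance of ``1e-6`` mirrors the one used in the component-wise loop.

.. code-block:: Python

    import numpy as np

    paGb = np.array([[0.4, 0.0], [0.2, 0.1], [0.4, 0.2], [0.0, 0.7]])  # p(a|b)
    pbGcd = np.array([[[0.8, 0.7, 0.9], [0.5, 0.3, 0.2]],
                      [[0.2, 0.3, 0.1], [0.5, 0.7, 0.8]]])  # p(b|c,d)
    pc = np.array([0.2, 0.8])  # p(c)
    pd = np.array([0.1, 0.3, 0.6])  # p(d)
    peGd = np.array([[0.1, 0.7, 0.0], [0.1, 0.3, 0.0], [0.2, 0.0, 0.0],
                     [0.3, 0.0, 0.0], [0.3, 0.0, 1.0]])  # p(e|d)

    # Full joint p(a,b,c,d,e) = p(a|b) p(b|c,d) p(c) p(d) p(e|d), shape (4, 2, 2, 3, 5).
    joint = np.einsum("ab,bcd,c,d,ed->abcde", paGb, pbGcd, pc, pd, peGd)

    # Marginalize out b, c, d and e (axes 1 to 4).
    bf_m_a = joint.sum(axis=(1, 2, 3, 4))

    # Stand-in for va.marginal(), copied from the printed message-passing output.
    m_a = np.array([0.1496, 0.1374, 0.2748, 0.4382])

    print("joint sums to one:", np.isclose(joint.sum(), 1.0))
    print("marginals agree:", np.allclose(m_a, bf_m_a, rtol=0.0, atol=1e-6))

Materializing the joint is only practical for a model this small. Message passing exists precisely to avoid building it.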

.. code-block:: Python


    tol = 1e-6
    for i in range(len(m_a)):
        if abs(m_a[i] - bf_m_a[i]) < tol:
            print(f"Agreement in component {i}")
        else:
            print(f"Disagreement in component {i}")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Agreement in component 0
    Agreement in component 1
    Agreement in component 2
    Agreement in component 3


.. rst-class:: sphx-glr-timing

**Total running time of the script:** (0 minutes 0.020 seconds)


.. _sphx_glr_download_auto_examples_plot_sumProduct.py:

.. only:: html

    .. container:: sphx-glr-footer sphx-glr-footer-example

        .. container:: sphx-glr-download sphx-glr-download-jupyter

            :download:`Download Jupyter notebook: plot_sumProduct.ipynb `

        .. container:: sphx-glr-download sphx-glr-download-python

            :download:`Download Python source code: plot_sumProduct.py `

        .. container:: sphx-glr-download sphx-glr-download-zip

            :download:`Download zipped: plot_sumProduct.zip `


.. only:: html

    .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_