import matplotlib.pyplot as plt
import pymc3 as pm
import aesara.tensor as at
import arviz as az
import scipy.stats as st
import numpy as np
az.style.use('arviz-whitegrid')
Multivariate Bernoulli distribution
Sources
Motivation
I have recently been introduced to the Bayesian framework for statistical analysis, using the PyMC3 package for probabilistic programming in Python 3.
I came across a dataset, given to me by a friend, which had multiple columns of binary data (responses to various questions with yes/no answers). Each column can be represented as a Bernoulli process with an unknown probability of the answer being "yes", which is what I wished to estimate. In addition, the data may also contain associations between the columns, which I wanted to estimate as well. After a lot of literature search I came across the work done by B. Dai in 2012-2013, cited above in Sources.
The following post aims to translate the log-likelihood function, joint probability distribution, marginal probability distribution and conditions for independence described in that work into PyMC3 code, which can be used by others to solve related problems.
Formulation
Let $Y = (Y_1, Y_2, \ldots, Y_K)$ be a $K$-dimensional random vector of possibly correlated Bernoulli random variables and let $y = (y_1, y_2, \ldots, y_K)$ be a realisation of $Y$. The more general form of the joint probability density is

$$P(Y = y) = p(0, 0, \ldots, 0)^{\prod_{j=1}^{K}(1 - y_j)} \; p(1, 0, \ldots, 0)^{y_1 \prod_{j=2}^{K}(1 - y_j)} \cdots \; p(1, 1, \ldots, 1)^{\prod_{j=1}^{K} y_j}$$

To simplify the notation, denote the quantity $S^{j_1 j_2 \cdots j_r}$ to be the sum of all natural parameters whose superscript is a non-empty subset of $\{j_1, j_2, \ldots, j_r\}$,

$$S^{j_1 j_2 \cdots j_r} = f^{j_1} + f^{j_2} + \cdots + f^{j_r} + f^{j_1 j_2} + \cdots + f^{j_1 j_2 \cdots j_r}$$

Also, we will define the interaction function $B^{j_1 j_2 \cdots j_r}(y)$ as

$$B^{j_1 j_2 \cdots j_r}(y) = y_{j_1} y_{j_2} \cdots y_{j_r}$$

The log-linear formulation of the multivariate Bernoulli distribution will be

$$\log P(Y = y) = \sum_{r=1}^{K} \; \sum_{1 \le j_1 < j_2 < \cdots < j_r \le K} f^{j_1 j_2 \cdots j_r} \, B^{j_1 j_2 \cdots j_r}(y) \; - \; b(f)$$

where $f = \left(f^{1}, f^{2}, \ldots, f^{1 2 \cdots K}\right)$ is the vector of natural parameters for the multivariate Bernoulli distribution.
The normalising factor $b(f)$ is defined as

$$b(f) = \log\left[1 + \sum_{r=1}^{K} \; \sum_{1 \le j_1 < j_2 < \cdots < j_r \le K} \exp\left(S^{j_1 j_2 \cdots j_r}\right)\right]$$
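As a concrete illustration of the notation (my own worked example, not taken from the cited paper), the bivariate case $K = 2$ reduces to

$$\log P(Y_1 = y_1, Y_2 = y_2) = f^{1} y_1 + f^{2} y_2 + f^{12} y_1 y_2 - b(f), \qquad b(f) = \log\left(1 + e^{f^{1}} + e^{f^{2}} + e^{f^{1} + f^{2} + f^{12}}\right)$$

so $f^{1}$ and $f^{2}$ act as main-effect log-odds terms, while $f^{12}$ captures the association between the two outcomes.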
PyMC3 implementation
from itertools import chain, combinations

def powerset(iterable):
    '''
    All non-empty subsets of an iterable, e.g.
    [0,1,2] -> [(0,), (1,), (2,), (0, 1), (0, 2), (1, 2), (0, 1, 2)]
    '''
    s = list(iterable)
    return chain.from_iterable(combinations(s, r) for r in range(1, len(s) + 1))
def bernoulli_mv(f, b_f, subsets):
    '''
    Multivariate Bernoulli Distribution
    ===================================
    f: vector of natural parameters, one per non-empty subset of the columns
    b_f: normalising factor b(f)
    subsets: list of tuples of indices like [0,1,2] -> [(0,), (1,), (2,), (0, 1), (0, 2), (1, 2), (0, 1, 2)]
    '''
    def logp_(y):
        # log-likelihood of the observed matrix y (rows = observations)
        res = 0
        for i, sub in enumerate(subsets):
            # f^{sub} * B^{sub}(y), summed over observations
            res += pm.math.sum(f[i] * pm.math.prod(y[:, sub], axis = 1))
        # subtract the normalising factor once per observation
        res = res - (b_f * y.shape[0])
        return res
    return logp_
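As a quick sanity check of this formulation (a small NumPy-only sketch, separate from the PyMC3 model; the `_chk` names are my own), the probabilities implied by an arbitrary parameter vector should sum to 1 over all $2^K$ outcome patterns, since $\exp(-b(f))\left(1 + \sum \exp(S)\right) = 1$ by the definition of $b(f)$.

# NumPy-only check that exp(S - b(f)) sums to 1 over all outcome patterns
K = 3
subsets_chk = list(powerset(range(K)))
rng = np.random.default_rng(0)
f_chk = rng.normal(size = len(subsets_chk))
# S for each non-empty subset: sum of f over all of its non-empty sub-subsets
S_chk = np.array([sum(f_chk[j] for j, t in enumerate(subsets_chk) if set(t) <= set(s))
                  for s in subsets_chk])
b_chk = np.log(1 + np.exp(S_chk).sum())
total = np.exp(-b_chk) + np.exp(S_chk - b_chk).sum()
print(np.isclose(total, 1.0))  # True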
y = np.random.choice([0, 1], (50,5))
subsets = list((powerset(range(y.shape[1]))))
subsets
[(0,),
(1,),
(2,),
(3,),
(4,),
(0, 1),
(0, 2),
(0, 3),
(0, 4),
(1, 2),
(1, 3),
(1, 4),
(2, 3),
(2, 4),
(3, 4),
(0, 1, 2),
(0, 1, 3),
(0, 1, 4),
(0, 2, 3),
(0, 2, 4),
(0, 3, 4),
(1, 2, 3),
(1, 2, 4),
(1, 3, 4),
(2, 3, 4),
(0, 1, 2, 3),
(0, 1, 2, 4),
(0, 1, 3, 4),
(0, 2, 3, 4),
(1, 2, 3, 4),
(0, 1, 2, 3, 4)]
with pm.Model() as mod:
    # one natural parameter f per non-empty subset of the columns
    fs = pm.Normal('f', 0, 2, shape = len(subsets))
    # normalising factor: the leading 1 corresponds to the all-zeros outcome
    b_f = 1
    for s in subsets:
        ss = list(powerset(s))
        # add exp(S^s), where S^s sums every f whose index set is a subset of s
        b_f += pm.math.exp(pm.math.sum(fs[np.array([sub in ss for sub in subsets])]))
    b_f = pm.Deterministic('b_f', pm.math.log(b_f))
    ys = pm.DensityDist('ys', bernoulli_mv(fs, b_f, subsets), observed = {'y': y})
with mod:
    trace = pm.sample(return_inferencedata=True)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
/home/suman/miniconda3/lib/python3.8/site-packages/aesara/graph/fg.py:525: UserWarning: Variable Elemwise{mul,no_inplace}.0 cannot be replaced; it isn't in the FunctionGraph
warnings.warn(
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [f]
Sampling 2 chains for 1_000 tune and 1_000 draw iterations (2_000 + 2_000 draws total) took 166 seconds.
az.plot_trace(trace)
sumry = az.summary(trace)
sumry.index = subsets + ['b_f']
sumry
subset | mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat
---|---|---|---|---|---|---|---|---|---
(0,) | 0.105 | 0.811 | -1.492 | 1.486 | 0.026 | 0.018 | 982.0 | 1376.0 | 1.00 |
(1,) | -0.224 | 0.807 | -1.653 | 1.322 | 0.030 | 0.021 | 705.0 | 796.0 | 1.00 |
(2,) | -1.284 | 0.867 | -3.030 | 0.183 | 0.031 | 0.022 | 779.0 | 995.0 | 1.00 |
(3,) | 0.205 | 0.737 | -1.250 | 1.529 | 0.025 | 0.017 | 891.0 | 1088.0 | 1.00 |
(4,) | 1.300 | 0.703 | -0.049 | 2.620 | 0.030 | 0.021 | 538.0 | 1077.0 | 1.00 |
(0, 1) | 0.846 | 0.945 | -0.828 | 2.673 | 0.031 | 0.023 | 917.0 | 1140.0 | 1.01 |
(0, 2) | 0.678 | 1.003 | -1.277 | 2.468 | 0.034 | 0.024 | 887.0 | 927.0 | 1.00 |
(0, 3) | 0.012 | 0.949 | -1.762 | 1.831 | 0.033 | 0.023 | 833.0 | 1173.0 | 1.00 |
(0, 4) | -2.639 | 1.109 | -4.892 | -0.755 | 0.034 | 0.024 | 1075.0 | 1355.0 | 1.00 |
(1, 2) | 0.588 | 0.975 | -1.185 | 2.402 | 0.031 | 0.023 | 978.0 | 1196.0 | 1.00 |
(1, 3) | -0.210 | 0.962 | -2.129 | 1.499 | 0.039 | 0.028 | 610.0 | 870.0 | 1.00 |
(1, 4) | 0.072 | 0.908 | -1.601 | 1.734 | 0.034 | 0.024 | 721.0 | 1003.0 | 1.00 |
(2, 3) | 0.721 | 0.990 | -1.192 | 2.561 | 0.033 | 0.026 | 878.0 | 1012.0 | 1.00 |
(2, 4) | 0.657 | 0.960 | -1.319 | 2.210 | 0.034 | 0.024 | 818.0 | 1054.0 | 1.00 |
(3, 4) | -0.691 | 0.865 | -2.355 | 0.923 | 0.027 | 0.019 | 1031.0 | 947.0 | 1.00 |
(0, 1, 2) | -0.950 | 1.164 | -3.118 | 1.183 | 0.035 | 0.026 | 1087.0 | 920.0 | 1.00 |
(0, 1, 3) | -0.967 | 1.189 | -3.129 | 1.154 | 0.040 | 0.028 | 878.0 | 948.0 | 1.00 |
(0, 1, 4) | 0.432 | 1.195 | -1.754 | 2.582 | 0.036 | 0.026 | 1084.0 | 1280.0 | 1.00 |
(0, 2, 3) | -0.200 | 1.171 | -2.328 | 1.996 | 0.037 | 0.027 | 974.0 | 1180.0 | 1.00 |
(0, 2, 4) | 0.579 | 1.259 | -1.800 | 2.954 | 0.040 | 0.028 | 995.0 | 1087.0 | 1.01 |
(0, 3, 4) | -0.727 | 1.414 | -3.555 | 1.626 | 0.042 | 0.036 | 1139.0 | 943.0 | 1.01 |
(1, 2, 3) | 1.388 | 1.144 | -0.822 | 3.425 | 0.038 | 0.027 | 927.0 | 1070.0 | 1.00 |
(1, 2, 4) | -0.323 | 1.125 | -2.514 | 1.715 | 0.034 | 0.024 | 1076.0 | 1242.0 | 1.00 |
(1, 3, 4) | 0.622 | 1.099 | -1.404 | 2.655 | 0.037 | 0.026 | 861.0 | 1056.0 | 1.00 |
(2, 3, 4) | -0.930 | 1.132 | -3.074 | 1.220 | 0.034 | 0.024 | 1129.0 | 1156.0 | 1.00 |
(0, 1, 2, 3) | -0.570 | 1.340 | -3.090 | 1.855 | 0.040 | 0.028 | 1132.0 | 1169.0 | 1.00 |
(0, 1, 2, 4) | 1.173 | 1.407 | -1.447 | 3.790 | 0.043 | 0.032 | 1080.0 | 1207.0 | 1.00 |
(0, 1, 3, 4) | 1.452 | 1.487 | -1.266 | 4.308 | 0.047 | 0.038 | 1016.0 | 1039.0 | 1.01 |
(0, 2, 3, 4) | 0.101 | 1.554 | -2.724 | 3.120 | 0.045 | 0.040 | 1219.0 | 1013.0 | 1.00 |
(1, 2, 3, 4) | -2.363 | 1.360 | -4.810 | 0.373 | 0.043 | 0.032 | 1018.0 | 1085.0 | 1.00 |
(0, 1, 2, 3, 4) | 1.404 | 1.547 | -1.584 | 4.178 | 0.039 | 0.034 | 1544.0 | 1164.0 | 1.00 |
b_f | 3.955 | 0.622 | 2.760 | 5.096 | 0.022 | 0.015 | 832.0 | 982.0 | 1.00 |
Joint Probabilities
We can get all the joint probabilities as follows
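Concretely (restating the formulation above in terms of the sampled variables), for an outcome pattern $y$ whose set of 1-valued columns is $\tau$,

$$P(Y = y) = \exp\left(S^{\tau} - b(f)\right) = \frac{\exp\left(\sum_{\kappa \subseteq \tau,\; \kappa \ne \emptyset} f^{\kappa}\right)}{\exp\left(b(f)\right)}, \qquad P(Y = (0, \ldots, 0)) = \exp\left(-b(f)\right)$$

and the code below evaluates this for every posterior draw.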
# build a name like 'p_10010' for each outcome pattern
p_names = []
for sub in subsets:
    p_name = ['0'] * y.shape[1]
    for i in sub:
        p_name[i] = '1'
    p_names.append(p_name)
p_names = ['p_' + ''.join(name) for name in p_names]

for p_name, s in zip(p_names, subsets):
    ss = list(powerset(s))
    # pick out every f whose index set is a subset of s
    idx_S = np.array([sub in ss for sub in subsets])
    trace.posterior[p_name] = np.exp(np.sum(trace.posterior['f'][...,idx_S], axis = 2)) / \
        np.exp(trace.posterior['b_f'])
# probability of the all-zeros pattern
trace.posterior['p_' + '0' * y.shape[1]] = np.exp(-trace.posterior['b_f'])
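As an optional check (my own addition, not in the original analysis), the $2^K$ joint probabilities should add up to 1 for every posterior draw.

# all p_ variables together partition the sample space, so their sum should be ~1
p_vars = [v for v in trace.posterior.data_vars if str(v).startswith('p_')]
total_p = sum(trace.posterior[v] for v in p_vars)
print(float(total_p.min()), float(total_p.max()))  # both should be very close to 1.0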
az.summary(trace, var_names = ['p_'], filter_vars='like')
variable | mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat
---|---|---|---|---|---|---|---|---|---
p_10000 | 0.027 | 0.018 | 0.003 | 0.063 | 0.000 | 0.000 | 1269.0 | 1114.0 | 1.00 |
p_01000 | 0.019 | 0.014 | 0.002 | 0.044 | 0.000 | 0.000 | 932.0 | 1178.0 | 1.00 |
p_00100 | 0.008 | 0.007 | 0.000 | 0.020 | 0.000 | 0.000 | 780.0 | 1120.0 | 1.00 |
p_00010 | 0.028 | 0.018 | 0.005 | 0.060 | 0.000 | 0.000 | 1356.0 | 1359.0 | 1.00 |
p_00001 | 0.077 | 0.034 | 0.022 | 0.142 | 0.001 | 0.001 | 1559.0 | 1458.0 | 1.00 |
p_11000 | 0.047 | 0.027 | 0.006 | 0.095 | 0.001 | 0.000 | 1579.0 | 1282.0 | 1.00 |
p_10100 | 0.016 | 0.014 | 0.001 | 0.043 | 0.000 | 0.000 | 1319.0 | 1217.0 | 1.00 |
p_10010 | 0.033 | 0.021 | 0.002 | 0.068 | 0.000 | 0.000 | 1837.0 | 1401.0 | 1.00 |
p_10001 | 0.009 | 0.010 | 0.000 | 0.026 | 0.000 | 0.000 | 1430.0 | 1357.0 | 1.00 |
p_01100 | 0.012 | 0.011 | 0.000 | 0.030 | 0.000 | 0.000 | 1416.0 | 1003.0 | 1.00 |
p_01010 | 0.021 | 0.016 | 0.000 | 0.050 | 0.000 | 0.000 | 1312.0 | 1325.0 | 1.00 |
p_01001 | 0.068 | 0.032 | 0.013 | 0.124 | 0.001 | 0.001 | 1914.0 | 1092.0 | 1.00 |
p_00110 | 0.019 | 0.015 | 0.001 | 0.046 | 0.000 | 0.000 | 1615.0 | 1232.0 | 1.00 |
p_00101 | 0.045 | 0.027 | 0.006 | 0.092 | 0.001 | 0.000 | 1930.0 | 1295.0 | 1.00 |
p_00011 | 0.051 | 0.029 | 0.006 | 0.104 | 0.001 | 0.000 | 1856.0 | 1548.0 | 1.00 |
p_11100 | 0.022 | 0.018 | 0.001 | 0.055 | 0.000 | 0.000 | 1921.0 | 1436.0 | 1.00 |
p_11010 | 0.021 | 0.018 | 0.000 | 0.052 | 0.000 | 0.000 | 1718.0 | 1417.0 | 1.00 |
p_11001 | 0.024 | 0.019 | 0.001 | 0.057 | 0.000 | 0.000 | 1884.0 | 1303.0 | 1.00 |
p_10110 | 0.031 | 0.022 | 0.001 | 0.070 | 0.001 | 0.000 | 1735.0 | 1441.0 | 1.00 |
p_10101 | 0.016 | 0.015 | 0.000 | 0.044 | 0.000 | 0.000 | 1869.0 | 1480.0 | 1.00 |
p_10011 | 0.004 | 0.007 | 0.000 | 0.014 | 0.000 | 0.000 | 1274.0 | 913.0 | 1.00 |
p_01110 | 0.071 | 0.035 | 0.011 | 0.130 | 0.001 | 0.001 | 2181.0 | 1661.0 | 1.00 |
p_01101 | 0.050 | 0.028 | 0.006 | 0.100 | 0.001 | 0.000 | 2134.0 | 1634.0 | 1.00 |
p_01011 | 0.064 | 0.033 | 0.014 | 0.125 | 0.001 | 0.001 | 2077.0 | 1318.0 | 1.00 |
p_00111 | 0.026 | 0.020 | 0.001 | 0.063 | 0.000 | 0.000 | 2007.0 | 1581.0 | 1.00 |
p_11110 | 0.030 | 0.022 | 0.002 | 0.072 | 0.001 | 0.000 | 1971.0 | 1584.0 | 1.00 |
p_11101 | 0.061 | 0.033 | 0.010 | 0.121 | 0.001 | 0.001 | 1785.0 | 1539.0 | 1.00 |
p_11011 | 0.020 | 0.018 | 0.000 | 0.053 | 0.000 | 0.000 | 1957.0 | 1526.0 | 1.00 |
p_10111 | 0.006 | 0.009 | 0.000 | 0.022 | 0.000 | 0.000 | 1804.0 | 1358.0 | 1.00 |
p_01111 | 0.018 | 0.017 | 0.000 | 0.050 | 0.000 | 0.000 | 2120.0 | 1667.0 | 1.00 |
p_11111 | 0.033 | 0.023 | 0.002 | 0.077 | 0.001 | 0.000 | 1862.0 | 1732.0 | 1.00 |
p_00000 | 0.023 | 0.014 | 0.003 | 0.049 | 0.000 | 0.000 | 832.0 | 982.0 | 1.01 |
Marginal Probabilities
We represent the marginal probabilities as follows.
$P(Y_j = 1)$ is the probability of the $j^{\text{th}}$ column being 1, where $j \in \{1, \ldots, K\}$.
Similarly, $P(Y_j = 1, Y_k = 1)$ is the probability of the $j^{\text{th}}$ and $k^{\text{th}}$ columns both being 1, where $j \ne k$.
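Each marginal is obtained by summing the joint probabilities of every outcome pattern in which the relevant columns equal 1; for example (restated in the notation above),

$$P(Y_1 = 1) = \sum_{y_2, \ldots, y_K \in \{0, 1\}} P\left(Y = (1, y_2, \ldots, y_K)\right)$$

which is exactly the summation the helper below performs over the p_ variables stored in the trace.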
def marg_prob(idxs, trace, num_columns):
    '''
    Marginal probability of the given columns being 1, obtained by summing the
    posterior joint probabilities of all matching outcome patterns.
    idxs: list of indexes which are 1, e.g. [0,1] means first and second columns are 1
    trace: trace of pymc3 model output (InferenceData)
    num_columns: total number of binary columns
    '''
    assert max(idxs) < num_columns
    keys = [i for i in trace.posterior.keys() if i.startswith('p_')]
    res = []
    for k in keys:
        k_num = k[2:]
        # keep this pattern only if every requested column equals '1'
        if all(k_num[i] == '1' for i in idxs):
            res.append(k)
    return np.sum(trace.posterior[res].to_array(), axis = 0)
for i in range(y.shape[1]):
    print(f'Marginal Probability {i + 1} ---')
    print(az.summary(marg_prob([i], trace, y.shape[1]), kind = 'stats'))
Marginal Probability 1 ---
mean sd hdi_3% hdi_97%
x 0.573 0.067 0.442 0.69
Marginal Probability 2 ---
mean sd hdi_3% hdi_97%
x 0.4 0.065 0.27 0.511
Marginal Probability 3 ---
mean sd hdi_3% hdi_97%
x 0.58 0.067 0.46 0.706
Marginal Probability 4 ---
mean sd hdi_3% hdi_97%
x 0.463 0.071 0.323 0.588
Marginal Probability 5 ---
mean sd hdi_3% hdi_97%
x 0.477 0.07 0.345 0.603
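The same helper also gives pairwise (or higher-order) marginals; as a usage sketch, the posterior of the first two columns being 1 simultaneously:

# posterior of the pairwise marginal P(Y_1 = 1, Y_2 = 1), i.e. columns 0 and 1
az.summary(marg_prob([0, 1], trace, y.shape[1]), kind = 'stats')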
Independence of outcomes
The outcomes $Y_{j_1}, \ldots, Y_{j_r}$ are independent if the interaction parameters relating them are zero, i.e. if

$$f^{\tau} = 0 \quad \text{for every } \tau \subseteq \{j_1, j_2, \ldots, j_r\} \text{ with } |\tau| \ge 2$$

The function below returns the posterior of the sum of these interaction parameters, which should be concentrated around 0 when the outcomes are independent.
def trace_of_independence(idxs, trace, subsets, num_columns):
    '''
    idxs: list of indexes whose independence is to be checked, e.g. [0,1,2] means that
          independence of the 1st, 2nd and 3rd columns is to be checked
    trace: pymc3 trace, InferenceData
    subsets: list depicting the powerset of the indexes
    num_columns: number of outcomes
    '''
    assert len(idxs) > 1
    assert max(idxs) < num_columns
    # interaction parameters: subsets of idxs with at least two elements
    ss = [i for i in powerset(idxs) if len(i) > 1]
    idx_S = np.array([s in ss for s in subsets])
    # posterior of the sum of the interaction parameters
    return np.sum(trace.posterior['f'][...,idx_S], axis = 2)
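A usage sketch (my own, following the same pattern as the marginal probabilities above): if the HDI of this quantity comfortably contains 0, there is little evidence of association between the chosen columns.

# e.g. posterior of the summed interaction parameters for the first two columns
indep_01 = trace_of_independence([0, 1], trace, subsets, y.shape[1])
az.summary(indep_01, kind = 'stats')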
Issues
Sampling takes a prohibitively long time when the number of variables exceeds 7-8, since the number of natural parameters grows as $2^K - 1$; I am in the process of tackling this issue.
I also still need to build a mechanism for generating random samples from the fitted distribution.