Multivariate Bernoulli Distribution: Bayesian Analysis with PyMC3

import matplotlib.pyplot as plt
import pymc3 as pm
import aesara.tensor as at
import arviz as az
import scipy.stats as st
import numpy as np
az.style.use('arviz-whitegrid')

Multivariate Bernoulli distribution

Sources

  1. https://doi.org/10.1016/0047-259X(90)90084-U

  2. https://arxiv.org/pdf/1206.1874

Motivation

I have recently been introduced to the Bayesian framework for statistical analysis through the PyMC3 package for probabilistic programming in Python 3.

A friend recently shared a dataset with me that had multiple columns of binary data (responses to various questions with yes/no answers). Each column can be modelled as a Bernoulli process with an unknown probability of getting the answer "yes", which is what I wished to estimate. In addition, the data may contain associations between the columns, which I also dearly wanted to estimate. After a lot of literature search I came across the work done by B. Dai in 2012–2013, cited above in Sources.

The following post aims to translate the log-likelihood function, joint probability distribution, marginal probability distribution and conditions for independence described in those papers into PyMC3 code, which others can use to solve related problems.

Formulation

Let $Y = (Y_1, Y_2, \ldots, Y_K)$ be a $K$-dimensional random vector of possibly correlated Bernoulli random variables and let $y = (y_1, y_2, \ldots, y_K)$ be a realisation of $Y$. The general form of the joint probability density $p(y_1, y_2, \ldots, y_K)$ is

$$P(Y_1=y_1,\ldots,Y_K=y_K) = p_{0,0,\ldots,0}^{\prod_{j=1}^{K}(1-y_j)}\; p_{1,0,\ldots,0}^{y_1\prod_{j=2}^{K}(1-y_j)}\; p_{0,1,\ldots,0}^{(1-y_1)\,y_2\prod_{j=3}^{K}(1-y_j)} \cdots\; p_{1,1,\ldots,1}^{\prod_{j=1}^{K} y_j}$$

To simplify the notation, define the quantity $S$ as

$$S^{j_1 j_2 \cdots j_r} = \sum_{1 \le s \le r} f^{j_s} + \sum_{1 \le s < t \le r} f^{j_s j_t} + \cdots + f^{j_1 j_2 \cdots j_r}$$

Also, we will define the interaction function B as

$$B^{j_1 j_2 \cdots j_r}(y) = y_{j_1} y_{j_2} \cdots y_{j_r}$$

The log-linear formulation of the multivariate Bernoulli distribution is

$$l(y, f) = \log[p(y)] = \sum_{r=1}^{K}\left(\sum_{1 \le j_1 < j_2 < \cdots < j_r \le K} f^{j_1 j_2 \cdots j_r}\, B^{j_1 j_2 \cdots j_r}(y)\right) - b(f)$$

where $f = (f^{1}, f^{2}, \ldots, f^{12\cdots K})^{T}$ is the vector of natural parameters of the multivariate Bernoulli distribution.

The normalising factor $b(f) = -\log p_{0,0,\ldots,0}$ is defined as

$$b(f) = \log\left[1 + \sum_{r=1}^{K}\;\sum_{1 \le j_1 < j_2 < \cdots < j_r \le K} e^{S^{j_1 j_2 \cdots j_r}}\right]$$
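
For example, in the bivariate case ($K = 2$) these definitions give

$$S^{12} = f^{1} + f^{2} + f^{12}, \qquad b(f) = \log\left[1 + e^{f^{1}} + e^{f^{2}} + e^{S^{12}}\right],$$

so that $\log p(y_1, y_2) = f^{1} y_1 + f^{2} y_2 + f^{12} y_1 y_2 - b(f)$; when the interaction parameter $f^{12} = 0$ the two components reduce to independent Bernoulli variables with logits $f^{1}$ and $f^{2}$.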

PyMC3 implementation

from itertools import chain, combinations

def powerset(iterable):
    '''
    [0,1,2] -> [(0,), (1,), (2,), (0, 1), (0, 2), (1, 2), (0, 1, 2)]
    '''
    s = list(iterable)
    return chain.from_iterable(combinations(s, r) for r in range(1, len(s) + 1))
def bernoulli_mv(f, b_f, subsets):
    '''
    Multivariate Bernoulli Distribution
    ===================================
    f: vector of natural parameters, one per subset
    b_f: log normalising constant b(f)
    subsets: list of tuples of indices like [0,1,2] -> [(0,), (1,), (2,), (0, 1), (0, 2), (1, 2), (0, 1, 2)]
    '''
    
    def logp_(y):
        # log p(y) = sum over subsets of f^{subset} * B^{subset}(y) - b(f),
        # accumulated over all rows of y
        res = 0
        for i, sub in enumerate(subsets):
            # B^{subset}(y) is the product of the y columns indexed by the subset
            res += pm.math.sum(f[i] * pm.math.prod(y[:, sub], axis = 1))
        
        # subtract the normalising constant once per observation
        res = res - (b_f * y.shape[0])
        
        return res
    
    return logp_
y = np.random.choice([0, 1], (50,5))
subsets = list((powerset(range(y.shape[1]))))
subsets
[(0,),
 (1,),
 (2,),
 (3,),
 (4,),
 (0, 1),
 (0, 2),
 (0, 3),
 (0, 4),
 (1, 2),
 (1, 3),
 (1, 4),
 (2, 3),
 (2, 4),
 (3, 4),
 (0, 1, 2),
 (0, 1, 3),
 (0, 1, 4),
 (0, 2, 3),
 (0, 2, 4),
 (0, 3, 4),
 (1, 2, 3),
 (1, 2, 4),
 (1, 3, 4),
 (2, 3, 4),
 (0, 1, 2, 3),
 (0, 1, 2, 4),
 (0, 1, 3, 4),
 (0, 2, 3, 4),
 (1, 2, 3, 4),
 (0, 1, 2, 3, 4)]
with pm.Model() as mod:
    # one natural parameter f for every non-empty subset of columns
    fs = pm.Normal('f', 0, 2, shape = len(subsets))

    # b(f) = log(1 + sum over subsets of exp(S^{subset})), where S^{subset}
    # is the sum of f over the powerset of that subset
    b_f = 1
    for s in subsets:
        ss = list(powerset(s))
        b_f += pm.math.exp(pm.math.sum(fs[np.array([sub in ss for sub in subsets])]))
    b_f = pm.Deterministic('b_f', pm.math.log(b_f))
    
    ys = pm.DensityDist('ys', bernoulli_mv(fs, b_f, subsets), observed = {'y': y})
with mod:
    trace = pm.sample(return_inferencedata=True)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
/home/suman/miniconda3/lib/python3.8/site-packages/aesara/graph/fg.py:525: UserWarning: Variable Elemwise{mul,no_inplace}.0 cannot be replaced; it isn't in the FunctionGraph
  warnings.warn(
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [f]
100.00% [4000/4000 02:41<00:00 Sampling 2 chains, 0 divergences]
Sampling 2 chains for 1_000 tune and 1_000 draw iterations (2_000 + 2_000 draws total) took 166 seconds.
az.plot_trace(trace)
sumry = az.summary(trace)
sumry.index = subsets + ['b_f']
sumry
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
(0,) 0.105 0.811 -1.492 1.486 0.026 0.018 982.0 1376.0 1.00
(1,) -0.224 0.807 -1.653 1.322 0.030 0.021 705.0 796.0 1.00
(2,) -1.284 0.867 -3.030 0.183 0.031 0.022 779.0 995.0 1.00
(3,) 0.205 0.737 -1.250 1.529 0.025 0.017 891.0 1088.0 1.00
(4,) 1.300 0.703 -0.049 2.620 0.030 0.021 538.0 1077.0 1.00
(0, 1) 0.846 0.945 -0.828 2.673 0.031 0.023 917.0 1140.0 1.01
(0, 2) 0.678 1.003 -1.277 2.468 0.034 0.024 887.0 927.0 1.00
(0, 3) 0.012 0.949 -1.762 1.831 0.033 0.023 833.0 1173.0 1.00
(0, 4) -2.639 1.109 -4.892 -0.755 0.034 0.024 1075.0 1355.0 1.00
(1, 2) 0.588 0.975 -1.185 2.402 0.031 0.023 978.0 1196.0 1.00
(1, 3) -0.210 0.962 -2.129 1.499 0.039 0.028 610.0 870.0 1.00
(1, 4) 0.072 0.908 -1.601 1.734 0.034 0.024 721.0 1003.0 1.00
(2, 3) 0.721 0.990 -1.192 2.561 0.033 0.026 878.0 1012.0 1.00
(2, 4) 0.657 0.960 -1.319 2.210 0.034 0.024 818.0 1054.0 1.00
(3, 4) -0.691 0.865 -2.355 0.923 0.027 0.019 1031.0 947.0 1.00
(0, 1, 2) -0.950 1.164 -3.118 1.183 0.035 0.026 1087.0 920.0 1.00
(0, 1, 3) -0.967 1.189 -3.129 1.154 0.040 0.028 878.0 948.0 1.00
(0, 1, 4) 0.432 1.195 -1.754 2.582 0.036 0.026 1084.0 1280.0 1.00
(0, 2, 3) -0.200 1.171 -2.328 1.996 0.037 0.027 974.0 1180.0 1.00
(0, 2, 4) 0.579 1.259 -1.800 2.954 0.040 0.028 995.0 1087.0 1.01
(0, 3, 4) -0.727 1.414 -3.555 1.626 0.042 0.036 1139.0 943.0 1.01
(1, 2, 3) 1.388 1.144 -0.822 3.425 0.038 0.027 927.0 1070.0 1.00
(1, 2, 4) -0.323 1.125 -2.514 1.715 0.034 0.024 1076.0 1242.0 1.00
(1, 3, 4) 0.622 1.099 -1.404 2.655 0.037 0.026 861.0 1056.0 1.00
(2, 3, 4) -0.930 1.132 -3.074 1.220 0.034 0.024 1129.0 1156.0 1.00
(0, 1, 2, 3) -0.570 1.340 -3.090 1.855 0.040 0.028 1132.0 1169.0 1.00
(0, 1, 2, 4) 1.173 1.407 -1.447 3.790 0.043 0.032 1080.0 1207.0 1.00
(0, 1, 3, 4) 1.452 1.487 -1.266 4.308 0.047 0.038 1016.0 1039.0 1.01
(0, 2, 3, 4) 0.101 1.554 -2.724 3.120 0.045 0.040 1219.0 1013.0 1.00
(1, 2, 3, 4) -2.363 1.360 -4.810 0.373 0.043 0.032 1018.0 1085.0 1.00
(0, 1, 2, 3, 4) 1.404 1.547 -1.584 4.178 0.039 0.034 1544.0 1164.0 1.00
b_f 3.955 0.622 2.760 5.096 0.022 0.015 832.0 982.0 1.00

Joint Probabilities

$$p(j_1, j_2, \ldots, j_r \text{ positions are } 1,\ \text{others are } 0) = \frac{e^{S^{j_1 j_2 \cdots j_r}}}{e^{b(f)}}$$

We can get all the joint probabilities as follows

# build names like 'p_10010' describing which columns are 1
p_names = []
for sub in subsets:
    p_name = ['0'] * y.shape[1]
    for i in sub:
        p_name[i] = '1'
    p_names.append(p_name)

p_names = ['p_' + ''.join(name) for name in p_names]

# p(subset positions are 1, others 0) = exp(S^{subset}) / exp(b(f))
for p_name, s in zip(p_names, subsets):
    ss = list(powerset(s))
    idx_S = np.array([sub in ss for sub in subsets])
    trace.posterior[p_name] = np.exp(np.sum(trace.posterior['f'][...,idx_S], axis = 2)) / \
                              np.exp(trace.posterior['b_f'])

# probability of all columns being 0 is exp(-b(f))
trace.posterior['p_' + '0' * y.shape[1]] = np.exp(-trace.posterior['b_f'])
az.summary(trace, var_names = ['p_'], filter_vars='like')
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
p_10000 0.027 0.018 0.003 0.063 0.000 0.000 1269.0 1114.0 1.00
p_01000 0.019 0.014 0.002 0.044 0.000 0.000 932.0 1178.0 1.00
p_00100 0.008 0.007 0.000 0.020 0.000 0.000 780.0 1120.0 1.00
p_00010 0.028 0.018 0.005 0.060 0.000 0.000 1356.0 1359.0 1.00
p_00001 0.077 0.034 0.022 0.142 0.001 0.001 1559.0 1458.0 1.00
p_11000 0.047 0.027 0.006 0.095 0.001 0.000 1579.0 1282.0 1.00
p_10100 0.016 0.014 0.001 0.043 0.000 0.000 1319.0 1217.0 1.00
p_10010 0.033 0.021 0.002 0.068 0.000 0.000 1837.0 1401.0 1.00
p_10001 0.009 0.010 0.000 0.026 0.000 0.000 1430.0 1357.0 1.00
p_01100 0.012 0.011 0.000 0.030 0.000 0.000 1416.0 1003.0 1.00
p_01010 0.021 0.016 0.000 0.050 0.000 0.000 1312.0 1325.0 1.00
p_01001 0.068 0.032 0.013 0.124 0.001 0.001 1914.0 1092.0 1.00
p_00110 0.019 0.015 0.001 0.046 0.000 0.000 1615.0 1232.0 1.00
p_00101 0.045 0.027 0.006 0.092 0.001 0.000 1930.0 1295.0 1.00
p_00011 0.051 0.029 0.006 0.104 0.001 0.000 1856.0 1548.0 1.00
p_11100 0.022 0.018 0.001 0.055 0.000 0.000 1921.0 1436.0 1.00
p_11010 0.021 0.018 0.000 0.052 0.000 0.000 1718.0 1417.0 1.00
p_11001 0.024 0.019 0.001 0.057 0.000 0.000 1884.0 1303.0 1.00
p_10110 0.031 0.022 0.001 0.070 0.001 0.000 1735.0 1441.0 1.00
p_10101 0.016 0.015 0.000 0.044 0.000 0.000 1869.0 1480.0 1.00
p_10011 0.004 0.007 0.000 0.014 0.000 0.000 1274.0 913.0 1.00
p_01110 0.071 0.035 0.011 0.130 0.001 0.001 2181.0 1661.0 1.00
p_01101 0.050 0.028 0.006 0.100 0.001 0.000 2134.0 1634.0 1.00
p_01011 0.064 0.033 0.014 0.125 0.001 0.001 2077.0 1318.0 1.00
p_00111 0.026 0.020 0.001 0.063 0.000 0.000 2007.0 1581.0 1.00
p_11110 0.030 0.022 0.002 0.072 0.001 0.000 1971.0 1584.0 1.00
p_11101 0.061 0.033 0.010 0.121 0.001 0.001 1785.0 1539.0 1.00
p_11011 0.020 0.018 0.000 0.053 0.000 0.000 1957.0 1526.0 1.00
p_10111 0.006 0.009 0.000 0.022 0.000 0.000 1804.0 1358.0 1.00
p_01111 0.018 0.017 0.000 0.050 0.000 0.000 2120.0 1667.0 1.00
p_11111 0.033 0.023 0.002 0.077 0.001 0.000 1862.0 1732.0 1.00
p_00000 0.023 0.014 0.003 0.049 0.000 0.000 832.0 982.0 1.01
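
As a quick sanity check (not part of the original analysis above), the $2^K$ joint probabilities should sum to one for every posterior draw; a short snippet along these lines can verify that:

# collect every joint-probability variable added to the posterior above
p_vars = [v for v in trace.posterior.data_vars if str(v).startswith('p_')]

# sum them draw by draw; the totals should all be very close to 1.0
totals = sum(trace.posterior[v] for v in p_vars)
print(float(totals.min()), float(totals.max()))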

Marginal Probabilities

We represent the marginal probabilities as follows:

$p_i$ is the probability of the $i$th column being 1, where $i \in \{1, 2, \ldots, K\}$.

Similarly, $p_{ij}$ is the probability of the $i$th and $j$th columns both being 1, where $i, j \in \{1, 2, \ldots, K\}$.

$$p_i = \sum_{\text{all } j \text{ with } j_i = 1} p_{j_1 j_2 \cdots j_K}$$

def marg_prob(idxs, trace, num_columns):
    '''
    idxs: list of indexes which are 1, e.g. [0, 1] means first and second columns are 1
    trace: trace of pymc3 model output (InferenceData)
    num_columns: number of binary columns K
    '''
    assert max(idxs) < num_columns
    # all joint-probability variables previously added to the posterior
    keys = [i for i in trace.posterior.keys() if i.startswith('p_')]
    res = []
    for k in keys:
        k_num = k[2:]
        # keep only the configurations where every requested index is 1
        for c, i in enumerate(idxs):
            if k_num[i] != '1':
                break
            if c + 1 == len(idxs):
                res.append('p_' + k_num)
    # marginal probability = sum of the matching joint probabilities
    return np.sum(trace.posterior[res].to_array(), axis = 0)
for i in range(y.shape[1]):
    print(f'Marginal Probability {i + 1} ---')
    print(az.summary(marg_prob([i], trace, y.shape[1]), kind = 'stats'))
Marginal Probability 1 ---
    mean     sd  hdi_3%  hdi_97%
x  0.573  0.067   0.442     0.69
Marginal Probability 2 ---
   mean     sd  hdi_3%  hdi_97%
x   0.4  0.065    0.27    0.511
Marginal Probability 3 ---
   mean     sd  hdi_3%  hdi_97%
x  0.58  0.067    0.46    0.706
Marginal Probability 4 ---
    mean     sd  hdi_3%  hdi_97%
x  0.463  0.071   0.323    0.588
Marginal Probability 5 ---
    mean    sd  hdi_3%  hdi_97%
x  0.477  0.07   0.345    0.603
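
The same helper also works for higher-order marginals; for example, a sketch for the pairwise marginal of the first two columns (both equal to 1) would be:

# pairwise marginal probability that columns 0 and 1 are both 1
print(az.summary(marg_prob([0, 1], trace, y.shape[1]), kind = 'stats'))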

Independence of outcomes

The outcomes are independent if

$$S^{j_1 j_2 \cdots j_r} - \sum_{k=1}^{r} f^{j_k} = 0, \qquad \forall\, r \ge 2$$
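
In particular, for $r = 2$ the condition reduces to $f^{j_1 j_2} = 0$, i.e. two outcomes are independent exactly when their pairwise interaction parameter vanishes; full independence of all outcomes corresponds to every interaction parameter of order two or higher being zero.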

def trace_of_independence(idxs, trace, subsets, num_columns):
    '''
        idxs: list of indexes whose independence is to be checked, e.g. [0, 1, 2] means that
        independence of the 1st, 2nd and 3rd columns is to be checked
        trace: pymc3 trace, InferenceData
        subsets: list depicting the powerset of the indexes
        num_columns: number of outcomes
    '''
    assert len(idxs) > 1
    
    assert max(idxs) < num_columns
    
    # interaction terms of order >= 2 among the chosen indexes;
    # their sum equals S^{idxs} minus the first-order f's
    ss = [i for i in powerset(idxs) if len(i) > 1]
    idx_S = np.array([sub in ss for sub in subsets])
    
    # posterior of the independence statistic; values near 0 support independence
    return np.sum(trace.posterior['f'][...,idx_S], axis = 2)
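
There is no worked example for this function in the post above, so here is a minimal usage sketch (reusing the trace and subsets from earlier) that checks whether the first two columns can be treated as independent:

# posterior of S^{01} minus the first-order terms, i.e. the interaction f^{01};
# an HDI that comfortably covers 0 gives little evidence of association
indep_01 = trace_of_independence([0, 1], trace, subsets, y.shape[1])
print(az.summary(indep_01, kind = 'stats'))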

Issues

The sampling takes a prohibitively long time when the number of variables K exceeds 7 or 8. I am working on addressing this issue.

Also, I still need to build a mechanism for generating random samples from the fitted distribution.