'''
The posterior distribution results from running inference for a given query using MCMC.
This module contains
* Data structures to collect posterior samples
* Convergence diagnostics
The convergence diagnostics have to be called from the ProbReM project script, i.e. from outside the framework.
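
For example, a project script might run the diagnostics after inference has
finished (a minimal sketch; the module alias `posterior` is hypothetical)::

    posterior.plotCumulativeMeanAllChains()
    posterior.gelman_rubin()
    posterior.autocorrelation(max_l=50)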
'''
import logging
import numpy as N
import pylab as PL
from network.vertices import ReferenceVertex
from inference import engine 
'''The engine module contains the :class:`.GBNgraph` instance
'''
samples = {}
'''
Dictionary containing all the samples collected during inference. It is common to run more than one chain to monitor convergence; the key/value pairs are stored as
    { key = 'chainIdentification' : value = :class:`numpy.array` }
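
For example (chain name hypothetical)::

    samples['chain1']                        # ITER x nVariables array
    samples['chain1'][n, currentIndex[vID]]  # value of vertex `vID` in sample `n`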
'''
currentChain = None
'''
:class:`numpy.array` that is currently being used by the sampler
'''
posteriorVertices = None
'''
Dictionary of vertices that samples are collected from (e.g. only the event vertices, or all sampling vertices including the latent variables), set by :meth:`.initChain`
'''
currentIndex = {}
'''
Dictionary mapping each event variable `ID` to an index used to access :attr:`.currentChain`
'''
def initChain(chainID,ITER,onlyEvent=False):
    '''
    Initializes a new MCMC run. Note that by default `onlyEvent=False`, so the samples of all `engine.GBN.samplingVertices` are collected. The posterior itself is a joint distribution over the event variables only; restricting attention to their columns implicitly marginalizes out the other sampling variables.
    
    :arg chainID: String identification for new chain
    :arg ITER: Number of samples to be collected
    :arg onlyEvent: Boolean, if `True` only the values of the event vertices are collected (i.e. not the latent sampling variables)
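
    A typical driver loop (a minimal sketch; `drawSample` stands in for one
    transition of the actual MCMC sampler, which updates the vertex values)::

        initChain('chain1', 1000)
        for n in range(1000):
            drawSample()         # hypothetical: one MCMC transition
            collectSamples(n)    # store the current state as row `n`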
    '''
    global samples,currentChain,currentIndex,posteriorVertices
    
    posteriorVertices = engine.GBN.samplingVertices
    if onlyEvent:
        posteriorVertices = engine.GBN.eventVertices
    currentIndex = {}
    for i,vertexID in enumerate(posteriorVertices.keys()):
        currentIndex[vertexID] = i
        
    nVariables = len(posteriorVertices)
    samples[chainID] = N.zeros((ITER,nVariables))
    currentChain = samples[chainID]
        
 
def collectSamples(nSample):
    '''
    Extracts the value of every posterior vertex and stores it in the appropriate `numpy.array`, :attr:`.currentChain`.
    
    :arg nSample: Int, count of the collected sample, i.e. the row number
    '''    
    
    for gbnID,gbnV in posteriorVertices.items():
        
        # If the vertex is a reference vertex, we collect the current reference 
        if isinstance(gbnV,ReferenceVertex):
            i = currentIndex[gbnID]
            # at this point we are assuming that k=1, i.e. a single reference
            # the referenced ID is the first entry of the `obj` list of the attribute, e.g. 2 in 'Professor.fame.2'
            currentChain[nSample,i] = gbnV.references.values()[0].obj[0]
            
        # If the vertex is a normal vertex, we collect the sampled value
        else:
            
            i = currentIndex[gbnID]
            currentChain[nSample,i] = gbnV.value
 
def plotCumulativeMeanAllChains(**kwargs):
    '''
    Plots the cumulative mean of all available chains using :meth:`.plotCumulativeMean`; all chains are drawn in the same (new) figure.
    To plot only a specific variable, pass the `gbnV` or `varIndex` keyword.
    
    :arg kwargs: Optional arguments for :meth:`.plotCumulativeMean`
    '''    
    # create new figure to plot all chains in
    fig = PL.figure()     
    for chainID in samples.keys():        
        plotCumulativeMean(fig=fig,chainID=chainID, **kwargs)        
 
def plotCumulativeMean(**kwargs):
    '''
    Convergence diagnostic that plots the cumulative mean of all the sampling variables in the :attr:`.currentChain`. If a `chainID` is provided, the cumulative mean of the associated chain is plotted instead. If `varIndex` - i.e. the index of a sampling variable - or `gbnV` is provided, only the cumulative mean of that variable is plotted.
    
    :arg chainID: Optional identification of chain to be analyzed
    :arg varIndex: Optional index of event variable to be analyzed
    :arg gbnV: Optional :class:`GBNvertex` to be analyzed
    :arg fig: Optional `matplotlib.figure.Figure` to be used 
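
    For example (chain name hypothetical)::

        plotCumulativeMean(chainID='chain1', varIndex=0)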
    '''
    chain = currentChain
    if 'chainID' in kwargs:
        chain = samples[kwargs['chainID']]
    
    #calculate cumulative mean
    cumChain = cumulativeMean(chain)
        
    # either plot on a new or specific figure window
    fig = None
    if 'fig' in kwargs:
        fig = kwargs['fig']
        PL.figure(fig.number)
    else:
        fig = PL.figure()
        
    if 'varIndex' in kwargs:
        PL.plot(cumChain[:,kwargs['varIndex']])        
        PL.xlabel('Samples for index %s'%kwargs['varIndex'])
        PL.ylabel('Mean')        
    elif 'gbnV' in kwargs:
        gbnID = kwargs['gbnV'].ID
        PL.plot(cumChain[:,currentIndex[gbnID]])
        PL.xlabel('Samples for %s'%gbnID)
        PL.ylabel('Mean')        
    else:
        # plot cumulative for all posterior variables in subplots
        for i,pID in enumerate(posteriorVertices.keys()):        
            PL.subplot(len(posteriorVertices),1,(i+1))            
            PL.plot(cumChain[:,currentIndex[pID]])            
            PL.xlabel('Samples for %s'%pID)
            PL.ylabel('Mean')
 
def cumulativeMean(chain):
    '''
    Returns the cumulative mean of `chain`, i.e. row `i` contains the mean of the first `i+1` samples.
    
    :arg chain: `numpy.array` 
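
    For example::

        cumulativeMean(N.array([[1.],[3.],[5.]]))
        # -> [[1.],[2.],[3.]]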
    '''
    cumChain = chain.cumsum(axis=0)
    # divide each running sum by the number of samples it contains
    cumChain /= N.arange(1., cumChain.shape[0]+1)[:,N.newaxis]
    return cumChain
 
def mean(chainID=None, gbnV=None, sVarInd=None, combined=False):
    '''
    Returns the posterior mean of all the sampling variables in the :attr:`.currentChain`. 
    If a `chainID` is provided, the mean of the associated chain is returned instead. If `sVarInd` or `gbnV` is provided, only the mean of that variable is returned.
    
    :arg chainID: Optional identification of chain to be analyzed
    :arg sVarInd: Optional index of event variable to be analyzed
    :arg gbnV: Optional :class:`GBNvertex` event variable to be analyzed 
    :arg combined: Optional (default `False`), if `True` a single grand mean over all event variables and samples is returned.   
    :returns: Posterior mean as :class:`numpy.array`. If `sVarInd` or `gbnV` is specified, this is a single value
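
    For example (chain name hypothetical)::

        mean()                                 # vector of per-variable means
        mean(chainID='chain1', combined=True)  # single grand mean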
    '''
    chain = currentChain
    if chainID is not None:
        chain = samples[chainID]
            
    #computing mean for one or all variables?
    if combined:
        return N.mean(chain)
    elif sVarInd is not None:
        return N.mean(chain[:,sVarInd])
    elif gbnV is not None:
        ind = currentIndex[gbnV.ID]
        return N.mean(chain[:,ind])
    else:
        return N.mean(chain,0)    
 
def histogramm(**kwargs):
    '''
    Convergence diagnostic that plots the posterior density (as a matplotlib histogram) of the sampling variables in the :attr:`.currentChain`. 
    If a `chainID` is provided, the histogram of the associated chain is plotted instead. If `varIndex` or `gbnV` is provided, only the histogram of that variable is plotted.
    
    :arg chainID: Optional identification of chain to be analyzed
    :arg varIndex: Optional index of event variable to be analyzed
    :arg gbnV: Optional :class:`GBNvertex` event variable to be analyzed    
    :arg fig: Optional `matplotlib.figure.Figure` to be used 
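
    For example (the vertex ID is hypothetical)::

        histogramm(gbnV=engine.GBN.eventVertices['Professor.fame.2'])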
    '''
    chain = currentChain
    if 'chainID' in kwargs:
        chain = samples[kwargs['chainID']]
        
    
    # either plot on a new or specific figure window
    fig = None
    if 'fig' in kwargs:
        fig = kwargs['fig']
        PL.figure(fig.number)
    else:
        fig = PL.figure()
                
    if 'varIndex' in kwargs:        
        PL.hist(chain[:,kwargs['varIndex']])
    elif 'gbnV' in kwargs:        
        PL.hist(chain[:,currentIndex[kwargs['gbnV'].ID]])
    else:
        logging.warning("no posterior variable passed as argument for histogramm")
        PL.hist(chain)
 
def gelman_rubin():
    '''
    Plots the Gelman-Rubin convergence diagnostic, according to `Probabilistic Graphical Models` (Koller & Friedman, p. 523).
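
    The statistic computed below is `R_hat = sqrt(V/W)` with `V = (M-1)/M * W + B/M`,
    where `B` is the between-chain and `W` the within-chain variance (following
    the book's notation); values close to 1 indicate convergence.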
    '''
    # extract the number of samples M 
    M = engine.inferenceAlgo.ITER
    # number of discarded samples (burnin)
    T = engine.inferenceAlgo.BURNIN
    # number of chains
    K = engine.inferenceAlgo.CHAINS
    # cumulative mean for each chain
    # f^{bar}_{k} in the Koller book
    f_bar_K = {}
    for ID,chain in samples.items():
        f_bar_k = cumulativeMean(chain)
        f_bar_K[ID] = f_bar_k
    
    # cumulative mean across all chains
    # f^{bar} in the Koller book
    f_bar = N.zeros(currentChain.shape,dtype=float)
    for f_bar_k in f_bar_K.values():
        f_bar += f_bar_k
    f_bar /= K
    # between-chain variance
    # B in the Koller book
    B = N.zeros(currentChain.shape,dtype=float)
    for f_bar_k in f_bar_K.values():
        B += (f_bar_k - f_bar)**2
    # use a float denominator to avoid integer division
    B *= M/(K-1.)
    # within-chain variance
    # W in the Koller book
    W = N.zeros(currentChain.shape,dtype=float)
    # with just one sample the variance is undefined; set the first row to 1
    # to avoid dividing by zero in R_hat below
    W[0,:] = 1
    for ID,chain in samples.items():
        for j in range(1,M):
            W[j,:] += 1./(j)*((chain[0:(j+1),:]-f_bar_K[ID][j,:])**2).sum(axis=0)
    W *= (1./K)
    # an estimator that can be shown to overestimate the variance
    # V in the Koller book
    V = (M-1.)/M * W + 1./M * B
    # as M goes to infinity, both V and W converge to the true variance of the estimate
    
    # measure of disagreement between the chains
    # R_hat in the Koller book
    R_hat = N.sqrt(V/W)
    PL.figure()
    for i,pID in enumerate(posteriorVertices.keys()):        
        PL.subplot(len(posteriorVertices),1,(i+1))
        PL.plot(R_hat[1:,currentIndex[pID]])
        PL.xlabel(pID)
        PL.ylabel('Gelman-Rubin')
    
    return R_hat
     
def autocorrelation(max_l = 50, **kwargs):
    '''
    Plots the autocorrelation, computed according to `Probabilistic Graphical Models` (p. 521). 
    
    :arg max_l: The autocorrelation will be calculated up to lag `max_l` (default=50)
    :arg chainID: Optional identification of chain to be analyzed
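
    Sketch of the computation (following the book's notation): with posterior mean
    `E_bar`, the lag-`l` covariance is
    `C[l] = 1/(M-l) * sum_m (f(x_m) - E_bar) * (f(x_{m+l}) - E_bar)`,
    and the autocorrelation is `A[l] = C[l]/C[0]`, where `C[0]` is the variance.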
    '''
    chain = currentChain
    if 'chainID' in kwargs:
        chain = samples[kwargs['chainID']]
    # extract the number of samples M 
    M = engine.inferenceAlgo.ITER
    
    # The cumulative mean
    E_bar = cumulativeMean(chain)
    # The `cumulative` Variance (equation 12.27)
    V = N.zeros(chain.shape,dtype=float)
    # the sample variance is undefined for a single sample; mark it as infinite
    V[0,:] = N.inf
    for j in range(1,M):
        V[j,:] = 1./(j)*((chain[0:(j+1),:]-E_bar[j,:])**2).sum(axis=0)
    # Note that so far we have calculated the cumulative values of the mean and variance, i.e. E_bar and V are M x |Y| matrices.
    # However, for the covariance and the autocorrelation we only use the mean/variance over all M samples, i.e. V[-1,:] and E_bar[-1,:]
    # The covariance (equation 12.28)
    # A max_l x |Y| matrix, column `l` is the covariance with lag `l`
    C = N.zeros((max_l,chain.shape[1]),dtype=float)
    # For l=0, the covariance is equal to the variance
    C[0,:] = V[-1,:]
    for l in range(1,max_l):
        C[l,:] = 1./(M-l) * ((chain[0:(M-l),:]-E_bar[-1])*(chain[l:,:]-E_bar[-1])).sum(axis=0)
    
    # autocorrelation   
    A = C / V[-1,:]
    PL.figure()
    for i,pID in enumerate(posteriorVertices.keys()):        
        PL.subplot(len(posteriorVertices),1,(i+1))
        PL.bar(N.arange(max_l), A[:,currentIndex[pID]])        
        PL.xlabel(pID)
        PL.ylabel('Autocorrelation')
    
 
def __marginalize(margVertices):
    '''
    TODO: Marginalize `margVertices` (which must be a subset of `posteriorVertices`) from the posterior
    '''
    
def __repr__():                      
    return 'Posterior distribution module'