#!/usr/bin/python3

import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation


def create_dataset(point_count,
                   mean = (0, 0),
                   variance = 1,
                   seed = 0):
    '''Create a 2D data cluster with specified mean and variance

    Points are drawn from a 2D normal distribution whose covariance is
    `variance` times the 2x2 identity matrix.

    @param {int} point_count number of points
    @param {list|tuple} mean [x, y] mean of the dataset
    @param {float} variance diagonal variance value
    @param {int} seed random seed number for reproducible results
    @return {tuple} (x, y) pair of arrays, each holding `point_count`
                    coordinates

    '''
    # A dedicated RandomState keeps the draw reproducible without touching
    # numpy's global random state.
    rstate = np.random.RandomState(seed)
    cov = np.identity(2) * variance
    # The sampler returns one row per point; transpose to split into the
    # x-coordinates and the y-coordinates.
    x, y = rstate.multivariate_normal(mean, cov, point_count).T
    return x, y


def create_clusters(counts, means, seeds):
    '''Build one reproducible unit-variance cluster per (mean, seed) pair

    Means and seeds are paired positionally; surplus entries in the longer
    list are ignored.

    @param {int} counts number of points per cluster
    @param {list} means list of cluster means
    @param {list} seeds list of random seeds, one per cluster
    @return {list} one create_dataset() result per cluster

    '''
    unit_variance = 1
    generated = []
    for center, seed in zip(means, seeds):
        generated.append(create_dataset(counts, center, unit_variance, seed))
    return generated


def gaussian(x, mu, sigma2):
    '''Gaussian probability density function

    With a scalar `sigma2` the univariate density is evaluated
    (element-wise when `x` or `mu` are arrays). With a square matrix
    `sigma2` the multivariate density is evaluated for the vector `x`.

    @param {float|np.ndarray} x gaussian evaluation location
    @param {float|np.ndarray} mu gaussian mean
    @param {float|np.ndarray} sigma2 gaussian variance (scalar) or
                                     covariance matrix (DxD)
    @return {float|np.ndarray} gaussian density value at x
    '''

    # np.ndim also recognises numpy scalar types, not just float/int
    if np.ndim(sigma2) == 0:
        coeff = 1.0/np.sqrt(2*np.pi*sigma2)
        return coeff*np.exp(-(x-mu)**2/(2*sigma2))

    sigma = np.asarray(sigma2, dtype=float)
    diff = np.asarray(x, dtype=float) - np.asarray(mu, dtype=float)

    # Dimensionality is the number of rows of the covariance matrix, not
    # its total element count.
    D = sigma.shape[0]

    # Normalisation constant: (2*pi)^(D/2) * |Sigma|^(1/2); the density is
    # divided by it.
    norm = (2*np.pi)**(D/2.0) * np.sqrt(np.linalg.det(sigma))

    # Squared Mahalanobis distance; solving the linear system avoids
    # forming the explicit inverse.
    maha = np.dot(diff, np.linalg.solve(sigma, diff))
    return np.exp(-0.5*maha)/norm


def flatten_clusters(clusters):
    '''Iterate over every point of every cluster, cluster by cluster

    @param {list} clusters list of [x,y] coordinates for each
                  cluster. Treated as incomplete dataset
    @return {iterator} (x, y) coordinate tuples of the points in dataset
    '''

    for coords in clusters:
        # zip pairs the i-th x with the i-th y of this cluster
        yield from zip(*coords)


def cluster_point_count(clusters):
    '''Count the points across all clusters

    The size of a cluster is taken from the length of its first
    coordinate list.

    @param {list} clusters list of [x,y] coordinates for each
                  cluster. Treated as incomplete dataset
    @return {int} total number of points in all clusters

    '''
    return sum(len(coords[0]) for coords in clusters)


def e_step(clusters, pis, mus, sigmas):
    '''(E)xpectation step

    Return the expected value of the responsibilities: for every point,
    each component's weighted density divided by the sum of the weighted
    densities of all components.

    @param {list} clusters list of [x,y] coordinates for each
                  cluster. Treated as incomplete dataset
    @param {list} pis mixture coefficients
    @param {list} mus gaussian means
    @param {list} sigmas gaussian variances
    @return {np.ndarray} NxK responsibilities; row n holds the values
                         for the n-th point returned by flatten_clusters

    '''

    # Note: the responsibility of class k for point n is the expected
    # value of the indicator z_{nk}, z_k \in {0, 1}
    N = cluster_point_count(clusters)
    K = len(pis)
    ret = np.zeros((N, K))
    for n, x in enumerate(flatten_clusters(clusters)):
        point = np.array(x)
        # Evaluate each weighted density once and reuse it for both the
        # numerator and the normalising denominator.
        weighted = [pis[k] * gaussian(point, mus[k], sigmas[k])
                    for k in range(K)]
        den = sum(weighted)
        for k in range(K):
            ret[n][k] = weighted[k]/den
    return ret


def m_step(clusters, pis, mus, sigmas, gammas):
    '''(M)aximization step

    Return mixing coefficients, means, variances re-estimated from the
    responsibilities. Each returned list has exactly one entry per
    mixture component.

    @param {list} clusters list of [x,y] coordinates for each
                  cluster. Treated as incomplete dataset
    @param {list} pis mixture coefficients (only its length is used, to
                  get the number of components)
    @param {list} mus gaussian means (unused, kept for interface symmetry)
    @param {list} sigmas gaussian variances (unused, kept for interface
                  symmetry)
    @param {np.ndarray} gammas NxK responsibilities
    @return {list} updated mixing coefficients
            {list} updated means
            {list} updated variances

    '''

    K = len(pis)
    N = cluster_point_count(clusters)
    # Materialise the points once; they are reused for every component.
    points = [np.array(p, dtype=float) for p in flatten_clusters(clusters)]

    pis_new, mus_new, sigmas_new = [], [], []
    for k in range(K):
        # effective number of points assigned to component k
        Nk = np.sum(gammas[:, k])

        # mixing coefficients
        pis_new.append(Nk/float(N))

        # means: responsibility-weighted average of the points
        muk = np.zeros(2)
        for n, x in enumerate(points):
            muk += gammas[n][k]/Nk * x
        mus_new.append(muk)

        # variances: responsibility-weighted scatter around the new mean
        sigmak = np.zeros((2, 2))
        for n, x in enumerate(points):
            diff = x - muk
            sigmak += gammas[n][k]/Nk * np.outer(diff, diff)
        # Append once per component, after all points are accumulated, so
        # that sigmas_new[k] lines up with pis_new[k] and mus_new[k].
        sigmas_new.append(sigmak)

    return pis_new, mus_new, sigmas_new


def gauss_mixture_grid(xp, yp, pis, mus, sigmas):
    '''Sample the Gaussian Mixture model density on a rectangular grid

    @param {np.ndarray} xp x values on a grid
    @param {np.ndarray} yp y values on a grid
    @param {list} pis mixture coefficients
    @param {list} mus gaussian means
    @param {list} sigmas gaussian variances
    @return {np.ndarray} 2D grid of shape (len(yp), len(xp)) with the
                         gaussian mixture model values at [y, x]

    '''

    component_ids = range(len(pis))
    grid = np.zeros((len(yp), len(xp)))
    for row, y_val in enumerate(yp):
        for col, x_val in enumerate(xp):
            location = np.array([x_val, y_val])
            # mixture density = sum of the weighted component densities
            grid[row, col] = sum(
                pis[k] * gaussian(location, mus[k], sigmas[k])
                for k in component_ids)
    return grid


def update_plot(frame, history, x_plot, y_plot, fig_data):
    '''FuncAnimation plot update function

    Redraw the plot for the EM iteration stored at history[frame]: move
    the mean markers, replace the mixture contour plot and update the
    colorbar limits.

    @param {int} frame frame number being rendered
    @param {list} history one dict per iteration with the keys
           {list} 'mus' means of the clusters
           {list} 'sigmas' variances of the clusters
           {list} 'pis' mixture coefficients
    @param {np.1darray} x_plot x-range of values to plot mixture model
    @param {np.1darray} y_plot y-range of values to plot mixture model
    @param {dict} fig_data plot state with the keys
           {matplotlib.axes} 'ax' plot axis object
           {[matplotlib.Line2D]} 'means' mean scatter plot objs
           {ContourSet} 'cplot' current contour plot (optional, replaced)
           {ScalarMappable} 'cbar_map' mappable behind the colorbar
    @return {list} empty list (no artists are reported for blitting)

    '''
    print ('frame {}  '.format(frame), end='\r')
    # fetch the mixture parameters recorded for this iteration
    row = history[frame]
    pis  = np.array(row['pis'])
    mus  = np.array(row['mus'])
    sigmas = np.array(row['sigmas'])

    ax = fig_data['ax']
    ax.set_title(r'Iteration {}'.format(frame))

    mean_plot = fig_data['means']
    for d, p in zip(mus, mean_plot):
        # Line2D data must be sequences; pass the single marker position
        # as one-element x and y lists.
        p.set_data([d[0]], [d[1]])

    # Remove existing contour plot to speed up replotting
    if 'cplot' in fig_data:
        cplot = fig_data['cplot']
        if hasattr(cplot, 'remove'):
            # contour set that is itself a removable artist
            cplot.remove()
        else:
            # older matplotlib: remove the per-level collections
            for col in cplot.collections:
                col.remove()

    X, Y = np.meshgrid(x_plot, y_plot)
    Z = gauss_mixture_grid(x_plot, y_plot, pis, mus, sigmas)
    fig_data['cplot'] = ax.contourf(X, Y, Z)
    fig_data['cbar_map'].set_clim(np.min(Z), np.max(Z))

    return []


def plot_em_iterations(data_sets, history):
    '''Plot Expectation Maximization iterations

    Build the figure (data points, mean markers, mixture contour and
    colorbar) from the first history entry and create an animation with
    one frame per history entry.

    @param {list} data_sets list of [x,y] coordinates for each
                  cluster. Treated as incomplete dataset
    @param {list} history one dict per iteration with the keys
           {list} 'mus' means of the clusters
           {list} 'sigmas' variances of the clusters
           {list} 'pis' mixture coefficients
    @return {animation.FuncAnimation} animation object
            {animation.writers} writer object

    '''

    frame_count = len(history)
    Writer = animation.writers['ffmpeg']
    writer = Writer(fps=15, metadata=dict(artist='Me'), bitrate=1800)

    fig, ax = plt.subplots()
    ax.grid(True)
    ax.set_xlabel('$x_1$')
    ax.set_ylabel('$x_2$')

    # plot clusters as incomplete dataset (uniform color)
    for x, y in data_sets:
        ax.plot(x, y,
                linestyle='',
                marker='o', markersize=4,
                color='c')

    # means: one marker per component, placed at (m[0], m[1])
    row = history[0]
    means = row['mus']
    means_plot = []
    for m in means:
        p, = ax.plot([m[0]], [m[1]],
                     marker='x', markersize=8, markeredgewidth=4)
        means_plot.append(p)

    # evaluate the mixture over the axis range spanned by the plotted data
    xr = ax.get_xlim()
    yr = ax.get_ylim()
    xp = np.linspace(xr[0], xr[1], 50)
    yp = np.linspace(yr[0], yr[1], 50)

    fig_data = {
        'ax': ax,
        'fig': fig,
        'means': means_plot
    }

    X, Y = np.meshgrid(xp, yp)
    Z = gauss_mixture_grid(xp, yp, row['pis'], means, row['sigmas'])
    fig_data['cplot'] = ax.contourf(X, Y, Z, alpha=0.8)
    # standalone mappable so the colorbar limits can be updated per frame
    cbar_map = plt.cm.ScalarMappable()
    cbar_map.set_clim(np.min(Z), np.max(Z))
    fig.colorbar(cbar_map, ax=ax, format='%.0e')
    fig_data['cbar_map'] = cbar_map

    c_ani = animation.FuncAnimation(fig, update_plot, frame_count,
                                    fargs=(history,
                                           xp, yp,
                                           fig_data),
                                    interval=20, blit=True)
    return c_ani, writer


def main():
    '''Fit a gaussian mixture to synthetic clusters and save an animation

    Generates five 20-point clusters, runs 100 EM iterations starting
    from seeded random means, and writes the animated iterations to
    ../video/cluster-gaussian-mixture.mp4 relative to this script.

    '''
    true_means = [[-3, +3],
                  [+3, +3],
                  [+3, -3],
                  [-3, -3],
                  [+0, -0]]
    cluster_seeds = [1, 4, 8, 9, 10]
    clusters = create_clusters(20, true_means, cluster_seeds)
    K = len(clusters)

    # initial guess: equal weights, seeded random means, unit covariances
    pis = [1.0/K] * K
    mus = [np.random.RandomState(k).uniform(-5,5, 2) for k in range(K)]
    sigmas = [np.eye(2)] * K

    # history[0] is the initial guess, followed by one entry per iteration
    history = [{
        'mus': mus,
        'pis': pis,
        'sigmas': sigmas
    }]
    for _ in range(100):
        gammas = e_step(clusters, pis, mus, sigmas)
        pis, mus, sigmas = m_step(clusters, pis, mus, sigmas, gammas)
        history.append({
            'mus': mus,
            'pis': pis,
            'sigmas': sigmas
        })

    c_ani, writer = plot_em_iterations(clusters, history)
    script_dir = os.path.dirname(os.path.realpath(__file__))
    out_path = os.path.join(script_dir,
                            '../video/cluster-gaussian-mixture.mp4')
    c_ani.save(out_path, writer=writer)


if __name__ == '__main__':
    main()