#!/usr/bin/python3

import os
import matplotlib.pyplot as plt
import numpy as np
import functools as ft
import matplotlib.animation as animation



# Default colors used for plotting data points
COLORS = ['#1f77b4', '#ff7f0e', '#2ca02c',
          '#d62728', '#9467bd', '#8c564b']


def phi(x):
    '''Basis functions evaluated at a given point

    @param x {array} data point
    @return {numpy.array} basis function values at `x`
    '''

    # First element is 1 for the bias term, w0
    return np.array([1, x[0], x[1]], np.float64)
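
# For example (illustrative), the basis vector for the point (2, 3)
# keeps the leading 1 for the bias term:
#   phi([2, 3])    # -> array([1., 2., 3.])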


def w_class(w_vecs, class_index, block_size):
    '''Convenience function to get the discriminant for a given class

    Input vector layout is (should be)
       [w0_0, w0_1, w0_2,
        w1_0, w1_1, w1_2,
        w2_0, w2_1, w2_2, ...]

    where e.g. `w1_0` is the first component of the second
    discriminant vector

    @param w_vecs {numpy.array} vectorized discriminant vectors for all
           classes
    @param class_index {integer} class index
    @param block_size {integer} block size of vectorized
           components. Must match the number of basis functions. @see
           phi
    @return {numpy.array} discriminant vector of `class_index`

    '''
    i = class_index * block_size
    return np.array(w_vecs[i:i + block_size], np.float64)
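
# For example (illustrative), with three basis functions the weights of
# class 1 occupy indices 3..5 of the flattened vector:
#   W = np.arange(9, dtype=np.float64)
#   w_class(W, 1, 3)    # -> array([3., 4., 5.])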


def y_func(w_vecs, x, class_index):
    '''Posterior conditional probability of a class. See Eq. 4.104

    p(C_i | phi(x)) = exp(w_i . phi(x)) / sum_j( exp(w_j . phi(x)) )

    @param w_vecs {np.array} vectorized discriminant vectors for all
           classes
    @param x {array} evaluation point of the posterior probability
    @param class_index {integer} class id
    @return {float} posterior conditional probability of `class_index`
           at `x`
    '''
    den = 0.0
    BASIS_COUNT = 3
    N = len(w_vecs) // BASIS_COUNT
    for i in range(N):
        a = np.dot(w_class(w_vecs, i, BASIS_COUNT), phi(x))
        den += np.exp(a)

    a = np.dot(w_class(w_vecs, class_index, BASIS_COUNT), phi(x))
    num = np.exp(a)
    return num/den
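

# Note: np.exp(a) can overflow for large activations. A numerically
# stabler variant (an illustrative sketch; this helper is not used by
# the rest of the script) shifts the activations by their maximum
# before exponentiating, which leaves the ratios unchanged:
def y_func_stable(w_vecs, x, class_index):
    BASIS_COUNT = 3
    N = len(w_vecs) // BASIS_COUNT
    a = np.array([np.dot(w_class(w_vecs, i, BASIS_COUNT), phi(x))
                  for i in range(N)])
    a -= a.max()            # largest exponent becomes exp(0) = 1
    e = np.exp(a)
    return e[class_index] / e.sum()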


def error(datasets, w_vecs):
    '''Cross-entropy error metric

    The input data set is expected to be a list indexed by class,
    where each entry is a pair of (x, y) coordinate vectors:
       [(x_vec, y_vec), ...]

    Input vector layout is (should be)
       [w0_0, w0_1, w0_2,
        w1_0, w1_1, w1_2,
        w2_0, w2_1, w2_2, ...]

    @param datasets {list} input data set, see above
    @param w_vecs {np.array} vectorized discriminant vectors for all
           classes
    @return {float} error for given data set and model weights
    '''
    ret = 0
    TOL = 1E-5
    for class_id, dset in enumerate(datasets):
        for p in zip(*dset):
            y = y_func(w_vecs, p, class_id)
            # Clamp the posterior to avoid taking log(0)
            ret -= np.log(np.maximum(y, TOL))
    return ret


def grad_error(datasets, w_vecs):
    '''Gradient of the cross-entropy error wrt the model weights

    The input data set is expected to be a list indexed by class,
    where each entry is a pair of (x, y) coordinate vectors:
       [(x_vec, y_vec), ...]

    Input vector layout is (should be)
       [w0_0, w0_1, w0_2,
        w1_0, w1_1, w1_2,
        w2_0, w2_1, w2_2, ...]

    Return layout is
       [d/dw0_0, d/dw0_1, d/dw0_2,
        d/dw1_0, d/dw1_1, d/dw1_2,
        d/dw2_0, d/dw2_1, d/dw2_2, ...]

    @param datasets {list} input data, see above
    @param w_vecs {np.array} vectorized discriminant vectors for all
           classes. See description
    @return {numpy.array} vectorized gradient of error wrt w, see
          above
    '''

    K = len(datasets)
    ret = []
    for k in range(K):
        # Gradient wrt class k's weights: sum of (y - t)*phi over all
        # points, where the target t is 1 for points of class k and 0
        # otherwise
        v = np.zeros(3)
        for class_id, dset in enumerate(datasets):
            for p in zip(*dset):
                y = y_func(w_vecs, p, k)
                t = float(k == class_id)
                v += (y - t) * phi(p)
        ret.extend(v)
    return np.array(ret, np.float64)
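

# A finite-difference sanity check of `grad_error` (an illustrative
# sketch; the helper name and the step size `h` are arbitrary choices,
# and it is not called by the rest of the script):
def check_grad_error(datasets, w_vecs, h=1E-6):
    num_grad = np.zeros_like(w_vecs)
    for i in range(len(w_vecs)):
        dw = np.zeros_like(w_vecs)
        dw[i] = h
        # Central-difference approximation of dE/dw_i
        num_grad[i] = (error(datasets, w_vecs + dw)
                       - error(datasets, w_vecs - dw)) / (2*h)
    return np.linalg.norm(num_grad - grad_error(datasets, w_vecs))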


def backtrack(f_func, grad_f, x, dx, alpha=0.1, beta=0.3, max_iter=20):
    '''Suggested step size scale for gradient descent (backtracking
    line search)

    Unless `max_iter` is exceeded, the returned step scale `t`
    satisfies the Armijo sufficient-decrease condition

       f(x + t*dx) <= f(x) + alpha*t*<grad_f, dx>

    @param f_func {function} callable objective function being minimized
    @param grad_f {np.array} f_func gradient at `x`
    @param x {np.array} current solution point
    @param dx {np.array} descent direction
    @param alpha {float} number in the open interval (0, 0.5)
    @param beta {float} number in the open interval (0, 1.0)
    @param max_iter {integer} maximum number of iterations to perform
    @return {float} step scale for next iteration
    '''
    t = 1.0
    iter_count = 1
    f_x = f_func(x)
    f_n1 = f_func(x + t*dx)
    # The NaN check guards against overflow in the objective for
    # overly large steps
    while (f_n1 > f_x + alpha*t*np.dot(grad_f, dx)) or np.isnan(f_n1):
        t *= beta
        iter_count += 1
        f_n1 = f_func(x + t*dx)
        if iter_count - 1 > max_iter:
            print('warning: backtrack did not converge')
            return t
    return t
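
# Example usage on a simple quadratic (illustrative; the objective and
# the starting point are arbitrary):
#   f = lambda v: np.dot(v, v)
#   x0 = np.array([1.0, 1.0])
#   g = 2*x0                     # gradient of f at x0
#   t = backtrack(f, g, x0, -g)  # step scale along steepest descent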


def classify(W, x, region_count):
    '''Classify a given point based on the highest class posterior

    @param W {np.array} vectorized model discriminant vectors for
           all classes
    @param x {array} point being classified
    @param region_count {integer} total number of regions in model
    @return {integer} index of the most probable class
    '''

    # 'Likelihood' of belonging to each region
    probs = [y_func(W, x, i) for i in range(region_count)]
    # Return the region with highest 'likelihood'
    return np.argmax(probs)


def plot_class_regions(ax, W, region_count):
    '''Shade the decision region of each class over the plot area'''
    # Generate a grid over the available plot region
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()
    x = np.linspace(xlim[0], xlim[1], 70)
    y = np.linspace(ylim[0], ylim[1], 70)
    X, Y = np.meshgrid(x, y)

    # Fill class region for each grid point; Z rows follow y and
    # columns follow x (the meshgrid layout)
    Z = np.empty_like(X)
    for i in range(Z.shape[0]):
        for j in range(Z.shape[1]):
            Z[i, j] = classify(W, [x[j], y[i]], region_count)

    cplot = ax.contourf(X, Y, Z, alpha=0.8)
    return cplot


def plot_data_points(ax, datasets):
    '''Scatter-plot each class's data points in its own color'''
    for class_id, dset in enumerate(datasets):
        x, y = dset
        ax.plot(x, y,
                marker='o', linestyle='',
                label='set {}'.format(class_id),
                color=COLORS[class_id])


def poly(x, w):
    '''Evaluate the polynomial with coefficients `w` (lowest order
    first) at each point of `x`'''
    ret = []
    for x_ in x:
        v = 0
        for i, a in enumerate(w):
            v += a * x_**i
        ret.append(v)
    return ret


def create_dataset(point_count, mean=[0, 0],
                   variance=1, seed=0):
    '''Create a 2D data cluster with specified mean and variance

    @param {int} point_count number of points
    @param {list} mean mean of the dataset
    @param {float} variance diagonal variance value
    @param {int} seed random seed number for reproducible results
    @return {tuple} (x, y) coordinate vectors

    '''
    RSTATE = np.random.RandomState(seed)
    cov = np.identity(2) * variance
    x, y = RSTATE.multivariate_normal(mean, cov, point_count).T
    return x, y


def create_clusters(counts, means, seeds):
    '''Create a random, reproducible set of clusters with `counts` points
       each

    @param {int} counts number of points per dataset
    @param {list} means list of generated class means
    @param {list} seeds list of seeds used for random number generation
    @return {list} list of (x, y) coordinate vector pairs

    '''
    return [create_dataset(counts, mean=mean, seed=seed)
            for mean, seed in zip(means, seeds)]


def train(datasets, W):
    '''Minimize the cross-entropy error with gradient descent

    Uses backtracking line search for the step size and stops when the
    gradient norm drops below a tolerance or MAX_ITER is exceeded

    @param datasets {list} per-class (x, y) coordinate vectors
    @param W {np.array} initial vectorized weights for all classes
    @return {(np.array, list)} trained weights and the per-iteration
            history of [W, error, |grad error|] entries
    '''
    TOL = 1E-5
    iter_num = 1
    MAX_ITER = 100

    # Error function for given model weights
    Err = lambda w_vec: error(datasets, w_vec)

    grad_err = grad_error(datasets, W)
    grad_err_norm = np.linalg.norm(grad_err)
    iteration_history = []

    while grad_err_norm > TOL:
        iter_num += 1
        # Steepest descent direction
        dW = -grad_err
        t = backtrack(Err, grad_err, W, dW, 0.4, 0.9)
        W += t*dW
        grad_err = grad_error(datasets, W)
        grad_err_norm = np.linalg.norm(grad_err)
        err = Err(W)
        print('{}: {} ... {}'.format(iter_num, err, grad_err_norm))
        iteration_history.append([W.copy(), err, grad_err_norm])
        if iter_num > MAX_ITER:
            break
    return W, iteration_history


def class_discriminant(x, W, class_id):
    '''Class decision boundary: w0 + w1*x + w2*y = 0 solved for y'''
    w0 = W[class_id*3]
    w1 = W[class_id*3 + 1]
    w2 = W[class_id*3 + 2]
    # y = -(w0 + w1*x) / w2
    return poly(x, [-w0/w2, -w1/w2])


def update_plot(frame, training_history, x_plot, fig_data):
    '''FuncAnimation plot update function

    Redraws the class discriminant lines and decision regions for the
    training iteration corresponding to `frame`

    @param frame {int} frame number being rendered
    @param training_history {list} per-iteration [W, error, |grad error|]
           entries recorded by `train`
    @param x_plot {list} x range used to draw the discriminant lines
    @param fig_data {dict} plot objects: 'fig' parent figure, 'ax' plot
           axes, 'discriminants' discriminant line plots and, after the
           first frame, 'cplot' the class-region contour plot

    '''
    print('frame {}  '.format(frame), end='\r')
    # Fetch the recorded state of this training iteration
    W, err, grad_err = training_history[frame]

    discriminants = fig_data['discriminants']
    for i, lp in enumerate(discriminants):
        lp.set_data(x_plot, class_discriminant(x_plot, W, i))

    ax = fig_data['ax']
    ax.set_title(r'$E$: {:.3e}, $\nabla E$: {:.3e}'.format(err, grad_err))

    if 'cplot' in fig_data:
        # Remove the previous frame's contour plot before redrawing
        for col in fig_data['cplot'].collections:
            col.remove()

    set_count = len(discriminants)
    fig_data['cplot'] = plot_class_regions(ax, W, set_count)

    return []


def plot_training_iterations(dataset, training_history, filename):
    '''Animate the training iterations and save the result to `filename`'''
    x_plot = [-12, 12]
    fig, ax = plt.subplots()
    ax.grid(True)
    ax.set_xlabel('$x$')
    ax.set_ylabel('$y$')

    set_count = len(dataset)
    discriminants = []
    W = training_history[0][0]
    for i in range(set_count):
        lp, = ax.plot(x_plot, class_discriminant(x_plot, W, i))
        discriminants.append(lp)

    for class_id, dset in enumerate(dataset):
        x, y = dset
        ax.plot(x, y,
                marker='o', linestyle='',
                label='Set {}'.format(class_id + 1),
                color=COLORS[class_id])

    ax.legend(loc='best', fancybox=True, framealpha=0.5)
    ax.relim()
    ax.set_xlim([-12, 12])
    ax.set_ylim([-12, 12])

    fig_data = {
        'discriminants': discriminants,
        'fig': fig,
        'ax': ax
    }

    frame_count = len(training_history)
    Writer = animation.writers['ffmpeg']
    writer = Writer(fps=15, metadata=dict(artist='Me'), bitrate=1800)
    # blit is disabled since the contour plot is redrawn from scratch
    # each frame and update_plot returns no blittable artists
    c_ani = animation.FuncAnimation(fig, update_plot, frame_count,
                                    fargs=(training_history,
                                           x_plot,
                                           fig_data),
                                    interval=50, blit=False)
    c_ani.save(filename, writer=writer)


def main():
    means = [[+5, +5],
             [-5, -5],
             [-5, +5],
             [-2, -1]]
    seeds = [1, 2, 3, 4]
    data = create_clusters(30, means, seeds)

    # Initial guess
    W = np.array([
        1, .1,  1,              # first data set  (w0, w1, w2)
        1,  1,  1,              # second data set (w0, w1, w2)
        1,  1, .1,              # third data set  (w0, w1, w2)
        1, .1, .2,              # fourth data set (w0, w1, w2)
    ], np.float64)

    W, train_history = train(data, W)

    __dirname = os.path.dirname(os.path.realpath(__file__))
    fn = os.path.join(__dirname, '../video/logistic-regression-bt.mp4')
    plot_training_iterations(data, train_history, fn)


if __name__ == '__main__':
    main()