#!/usr/bin/python3

import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation


def sigma_func(v):
    '''Activation function

    @param {float|np.ndarray} v input argument(s)
    @return {float|np.ndarray} component-wise value of the
                               activation function

    '''
    return 1.0/(1 + np.exp(-v))


def sigma_func_diff(v):
    '''Derivative of the model activation function

    Uses the identity sigma'(v) = sigma(v) * (1 - sigma(v)), which is
    algebraically equal to exp(-v) * sigma(v)**2 but stays finite for
    large negative v, where exp(-v) overflows.

    @param {float|np.ndarray} v input argument(s)
    @return {float|np.ndarray} component-wise derivative of the
                               activation function

    '''

    sig = sigma_func(v)
    return sig * (1 - sig)


def h_func(vec):
    '''First layer activation function

    @param {np.ndarray} vec input argument
    @return {np.ndarray} component-wise value of the activation
                         function

    '''
    return np.tanh(vec)


def grad_y_W1(x, W1, w10, W2, w20):
    '''Gradient of model output with respect to W1

    W1 is the first layer model weights

    @param {np.ndarray} x D-dimensional point to compute derivative
    @param {np.ndarray} W1 MxD dimensional first layer model weights
    @param {np.ndarray} w10 M dimensional first layer model biases
    @param {np.ndarray} W2 KxM dimensional second layer model weights
    @param {np.ndarray} w20 K dimensional second layer model biases
    @return {np.ndarray} KxMxD dimensional array; entry (k, m, d) is the
                         derivative of the k-th output wrt the (m, d)
                         component of W1

    '''

    M, D = W1.shape
    K, _ = W2.shape
    ret = np.zeros((K, M, D), np.float64)
    a_vec = W1.dot(x) + w10
    a_out = W2.dot(h_func(a_vec)) + w20
    sod = sigma_func_diff(a_out)

    for k in range(K):
        for m in range(M):
            for d in range(D):
                ret[k, m, d] = sod[k] * W2[k, m] * (1 - h_func(a_vec[m])**2) * x[d]
    return ret

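
def check_grad_y_W1(x, W1, w10, W2, w20, eps=1e-6):
    '''Finite-difference sanity check for grad_y_W1

    A diagnostic sketch that is not called anywhere in this script;
    the helper name and epsilon are choices made here. Perturbs each
    entry of W1 by `eps` and compares the central difference of y_vec
    against the analytic gradient.

    @return {float} maximum absolute deviation between the two

    '''
    M, D = W1.shape
    K, _ = W2.shape
    analytic = grad_y_W1(x, W1, w10, W2, w20)
    numeric = np.zeros((K, M, D), np.float64)
    for m in range(M):
        for d in range(D):
            Wp, Wm = W1.copy(), W1.copy()
            Wp[m, d] += eps
            Wm[m, d] -= eps
            numeric[:, m, d] = (y_vec(x, Wp, w10, W2, w20)
                                - y_vec(x, Wm, w10, W2, w20)) / (2*eps)
    return np.max(np.abs(analytic - numeric))
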

def grad_y_w10(x, W1, w10, W2, w20):
    '''Gradient of model output with respect to w10

    w10 is the first layer model biases

    @param {np.ndarray} x D-dimensional point to compute derivative
    @param {np.ndarray} W1 MxD dimensional first layer model weights
    @param {np.ndarray} w10 M dimensional first layer model biases
    @param {np.ndarray} W2 KxM dimensional second layer model weights
    @param {np.ndarray} w20 K dimensional second layer model biases
    @return {np.ndarray} KxM dimensional array; entry (k, m) is the
                         derivative of the k-th output wrt the m-th
                         component of w10

    '''

    M, D = W1.shape
    K, _ = W2.shape
    ret = np.zeros((K, M), np.float64)
    a_vec = W1.dot(x) + w10
    a_out = W2.dot(h_func(a_vec)) + w20
    sod = sigma_func_diff(a_out)

    for k in range(K):
        for m in range(M):
            ret[k, m] = sod[k] * W2[k, m] * (1 - h_func(a_vec[m])**2)
    return ret


def grad_y_W2(x, W1, w10, W2, w20):
    '''Gradient of model output with respect to W2

    W2 is the second layer model weights

    @param {np.ndarray} x D-dimensional point to compute derivative
    @param {np.ndarray} W1 MxD dimensional first layer model weights
    @param {np.ndarray} w10 M dimensional first layer model biases
    @param {np.ndarray} W2 KxM dimensional second layer model weights
    @param {np.ndarray} w20 K dimensional second layer model biases
    @return {np.ndarray} KxKxM dimensional array; entry (k, j, m) is the
                         derivative of the k-th output wrt the (j, m)
                         component of W2 (nonzero only when j == k)

    '''

    M, D = W1.shape
    K, _ = W2.shape
    ret = np.zeros((K, K, M), np.float64)
    a_vec = W1.dot(x) + w10
    a_out = W2.dot(h_func(a_vec)) + w20
    sod = sigma_func_diff(a_out)
    for k in range(K):
        for m in range(M):
            ret[k, k, m] = sod[k] * h_func(a_vec[m])
    return ret


def grad_y_w20(x, W1, w10, W2, w20):
    '''Gradient of model output with respect to w20

    w20 is the second layer model biases

    @param {np.ndarray} x D-dimensional point to compute derivative
    @param {np.ndarray} W1 MxD dimensional first layer model weights
    @param {np.ndarray} w10 M dimensional first layer model biases
    @param {np.ndarray} W2 KxM dimensional second layer model weights
    @param {np.ndarray} w20 K dimensional second layer model biases
    @return {np.ndarray} KxK dimensional array; entry (k, j) is the
                         derivative of the k-th output wrt the j-th
                         component of w20 (a diagonal matrix)

    '''

    K, _ = W2.shape
    a_vec = W1.dot(x) + w10
    a_out = W2.dot(h_func(a_vec)) + w20
    sod = sigma_func_diff(a_out)
    return sod * np.eye(K)


def grad_y_w(x, W1, w10, W2, w20):
    '''Gradient of model output with respect to model weights

    @param {np.ndarray} x D-dimensional point to compute derivative
    @param {np.ndarray} W1 MxD dimensional first layer model weights
    @param {np.ndarray} w10 M dimensional first layer model biases
    @param {np.ndarray} W2 KxM dimensional second layer model weights
    @param {np.ndarray} w20 K dimensional second layer model biases
    @return {np.ndarray} KxMxD array; derivatives wrt W1
            {np.ndarray} KxKxM array; derivatives wrt W2
            {np.ndarray} KxM array; derivatives wrt w10
            {np.ndarray} KxK array; derivatives wrt w20

    '''

    gw1 = grad_y_W1(x, W1, w10, W2, w20)
    gw2 = grad_y_W2(x, W1, w10, W2, w20)
    gw10 = grad_y_w10(x, W1, w10, W2, w20)
    gw20 = grad_y_w20(x, W1, w10, W2, w20)
    return gw1, gw2, gw10, gw20


def y_vec(x, W1, w10, W2, w20):
    '''Model outputs for each class

    @param {np.ndarray} x D-dimensional point at which to evaluate the model
    @param {np.ndarray} W1 MxD dimensional first layer model weights
    @param {np.ndarray} w10 M dimensional first layer model biases
    @param {np.ndarray} W2 KxM dimensional second layer model weights
    @param {np.ndarray} w20 K dimensional second layer model biases
    @return {np.ndarray} K dim array of model outputs

    '''

    ak = h_func(W1.dot(x) + w10)
    a  = W2.dot(ak) + w20
    return sigma_func(a)

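
# A hedged usage sketch for y_vec (illustrative shapes only, not part
# of the training flow): with D=2 inputs, M=3 hidden units and K=2
# classes, y_vec maps a 2-vector to a 2-vector of per-class sigmoid
# outputs.
#
#   >>> rng = np.random.RandomState(0)
#   >>> W1, w10 = rng.randn(3, 2), rng.randn(3)
#   >>> W2, w20 = rng.randn(2, 3), rng.randn(2)
#   >>> y_vec(np.array([0.5, -1.0]), W1, w10, W2, w20).shape
#   (2,)
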

def flatten_grad(grad_W1, grad_w10, grad_W2, grad_w20):
    '''Concatenate the per-parameter gradients into a single flat vector'''
    return np.concatenate((grad_W1.flatten(),
                           grad_w10.flatten(),
                           grad_W2.flatten(),
                           grad_w20.flatten()))


def flatten_w(W1, w10, W2, w20):
    '''Concatenate all model parameters into a single flat vector'''
    return np.concatenate((W1.flatten(),
                           w10.flatten(),
                           W2.flatten(),
                           w20.flatten()))


def unflatten_w(W, M, D, K):
    '''Split a flat parameter vector back into (W1, w10, W2, w20)'''
    W1  = W[              : M*D              ].reshape(M, D)
    w10 = W[M*D           : M*D + M          ].reshape(M)
    W2  = W[M*D + M       : M*D + M + K*M    ].reshape(K, M)
    w20 = W[M*D + M + K*M :                  ].reshape(K)
    return W1, w10, W2, w20

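
def check_flatten_round_trip(M=3, D=2, K=2, seed=0):
    '''Sanity-check that unflatten_w inverts flatten_w

    A small sketch that is not used elsewhere in the script; the
    shapes and seed are arbitrary assumptions made for the check.

    '''
    rng = np.random.RandomState(seed)
    parts = (rng.randn(M, D), rng.randn(M), rng.randn(K, M), rng.randn(K))
    rebuilt = unflatten_w(flatten_w(*parts), M, D, K)
    return all(np.allclose(a, b) for a, b in zip(parts, rebuilt))
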

def err(dataset, W, M, D):
    '''Cross-entropy error of the model over all training points

    E(W) = -sum_n log y_{c_n}(x_n), where c_n is the true class of x_n.
    '''
    K = len(dataset)
    W1, w10, W2, w20 = unflatten_w(W, M, D, K)
    ret = 0
    for class_id, dset in enumerate(dataset):
        for x in zip(*dset):
            y = y_vec(x, W1, w10, W2, w20)
            ret -= np.log(y[class_id])
    return ret

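
def err_safe(dataset, W, M, D, eps=1e-12):
    '''Numerically guarded variant of `err`

    A sketch that is not wired into `train`: clips the predicted
    probability away from zero before taking the log, so a confidently
    wrong prediction cannot produce -log(0) = inf. The epsilon value
    is an assumption made here.

    '''
    K = len(dataset)
    W1, w10, W2, w20 = unflatten_w(W, M, D, K)
    ret = 0.0
    for class_id, dset in enumerate(dataset):
        for x in zip(*dset):
            y = y_vec(x, W1, w10, W2, w20)
            ret -= np.log(max(y[class_id], eps))
    return ret
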

def backtrack(f_func, grad_f, x, dx, alpha=0.1, beta=0.3, max_iter=20):
    '''Simple backtrack function

    If the number of iterations exceeds `max_iter`, the most recent
    step length is returned.

    @param {func} f_func callable function being minimized
    @param {np.ndarray} grad_f gradient of f_func at `x`
    @param {np.ndarray} x current solution point
    @param {np.ndarray} dx descent direction
    @param {float} alpha number in the open interval (0, 0.5)
    @param {float} beta number in the open interval (0, 1.0)
    @param {int} max_iter maximum number of iterations to perform
    @return {float} search scale factor (step length)

    '''
    t = 1.0
    f_x = f_func(x)
    for _ in range(max_iter):
        # Armijo sufficient-decrease condition
        if f_func(x + t*dx) <= f_x + alpha*t*np.dot(grad_f, dx):
            break
        t *= beta
    return t

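
# A hedged usage sketch for `backtrack` (illustrative values only):
# minimizing f(x) = x.x from x0 = [4.0] along the steepest-descent
# direction dx = -grad returns a step length t that satisfies the
# Armijo sufficient-decrease condition tested in the loop above.
#
#   >>> f = lambda x: float(np.dot(x, x))
#   >>> x0 = np.array([4.0])
#   >>> g0 = 2 * x0
#   >>> t = backtrack(f, g0, x0, -g0)
#   >>> f(x0 + t * (-g0)) <= f(x0)
#   True
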

def grad_err(datasets, W1, w10, W2, w20):
    M, D = W1.shape
    K, _ = W2.shape
    # Accumulated gradients of the cross-entropy error wrt each parameter
    grad_e_W1 = np.zeros((M,D), np.float64)
    grad_e_W2 = np.zeros((K,M), np.float64)
    grad_e_w10 = np.zeros((M), np.float64)
    grad_e_w20 = np.zeros((K), np.float64)
    for class_id, dset in enumerate(datasets):
        for x in zip(*dset):
            gyw1, gyw2, gyw10, gyw20 = grad_y_w(x, W1, w10, W2, w20)
            y = y_vec(x, W1, w10, W2, w20)
            # grad wrt w10
            for m in range(M):
                grad_e_w10[m] -= 1.0/y[class_id] * gyw10[class_id, m]

            # gradient wrt w20
            for k in range(K):
                grad_e_w20[k] -= 1.0/y[class_id] * gyw20[class_id, k]

            # gradient wrt W2
            for k in range(K):
                for m in range(M):
                    grad_e_W2[k, m] -= 1.0/y[class_id] * gyw2[class_id, k, m]

            # Gradient wrt W1
            for m in range(M):
                for d in range(D):
                    grad_e_W1[m, d] -= 1.0/y[class_id] * gyw1[class_id, m, d]

    return grad_e_W1, grad_e_w10, grad_e_W2, grad_e_w20

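
def check_grad_err(datasets, W1, w10, W2, w20, eps=1e-6):
    '''Finite-difference check of the full error gradient

    A diagnostic sketch that `train` does not call; the helper name
    and epsilon are assumptions made here. Compares the flattened
    analytic gradient from grad_err against a central-difference
    gradient of `err` and returns the maximum absolute deviation.

    '''
    M, D = W1.shape
    analytic = flatten_grad(*grad_err(datasets, W1, w10, W2, w20))
    W = flatten_w(W1, w10, W2, w20)
    numeric = np.zeros_like(W)
    for i in range(len(W)):
        Wp, Wm = W.copy(), W.copy()
        Wp[i] += eps
        Wm[i] -= eps
        numeric[i] = (err(datasets, Wp, M, D) - err(datasets, Wm, M, D)) / (2*eps)
    return np.max(np.abs(analytic - numeric))
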

def train(datasets, W1, w10, W2, w20, max_iter=20):
    '''Train the two-layer network with gradient descent

    Minimizes the cross-entropy error by steepest descent; the step
    length is chosen by a backtracking line search.

    @param {list} datasets list of per-class (x, y) coordinate arrays
    @param {np.ndarray} W1 MxD dimensional first layer model weights
    @param {np.ndarray} w10 M dimensional first layer model biases
    @param {np.ndarray} W2 KxM dimensional second layer model weights
    @param {np.ndarray} w20 K dimensional second layer model biases
    @param {int} max_iter maximum number of iterations performed
    @return {np.ndarray} updated W1, w10, W2, w20 values
            {list} per-iteration history of weights and norms

    '''

    # Objective function
    M, D = W1.shape
    F = lambda W : err(datasets, W, M, D)

    GRAD_TOL = 1E-8
    F_TOL = 1E-8

    W = flatten_w(W1, w10, W2, w20)
    K = len(datasets)

    history = []
    # Minimize the cross-entropy error by steepest descent
    for i in range(max_iter):
        # gradient of minimizing function
        grad_e_W1, grad_e_w10, grad_e_W2, grad_e_w20 = grad_err(datasets, W1, w10, W2, w20)
        gradW = flatten_grad(grad_e_W1, grad_e_w10, grad_e_W2, grad_e_w20)

        # Descent direction
        dW = -1 * gradW

        t_ = backtrack(F, gradW, W, dW, alpha=0.3, beta=0.1)
        if t_ < 1E-10:
            # Fall back to a fixed step when the line search collapses
            t_ = 0.123

        W += t_*dW

        W1, w10, W2, w20 = unflatten_w(W, M, D, K)
        grad_e_W1, grad_e_w10, grad_e_W2, grad_e_w20 = grad_err(datasets, W1, w10, W2, w20)
        gradW = flatten_grad(grad_e_W1, grad_e_w10, grad_e_W2, grad_e_w20)
        gradW_norm = np.linalg.norm(gradW)
        F_norm = np.linalg.norm(F(W))

        history.append(
            {
                'W1':W1.copy().tolist(),
                'w10': w10.copy().tolist(),
                'W2': W2.copy().tolist(),
                'w20': w20.copy().tolist(),
                'f_norm': F_norm,
                'gradw_norm': gradW_norm
            }
        )
        print(i, t_, F_norm, gradW_norm)

        if ((gradW_norm < GRAD_TOL) or (F_norm < F_TOL)):
            break

    return W1, w10, W2, w20, history


def create_dataset(point_count,
                   mean = [0, 0],
                   variance = 1,
                   seed = 0):
    '''Create a 2D data cluster with specified mean and variance

    @param {int} point_count number of points
    @param {list} mean mean of the dataset
    @param {float} variance diagonal variance value
    @param {int} seed random seed number for reproducible results
    @return {tuple} (x, y) coordinate arrays

    '''
    RSTATE = np.random.RandomState(seed)
    cov = np.identity(2) * variance
    x, y = RSTATE.multivariate_normal(mean, cov, point_count).T
    return x, y


def create_clusters(counts, means, seeds):
    '''Create a random, reproducible set of clusters with `counts` points
       each

    @param {int} counts number of points per cluster
    @param {list} means list of cluster means
    @param {list} seeds list of seeds used to generate the random points
    @return {list} list of (x, y) coordinate array pairs

    '''
    return [create_dataset(counts, mean, 1, seed) for mean, seed in zip(means, seeds)]


def classify_grid(xgrid, ygrid, W1, w10, W2, w20):
    '''Return a 2D classified grid

    @param {np.ndarray} xgrid x values on a grid
    @param {np.ndarray} ygrid y values on a grid
    @param {np.ndarray} W1 first layer model weights
    @param {np.ndarray} w10 first layer model biases
    @param {np.ndarray} W2 second layer model weights
    @param {np.ndarray} w20 second layer model biases
    @return {np.ndarray} 2D grid with classification values

    '''

    ret = np.zeros((len(ygrid), len(xgrid)))
    for i, y in enumerate(ygrid):
        for j, x in enumerate(xgrid):
            yv = y_vec(np.array([x, y]), W1, w10, W2, w20)
            ret[i, j] = np.argmax(yv)
    return ret

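
def classify_grid_vectorized(xgrid, ygrid, W1, w10, W2, w20):
    '''Vectorized variant of classify_grid

    A sketch that evaluates the forward pass for all grid points at
    once instead of looping; it is assumed equivalent to classify_grid
    and is offered as an alternative, not used by the plotting code.

    '''
    X, Y = np.meshgrid(xgrid, ygrid)
    points = np.column_stack((X.ravel(), Y.ravel()))   # (N, 2)
    hidden = h_func(points.dot(W1.T) + w10)            # (N, M)
    outputs = sigma_func(hidden.dot(W2.T) + w20)       # (N, K)
    return np.argmax(outputs, axis=1).reshape(X.shape)
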

def update_plot(frame, training_history, x_plot, y_plot, fig_data):
    '''FuncAnimation plot update function

    Redraws the classification regions for the model weights recorded
    at the given training iteration.

    @param {int} frame frame number being rendered
    @param {list} training_history list of per-iteration dicts with
           keys 'W1', 'w10', 'W2', 'w20', 'f_norm' and 'gradw_norm'
    @param {np.1darray} x_plot x range of values to eval the model estimate
    @param {np.1darray} y_plot y range of values to eval the model estimate
    @param {dict} fig_data fig_data['fig'] parent figure object
           fig_data['ax'] axes holding the scatter and contour plots

    '''
    print ('frame {}  '.format(frame), end='\r')
    # fetch new pseudo random data
    row = training_history[frame]
    W1  = np.array(row['W1'])
    w10 = np.array(row['w10'])
    W2  = np.array(row['W2'])
    w20 = np.array(row['w20'])
    err = row['f_norm']
    grad_err = row['gradw_norm']

    ax = fig_data['ax']
    ax.set_title(r'$E$: {:.3e}, $||\nabla E||$: {:.3e}'.format(err, grad_err))

    if 'cplot' in fig_data:
        # Remove the previous contour set; newer matplotlib exposes it as
        # a single removable artist, older versions keep a list of
        # collections instead.
        cplot = fig_data['cplot']
        try:
            cplot.remove()
        except (AttributeError, NotImplementedError):
            for col in cplot.collections:
                col.remove()

    X, Y = np.meshgrid(x_plot, y_plot)
    Z = classify_grid(x_plot, y_plot, W1, w10, W2, w20)
    fig_data['cplot'] = ax.contourf(X, Y, Z, alpha=0.8)
    return []


def plot_training_iterations(dataset, training_history):
    xp = np.linspace(-5, 5)
    yp = np.linspace(-5, 5)
    fig, ax = plt.subplots()
    ax.grid(True)
    ax.set_xlabel('$x_1$')
    ax.set_ylabel('$x_2$')
    # The title is set per frame by update_plot
    ax.set_title(r'')

    for class_id, dset in enumerate(dataset):
        x, y = dset
        ax.plot(x, y,
                marker='o', linestyle='',
                label='Set {}'.format(class_id + 1),
        )

    ax.legend(loc='best', fancybox=True, framealpha=0.5)
    ax.relim()
    ax.set_xlim([-5, 5])
    ax.set_ylim([-5, 5])

    fig_data = {
        'fig': fig,
        'ax': ax
    }

    frame_count = len(training_history)
    Writer = animation.writers['ffmpeg']
    writer = Writer(fps=15, metadata=dict(artist='Me'), bitrate=1800)
    # update_plot redraws the contour set itself and returns no artists,
    # so blitting is disabled.
    c_ani = animation.FuncAnimation(fig, update_plot, frame_count,
                                    fargs=(training_history,
                                           xp, yp,
                                           fig_data),
                                    interval=80, blit=False)
    return c_ani, writer


def main():
    D = 2  # input dimension
    M = 8  # number of hidden units
    K = 5  # number of classes

    RSTATE = np.random.RandomState(9)

    W1  = RSTATE.uniform(-1, 1, (M,D))
    W2  = RSTATE.uniform(-1, 1, (K,M))
    w20 = RSTATE.uniform(-1, 1, K)
    w10 = RSTATE.uniform(-1, 1, M)

    means = [[-3, +3],
             [+3, +3],
             [+3, -3],
             [-3, -3],
             [+0, -0]]
    seeds = [1, 4, 8, 9, 10]
    datasets = create_clusters(20, means, seeds)

    W1, w10, W2, w20, history = train(datasets, W1, w10, W2, w20, max_iter=200)

    __dirname = os.path.dirname(os.path.realpath(__file__))
    fn = os.path.join(__dirname, '../video/neural-networks-classification.mp4')
    c_ani, writer = plot_training_iterations(datasets, history)
    c_ani.save(fn, writer=writer)


if __name__ == '__main__':
    main()