#!/usr/bin/env python3


import os
import numpy as np
import matplotlib.pyplot as plt


def noisy_sin_sample(max_x=2*np.pi, count=15,
                     freq=1.0, amp=1.0,
                     noise_std=0.2, noise_mean=0):
    '''Generate a random data set resembling noisy samples of a sine

    @param {float} max_x half-width of the x range; x spans [-max_x, max_x]
    @param {int} count number of sample points
    @param {float} freq frequency of the target sin function
    @param {float} amp amplitude of the target sin function
    @param {float} noise_std standard deviation of the noise
    @param {float} noise_mean mean of the noise
    @return {np.array} x x values
            {np.array} y y values

    '''

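    # Fixed random state makes the generated data set reproducible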
    rstate = np.random.RandomState(0)
    x = np.linspace(-max_x, max_x, count)
    target = amp*np.sin(freq*x)
    y = target + rstate.normal(noise_mean, noise_std, len(x))
    return x, y
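

# Example usage (a sketch; the fixed seed inside makes the data reproducible):
#
#   x, y = noisy_sin_sample(count=10)  # 10 noisy samples of sin over [-2*pi, 2*pi]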


def hFunc(x):
    '''Hidden-layer activation function (tanh)

    '''
    return np.tanh(x)


def predict(x, w1, w2):
    '''Evaluate the network at each point in `x`

    Model: y(x) = w2 . [1, tanh(W1 . [1, x])]

    @param {np.array} x  evaluation points
    @param {np.array} w1 first-layer weights, interleaved as in gradE_w
    @param {np.array} w2 second-layer weights; w2[0] is the bias
    @return {np.array|float} model value at each point
            (a scalar if x holds a single point)

    '''
    M = len(w1) // 2

    # Each row of W1 holds one hidden unit's (bias, weight) pair
    W1 = w1.reshape(M, 2)
    ret = []
    for xi in x:
        a1 = W1.dot(np.array([1.0, xi]))        # hidden pre-activations
        z = np.concatenate(([1.0], hFunc(a1)))  # prepend the bias unit
        a2 = w2.dot(z)                          # linear output layer
        ret.append(a2)

    if len(ret) == 1:
        return ret[0]
    return np.array(ret)
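

# Quick shape check (a sketch, not part of the script's flow): with M = 3
# hidden units, w1 has 2*M = 6 entries and w2 has M + 1 = 4 entries.
#
#   w1 = np.zeros(6)
#   w2 = np.zeros(4)
#   predict(np.array([0.0, 1.0]), w1, w2)   # -> array([0., 0.])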


def gradE_w(x, t, w1, w2):
    '''Gradient of the error function w.r.t. w1 and w2

    w1 is expected to be structured as follows:
       w1 = [w00, w01,
             w10, w11,
             w20, w21, ...]

       where w_{j0} are the biases, and w_{j1} are the weights

    w2 is expected to be structured as follows:
       w2 = [w0, w1, w2, ...]
    where `w0` here is the bias, and the rest are the multiplying
    weights

    grad_w1 = [grad_w00, grad_w01,
               grad_w10, grad_w11,
               ...]

    grad_w2 = [grad_w0, grad_w1, grad_w2, ...]

    len(w1) must be even (2*M)
    len(w2) must be M + 1

    @param {np.array} x  training input points
    @param {np.array} t  target values at x
    @param {np.array} w1 weight vector for the first layer
    @param {np.array} w2 weight vector for the second layer
    @return {np.array} gradE_w1 gradient of E wrt w1
            {np.array} gradE_w2 gradient of E wrt w2

    '''
    M = len(w1) // 2
    grad_w1 = np.zeros(len(w1))
    grad_w2 = np.zeros(len(w2))

    y = predict(x, w1, w2)
    # Error at each training point
    y_minus_t = y - t

    # Used to compute the first layer coefficients. w1 is expected to
    # have biases and weights interweaved
    # Example:
    # w1 = [w00, w01, w10, w11, w20, w21, ...]
    # w1.reshape(M,2) =
    #            [[w00, w01],
    #             [w10, w11],
    #             [w20, w21], ...]
    W1 = w1.reshape(M, 2)

    # Iterate through all data points
    for i, diff in enumerate(y_minus_t):
        a1 = W1.dot(np.array([1.0, x[i]]))
        h = hFunc(a1)
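        # tanh'(a) = 1 - tanh(a)^2, so hp is the activation derivative at a1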
        hp = 1.0 - np.power(h, 2)
        for j in range(M):
            # Offset w2 by 1 because first entry is the bias
            grad_w1[2*j]     += diff * w2[j+1] * hp[j] * 1
            grad_w1[2*j + 1] += diff * w2[j+1] * hp[j] * x[i]
            grad_w2[j + 1]   += diff * h[j]
        # gradient wrt w2 bias
        grad_w2[0] += diff

    return grad_w1, grad_w2
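

# Gradient sanity check (a hedged sketch, not called anywhere): compares the
# analytic gradient from gradE_w against a central finite-difference
# approximation of Err (defined below). Both should agree to ~O(eps**2).
def _check_gradE_w(x, t, w1, w2, eps=1e-6):
    grad_w1, grad_w2 = gradE_w(x, t, w1, w2)
    fd_w1 = np.zeros_like(w1, dtype=float)
    fd_w2 = np.zeros_like(w2, dtype=float)
    for j in range(len(w1)):
        d = np.zeros_like(w1, dtype=float)
        d[j] = eps
        fd_w1[j] = (Err(x, t, w1 + d, w2) - Err(x, t, w1 - d, w2)) / (2*eps)
    for j in range(len(w2)):
        d = np.zeros_like(w2, dtype=float)
        d[j] = eps
        fd_w2[j] = (Err(x, t, w1, w2 + d) - Err(x, t, w1, w2 - d)) / (2*eps)
    # Maximum absolute deviation for each weight vector
    return np.max(np.abs(grad_w1 - fd_w1)), np.max(np.abs(grad_w2 - fd_w2))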


def backtrack(f_func, grad_f, x, dx, alpha=0.1, beta=0.3, max_iter=20):
    '''Simple backtracking line search

    Shrinks the step length t until the Armijo sufficient-decrease
    condition holds:

        f(x + t*dx) <= f(x) + alpha * t * <grad_f, dx>

    If the number of iterations exceeds `max_iter`, returns the last
    iteration value.

    @param {func} f_func callable function being minimized
    @param {np.array} grad_f gradient of f_func at `x`
    @param {np.array} x current solution point
    @param {np.array} dx descent direction
    @param {float} alpha number in the open interval (0, 0.5)
    @param {float} beta number in the open interval (0, 1.0)
    @param {int} max_iter maximum number of iterations to perform
    @return {float} search scale factor (step length)

    '''
    t = 1.0
    iter_count = 0
    # Geometrically shrink t until the sufficient-decrease condition holds
    while f_func(x + t*dx) > f_func(x) + alpha*t*np.dot(grad_f, dx):
        t *= beta
        iter_count += 1
        if iter_count > max_iter:
            return t
    return t


def Err(x, t, w1, w2):
    '''Objective function

    0.5 * squared L2 norm of the total error:
    E(w) = 0.5 * sum_n (y(x_n; w) - t_n)^2

    @param x  {np.array} evaluation points of y
    @param t  {np.array} target values at x
    @param w1 {np.array} weight vector for the first layer
    @param w2 {np.array} weight vector for the second layer
    @return   {number} Total error between model and target

    '''
    y = predict(x, w1, w2)
    return 0.5*np.linalg.norm(y - t)**2


def train(x, t, w1, w2, max_iter=1000, iter_cb=None):
    '''Fit the network weights by gradient descent with backtracking line search

    @param {np.array1d} x training observation points
    @param {np.array1d} t training observation values
    @param {np.array1d} w1 initial guess for w1
    @param {np.array1d} w2 initial guess for w2
    @param {int} max_iter maximum number of iterations performed
    @param {func} iter_cb callback called with
           (iter_count, w1, w2, F_norm, gradW_norm) in each iteration;
           return a truthy value to stop the iteration early
    @return {np.array1d} updated w1 values
            {np.array1d} updated w2 values
            {float} gradient norm at the final iterate
            {float} objective value at the final iterate

    '''

    # Objective function
    F = lambda W : Err(x, t, W[:len(w1)], W[len(w1):])

    W = np.concatenate((w1, w2))

    if iter_cb is None:
        iter_cb = lambda *args: None

    GRAD_TOL = 1E-8
    F_TOL = 1E-8

    # Minimize the square of the L2 error norm
    for i in range(max_iter):
        # gradient of minimizing function
        grad_w1, grad_w2 = gradE_w(x, t, w1, w2)
        gradW = np.concatenate((grad_w1, grad_w2))

        # Descent direction
        dW = -1 * gradW

        t_ = backtrack(F, gradW, W, dW, alpha=0.3, beta=0.1)

        W += t_*dW
        w1 = W[:len(w1)]
        w2 = W[len(w1):]

        grad_w1, grad_w2 = gradE_w(x, t, w1, w2)
        gradW = np.concatenate((grad_w1, grad_w2))
        gradW_norm = np.linalg.norm(gradW)
        # Err is a non-negative scalar, so its 'norm' is just its value
        F_norm = F(W)

        # Stop iteration if cb calls for it
        if (iter_cb(i, w1, w2, F_norm, gradW_norm)):
            break

        if ((gradW_norm < GRAD_TOL) or (F_norm < F_TOL)):
            break

    return w1, w2, gradW_norm, F_norm
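

# Typical call sequence (a sketch; uses only helpers from this file):
#
#   x, t = noisy_sin_sample(count=10)
#   w1, w2 = initial_guess(3)
#   w1, w2, gnorm, err = train(x, t, w1, w2, max_iter=100)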


def initial_guess(M):
    '''Initial guesses for various model complexities M

    Generated from previous training runs

    If no stored guess exists for the requested M, returns a 'random'
    but reproducible set of values

    '''

    cases = {
        1: {
            'w1': [8.86361978, 2.7181682],
            'w2': [0.32038421, -0.388877]
        },
        2: {
            'w1': [-0.2733636,   0.73393021, -0.21517983,  0.46452915],
            'w2': [ 0.03697154,  5.22825299, -5.8836233 ]
        },
        3: {
            'w1': [499.2930588464916, -219.13237238727913, -69.87306961741615,
                  328.104441010302, -294.1889361175237, -129.59594169967295],
            'w2': [0.10996566600959055, 0.5917810150746053, 0.5476244737835927,
                   0.497683594734165]
        },
        4: {
            'w1': [0.6146752, 0.17918364, -3.69317985, 1.10103362,
                   -1.36141653, 0.27402917, -2.7527837, -0.88794461],
            'w2': [0.65788761, 2.56326775, -2.37379236, 3.03258071, 1.67726044]
        },
        5: {
            'w1': [1.72620813, 0.33339141, -0.75851512,  0.81512907,  3.1821002,
                   0.92838425, 2.88101134, -0.86886922, -2.15493726,  0.37969748],
            'w2': [0.56316311, 2.00633212,  0.86616563, -1.66512667,  2.29848745,
                   2.58398541]
        }
    }

    if M not in cases:
        rstate = np.random.RandomState(3)
        w1 = rstate.uniform(size=M*2)
        w2 = rstate.uniform(size=M+1)
        return w1, w2

    c = cases[M]
    return np.array(c['w1']), np.array(c['w2'])


def train_model(x, t, M, print_iters=True):
    '''Train a model with M hidden units on the data set (x, t)

    Returns the tuple produced by `train`: (w1, w2, gradW_norm, F_norm)

    '''
    w1, w2 = initial_guess(M)
    MAX_ITER = 100

    # Callback for each iteration -- print values to screen
    if print_iters:
        itercb = lambda iter_num, w1_, w2_, err_norm, grad_norm: \
            print('{}: err = {} ... gw = {}'.format(iter_num, err_norm, grad_norm))
    else:
        itercb = None

    # Find the model weights
    return train(x, t, w1, w2, MAX_ITER, itercb)


def generate_random_results(x, t, complexities,
                            count, low=-1, high=1):
    '''Train models of each complexity in `complexities`, restarting
    `count` times from random initial weights drawn uniformly
    from [low, high)

    '''

    ret = {}
    for M in complexities:
        ret[M] = {}
        for c in range(count):
            # Use a fixed random state to make results reproducible
            rstate = np.random.RandomState(c)
            w1 = rstate.uniform(low=low, high=high, size=M*2)
            w2 = rstate.uniform(low=low, high=high, size=M+1)
            w1, w2, grad_err, err = train(x, t, w1, w2, max_iter=200)

            ret[M][c] = {
                'w1': w1,
                'w2': w2,
                'err': err,
                'grad_err': grad_err
            }
    return ret
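

# Example (a sketch): five random restarts each for M = 1 and M = 3; pick the
# restart with the smallest final objective for a given M:
#
#   results = generate_random_results(x, t, complexities=(1, 3), count=5)
#   best = min(results[3].values(), key=lambda r: r['err'])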


def main():
    x, t = noisy_sin_sample(count=10)
    x_p = np.linspace(-2*np.pi, 2*np.pi, 100)
    target = np.sin(x_p)

    # Plot results for model complexities 1 through 5
    for M in range(1, 6):
        w1, w2, grad_e_norm, error = train_model(x, t, M)
        y = predict(x_p, w1, w2)
        print('w1:', w1)
        print('w2:', w2)

        fig, ax = plt.subplots()
        ax.plot(x, t, linestyle='', marker='o', markersize=5, label='Data')
        ax.plot(x_p, target, linestyle='--', label='Target')
        ax.plot(x_p, y, label='Model')

        ax.grid(True)
        ax.legend(loc='best', fancybox=True, framealpha=0.5, fontsize='small')
        ax.set_title(r'$M$ = {:1}, $||E||$ = {:.3e}, $||\nabla_w E||$ = {:.3e}'.format(M,
                                                                                       error,
                                                                                       grad_e_norm))
        ax.set_ylabel('y')
        ax.set_xlabel('x')
        fig.patch.set_alpha(0.0)
        __dirname = os.path.dirname(os.path.realpath(__file__))
        fn = '../img/neural-network-regression_m{}.svg'.format(M)
        fn = os.path.join(__dirname, fn)
        plt.savefig(fn,
                    facecolor=fig.get_facecolor(),
                    edgecolor='none')
    # plt.show()


if __name__ == '__main__':
    main()