#!/usr/bin/python3


import os
import numpy as np
from matplotlib import pyplot as plt


def data_set_linear(count=10, sigma=0.1, seed=0):
    '''Return a noisy sample of the line t = x on [0, 1]

    @param {int} count number of points
    @param {float} sigma std deviation of the added Gaussian noise
    @param {int} seed seed for the private random state; the default (0)
           keeps the data set reproducible between runs
    @return {np.ndarray} x data
            {np.ndarray} y data (x plus noise)
            {np.ndarray} unperturbed y data (equal to x)
    '''
    # A private RandomState makes the values reproducible without touching
    # numpy's global random state.
    rstate = np.random.RandomState(seed)
    x = np.linspace(0, 1, count)
    # Line through the origin (slope 1, no bias/offset) plus noise
    y = x + rstate.normal(0, sigma, len(x))
    # Return a copy as the target so mutating the returned x does not
    # silently alter the unperturbed data as well.
    return x, y, x.copy()


def predict(x, w):
    ''' Predict observation values at `x` with the linear model w0 + w1 * x

    @param {float|np.1darray} x location(s) at which to evaluate the model
    @param {np.1darray} w model coefficients [w0, w1] (intercept, slope)
    @return {np.ndarray|np.float64} predicted values at @x, same shape as @x
    '''
    # Vectorised evaluation: works for scalars as well as sequences/arrays,
    # and avoids a Python-level loop over the points.
    return w[0] + w[1] * np.asarray(x)


def train(x, t):
    '''Fit a line t ~ w0 + w1 * x by total least squares (orthogonal regression)

    Minimises the sum of squared perpendicular distances from the points
    (x_i, t_i) to the line, instead of the vertical distances minimised by
    ordinary least squares.

    @param {np.1darray} x input values
    @param {np.1darray} t observed target values, same length as @x
    @return {np.ndarray} model coefficients [w0, w1] (intercept, slope)
    @raise {ValueError} if @x and @t are not 1-D arrays of equal length,
           fewer than two points are given, or the best-fit line is
           vertical/undefined (not expressible as t = w0 + w1 * x)
    '''
    x = np.asarray(x, dtype=float)
    t = np.asarray(t, dtype=float)
    if x.ndim != 1 or x.shape != t.shape:
        raise ValueError('x and t must be 1-D arrays of equal length')
    if x.size < 2:
        raise ValueError('at least two points are required')

    x_mean = np.mean(x)
    t_mean = np.mean(t)
    # Centred scatter sums. Computing them from centred data avoids the
    # catastrophic cancellation of the ||v||^2 - N * mean^2 form.
    dx = x - x_mean
    dt = t - t_mean
    s_xx = np.dot(dx, dx)
    s_tt = np.dot(dt, dt)
    s_xt = np.dot(dx, dt)

    # Setting the derivative of the perpendicular error to zero gives
    #     s_xt * w1^2 + (s_xx - s_tt) * w1 - s_xt = 0.
    # Its two roots multiply to -1; the minimum is the root whose sign
    # matches s_xt (the principal axis of the scatter matrix):
    #     w1 = (d + sqrt(d^2 + 4 s_xt^2)) / (2 s_xt),   d = s_tt - s_xx
    # Always taking the positive root gives the *worst* fit for
    # negatively correlated data.
    d = s_tt - s_xx
    r = np.hypot(d, 2 * s_xt)
    if d >= 0:
        if s_xt == 0:
            # Either all spread is along t (vertical line) or the scatter
            # is isotropic (every direction fits equally well).
            raise ValueError('best-fit line is vertical or undefined; '
                             'cannot express it as t = w0 + w1 * x')
        w1 = (d + r) / (2 * s_xt)
    else:
        # Algebraically identical rationalised form: avoids cancellation in
        # d + r when d < 0 and yields a slope of 0 when s_xt == 0.
        w1 = 2 * s_xt / (r - d)
    # The optimal line always passes through the centroid.
    w0 = t_mean - w1 * x_mean
    return np.array([w0, w1])


def main():
    '''Fit a total-least-squares line to a noisy linear data set and save a
    plot of the data, the target line and the fitted model to
    ../img/total-regression.svg (relative to this script's directory).
    '''
    x, t, t_unperturbed = data_set_linear()
    w = train(x, t)

    # Evaluate the model on a grid twice as dense as the data
    x_p = np.linspace(x[0], x[-1], len(x) * 2)
    y = predict(x_p, w)

    fig, ax = plt.subplots()
    # "Training" dataset
    ax.plot(x, t, linestyle='', marker='o', markersize=4, label='Data set')
    # Target
    ax.plot(x, t_unperturbed, linestyle='--', label=r'Target')
    # Model graph
    ax.plot(x_p, y, linestyle='-', label=r'Model $y(x, \mathbf{w})$')

    ax.set_ylabel('t')
    ax.set_xlabel('x')
    ax.grid(True)
    ax.legend(loc='best', fancybox=True, framealpha=0.5, fontsize='medium')

    script_dir = os.path.dirname(os.path.realpath(__file__))
    fn = os.path.join(script_dir, '../img/total-regression.svg')
    # Create the output directory if needed, otherwise savefig fails
    os.makedirs(os.path.normpath(os.path.dirname(fn)), exist_ok=True)
    # Transparent figure background
    fig.patch.set_alpha(0.0)
    # Save this figure explicitly rather than whatever pyplot considers
    # "current"; bbox_inches is kept as in the original (a falsy value, so
    # no 'tight' cropping is requested).
    fig.savefig(fn,
                facecolor=fig.get_facecolor(),
                edgecolor='none',
                bbox_inches=0)
    # Release the figure's resources once it has been written
    plt.close(fig)


if __name__ == '__main__':
    # Generate the figure only when run as a script, not when imported
    main()