#!/usr/bin/python3


import os
import matplotlib.pyplot as plt
import numpy as np
import functools as ft

def create_dataset(point_count = 10,
                   mean = (0, 0),
                   variance = 1,
                   seed = 0):
    '''Create a 2D data cluster with specified mean and variance

    Points are drawn from a bivariate normal with diagonal covariance
    variance * I, using a dedicated RandomState so results are
    reproducible for a given seed.

    @param {int} point_count number of points
    @param {tuple} mean of the dataset
    @param {float} variance diagonal variance value
    @param {int} seed random seed number for reproducible results
    @return {list} list of 2-element np.float64 arrays

    '''
    # Local RandomState keeps the global numpy RNG untouched.
    rstate = np.random.RandomState(seed)
    cov = np.identity(2) * variance
    x, y = rstate.multivariate_normal(mean, cov, point_count).T
    return [np.array(p, np.float64) for p in zip(x, y)]


def plot_data_points(ax, data_sets):
    '''Plot each data set as a scatter of markers on the given axis

    @param {matplotlib.Axis} ax axis object to plot on
    @param {list} data_sets list of data sets, each a list of (x, y) points
    @return {None}

    '''

    # Default colors used for plotting data points
    COLORS = ['#1f77b4', '#ff7f03', '#2ca02c',
              '#d62728', '#9467bd', '#8c564b']
    for class_id, points in enumerate(data_sets):
        x, y = zip(*points)
        ax.plot(x, y,
                marker='o',
                linestyle='',
                markersize=4,
                label='Set {}'.format(class_id+1),
                # Cycle through the palette so more than len(COLORS)
                # sets no longer raise an IndexError.
                color=COLORS[class_id % len(COLORS)])


def data_mean_cov_set(dataset):
    '''Return the mean and (unnormalized) scatter matrix of a data set

    Note: the second return value is the scatter matrix
    sum_p (p - mean)(p - mean)^T, i.e. the covariance WITHOUT the 1/N
    normalization — this matches how the Sw matrices are built below.

    @param {list} dataset list of (x, y) points
    @return {np.ndarray} mean of the dataset
            {np.ndarray} 2x2 scatter matrix of the dataset

    '''

    pts = np.asarray(dataset, dtype=np.float64)
    mean = pts.mean(axis=0)
    centered = pts - mean
    # Sum of outer products, vectorized: centered^T centered.
    cov = centered.T.dot(centered)
    return mean, cov


def sw_set(data_set, mean):
    '''Compute the within-class scatter matrix of a single data set

    Sw = sum_p (p - mean)(p - mean)^T

    @param {list} data_set list of (x, y) points
    @param {np.ndarray} mean mean of the data set
    @return {np.ndarray} 2x2 within-class scatter matrix

    '''
    centered = np.asarray(data_set) - np.asarray(mean)
    # Equivalent to accumulating np.outer(p - mean, p - mean) per point.
    return centered.T.dot(centered)

def sb_set(means):
    '''Compute the between-class scatter matrix for two class means

    Sb = (m2 - m1)(m2 - m1)^T

    @param {list} means list of the two class mean vectors
    @return {np.ndarray} 2x2 Sb matrix

    '''
    diff = np.subtract(means[1], means[0])
    return np.outer(diff, diff)


def Sb_Sw(datasets, means):
    '''Return the Sb and Sw scatter matrices for two data sets

    @param {list} datasets list of two data sets
    @param {list} means list of the two class means
    @return {np.ndarray} Sb matrix
            {np.ndarray} Sw matrix

    '''
    # Within-class scatter: sum of the per-class scatter matrices.
    within = sw_set(datasets[0], means[0]) + sw_set(datasets[1], means[1])
    # Between-class scatter from the pair of means.
    between = sb_set(means)
    return between, within


def zero_line(x, w, w0):
    '''Return the y values corresponding to a linear discriminant

    The discriminant is defined by
                      y = w^{T} * x + w0
                        = w_0 * x_0 + w_1 * x_1 + w0
                        = 0

    x_1 is computed as
                      x_1 = -1/w_1 * ( w_0 * x_0 + w0 )

    @param {np.ndarray} x values for which to compute the
            corresponding y-value
    @param {list} w array of linear discriminant weights
    @param {float} w0 bias (offset) term of the discriminant
    @return {np.ndarray} y-values of the linear discriminant

    '''

    return -1/w[1] * (w[0] * x + w0)


def set_extent(dataset):
    '''Return the axis-aligned bounding box of a single data set

    @param {list} dataset non-empty list of (x, y) points
    @return {list} [x_min, x_max]
            {list} [y_min, y_max]

    '''
    pts = np.asarray(dataset)
    # Column-wise extrema over all points at once.
    lo = pts.min(axis=0)
    hi = pts.max(axis=0)
    return [lo[0], hi[0]], [lo[1], hi[1]]


def sets_extent(datasets):
    '''Return the joint axis-aligned bounding box over all data sets

    The original implementation seeded the extent with datasets[0] and
    then folded every set (including datasets[0] again) into it; here all
    points are pooled and reduced once.

    @param {list} datasets non-empty list of data sets,
            each a non-empty list of (x, y) points
    @return {list} [x_min, x_max]
            {list} [y_min, y_max]

    '''
    pts = np.concatenate([np.asarray(d) for d in datasets])
    lo = pts.min(axis=0)
    hi = pts.max(axis=0)
    return [lo[0], hi[0]], [lo[1], hi[1]]


def create_plot(seed1, seed2):
    '''Build the least-squares classification demo figure

    Two Gaussian clusters are generated, the discriminant direction
    w = Sw^{-1} (m2 - m1) with bias w0 = -m_total^T w is computed, and
    the points plus the decision boundary are drawn on one axis.

    @param {int} seed1 random seed for the first cluster
    @param {int} seed2 random seed for the second cluster
    @return {matplotlib.Figure} figure object
            {matplotlib.Axis} axis object

    '''
    fig, ax = plt.subplots()

    set1 = create_dataset(30, mean=[+1, +1], variance=1, seed=seed1)
    set2 = create_dataset(30, mean=[-1, -1], variance=1, seed=seed2)
    data = [set1, set2]

    # Only the means are needed here; the scatter matrices are
    # recomputed by Sb_Sw below.
    mean1, _ = data_mean_cov_set(set1)
    mean2, _ = data_mean_cov_set(set2)
    means = [mean1, mean2]

    # Sample the boundary line across the horizontal extent of the data.
    x_range, _ = sets_extent(data)
    x = np.linspace(x_range[0], x_range[1], 5)

    _, sw = Sb_Sw(data, means)

    # Total mean weighted by class sizes: m = (N1*m1 + N2*m2) / N
    N = len(data[0]) + len(data[1])
    N1, N2 = len(data[0]), len(data[1])
    mean_total = 1.0/N*(N1 * means[0] + N2 * means[1])
    w = np.linalg.inv(sw).dot(means[1] - means[0])
    w0 = - mean_total.dot(w)
    ax.plot(x, zero_line(x, w, w0), color='red')

    plot_data_points(ax, data)
    ax.legend(loc='best', fancybox=True,
              framealpha=0.5, fontsize='medium')
    ax.grid(True)
    return fig, ax


def main():
    '''Render the demo figure and save it as an SVG under ../img/'''
    fig, _ = create_plot(4, 8)

    __dirname = os.path.dirname(os.path.realpath(__file__))
    fig.patch.set_alpha(0.0)  # transparent figure background
    fn = '../img/classification-least-squares.svg'
    fn = os.path.join(__dirname, fn)
    # Create the output directory if missing so savefig does not fail
    # on a fresh checkout.
    os.makedirs(os.path.dirname(fn), exist_ok=True)
    plt.savefig(fn,
                facecolor=fig.get_facecolor(),
                edgecolor='none',
                # NOTE(review): bbox_inches=0 is not a documented value
                # ('tight', a Bbox, or None is expected) — confirm intent.
                bbox_inches=0)


if __name__ == '__main__':
    main()