/*
 * Decompiled with CFR 0.152.
 */
package org.unijena.predictionnet;

import Jama.CholeskyDecomposition;
import Jama.Matrix;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileReader;
import java.io.FileWriter;
import org.unijena.jams.data.JAMSBoolean;
import org.unijena.jams.data.JAMSDoubleArray;
import org.unijena.jams.data.JAMSEntity;
import org.unijena.jams.data.JAMSInteger;
import org.unijena.jams.data.JAMSString;
import org.unijena.jams.model.JAMSVarDescription;
import org.unijena.predictionnet.Learner;
import org.unijena.predictionnet.kernels.Exponential;
import org.unijena.predictionnet.kernels.FixedMeanModell;
import org.unijena.predictionnet.kernels.Kernel;
import org.unijena.predictionnet.kernels.LinearMeanModell;
import org.unijena.predictionnet.kernels.MaternClass;
import org.unijena.predictionnet.kernels.NeuralNetwork;
import org.unijena.predictionnet.kernels.QuadraticMeanModell;
import org.unijena.predictionnet.kernels.RationalQuadratic;
import org.unijena.predictionnet.kernels.SimpleExponential;
import org.unijena.predictionnet.kernels.SimpleMatern;
import org.unijena.predictionnet.kernels.SimpleNeuralNetwork;
import org.unijena.predictionnet.kernels.SimplePeriodic;
import org.unijena.predictionnet.kernels.SimpleRationalQuadratic;
import org.unijena.predictionnet.kernels.TestKernel;

/**
 * Gaussian-process regression learner for the prediction network.
 *
 * <p>The learner builds the covariance matrix {@code K} of the (normalized) training inputs with a
 * configurable {@link Kernel} (see {@link #setKernels()}). It Cholesky-factorizes {@code K} and
 * keeps {@code alpha = K^-1 y} for prediction. The kernel hyper-parameters come from one of three
 * sources:
 * <ul>
 *   <li>supplied directly in {@link #param_theta};</li>
 *   <li>read from {@link #parameterFile};</li>
 *   <li>optimized by {@link #GradientDescent(double[])} against one of the performance measures
 *       accepted by {@link #Train(int)}.</li>
 * </ul>
 *
 * <p>All training state lives in mutable instance fields without synchronization.
 */
public class GaussianLearner
extends Learner {
    // NOTE(review): every description below reads "TimeSerie of Temp Data", which looks
    // copy-pasted; the comments describe what the code actually uses each field for.

    // Training set; trainInit() copies it into Learner.trainData (this field hides that one).
    @JAMSVarDescription(access=JAMSVarDescription.AccessType.READ, update=JAMSVarDescription.UpdateType.RUN, description="TimeSerie of Temp Data")
    public JAMSEntity trainData;
    // Training set used while optimizing hyper-parameters; optInit() copies it into Learner.trainData.
    @JAMSVarDescription(access=JAMSVarDescription.AccessType.READ, update=JAMSVarDescription.UpdateType.RUN, description="TimeSerie of Temp Data")
    public JAMSEntity optimizationData;
    // Hold-out set; Predict() reads its "data" and "predict" objects.
    @JAMSVarDescription(access=JAMSVarDescription.AccessType.READ, update=JAMSVarDescription.UpdateType.RUN, description="TimeSerie of Temp Data")
    public JAMSEntity validationData;
    // Kernel selector, see setKernels().
    @JAMSVarDescription(access=JAMSVarDescription.AccessType.READ, update=JAMSVarDescription.UpdateType.RUN, description="TimeSerie of Temp Data")
    public JAMSInteger kernelMethod;
    // Mean-model selector (1 linear, 2 quadratic, anything else fixed), see setKernels().
    @JAMSVarDescription(access=JAMSVarDescription.AccessType.READ, update=JAMSVarDescription.UpdateType.RUN, description="TimeSerie of Temp Data")
    public JAMSInteger MeanMethod;
    // Objective passed to Train() during optimization (1..4, see Train()).
    @JAMSVarDescription(access=JAMSVarDescription.AccessType.READ, update=JAMSVarDescription.UpdateType.RUN, description="TimeSerie of Temp Data")
    public JAMSInteger PerformanceMeasure;
    // Kernel hyper-parameters (not log-scaled); overwritten by funct() during optimization.
    @JAMSVarDescription(access=JAMSVarDescription.AccessType.READ, update=JAMSVarDescription.UpdateType.RUN, description="TimeSerie of Temp Data")
    public JAMSDoubleArray param_theta;
    // Optional file with one hyper-parameter per line; only read while param_theta is null.
    @JAMSVarDescription(access=JAMSVarDescription.AccessType.READ, update=JAMSVarDescription.UpdateType.RUN, description="TimeSerie of Temp Data")
    public JAMSString parameterFile = null;
    // File to which Predict(true) appends "observed<TAB>predicted" lines.
    @JAMSVarDescription(access=JAMSVarDescription.AccessType.READ, update=JAMSVarDescription.UpdateType.RUN, description="TimeSerie of Temp Data")
    public JAMSString resultFile = null;
    // Whether run() optimizes the hyper-parameters before the final training.
    @JAMSVarDescription(access=JAMSVarDescription.AccessType.READ, update=JAMSVarDescription.UpdateType.RUN, description="TimeSerie of Temp Data")
    public JAMSBoolean doOptimization;
    // Covariance matrix K of the training inputs, built by Train().
    Matrix CovarianzMatrix;
    // Cholesky decomposition of K.
    CholeskyDecomposition Solver;
    // Training targets after the mean model's Transform(), as a column vector y.
    Matrix Observations;
    // alpha = K^-1 y.
    Matrix alpha;
    // Explicit K^-1; only computed by the leave-one-out performance measures.
    Matrix invCovarianzMatrix;
    // Natural logarithms of the current hyper-parameters.
    double[] logtheta;
    // Current hyper-parameters handed to the kernel.
    double[] theta;
    // Values of MaximizeEff: how funct() maps a performance to the minimized objective.
    static final int MAXIMIZATION = 1;
    static final int MINIMIZATION = 2;
    static final int ABSMAXIMIZATION = 3;
    static final int ABSMINIMIZATION = 4;
    int MaximizeEff = 1;
    Kernel kernel;
    // Step width and upper bound of the tabulated standard normal CDF.
    static final double resolution = 0.01;
    static final double limit = 200.0;
    // gaussianDistribution[i] holds Phi((i + 1) * resolution); see BuildGaussDistributionTable().
    static double[] gaussianDistribution;
    // Scale factor applied to the predictive variance in GetProbabilityForXLessY().
    public double variancecontrol = 1.0;
    // Performance reported when the model could not be trained or evaluated.
    private static final double FAILED_PERFORMANCE = -1.0E12;

    /**
     * Returns the log marginal likelihood of the training targets:
     * {@code -1/2 y^T K^-1 y - 1/2 log|K| - n/2 log(2 pi)}.
     *
     * <p>{@code log|K|} is taken from the Cholesky factor as {@code 2 * sum(log L_ii)}.
     * This requires a preceding successful {@link #Train(int)}.
     *
     * @return the log marginal likelihood
     */
    public double getMarginalLikelihood() {
        double n = this.TrainLength;
        double dataFit = -0.5 * this.Observations.transpose().times(this.alpha).get(0, 0);
        Matrix L = this.Solver.getL();
        double logDeterminant = 0.0;
        for (int i = 0; i < L.getColumnDimension(); ++i) {
            logDeterminant += 2.0 * Math.log(L.get(i, i));
        }
        double complexityPenalty = -0.5 * logDeterminant;
        double normalization = -n / 2.0 * Math.log(Math.PI * 2);
        return dataFit + complexityPenalty + normalization;
    }

    /**
     * Returns the leave-one-out log predictive probability of the training targets.
     *
     * <p>For every point, the LOO mean is {@code mu_i = y_i - alpha_i / [K^-1]_ii} and the LOO
     * variance is {@code 1 / [K^-1]_ii}. The constant {@code -1/2 log(2 pi)} per point is omitted.
     * As a side effect, {@link #invCovarianzMatrix} is recomputed.
     *
     * @return the sum of the per-point LOO log densities (without the constant)
     */
    public double getLOOlogPredictiveProbability() {
        this.invCovarianzMatrix = this.CovarianzMatrix.inverse();
        double logp = 0.0;
        for (int i = 0; i < this.TrainLength; ++i) {
            double invDiagonal = this.invCovarianzMatrix.get(i, i);
            double observed = this.Observations.get(i, 0);
            double mu_i = observed - this.alpha.get(i, 0) / invDiagonal;
            double sigma_i = 1.0 / invDiagonal;
            // Guard against a numerically negative diagonal entry of the inverse.
            if (sigma_i < 0.0) {
                sigma_i = -sigma_i;
            }
            sigma_i = Math.sqrt(sigma_i);
            logp += -Math.log(sigma_i) - Math.pow((observed - mu_i) / sigma_i, 2.0) / 2.0;
        }
        return logp;
    }

    /**
     * Returns the negated sum of squared leave-one-out prediction errors.
     *
     * <p>The value is negated so that larger is better. As a side effect,
     * {@link #invCovarianzMatrix} is recomputed.
     *
     * @return {@code -sum_i (mu_i - y_i)^2}
     */
    public double getLOOSquareError() {
        this.invCovarianzMatrix = this.CovarianzMatrix.inverse();
        double error = 0.0;
        for (int i = 0; i < this.TrainLength; ++i) {
            double observed = this.Observations.get(i, 0);
            double mu_i = observed - this.alpha.get(i, 0) / this.invCovarianzMatrix.get(i, i);
            error += (mu_i - observed) * (mu_i - observed);
        }
        return -error;
    }

    /**
     * Returns the negated sum of squared errors of the current model on {@link #validationData}.
     *
     * <p>The inherited {@code result} field is restored after predicting. Predict() overwrites
     * it, but Train() reads it as the training targets, and later Train() calls during
     * optimization must still see those targets.
     *
     * @return {@code -sum (prediction - observed)^2}, or {@code -1.0E12} if the prediction or the
     *     observed values could not be obtained
     */
    public double getSplitValidationError() {
        double[] trainingTargets = this.result;
        double[] predicted = null;
        double[] correctValue = null;
        try {
            predicted = this.Predict(false);
            correctValue = (double[])this.validationData.getObject("predict");
        }
        catch (Exception e) {
            System.out.println("GP SplitValidation - Error: " + e.toString());
        }
        finally {
            this.result = trainingTargets;
        }
        if (predicted == null || correctValue == null) {
            return FAILED_PERFORMANCE;
        }
        double sum = 0.0;
        for (int i = 0; i < predicted.length; ++i) {
            double residual = predicted[i] - correctValue[i];
            sum += residual * residual;
        }
        return -sum;
    }

    /**
     * Trains the GP with the current hyper-parameters and scores the result.
     *
     * <p>The hyper-parameters are chosen as follows:
     * <ul>
     *   <li>from {@link #param_theta}, if it is set;</li>
     *   <li>otherwise from {@link #parameterFile}, if one is configured;</li>
     *   <li>otherwise every parameter is 1.</li>
     * </ul>
     * The covariance matrix of the normalized training inputs is then built, Cholesky-factorized
     * and solved for {@link #alpha}.
     *
     * @param PerformanceMeasure selects the score that is returned:
     *     1 = log marginal likelihood, 2 = LOO log predictive probability, 3 = negated LOO squared
     *     error, 4 = negated split-validation squared error. Any other value only trains the model
     *     and returns 0.
     * @return the selected performance (larger is better); {@code -1.0E12} if the covariance
     *     matrix is not positive definite; 0 if the parameter file could not be read
     */
    public double Train(int PerformanceMeasure) {
        this.CovarianzMatrix = new Matrix(this.TrainLength, this.TrainLength);
        int parameterCount = this.kernel.getParameterCount();
        this.theta = new double[parameterCount];
        this.logtheta = new double[parameterCount];
        if (this.parameterFile != null && this.param_theta == null) {
            if (!this.readLogThetaFromFile()) {
                return 0.0;
            }
        }
        if (this.param_theta != null) {
            double[] given = this.param_theta.getValue();
            for (int i = 0; i < this.logtheta.length; ++i) {
                this.logtheta[i] = Math.log(given[i]);
            }
        }
        for (int i = 0; i < this.theta.length; ++i) {
            this.theta[i] = Math.exp(this.logtheta[i]);
        }
        if (!this.kernel.SetParameter(this.theta)) {
            // "too few parameters"
            System.out.println("zu wenig Parametern");
        }
        // K is symmetric: fill the lower triangle and mirror it, then set the diagonal.
        for (int i = 0; i < this.TrainLength; ++i) {
            for (int j = 0; j < i; ++j) {
                double covariance = this.kernel.kernel(this.normalize(this.data[i]), this.normalize(this.data[j]), i, j);
                this.CovarianzMatrix.set(i, j, covariance);
                this.CovarianzMatrix.set(j, i, covariance);
            }
            double variance = this.kernel.kernel(this.normalize(this.data[i]), this.normalize(this.data[i]), i, i);
            this.CovarianzMatrix.set(i, i, variance);
        }
        this.Observations = this.kernel.MM.Transform(this.data, this.result);
        this.Solver = this.CovarianzMatrix.chol();
        if (!this.Solver.isSPD()) {
            System.out.println("NOT a SPD Matrix");
            return FAILED_PERFORMANCE;
        }
        this.alpha = this.Solver.solve(this.Observations);
        switch (PerformanceMeasure) {
            case 1:
                return this.getMarginalLikelihood();
            case 2:
                return this.getLOOlogPredictiveProbability();
            case 3:
                return this.getLOOSquareError();
            case 4:
                return this.getSplitValidationError();
            default:
                return 0.0;
        }
    }

    /**
     * Reads one hyper-parameter per line from {@link #parameterFile} into {@link #logtheta},
     * storing natural logarithms. The reader is always closed.
     *
     * @return {@code true} on success; {@code false} after printing the reason otherwise
     */
    private boolean readLogThetaFromFile() {
        BufferedReader reader = null;
        try {
            reader = new BufferedReader(new FileReader(this.parameterFile.getValue()));
            for (int i = 0; i < this.theta.length; ++i) {
                this.logtheta[i] = Math.log(Double.parseDouble(reader.readLine()));
            }
            return true;
        }
        catch (Exception e) {
            System.out.println("Could not open or read parameter file, becauce:" + e.toString());
            return false;
        }
        finally {
            if (reader != null) {
                try {
                    reader.close();
                }
                catch (java.io.IOException ignored) {
                    // Nothing useful can be done if closing a read-only file fails.
                }
            }
        }
    }

    /**
     * Returns the predictive mean at {@code x}, mapped back by the mean model's ReTransform().
     *
     * <p>The inherited {@code result} field is not modified, because Train() reads it as the
     * training targets.
     *
     * @param x raw (un-normalized) input vector
     * @return the predicted value
     */
    public double GetMean(double[] x) {
        Matrix kstar = new Matrix(1, this.TrainLength);
        for (int i = 0; i < this.TrainLength; ++i) {
            double covariance = this.kernel.kernel(this.normalize(this.data[i]), this.normalize(x), i, -1);
            kstar.set(0, i, covariance);
        }
        Matrix prediction = kstar.times(this.alpha);
        double[] mean = this.kernel.MM.ReTransform(new double[][]{x}, prediction);
        return mean[0];
    }

    /**
     * Returns the predictive variance at {@code x} as {@code 1 - k*^T K^-1 k*}.
     *
     * <p>NOTE(review): the prior variance k(x, x) is taken as 1 instead of being evaluated with
     * the kernel. This is only exact for kernels with unit signal variance; confirm.
     *
     * @param x raw (un-normalized) input vector
     * @return the predictive variance (may be slightly negative due to rounding)
     */
    public double GetVariance(double[] x) {
        Matrix kstar = new Matrix(1, this.TrainLength);
        for (int i = 0; i < this.TrainLength; ++i) {
            double covariance = this.kernel.kernel(this.normalize(this.data[i]), this.normalize(x), i, -1);
            kstar.set(0, i, covariance);
        }
        Matrix kInvKstar = this.Solver.solve(kstar.transpose());
        return 1.0 - kstar.times(kInvKstar).get(0, 0);
    }

    /**
     * Standard normal probability density at {@code a}.
     *
     * @param a evaluation point
     * @return {@code exp(-a^2 / 2) / sqrt(2 pi)}
     */
    public static double Gauss(double a) {
        return 1.0 / Math.sqrt(Math.PI * 2) * Math.exp(-0.5 * a * a);
    }

    /**
     * Tabulates the standard normal CDF on {@code [0, limit]} with step {@code resolution}.
     *
     * <p>Entry {@code i} holds {@code Phi((i + 1) * resolution)}. It is obtained by adding
     * Simpson's-rule integrals of {@link #Gauss(double)} to {@code Phi(0) = 0.5}. The abscissae
     * are computed from the index, so the table length does not depend on floating-point drift.
     */
    public static void BuildGaussDistributionTable() {
        int size = (int)Math.round(limit / resolution) + 1;
        double[] table = new double[size];
        double integral = 0.5;
        for (int i = 0; i < size; ++i) {
            double x1 = i * resolution;
            double x2 = (i + 1) * resolution;
            integral += (x2 - x1) / 6.0 * (Gauss(x1) + 4.0 * Gauss(0.5 * (x1 + x2)) + Gauss(x2));
            table[i] = integral;
        }
        gaussianDistribution = table;
    }

    /**
     * Returns {@code P(X < target)} for the Gaussian prediction X at input {@code x}.
     *
     * <p>X has mean {@link #GetMean(double[])} and variance
     * {@code GetVariance(x) * variancecontrol}. The CDF table is built on first use. When the
     * variance is below {@code 1e-4}, 0 is returned.
     *
     * <p>NOTE(review): for a near-deterministic prediction the true probability is 1 if the mean
     * lies below the target; returning 0 in that case may not be intended. Also, the mean is
     * ReTransform()ed by the mean model while the variance is not; confirm both are on the same
     * scale.
     *
     * @param x raw input vector
     * @param target threshold value
     * @return probability in [0, 1]
     */
    public double GetProbabilityForXLessY(double[] x, double target) {
        if (gaussianDistribution == null) {
            BuildGaussDistributionTable();
        }
        double mean = this.GetMean(x);
        double variance = this.GetVariance(x) * this.variancecontrol;
        if (variance < 1.0E-4) {
            return 0.0;
        }
        // Standardize by the standard deviation, not the variance.
        double z = (target - mean) / Math.sqrt(variance);
        double absZ = Math.abs(z);
        long index = (long)(absZ / limit * (double)gaussianDistribution.length);
        double prob;
        if (index >= (long)gaussianDistribution.length) {
            System.out.println("gp out of range!!");
            prob = 1.0;
        } else {
            prob = gaussianDistribution[(int)index];
        }
        // The table holds Phi(|z|); use Phi(-z) = 1 - Phi(z) for negative z.
        return z < 0.0 ? 1.0 - prob : prob;
    }

    /**
     * Predicts every input of {@link #validationData} and stores the predictions in the
     * inherited {@code result} field.
     *
     * @param writeOutput if true, "observed<TAB>predicted" lines are appended to
     *     {@link #resultFile}
     * @return the predictions, or {@code null} if the validation data could not be read
     */
    public double[] Predict(boolean writeOutput) {
        double[][] x;
        double[] correctValue;
        try {
            x = (double[][])this.validationData.getObject("data");
            correctValue = (double[])this.validationData.getObject("predict");
        }
        catch (Exception e) {
            System.out.println("Could not find validation data. " + e.toString());
            return null;
        }
        int m = x.length;
        Matrix kstar = new Matrix(m, this.TrainLength);
        for (int j = 0; j < m; ++j) {
            for (int i = 0; i < this.TrainLength; ++i) {
                double covariance = this.kernel.kernel(this.normalize(this.data[i]), this.normalize(x[j]), i, -1);
                kstar.set(j, i, covariance);
            }
        }
        Matrix prediction = kstar.times(this.alpha);
        this.result = this.kernel.MM.ReTransform(x, prediction);
        if (writeOutput) {
            this.appendResults(correctValue, m);
        }
        return this.result;
    }

    /**
     * Appends the first {@code count} "observed<TAB>predicted" pairs to {@link #resultFile}.
     *
     * <p>If the file cannot be opened, nothing is written. A failing line is reported and
     * skipped, and the writer is always closed.
     */
    private void appendResults(double[] correctValue, int count) {
        BufferedWriter writer;
        try {
            writer = new BufferedWriter(new FileWriter(this.resultFile.getValue(), true));
        }
        catch (Exception e) {
            System.out.println("Could not open result file, becauce:" + e.toString());
            System.out.println("results won't be saved");
            return;
        }
        try {
            for (int i = 0; i < count; ++i) {
                try {
                    writer.write(correctValue[i] + "\t" + this.result[i] + "\n");
                    writer.flush();
                }
                catch (Exception e) {
                    System.out.println("could not write, because: " + e.toString());
                }
            }
        }
        finally {
            try {
                writer.close();
            }
            catch (Exception e) {
                System.out.println("GP - Error" + e.toString());
            }
        }
    }

    /** Initializes the base learner with {@link #optimizationData} as the training set. */
    public void optInit() {
        this.initLearner(this.optimizationData);
    }

    /** Initializes the base learner with {@link #trainData} as the training set. */
    public void trainInit() {
        this.initLearner(this.trainData);
    }

    /**
     * Points the base class at {@code trainingSet} and {@link #validationData}, then runs
     * Learner.run(). Failures are printed, not propagated.
     */
    private void initLearner(JAMSEntity trainingSet) {
        ((Learner)this).trainData = trainingSet;
        ((Learner)this).validationData = this.validationData;
        try {
            super.run();
        }
        catch (Exception e) {
            // "GP init error"
            System.out.println("GP Init Fehler - " + e.toString());
        }
    }

    /**
     * Objective minimized by the optimizers.
     *
     * <p>The method sets the hyper-parameters to {@code exp(x)}, trains the model, and maps the
     * performance according to {@link #MaximizeEff}:
     * <ul>
     *   <li>MAXIMIZATION: {@code -p}</li>
     *   <li>MINIMIZATION: {@code p}</li>
     *   <li>ABSMAXIMIZATION: {@code -|p|}</li>
     *   <li>ABSMINIMIZATION: {@code |p|}</li>
     *   <li>anything else: 0</li>
     * </ul>
     *
     * @param x log hyper-parameters
     * @return the value to minimize
     */
    public double funct(double[] x) {
        if (this.param_theta == null) {
            this.param_theta = new JAMSDoubleArray();
        }
        double[] current = this.param_theta.getValue();
        if (current == null || current.length < x.length) {
            this.param_theta.setValue(new double[x.length]);
        }
        double[] params = this.param_theta.getValue();
        for (int j = 0; j < x.length; ++j) {
            params[j] = Math.exp(x[j]);
        }
        double performance = this.Train(this.PerformanceMeasure.getValue());
        switch (this.MaximizeEff) {
            case MINIMIZATION:
                return performance;
            case ABSMINIMIZATION:
                return Math.abs(performance);
            case ABSMAXIMIZATION:
                return -Math.abs(performance);
            case MAXIMIZATION:
                return -performance;
            default:
                return 0.0;
        }
    }

    /**
     * Minimizes {@link #funct(double[])} by coordinate-wise descent, starting from {@code x}.
     *
     * <p>How each coordinate is updated:
     * <ul>
     *   <li>The sign of a forward-difference derivative gives the step direction.</li>
     *   <li>The coordinate's step size is enlarged (times 4, capped at 2), then halved until the
     *       objective improves.</li>
     *   <li>A coordinate whose step falls below {@code alphaMin} is frozen.</li>
     *   <li>Values are clamped to [-4, 4].</li>
     * </ul>
     * Iteration stops when every coordinate is frozen, or when the relative change of a full
     * sweep drops to {@code diffMin}. Progress is printed to the model runtime.
     *
     * @param x start point; overwritten with the best point found. The last funct() call is
     *     made at that point, so the model is trained there on return.
     */
    public void GradientDescent(double[] x) {
        final double alphaMin = 0.001;
        final double diffMin = 0.025;
        final double approxError = 1.0E-4;
        double[] grad = new double[x.length];
        double[] alpha = new double[x.length];
        double[] xp = new double[x.length];
        for (int i = 0; i < x.length; ++i) {
            alpha[i] = 0.1;
        }
        double diff = 1.0;
        double y1 = this.funct(x);
        double yNeu = 1.0;
        double calpha = 0.1;
        while (calpha > alphaMin && diff > diffMin) {
            double yAlt = y1;
            for (int i = 0; i < x.length; ++i) {
                if (alpha[i] == 0.0) {
                    continue; // coordinate frozen
                }
                for (int j = 0; j < x.length; ++j) {
                    xp[j] = j == i ? x[j] + approxError : x[j];
                }
                double y2 = this.funct(xp);
                grad[i] = (y2 - y1) / approxError < 0.0 ? -1.0 : 1.0;
                alpha[i] = Math.min(alpha[i] * 4.0, 2.0);
                boolean improved = false;
                do {
                    System.arraycopy(x, 0, xp, 0, x.length);
                    xp[i] = Math.max(-4.0, Math.min(4.0, x[i] - alpha[i] * grad[i]));
                    yNeu = this.funct(xp);
                    if (yNeu < y1) {
                        improved = true;
                        break;
                    }
                    alpha[i] /= 2.0;
                } while (alpha[i] >= alphaMin);
                if (!improved) {
                    // No acceptable step: freeze the coordinate and retrain at the unchanged point.
                    xp[i] = x[i];
                    alpha[i] = 0.0;
                    yNeu = this.funct(xp);
                }
                y1 = yNeu;
                StringBuilder gradientInfo = new StringBuilder("Gradient:\t");
                StringBuilder positionInfo = new StringBuilder("Stelle:\t");
                for (int k = 0; k < x.length; ++k) {
                    x[k] = xp[k];
                    if (i == k) {
                        gradientInfo.append(grad[i]).append("\t");
                    } else {
                        gradientInfo.append("0.0\t");
                    }
                    positionInfo.append(x[k]).append("\t");
                }
                this.getModel().getRuntime().println(gradientInfo.toString());
                this.getModel().getRuntime().println(positionInfo.toString());
                this.getModel().getRuntime().println("Funktionswert:\t" + y1 + "\t Alpha: " + calpha + "\t diff:" + diff);
            }
            // calpha is the largest step still in use, so the loop ends once all coordinates freeze.
            calpha = 0.0;
            for (int i = 0; i < alpha.length; ++i) {
                if (alpha[i] > calpha) {
                    calpha = alpha[i];
                }
            }
            diff = Math.abs((yNeu - yAlt) / yNeu);
        }
    }

    /**
     * Minimizes {@link #funct(double[])} by steepest descent, starting from {@code x}.
     *
     * <p>Each step goes along the normalized forward-difference gradient. The step size is
     * multiplied by 4 and then halved until the objective improves. Despite the name, no
     * momentum term is used.
     *
     * <p>The method stops when any of the following happens:
     * <ul>
     *   <li>the gradient vanishes;</li>
     *   <li>no step of at least {@code alphaMin} improves the objective;</li>
     *   <li>the relative change falls to {@code diffMin}.</li>
     * </ul>
     * In the first two cases the model is retrained at the best point.
     *
     * @param x start point; overwritten with the best point found
     */
    public void MomentumGradientDescent(double[] x) {
        final double alphaMin = 1.0E-15;
        final double diffMin = 1.0E-10;
        final double approxError = 1.0E-4;
        double[] grad = new double[x.length];
        double[] xp = new double[x.length];
        double alpha = 0.1;
        double diff = 1.0;
        double y1 = this.funct(x);
        while (alpha > alphaMin && diff > diffMin) {
            for (int i = 0; i < x.length; ++i) {
                for (int j = 0; j < x.length; ++j) {
                    xp[j] = j == i ? x[j] + approxError : x[j];
                }
                grad[i] = (this.funct(xp) - y1) / approxError;
            }
            double norm = 0.0;
            for (int i = 0; i < grad.length; ++i) {
                norm += grad[i] * grad[i];
            }
            norm = Math.sqrt(norm);
            if (!(norm > 0.0)) {
                // Zero or NaN gradient: no usable direction.
                this.funct(x);
                break;
            }
            for (int i = 0; i < grad.length; ++i) {
                grad[i] /= norm;
            }
            alpha *= 4.0;
            double yNeu;
            boolean improved;
            while (true) {
                for (int i = 0; i < x.length; ++i) {
                    xp[i] = x[i] - alpha * grad[i];
                }
                yNeu = this.funct(xp);
                if (yNeu < y1) {
                    improved = true;
                    break;
                }
                alpha /= 2.0;
                if (alpha < alphaMin) {
                    improved = false;
                    break;
                }
            }
            if (!improved) {
                // Keep x at the best evaluated point and leave the model trained there.
                this.funct(x);
                break;
            }
            diff = Math.abs(y1 / yNeu - 1.0);
            y1 = yNeu;
            StringBuilder info = new StringBuilder("Gradient:\t");
            for (int i = 0; i < x.length; ++i) {
                x[i] = xp[i];
                info.append(grad[i]).append("\t");
            }
            this.getModel().getRuntime().println(info.toString());
            info = new StringBuilder("Stelle:\t\t");
            for (int i = 0; i < x.length; ++i) {
                info.append(x[i]).append("\t");
            }
            this.getModel().getRuntime().println(info.toString());
            this.getModel().getRuntime().println("Funktionswert:\t" + y1 + "\t Alpha: " + alpha);
        }
    }

    /**
     * Instantiates {@link #kernel} from {@link #kernelMethod} and attaches the mean model
     * selected by {@link #MeanMethod} (1 linear, 2 quadratic, anything else fixed).
     *
     * @throws IllegalArgumentException if {@link #kernelMethod} names no known kernel
     */
    public void setKernels() {
        switch (this.kernelMethod.getValue()) {
            case 0:
                this.kernel = new TestKernel(this.DataLength);
                break;
            case 2:
                this.kernel = new Exponential(this.DataLength);
                break;
            case 3:
                this.kernel = new MaternClass(this.DataLength);
                break;
            case 5:
                this.kernel = new RationalQuadratic(this.DataLength);
                break;
            case 6:
                this.kernel = new NeuralNetwork(this.DataLength);
                break;
            case 12:
                this.kernel = new SimpleExponential(this.DataLength);
                break;
            case 13:
                this.kernel = new SimpleMatern(this.DataLength);
                break;
            case 15:
                this.kernel = new SimpleRationalQuadratic(this.DataLength);
                break;
            case 16:
                this.kernel = new SimpleNeuralNetwork(this.DataLength);
                break;
            case 17:
                this.kernel = new SimplePeriodic(this.DataLength);
                break;
            default:
                this.kernel = null;
                System.out.println("No valid Kernel specified, this will propably cause an error!");
                throw new IllegalArgumentException("Unknown kernelMethod: " + this.kernelMethod.getValue());
        }
        switch (this.MeanMethod.getValue()) {
            case 1:
                this.kernel.SetMeanModell(new LinearMeanModell(this.DataLength));
                break;
            case 2:
                this.kernel.SetMeanModell(new QuadraticMeanModell(this.DataLength));
                break;
            default:
                this.kernel.SetMeanModell(new FixedMeanModell(this.DataLength));
                break;
        }
    }

    /**
     * Component entry point. The steps are:
     * <ol>
     *   <li>Initialize the learner (presumably this provides DataLength for setKernels()) and
     *       select the kernel.</li>
     *   <li>Optionally optimize the hyper-parameters on {@link #optimizationData}.</li>
     *   <li>Retrain on {@link #trainData}, using the optimized {@link #param_theta} if any.</li>
     *   <li>Write the validation predictions.</li>
     * </ol>
     */
    @Override
    public void run() {
        this.trainInit();
        this.setKernels();
        if (this.doOptimization.getValue()) {
            this.optInit();
            int parameterCount = this.kernel.getParameterCount();
            double[] x = new double[parameterCount];
            for (int i = 0; i < x.length; ++i) {
                x[i] = 1.0 / (double)parameterCount;
            }
            this.GradientDescent(x);
        }
        this.trainInit();
        // 0 is not a performance measure: this only trains the final model.
        this.Train(0);
        this.Predict(true);
    }
}

