/*
 * Decompiled with CFR 0.152.
 */
package org.unijena.predictionnet;

import Jama.CholeskyDecomposition;
import Jama.Matrix;
import java.util.Vector;
import org.unijena.jams.data.JAMSBoolean;
import org.unijena.jams.data.JAMSDouble;
import org.unijena.jams.data.JAMSEntity;
import org.unijena.jams.data.JAMSInteger;
import org.unijena.jams.data.JAMSString;
import org.unijena.jams.model.JAMSVarDescription;
import org.unijena.predictionnet.GenericFunction;
import org.unijena.predictionnet.InputNeuron;
import org.unijena.predictionnet.Learner;
import org.unijena.predictionnet.LogisticFunction;
import org.unijena.predictionnet.Neuron;

public class NNLearner
extends Learner {
    // NOTE(review): every annotation below carries the same description text
    // ("TimeSerie of Temp Data"); it looks copy-pasted and probably does not
    // describe most of these parameters -- confirm and correct in the model metadata.

    /** Training entity; run() reads its "data" (double[][]) and "predict" (double[]) attributes. */
    @JAMSVarDescription(access=JAMSVarDescription.AccessType.READ, update=JAMSVarDescription.UpdateType.RUN, description="TimeSerie of Temp Data")
    public JAMSEntity trainData;
    /** Validation entity; run() reads its "data" (double[][]) and "predict" (double[]) attributes. */
    @JAMSVarDescription(access=JAMSVarDescription.AccessType.READ, update=JAMSVarDescription.UpdateType.RUN, description="TimeSerie of Temp Data")
    public JAMSEntity validationData;
    /** Learning rate holder; written by SetParams, not read anywhere else in this class. */
    @JAMSVarDescription(access=JAMSVarDescription.AccessType.READ, update=JAMSVarDescription.UpdateType.RUN, description="TimeSerie of Temp Data")
    public JAMSDouble learningrate;
    /** Momentum holder; written by SetParams, not read anywhere else in this class. */
    @JAMSVarDescription(access=JAMSVarDescription.AccessType.READ, update=JAMSVarDescription.UpdateType.RUN, description="TimeSerie of Temp Data")
    public JAMSDouble momentum;
    /** Layer configuration string; layerTest overwrites it with the current loop index. */
    @JAMSVarDescription(access=JAMSVarDescription.AccessType.READ, update=JAMSVarDescription.UpdateType.RUN, description="TimeSerie of Temp Data")
    public JAMSString layers;
    /** Epoch count holder ("Epochen" is German for epochs); written by SetParams. */
    @JAMSVarDescription(access=JAMSVarDescription.AccessType.READ, update=JAMSVarDescription.UpdateType.RUN, description="TimeSerie of Temp Data")
    public JAMSInteger epochen;
    /** Result file name; layerTest sets it to "<options>_result_<i>.txt". Starts out null. */
    @JAMSVarDescription(access=JAMSVarDescription.AccessType.READ, update=JAMSVarDescription.UpdateType.RUN, description="TimeSerie of Temp Data")
    public JAMSString resultFile = null;
    /** Free-form options value; layerTest uses its string form as the result file name prefix. Starts out null. */
    @JAMSVarDescription(access=JAMSVarDescription.AccessType.READ, update=JAMSVarDescription.UpdateType.RUN, description="TimeSerie of Temp Data")
    public JAMSString options = null;
    /** When true, Train runs LinearRegression on the training set before anything else. */
    @JAMSVarDescription(access=JAMSVarDescription.AccessType.READ, update=JAMSVarDescription.UpdateType.RUN, description="TimeSerie of Temp Data")
    public JAMSBoolean applyLinearRegression;
    // Number of layers that setupNET builds and that Propagate/BackPropagate walk over.
    int LayerCount;
    // One vector of neurons per layer; index 0 holds the InputNeuron instances.
    Vector<Neuron>[] Layers;
    // Neuron count per layer, indexed like Layers; filled by setLayerSize.
    Vector<Integer> LayerSize;
    // Coefficient column vector produced by the most recent LinearRegression call.
    Matrix R;
    // First neuron of the last layer; its activation is the network output.
    Neuron outNeuron;

    public NNLearner() {
        this.normalizeData = true;
        this.LayerCount = 3;
        this.Layers = new Vector[this.LayerCount];
        this.LayerSize = new Vector();
    }

    /**
     * Selects one of the predefined network topologies.
     *
     * <ul>
     *   <li>1: input, one hidden layer of {@code (DataLength + 1) / 2} neurons, output</li>
     *   <li>2: input, two hidden layers of {@code (DataLength + 1) / 2} neurons each, output</li>
     *   <li>3: input, one hidden layer of 2 neurons, output</li>
     * </ul>
     * The input layer always has {@code DataLength + 1} neurons (the extra one is
     * fed a constant 1.0 by {@code Predict}); the output layer has one neuron.
     * Any other option value leaves the current configuration untouched.
     *
     * <p>The layer storage is resized to the chosen layer count. Previously the
     * array kept its constructor length of 3, so option 2 (four layers) made
     * {@code setupNET} fail with an {@code ArrayIndexOutOfBoundsException}.
     * Sizes from earlier calls are discarded instead of being shifted back.
     *
     * @param option topology selector (1, 2 or 3)
     */
    public void setLayerSize(int option) {
        final int inputSize = this.DataLength + 1;
        final int[] sizes;
        switch (option) {
            case 1:
                sizes = new int[]{inputSize, inputSize / 2, 1};
                break;
            case 2:
                sizes = new int[]{inputSize, inputSize / 2, inputSize / 2, 1};
                break;
            case 3:
                sizes = new int[]{inputSize, 2, 1};
                break;
            default:
                // Unknown selector: keep whatever was configured before.
                return;
        }
        this.LayerCount = sizes.length;
        this.LayerSize.clear();
        for (int size : sizes) {
            this.LayerSize.add(size);
        }
        // Keep the per-layer storage in step with the layer count.
        if (this.Layers == null || this.Layers.length != this.LayerCount) {
            this.Layers = new Vector[this.LayerCount];
        }
    }

    /**
     * Builds the network described by {@code LayerCount} and {@code LayerSize}.
     *
     * <p>Layer 0 is filled with {@link InputNeuron}s. Every neuron of a later
     * layer is connected to all neurons of the preceding layer with an initial
     * weight of {@code generator.nextDouble() - 0.5}. All layers except the last
     * get a logistic activation filter; the last layer is left unfiltered.
     * The generator is reseeded with -1 first (presumably so that the initial
     * weights are reproducible -- confirm against the generator's type).
     *
     * <p>The layer storage is reallocated when its length does not match
     * {@code LayerCount}, so a layer count changed after construction no longer
     * causes an index error.
     *
     * @throws IllegalStateException if {@code LayerSize} holds fewer entries
     *         than {@code LayerCount}
     */
    public void setupNET() {
        this.generator.setSeed(-1L);
        if (this.LayerSize.size() < this.LayerCount) {
            throw new IllegalStateException("layer sizes not configured: need " + this.LayerCount
                    + " entries but found " + this.LayerSize.size());
        }
        if (this.Layers == null || this.Layers.length != this.LayerCount) {
            this.Layers = new Vector[this.LayerCount];
        }
        // Allocate every layer with null placeholders of the right size.
        for (int layer = 0; layer < this.LayerCount; ++layer) {
            Vector<Neuron> neurons = new Vector<Neuron>();
            neurons.setSize(this.LayerSize.get(layer));
            this.Layers[layer] = neurons;
        }
        // Input layer.
        for (int i = 0; i < this.LayerSize.get(0); ++i) {
            InputNeuron inNeuron = new InputNeuron();
            inNeuron.initalize();
            this.Layers[0].set(i, inNeuron);
        }
        // Hidden and output layers, fully connected to the previous layer.
        final int lastLayer = this.LayerCount - 1;
        for (int m = 1; m < this.LayerCount; ++m) {
            final Vector<Neuron> previous = this.Layers[m - 1];
            for (int i = 0; i < this.LayerSize.get(m); ++i) {
                Neuron innerNeuron = new Neuron();
                LogisticFunction logf = new LogisticFunction(1.0);
                GenericFunction gf = new GenericFunction(logf);
                innerNeuron.initalize();
                if (m != lastLayer) {
                    innerNeuron.addFilter(gf);
                }
                this.Layers[m].set(i, innerNeuron);
                for (int k = 0; k < previous.size(); ++k) {
                    innerNeuron.AddConnection(previous.get(k), innerNeuron, this.generator.nextDouble() - 0.5);
                }
            }
        }
        this.outNeuron = this.Layers[lastLayer].get(0);
    }

    /**
     * Feeds one sample through the network and returns the output activation.
     *
     * <p>With normalisation enabled each component is mapped to
     * {@code 2 * (p[i] - base[i]) / (max[i] - min[i])}; otherwise the values
     * are used as given. The input neuron at index {@code DataLength} always
     * receives the constant 1.0.
     *
     * @param p raw input values of one sample
     * @return activation of the output neuron after propagation
     */
    public double Predict(double[] p) {
        double[] inputs = new double[p.length];
        if (!this.normalizeData) {
            inputs = p;
        } else {
            int idx = 0;
            while (idx < p.length) {
                final double range = this.max[idx] - this.min[idx];
                inputs[idx] = 2.0 * (p[idx] - this.base[idx]) / range;
                ++idx;
            }
        }
        for (int n = 0; n <= this.DataLength; ++n) {
            final InputNeuron target = (InputNeuron) this.Layers[0].get(n);
            // The last input neuron is the constant-1.0 one.
            target.SetInput(n == this.DataLength ? 1.0 : inputs[n]);
        }
        this.Propagate();
        return this.outNeuron.getActivation();
    }

    /**
     * Runs a forward pass: calls {@code propagate()} on every neuron, layer by
     * layer from layer 0 upwards.
     *
     * @return activation of the output neuron after the pass
     */
    public double Propagate() {
        for (int layer = 0; layer < this.LayerCount; ++layer) {
            for (Neuron neuron : this.Layers[layer]) {
                neuron.propagate();
            }
        }
        return this.outNeuron.getActivation();
    }

    /**
     * Runs a backward pass: adds the error to the output neuron, then walks the
     * layers from last to first, calling {@code backpropagate()} followed by
     * {@code updateWeightDelta()} on every neuron.
     *
     * @param error signed error added to the output neuron
     */
    public void BackPropagate(double error) {
        this.outNeuron.addToError(error);
        int layer = this.LayerCount;
        while (--layer >= 0) {
            final Vector<Neuron> neurons = this.Layers[layer];
            for (int n = 0; n < neurons.size(); ++n) {
                final Neuron neuron = neurons.get(n);
                neuron.backpropagate();
                neuron.updateWeightDelta();
            }
        }
    }

    /**
     * Calls {@code adjustWeight()} on every neuron of every layer except the
     * input layer (layer 0).
     */
    public void AdjustWeights() {
        for (int layer = 1; layer < this.LayerCount; ++layer) {
            for (Neuron neuron : this.Layers[layer]) {
                neuron.adjustWeight();
            }
        }
    }

    /**
     * Runs one pass over the first {@code TrainLength} samples: each sample is
     * predicted and its signed residual is back-propagated. The weights are
     * adjusted once, after all samples have been processed.
     *
     * @return sum of the absolute residuals over the pass
     */
    public double TrainCycle() {
        double absErrorSum = 0.0;
        int sample = 0;
        while (sample < this.TrainLength) {
            final double predicted = this.Predict(this.data[sample]);
            final double residual = this.result[sample] - predicted;
            absErrorSum += Math.abs(residual);
            this.BackPropagate(residual);
            ++sample;
        }
        this.AdjustWeights();
        return absErrorSum;
    }

    /**
     * Batch prediction entry point.
     *
     * <p>NOTE(review): this is a stub. It only dereferences {@code data} (so a
     * null or empty array still raises) and then returns {@code null}; the
     * {@code predict} and {@code writeResult} arguments are ignored. Callers
     * must not rely on a usable result until this is implemented.
     *
     * @param data        samples, one row per sample
     * @param predict     expected values for the samples (currently unused)
     * @param writeResult whether results should be written out (currently unused)
     * @return always {@code null}
     */
    public double[] Predict(double[][] data, double[] predict, boolean writeResult) {
        final int featureCount = data[0].length;
        final int sampleCount = data.length;
        return null;
    }

    /**
     * Training entry point.
     *
     * <p>When {@code applyLinearRegression} is set, {@code LinearRegression} is
     * run on the training set (which stores its coefficients in {@code R}).
     * NOTE(review): the resulting target vector is discarded and no network is
     * built or trained here -- this looks like an unfinished stub.
     *
     * @param data    training samples, one row per sample
     * @param predict expected value for each sample
     */
    public void Train(double[][] data, double[] predict) {
        final int featureCount = data[0].length;
        final int sampleCount = data.length;
        double[] targets;
        if (this.applyLinearRegression.getValue()) {
            targets = this.LinearRegression(data, predict);
        } else {
            targets = predict;
        }
    }

    /**
     * Trains on one data set, predicts another and scores the prediction.
     *
     * <p>The score is the sum (not the mean) of the squared differences
     * between {@code valPredict} and the predicted values.
     *
     * @param trainData    training samples
     * @param trainPredict expected values for the training samples
     * @param valData      validation samples
     * @param valPredict   expected values for the validation samples
     * @param writeResult  passed through to the batch {@code Predict}
     * @return sum of squared validation errors
     * @throws IllegalStateException if the batch {@code Predict} yields no
     *         result (previously this surfaced as a bare NullPointerException)
     */
    public double SingleRun(double[][] trainData, double[] trainPredict, double[][] valData, double[] valPredict, boolean writeResult) {
        this.Train(trainData, trainPredict);
        final double[] predicted = this.Predict(valData, valPredict, writeResult);
        if (predicted == null) {
            throw new IllegalStateException(
                    "Predict(double[][], double[], boolean) returned no predictions; cannot compute the validation error");
        }
        double squaredErrorSum = 0.0;
        for (int i = 0; i < predicted.length; ++i) {
            final double diff = valPredict[i] - predicted[i];
            squaredErrorSum += diff * diff;
        }
        return squaredErrorSum;
    }

    /**
     * Five-fold cross validation over contiguous chunks of the data.
     *
     * <p>The fold size is {@code ceil(N / 5)}. Fold {@code i} validates on rows
     * {@code [i * size, (i + 1) * size)} and trains on all remaining rows; the
     * row arrays are shared with {@code data}, not copied. Because the size is
     * rounded up, the trailing folds can be empty (e.g. N = 6 gives folds of
     * 2, 2, 2, 0, 0); such folds are skipped instead of being handed to
     * {@code SingleRun} with an empty validation set.
     *
     * @param data    samples, one row per sample
     * @param predict expected value for each sample
     * @return sum of the {@code SingleRun} errors over all non-empty folds
     * @throws IllegalArgumentException if an argument is null, fewer than two
     *         samples are given, or {@code predict} is shorter than {@code data}
     */
    public double crossvalidation(double[][] data, double[] predict) {
        if (data == null || predict == null) {
            throw new IllegalArgumentException("crossvalidation requires non-null data and predict arrays");
        }
        final int sampleCount = data.length;
        if (sampleCount < 2) {
            // With a single sample the only fold would leave nothing to train on.
            throw new IllegalArgumentException("crossvalidation requires at least two samples, got " + sampleCount);
        }
        if (predict.length < sampleCount) {
            throw new IllegalArgumentException("predict has " + predict.length
                    + " entries but data has " + sampleCount + " rows");
        }
        final int foldCount = 5;
        final int foldSize = (sampleCount + foldCount - 1) / foldCount;
        double error = 0.0;
        for (int fold = 0; fold < foldCount; ++fold) {
            final int valStart = fold * foldSize;
            if (valStart >= sampleCount) {
                // Rounding the fold size up exhausted the data early.
                break;
            }
            final int valEnd = Math.min(sampleCount, valStart + foldSize);
            final int valCount = valEnd - valStart;
            final int trainCount = sampleCount - valCount;
            final int tailCount = sampleCount - valEnd;

            double[][] valRows = new double[valCount][];
            double[] valTargets = new double[valCount];
            System.arraycopy(data, valStart, valRows, 0, valCount);
            System.arraycopy(predict, valStart, valTargets, 0, valCount);

            // Training set = everything before the fold followed by everything after it.
            double[][] trainRows = new double[trainCount][];
            double[] trainTargets = new double[trainCount];
            System.arraycopy(data, 0, trainRows, 0, valStart);
            System.arraycopy(data, valEnd, trainRows, valStart, tailCount);
            System.arraycopy(predict, 0, trainTargets, 0, valStart);
            System.arraycopy(predict, valEnd, trainTargets, valStart, tailCount);

            error += this.SingleRun(trainRows, trainTargets, valRows, valTargets, false);
        }
        return error;
    }

    /**
     * Stores the training hyper-parameters, replacing out-of-range values:
     * a non-positive epoch count becomes 1, a non-positive learning rate
     * becomes 0.01 and a negative momentum becomes 0.01. Missing parameter
     * holders are created on demand.
     *
     * @param EpochCounter number of epochs
     * @param LearningRate learning rate
     * @param Momentum     momentum
     */
    public void SetParams(int EpochCounter, double LearningRate, double Momentum) {
        final int epochs = EpochCounter <= 0 ? 1 : EpochCounter;
        final double rate = LearningRate <= 0.0 ? 0.01 : LearningRate;
        final double inertia = Momentum < 0.0 ? 0.01 : Momentum;

        if (this.epochen == null) {
            this.epochen = new JAMSInteger();
        }
        if (this.learningrate == null) {
            this.learningrate = new JAMSDouble();
        }
        if (this.momentum == null) {
            this.momentum = new JAMSDouble();
        }

        this.epochen.setValue(epochs);
        this.learningrate.setValue(rate);
        this.momentum.setValue(inertia);
    }

    /**
     * Tunes learning rate and momentum by alternating one-dimensional searches
     * scored with {@link #crossvalidation(double[][], double[])}.
     *
     * <p>Each round first scores the current parameters, then for learning rate
     * and momentum in turn: probes a +0.01 change to pick a direction, and tries
     * step sizes that are halved until the score improves or the step drops to
     * 1e-4 or below (in which case the parameter is left unchanged). Candidates
     * that would make the parameter non-positive are skipped. Rounds repeat
     * while a round improves the score by more than 5 %, plus up to four rounds
     * that do not. The epoch count stays fixed at 1000.
     *
     * <p>Changes against the previous version: the best parameters found are
     * written back with {@code SetParams} before returning (previously the
     * holders were left at whatever trial value was scored last), and the
     * step-size limit is checked before the positivity guard so the halving
     * loops are guaranteed to terminate.
     *
     * @param data    samples, one row per sample
     * @param predict expected value for each sample
     */
    public void optimize(double[][] data, double[] predict) {
        System.out.println("Optimization");
        final double EpochCounter = 1000.0;
        double LearningRate = 0.3;
        double Momentum = 0.2;
        double alpha_lrate = 0.1;
        double alpha_mom = 0.1;
        double y_best;
        double y_best_alt;
        int Counter = 0;
        do {
            // Score the current parameter set; this is the round's baseline.
            this.SetParams((int)EpochCounter, LearningRate, Momentum);
            double y_neu = this.crossvalidation(data, predict);
            y_best = y_neu;
            y_best_alt = y_neu;
            System.out.println("Startwert:" + y_best_alt);

            // --- learning rate: probe the direction, then search a step size ---
            this.SetParams((int)EpochCounter, LearningRate + 0.01, Momentum);
            double y_tmp = this.crossvalidation(data, predict);
            double gradient = y_tmp < y_neu ? 1.0 : -1.0;
            // Start from a larger step than the one the previous round ended with.
            alpha_lrate = 4.0 * alpha_lrate + 0.05;
            double ysave = y_neu;
            while (true) {
                alpha_lrate /= 2.0;
                if (alpha_lrate <= 1.0E-4) {
                    // Step too small: give up and keep the current learning rate.
                    alpha_lrate = 0.0;
                    y_neu = ysave;
                    break;
                }
                if (LearningRate + alpha_lrate * gradient <= 0.0) {
                    continue;
                }
                this.SetParams((int)EpochCounter, LearningRate + alpha_lrate * gradient, Momentum);
                y_neu = this.crossvalidation(data, predict);
                if (!(y_neu >= y_best)) break;
            }
            y_best = y_neu;
            LearningRate += alpha_lrate * gradient;
            System.out.println("Epochen:" + EpochCounter + "\tLernrate:" + LearningRate + "\tMomentum:" + Momentum + "\tyBest:" + y_best);

            // --- momentum: same scheme ---
            this.SetParams((int)EpochCounter, LearningRate, Momentum + 0.01);
            y_tmp = this.crossvalidation(data, predict);
            gradient = y_tmp < y_neu ? 1.0 : -1.0;
            alpha_mom = 4.0 * alpha_mom + 0.1;
            ysave = y_neu;
            while (true) {
                alpha_mom /= 2.0;
                if (alpha_mom <= 1.0E-4) {
                    alpha_mom = 0.0;
                    y_neu = ysave;
                    break;
                }
                if (Momentum + alpha_mom * gradient <= 0.0) {
                    continue;
                }
                this.SetParams((int)EpochCounter, LearningRate, Momentum + alpha_mom * gradient);
                y_neu = this.crossvalidation(data, predict);
                if (!(y_neu > y_best)) break;
            }
            y_best = y_neu;
            Momentum += alpha_mom * gradient;
            System.out.println("Epochen:" + EpochCounter + "\tLernrate:" + LearningRate + "\tMomentum:" + Momentum + "\tyBest:" + y_best);
            System.out.println("Verbessung in diesem Durchgang:" + y_best_alt / y_best);
        } while (y_best_alt / y_best > 1.05 || ++Counter <= 3);
        // Leave the learner configured with the best parameters found.
        this.SetParams((int)EpochCounter, LearningRate, Momentum);
    }

    /**
     * Fits a linear model with intercept by solving the normal equations
     * {@code (A^T A) x = A^T b} via a Cholesky decomposition, where {@code A}
     * is {@code data} with a trailing column of ones. The coefficients are
     * stored in {@code R}. A message is printed when the decomposition reports
     * that {@code A^T A} is not symmetric positive definite; the solve is
     * attempted regardless.
     *
     * @param data    samples, one row per sample
     * @param predict expected value for each sample
     * @return residuals {@code predict[i] - fitted[i]} for the first
     *         {@code data.length} entries (array length is {@code predict.length})
     */
    double[] LinearRegression(double[][] data, double[] predict) {
        final int rows = data.length;
        final int cols = data[0].length;
        final Matrix design = new Matrix(rows, cols + 1);
        final Matrix targets = new Matrix(rows, 1);
        for (int r = 0; r < rows; ++r) {
            final double[] sample = data[r];
            for (int c = 0; c < cols; ++c) {
                design.set(r, c, sample[c]);
            }
            // Intercept column.
            design.set(r, cols, 1.0);
            targets.set(r, 0, predict[r]);
        }

        final Matrix designT = design.transpose();
        final Matrix normal = designT.times(design);
        final Matrix rhs = designT.times(targets);
        final CholeskyDecomposition cholesky = normal.chol();
        if (!cholesky.isSPD()) {
            System.out.println("Error!! not a SPD - Matrix");
        }
        this.R = cholesky.solve(rhs);

        final Matrix fitted = design.times(this.R);
        final double[] residuals = new double[predict.length];
        for (int r = 0; r < rows; ++r) {
            residuals[r] = predict[r] - fitted.get(r, 0);
        }
        return residuals;
    }

    /**
     * Repeats train-and-predict for the odd values 1, 3, ..., 39. Before each
     * run the value is written to {@code layers} and the result file name is
     * set to {@code <options>_result_<value>.txt}.
     *
     * @param data               training samples
     * @param predict            expected values for the training samples
     * @param validation_data    validation samples
     * @param validation_predict expected values for the validation samples
     */
    public void layerTest(double[][] data, double[] predict, double[][] validation_data, double[] validation_predict) {
        int setting = 1;
        while (setting < 40) {
            final String label = String.valueOf(setting);
            final String fileName = this.options.toString() + "_result_" + label + ".txt";
            this.resultFile.setValue(fileName);
            this.layers.setValue(label);
            this.Train(data, predict);
            this.Predict(validation_data, validation_predict, true);
            setting += 2;
        }
    }

    /**
     * Reads the "data" and "predict" attributes of the training and validation
     * entities, trains on the former and predicts the latter with result
     * writing requested.
     *
     * @throws IllegalStateException if any of the attributes cannot be read.
     *         Previously the failure was only printed and execution went on
     *         with null arrays, ending in an unrelated NullPointerException.
     */
    @Override
    public void run() {
        double[][] data;
        double[] predict;
        double[][] validation_data;
        double[] validation_predict;
        try {
            data = (double[][])this.trainData.getObject("data");
            predict = (double[])this.trainData.getObject("predict");
            validation_data = (double[][])this.validationData.getObject("data");
            validation_predict = (double[])this.validationData.getObject("predict");
        }
        catch (Exception e) {
            System.out.println("could not find data!!" + e.toString());
            // Nothing sensible can be done without the input arrays; fail here and keep the cause.
            throw new IllegalStateException("could not find data: " + e, e);
        }
        this.Train(data, predict);
        this.Predict(validation_data, validation_predict, true);
    }
}

