/*
 * Decompiled with CFR 0.152.
 */
package optas.regression;

import java.io.File;
import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.TreeSet;
import java.util.logging.Level;
import optas.data.SimpleEnsemble;
import optas.data.TimeSerieEnsemble;
import optas.regression.Interpolation;
import optas.regression.TimeSeriesInterpolation;
import org.encog.engine.network.activation.ActivationFunction;
import org.encog.engine.network.activation.ActivationLinear;
import org.encog.engine.network.activation.ActivationSigmoid;
import org.encog.neural.data.NeuralDataSet;
import org.encog.neural.data.basic.BasicNeuralDataSet;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.layers.BasicLayer;
import org.encog.neural.networks.layers.Layer;
import org.encog.neural.networks.logic.FeedforwardLogic;
import org.encog.neural.networks.logic.NeuralLogic;
import org.encog.neural.networks.training.propagation.resilient.ResilientPropagation;
import org.encog.util.logging.Logging;
import org.encog.util.obj.SerializeObject;

/**
 * Time-series interpolation backed by an Encog feed-forward neural network.
 *
 * <p>The network consists of an input layer with one neuron per input value ({@code n}), a
 * sigmoid layer and a linear output layer, the latter two with one neuron per output value
 * ({@code m}). Samples are normalized with the base class' {@code normalizeX}/{@code normalizeY}
 * helpers and the network is trained with resilient propagation.
 *
 * <p>Not thread-safe: training and prediction mutate {@code network}, {@code isTrained} and
 * {@code error} without synchronization.
 */
public class TimeSerieNeuralNetwork
extends TimeSeriesInterpolation {
    /** Maximum number of resilient-propagation iterations per training run. */
    private static final int MAX_EPOCHS = 1500;
    /** Training stops early once the training error is at or below this value. */
    private static final double TARGET_ERROR = 0.005;
    /** Network property marking that per-output cross-validation errors are cached on it. */
    private static final String CV_ERROR_KEY = "CVError";

    boolean isTrained = false;
    double error = 0.0;
    BasicNetwork network;

    /**
     * Sets the training data and discards any previously trained network, together with any
     * cross-validation errors cached as properties of that network.
     */
    @Override
    public void setData(SimpleEnsemble[] x, TimeSerieEnsemble y) {
        super.setData(x, y);
        this.network = new BasicNetwork();
        this.isTrained = false;
    }

    /**
     * Trains the network on all samples unless a trained network is already present.
     *
     * @return the training error recorded by the last full-data training run (the field's
     *     previous value if the network was obtained via {@link #load(File)})
     */
    @Override
    public double init() {
        return this.trainNetwork();
    }

    /**
     * Serializes the current network to {@code f}.
     *
     * @param f destination file
     * @return {@code true} on success, {@code false} if an I/O error occurred
     */
    public boolean save(File f) {
        try {
            SerializeObject.save(f.getAbsolutePath(), (Serializable)this.network);
            return true;
        }
        catch (IOException ioe) {
            ioe.printStackTrace();
            return false;
        }
    }

    /**
     * Replaces the current network with one deserialized from {@code f} and marks it as trained.
     * On failure the current network is left untouched.
     *
     * @param f source file
     * @return {@code true} on success, {@code false} if the file could not be read or does not
     *     contain a {@link BasicNetwork}
     */
    public boolean load(File f) {
        Object loaded;
        try {
            loaded = SerializeObject.load(f.getAbsolutePath());
        }
        catch (IOException ioe) {
            ioe.printStackTrace();
            return false;
        }
        catch (ClassNotFoundException nfe) {
            nfe.printStackTrace();
            return false;
        }
        // A foreign or null payload is a load failure, not a reason to crash later on use.
        if (!(loaded instanceof BasicNetwork)) {
            System.err.println("File " + f + " does not contain a BasicNetwork");
            return false;
        }
        this.network = (BasicNetwork)loaded;
        this.isTrained = true;
        return true;
    }

    /**
     * Trains on all samples if no trained network is present and caches the training error.
     *
     * @return the cached training error of the full-data network
     */
    private double trainNetwork() {
        if (!this.isTrained) {
            this.error = this.trainNetwork(new TreeSet<Integer>());
        }
        return this.error;
    }

    /**
     * Builds a fresh network and trains it on every sample whose id is not in
     * {@code leaveOutIndex}. The new network replaces {@code this.network} and is marked trained;
     * the {@code error} field is not modified.
     *
     * @param leaveOutIndex sample ids to exclude from training (e.g. a validation fold)
     * @return the training error after the last iteration
     * @throws IllegalStateException if no samples remain after leaving out {@code leaveOutIndex}
     */
    private double trainNetwork(TreeSet<Integer> leaveOutIndex) {
        this.log("Train Neural Network");
        this.setProgress(0.0);
        Logging.setConsoleLevel(Level.OFF);
        BasicNeuralDataSet trainingSet = this.buildTrainingSet(leaveOutIndex);
        this.network = this.createNetwork();
        ResilientPropagation train = new ResilientPropagation(this.network, trainingSet);
        train.setError(1.0);
        int epoch = 1;
        do {
            train.iteration();
            System.out.println("Epoch #" + epoch + " Error:" + train.getError());
            this.setProgress((double)(++epoch) / (double)MAX_EPOCHS);
        } while (train.getError() > TARGET_ERROR && !train.isTrainingDone() && epoch < MAX_EPOCHS);
        this.isTrained = true;
        // Report the error actually reached instead of a constant, so init() is meaningful.
        return train.getError();
    }

    /** Creates a freshly initialized input/sigmoid/linear feed-forward network. */
    private BasicNetwork createNetwork() {
        BasicNetwork net = new BasicNetwork();
        // The input width must equal the length of the sample vectors (n values each) that are
        // fed to the network in buildTrainingSet and getInterpolatedValue(double[]).
        net.addLayer(new BasicLayer(new ActivationSigmoid(), true, this.n));
        net.addLayer(new BasicLayer(new ActivationSigmoid(), true, this.m));
        net.addLayer(new BasicLayer(new ActivationLinear(), true, this.m));
        net.setLogic(new FeedforwardLogic());
        net.getStructure().finalizeStructure();
        net.reset();
        return net;
    }

    /** Collects the normalized (input, output) pairs of all samples not in {@code leaveOutIndex}. */
    private BasicNeuralDataSet buildTrainingSet(TreeSet<Integer> leaveOutIndex) {
        ArrayList<double[]> xData = new ArrayList<double[]>();
        ArrayList<double[]> yData = new ArrayList<double[]>();
        for (int i = 0; i < this.L; ++i) {
            int id = this.x[0].getId(i);
            if (leaveOutIndex.contains(id)) {
                continue;
            }
            double[] sampleX = new double[this.n];
            for (int j = 0; j < this.n; ++j) {
                sampleX[j] = this.x[j].getValue(id);
            }
            double[] sampleY = new double[this.m];
            for (int j = 0; j < this.m; ++j) {
                sampleY[j] = this.getYData(id, j);
            }
            xData.add(this.normalizeX(sampleX));
            yData.add(this.normalizeY(sampleY));
        }
        if (xData.isEmpty()) {
            throw new IllegalStateException("No training samples left after leaving out "
                    + leaveOutIndex.size() + " sample id(s)");
        }
        BasicNeuralDataSet dataSet = new BasicNeuralDataSet(
                xData.toArray(new double[xData.size()][]),
                yData.toArray(new double[yData.size()][]));
        dataSet.setDescription("testdataset");
        return dataSet;
    }

    /**
     * Predicts the samples in {@code validationSet} with a network trained on all other samples.
     *
     * <p>The fold network is only temporary: the previous network and its trained state are
     * restored afterwards, so cross-validation does not replace the full-data model used by
     * {@link #getInterpolatedValue(double[])}.
     *
     * @param validationSet ids of the samples to leave out of training and predict
     * @return one denormalized prediction per id, in ascending id order
     */
    @Override
    protected double[][] getInterpolatedValue(TreeSet<Integer> validationSet) {
        BasicNetwork previousNetwork = this.network;
        boolean previousTrained = this.isTrained;
        try {
            this.isTrained = false;
            this.trainNetwork(validationSet);
            double[][] values = new double[validationSet.size()][];
            int counter = 0;
            for (Integer id : validationSet) {
                values[counter++] = this.getInterpolatedValue(this.getX(id));
            }
            return values;
        }
        finally {
            this.network = previousNetwork;
            this.isTrained = previousTrained;
        }
    }

    /**
     * Evaluates the network at input {@code u}, training it on all samples first if needed.
     *
     * @param u raw (not normalized) input vector
     * @return the denormalized network output, one value per output ({@code m} values)
     */
    @Override
    public double[] getInterpolatedValue(double[] u) {
        this.trainNetwork();
        double[] wholeOutput = new double[this.m];
        this.network.compute(this.normalizeX(u), wholeOutput);
        return this.denormalizeY(wholeOutput);
    }

    /**
     * Returns per-output cross-validation errors, caching them as properties of the full-data
     * network (keys {@code "0"} .. {@code "m-1"} plus the {@code "CVError"} marker).
     *
     * <p>NOTE(review): the cache is not keyed by {@code K} or {@code e}; a later call with
     * different arguments returns the first result until the network is retrained or
     * {@link #setData} is called.
     */
    @Override
    public double[] estimateCrossValidationError(int K, Interpolation.ErrorMethod e) {
        // Make sure the cache lives on the full-data model rather than on a network that is
        // about to be replaced by the next training run.
        this.trainNetwork();
        if (!this.network.getProperties().containsKey(CV_ERROR_KEY)) {
            double[] computed = super.estimateCrossValidationError(K, e);
            for (int t = 0; t < this.m; ++t) {
                this.network.setProperty(Integer.toString(t), computed[t]);
            }
            this.network.setProperty(CV_ERROR_KEY, 1.0);
        }
        double[] result = new double[this.m];
        for (int t = 0; t < this.m; ++t) {
            result[t] = this.network.getPropertyDouble(Integer.toString(t));
        }
        return result;
    }
}

