/*
 * Decompiled with CFR 0.152.
 */
package optas.regression;

import java.io.File;
import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.TreeSet;
import java.util.logging.Level;
import optas.data.SimpleEnsemble;
import optas.regression.SimpleInterpolation;
import org.encog.engine.network.activation.ActivationFunction;
import org.encog.engine.network.activation.ActivationLinear;
import org.encog.engine.network.activation.ActivationSigmoid;
import org.encog.neural.data.Indexable;
import org.encog.neural.data.NeuralDataSet;
import org.encog.neural.data.basic.BasicNeuralDataSet;
import org.encog.neural.data.folded.FoldedDataSet;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.layers.BasicLayer;
import org.encog.neural.networks.layers.Layer;
import org.encog.neural.networks.logic.FeedforwardLogic;
import org.encog.neural.networks.logic.NeuralLogic;
import org.encog.neural.networks.training.Train;
import org.encog.neural.networks.training.cross.CrossValidationKFold;
import org.encog.neural.networks.training.propagation.resilient.ResilientPropagation;
import org.encog.util.logging.Logging;
import org.encog.util.obj.SerializeObject;

/**
 * Regression model that fits an Encog feed-forward network to the data held by
 * {@link SimpleInterpolation}.
 *
 * <p>The network consists of a sigmoid input layer, two sigmoid hidden layers and a
 * linear output layer. It is trained with resilient propagation wrapped in a 4-fold
 * cross validation. Training happens lazily: the first query after
 * {@link #setData} triggers it and the result is cached until the data changes or a
 * leave-out evaluation is requested.
 *
 * <p>The fields {@code x}, {@code y}, {@code n}, {@code m} and {@code L} as well as the
 * normalisation helpers are inherited; presumably {@code n}/{@code m} are the number of
 * input/output ensembles and {@code L} the number of samples (confirm in
 * {@code SimpleInterpolation}).
 */
public class SimpleNeuralNetwork extends SimpleInterpolation {
    /** True once a training run has completed for the current data. */
    boolean isTrained = false;
    /** Error returned by the most recent full (no leave-out) training run. */
    double error = 0.0;
    /** The network used for prediction; rebuilt from scratch on every training run. */
    BasicNetwork network;
    /** Multiplier applied to the size of the first hidden layer. */
    int complexityAdjustmentFactor = 3;

    /**
     * Replaces the data set and discards any previously trained network.
     *
     * @param x input ensembles
     * @param y output ensembles
     */
    @Override
    public void setData(SimpleEnsemble[] x, SimpleEnsemble[] y) {
        super.setData(x, y);
        this.network = new BasicNetwork();
        this.isTrained = false;
    }

    /**
     * Serializes the current network to the given file.
     *
     * @param f target file
     * @return {@code true} on success, {@code false} if an {@link IOException} occurred
     *         (the stack trace is printed)
     */
    public boolean save(File f) {
        try {
            SerializeObject.save(f.getAbsolutePath(), this.network);
        }
        catch (IOException ioe) {
            ioe.printStackTrace();
            return false;
        }
        return true;
    }

    /**
     * Replaces the current network with one deserialized from the given file.
     *
     * <p>NOTE(review): {@code isTrained} is not touched here. If it is {@code false}
     * when this method is called, the next call to {@link #init()} or
     * {@link #getInterpolatedValue(double[])} retrains and overwrites the loaded
     * network. Confirm whether a successful load should mark the model as trained.
     *
     * <p>NOTE(review): presumably this uses Java object deserialization; only load
     * files from trusted sources.
     *
     * @param f source file
     * @return {@code true} on success, {@code false} if reading failed or the class
     *         could not be resolved (the stack trace is printed)
     */
    public boolean load(File f) {
        try {
            this.network = (BasicNetwork)SerializeObject.load(f.getAbsolutePath());
        }
        catch (IOException ioe) {
            ioe.printStackTrace();
            return false;
        }
        catch (ClassNotFoundException nfe) {
            nfe.printStackTrace();
            return false;
        }
        return true;
    }

    /**
     * Trains the network on all samples unless it is already trained.
     *
     * @return the cached or freshly computed training error
     */
    @Override
    public double init() {
        return this.trainNetwork();
    }

    /**
     * Trains on the complete data set if no trained network is available and caches
     * the resulting error in {@link #error}.
     *
     * @return the error of the last full training run
     */
    private double trainNetwork() {
        if (this.isTrained) {
            return this.error;
        }
        this.error = this.trainNetwork(new TreeSet<Integer>());
        return this.error;
    }

    /**
     * Sets the multiplier for the first hidden layer's neuron count. Takes effect on
     * the next training run.
     *
     * @param complexityAdjustmentFactor new multiplier
     */
    public void setComplexityAdjustmentFactor(int complexityAdjustmentFactor) {
        this.complexityAdjustmentFactor = complexityAdjustmentFactor;
    }

    /**
     * @return the multiplier applied to the first hidden layer's neuron count
     */
    public int getComplexityAdjustmentFactor() {
        return this.complexityAdjustmentFactor;
    }

    /**
     * Builds a fresh network and trains it on every sample whose id is not contained
     * in {@code leaveOutIndex}.
     *
     * <p>Training stops when the smoothed relative improvement drops to 1e-4 or below,
     * when the trainer reports completion, or after 10000 epochs. On return
     * {@code isTrained} is {@code true}; {@link #error} is not modified by this method.
     *
     * @param leaveOutIndex sample ids to exclude from training; may be empty
     * @return the error reported by the cross-validation trainer after the last epoch
     */
    private double trainNetwork(TreeSet<Integer> leaveOutIndex) {
        this.log("Train Neural Network");
        this.setProgress(0.0);
        Logging.setConsoleLevel(Level.OFF);

        // Base size shared by both hidden layers; the first one is additionally scaled
        // by complexityAdjustmentFactor.
        int hiddenBase = this.m + ((this.x.length + 1) / 2 + 1);
        this.network = new BasicNetwork();
        this.network.addLayer(new BasicLayer(new ActivationSigmoid(), true, this.x.length));
        this.network.addLayer(new BasicLayer(new ActivationSigmoid(), true, this.complexityAdjustmentFactor * hiddenBase));
        this.network.addLayer(new BasicLayer(new ActivationSigmoid(), true, hiddenBase));
        this.network.addLayer(new BasicLayer(new ActivationLinear(), true, this.m));
        this.network.setLogic(new FeedforwardLogic());
        this.network.getStructure().finalizeStructure();
        this.network.reset();

        // Collect the normalised training samples, skipping the left-out ids.
        ArrayList<double[]> xData = new ArrayList<double[]>();
        ArrayList<double[]> yData = new ArrayList<double[]>();
        for (int i = 0; i < this.L; ++i) {
            int id_i = this.x[0].getId(i);
            if (leaveOutIndex.contains(id_i)) {
                continue;
            }
            double[] sampleX = new double[this.n];
            double[] sampleY = new double[this.m];
            for (int j = 0; j < this.n; ++j) {
                sampleX[j] = this.x[j].getValue(id_i);
            }
            for (int j = 0; j < this.m; ++j) {
                sampleY[j] = this.y[j].getValue(id_i);
            }
            xData.add(this.normalizeX(sampleX));
            yData.add(this.normalizeY(sampleY));
        }
        double[][] xDataArray = xData.toArray(new double[xData.size()][]);
        double[][] yDataArray = yData.toArray(new double[yData.size()][]);

        BasicNeuralDataSet basicNDS = new BasicNeuralDataSet(xDataArray, yDataArray);
        basicNDS.setDescription("testdataset");
        FoldedDataSet folded = new FoldedDataSet(basicNDS);
        ResilientPropagation train = new ResilientPropagation(this.network, folded);
        CrossValidationKFold trainFolded = new CrossValidationKFold(train, 4);

        int epoch = 1;
        int epochMax = 10000;
        double improvement = 1.0;
        double errorNow = 0.0;
        double errorLast = 0.0;
        do {
            trainFolded.iteration();
            if (epoch % 100 == 0) {
                System.out.println("Epoch #" + epoch + " Error:" + trainFolded.getError() + " improvement " + improvement);
            }
            errorLast = errorNow;
            errorNow = trainFolded.getError();
            ++epoch;
            // The first iteration has no predecessor error, so the smoothed improvement
            // is only updated from the second iteration on.
            if (epoch > 2) {
                // Exponential moving average of the relative error reduction.
                // NOTE(review): errorLast == 0 yields NaN or -Infinity, which makes the
                // loop condition false and ends training.
                improvement = 0.95 * improvement + 0.05 * (1.0 - errorNow / errorLast);
            }
            this.setProgress((double)epoch / (double)epochMax);
        } while (improvement > 1.0E-4 && !trainFolded.isTrainingDone() && epoch < epochMax);
        this.isTrained = true;
        return trainFolded.getError();
    }

    /**
     * Retrains the network without the samples in {@code validationSet} and predicts
     * the outputs for exactly those samples.
     *
     * <p>NOTE(review): afterwards the network trained on the reduced data stays in
     * place with {@code isTrained == true}, and {@link #error} still holds the value
     * of the last full training run. Later calls to
     * {@link #getInterpolatedValue(double[])} therefore use the leave-out model until
     * {@link #setData} or this method is called again. Confirm that this is intended.
     *
     * @param validationSet ids of the samples to leave out and predict
     * @return one prediction per id, in the iteration order of {@code validationSet}
     */
    @Override
    protected double[][] getInterpolatedValue(TreeSet<Integer> validationSet) {
        this.isTrained = false;
        this.trainNetwork(validationSet);
        double[][] values = new double[validationSet.size()][];
        int counter = 0;
        for (Integer i : validationSet) {
            // isTrained is true at this point, so the nested call does not retrain.
            values[counter++] = this.getInterpolatedValue(this.getX(i));
        }
        return values;
    }

    /**
     * Predicts the outputs for a single input vector, training the network on all
     * samples first if no trained network is available.
     *
     * @param u input vector in original (non-normalised) units
     * @return the denormalised network output of length {@code m}
     */
    @Override
    public double[] getInterpolatedValue(double[] u) {
        this.trainNetwork();
        double[] wholeOutput = new double[this.m];
        this.network.compute(this.normalizeX(u), wholeOutput);
        return this.denormalizeY(wholeOutput);
    }
}

