/*
 * Decompiled with CFR 0.152.
 */
package org.encog.neural.networks.training.lma;

import org.encog.mathutil.matrices.Matrix;
import org.encog.mathutil.matrices.decomposition.LUDecomposition;
import org.encog.neural.data.Indexable;
import org.encog.neural.data.NeuralData;
import org.encog.neural.data.NeuralDataPair;
import org.encog.neural.data.NeuralDataSet;
import org.encog.neural.data.basic.BasicNeuralData;
import org.encog.neural.data.basic.BasicNeuralDataPair;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.layers.Layer;
import org.encog.neural.networks.structure.NetworkCODEC;
import org.encog.neural.networks.training.BasicTraining;
import org.encog.neural.networks.training.TrainingError;
import org.encog.neural.networks.training.lma.JacobianChainRule;

/**
 * Trains a {@link BasicNetwork} with the Levenberg-Marquardt algorithm, optionally combined
 * with Bayesian regularization of the weights.
 *
 * <p>Each {@link #iteration()} computes the Jacobian of the per-record errors. It then forms the
 * gradient {@code J^T e} and the Gauss-Newton Hessian approximation {@code beta * J^T J}.
 * Finally it searches for a damping factor {@code lambda} whose step lowers the objective
 * {@code beta * errorTerm + alpha * weightTerm}.
 *
 * <p>The constructor requires an {@link Indexable} training set and an {@code "OUTPUT"} layer
 * with exactly one neuron. Only output index 0 is ever compared against the ideal data.
 */
public class LevenbergMarquardtTraining
extends BasicTraining {
    /** Factor by which the damping parameter {@code lambda} is divided or multiplied. */
    public static final double SCALE_LAMBDA = 10.0;
    /** Upper bound on {@code lambda}; the step search gives up once it is reached. */
    public static final double LAMBDA_MAX = 1.0E25;

    /** The network being trained. */
    private final BasicNetwork network;
    /** The training set, viewed as random-access records. */
    private final Indexable indexableTraining;
    /** Number of training records (rows of the Jacobian). */
    private final int trainingLength;
    /** Number of trainable network parameters (columns of the Jacobian). */
    private final int parametersLength;
    /** Network weights captured at the start of the current iteration (null before the first). */
    private double[] weights;
    /** Matrix wrapper around {@link #hessian}, used for the LU decomposition. */
    private final Matrix hessianMatrix;
    /**
     * Raw storage for the Hessian approximation. It is written directly and then decomposed
     * through {@link #hessianMatrix}. This assumes {@code getData()} returns the live backing
     * array rather than a copy.
     */
    private final double[][] hessian;
    /** Weight-decay coefficient of the objective (re-estimated under Bayesian regularization). */
    private double alpha;
    /** Error coefficient of the objective (re-estimated under Bayesian regularization). */
    private double beta;
    /** Current Levenberg-Marquardt damping factor. */
    private double lambda;
    /** Gradient {@code J^T e}, one entry per parameter. */
    private final double[] gradient;
    /** Undamped diagonal of the Hessian, so damping can be re-applied for each trial lambda. */
    private final double[] diagonal;
    /** Most recently solved weight update. */
    private double[] deltas;
    /** Effective number of parameters computed by the Bayesian update. */
    private double gamma;
    /** Reusable record buffer for reading training pairs. */
    private final NeuralDataPair pair;
    /** Whether {@code alpha}/{@code beta} are re-estimated after every iteration. */
    private boolean useBayesianRegularization;

    /**
     * Returns the sum of the main-diagonal entries of a matrix.
     *
     * @param matrix a square matrix; one diagonal entry is read per row
     * @return the trace of {@code matrix}
     */
    public static double trace(double[][] matrix) {
        double sum = 0.0;
        for (int i = 0; i < matrix.length; ++i) {
            sum += matrix[i][i];
        }
        return sum;
    }

    /**
     * Creates a Levenberg-Marquardt trainer.
     *
     * @param network the network to train; its {@code "OUTPUT"} layer must have one neuron
     * @param training the training data; must implement {@link Indexable}
     * @throws TrainingError if the training set is not indexable, the network has no output
     *     layer, or the output layer does not have exactly one neuron
     */
    public LevenbergMarquardtTraining(BasicNetwork network, NeuralDataSet training) {
        if (!(training instanceof Indexable)) {
            throw new TrainingError("Levenberg Marquardt requires an indexable training set.");
        }
        Layer outputLayer = network.getLayer("OUTPUT");
        if (outputLayer == null) {
            throw new TrainingError("Levenberg Marquardt requires an output layer.");
        }
        if (outputLayer.getNeuronCount() != 1) {
            throw new TrainingError("Levenberg Marquardt requires an output layer with a single neuron.");
        }
        this.setTraining(training);
        this.indexableTraining = (Indexable) this.getTraining();
        this.network = network;
        this.trainingLength = (int) this.indexableTraining.getRecordCount();
        this.parametersLength = this.network.getStructure().calculateSize();
        this.hessianMatrix = new Matrix(this.parametersLength, this.parametersLength);
        this.hessian = this.hessianMatrix.getData();
        this.alpha = 0.0;
        this.beta = 1.0;
        this.lambda = 0.1;
        this.deltas = new double[this.parametersLength];
        this.gradient = new double[this.parametersLength];
        this.diagonal = new double[this.parametersLength];
        BasicNeuralData input = new BasicNeuralData(this.indexableTraining.getInputSize());
        BasicNeuralData ideal = new BasicNeuralData(this.indexableTraining.getIdealSize());
        this.pair = new BasicNeuralDataPair(input, ideal);
    }

    /**
     * Computes the gradient and the Gauss-Newton approximation of the Hessian, and caches the
     * Hessian's diagonal.
     *
     * <p>The results are stored in this trainer's buffers: {@code gradient = J^T e} and
     * {@code hessian = beta * J^T J}. {@code J^T J} is symmetric, so only the upper triangle
     * is computed and then mirrored.
     * NOTE(review): unlike the Hessian, the gradient is not scaled by {@code beta}.
     *
     * @param jacobian Jacobian with one row per training record and one column per parameter
     * @param errors error of each training record, indexed like the Jacobian's rows
     */
    public void calculateHessian(double[][] jacobian, double[] errors) {
        for (int row = 0; row < this.parametersLength; ++row) {
            double g = 0.0;
            for (int k = 0; k < this.trainingLength; ++k) {
                g += jacobian[k][row] * errors[k];
            }
            this.gradient[row] = g;

            for (int col = row; col < this.parametersLength; ++col) {
                double s = 0.0;
                for (int k = 0; k < this.trainingLength; ++k) {
                    s += jacobian[k][row] * jacobian[k][col];
                }
                double value = this.beta * s;
                this.hessian[row][col] = value;
                this.hessian[col][row] = value;
            }
            // hessian[row][row] was written above (col == row) and is never touched again
            // in this method, so it can be cached now.
            this.diagonal[row] = this.hessian[row][row];
        }
    }

    /**
     * Returns half the sum of the squared weights captured at the start of the iteration.
     */
    private double calculateSumOfSquaredWeights() {
        double sum = 0.0;
        for (double w : this.weights) {
            sum += w * w;
        }
        return sum / 2.0;
    }

    /**
     * Runs the network over every training record and returns half the sum of squared errors
     * on output index 0.
     */
    private double calculateSumOfSquaredErrors() {
        double sum = 0.0;
        for (int i = 0; i < this.trainingLength; ++i) {
            this.indexableTraining.getRecord(i, this.pair);
            NeuralData actual = this.network.compute(this.pair.getInput());
            double e = this.pair.getIdeal().getData(0) - actual.getData(0);
            sum += e * e;
        }
        return sum / 2.0;
    }

    /**
     * @return the network being trained
     */
    public BasicNetwork getNetwork() {
        return this.network;
    }

    /**
     * @return whether Bayesian regularization re-estimates {@code alpha} and {@code beta}
     */
    public boolean isUseBayesianRegularization() {
        return this.useBayesianRegularization;
    }

    /**
     * Performs one Levenberg-Marquardt iteration.
     *
     * <p>The damping factor is first divided by {@link #SCALE_LAMBDA}. It is then multiplied by
     * that factor before each trial step, until a step lowers the objective or
     * {@link #LAMBDA_MAX} is reached. If no trial step lowers the objective, the network is
     * restored to the weights it had when the iteration started. A failed search therefore
     * never leaves the network worse off. When Bayesian regularization is enabled,
     * {@code gamma}, {@code alpha} and {@code beta} are re-estimated afterwards from the trace
     * of the inverse of the last damped Hessian.
     */
    public void iteration() {
        LUDecomposition decomposition = null;
        this.preIteration();

        this.weights = NetworkCODEC.networkToArray(this.network);
        JacobianChainRule jacobianRule = new JacobianChainRule(this.network, this.indexableTraining);
        final double initialErrorTerm = jacobianRule.calculate(this.weights);
        final double initialWeightTerm = this.calculateSumOfSquaredWeights();
        this.calculateHessian(jacobianRule.getJacobian(), jacobianRule.getRowErrors());

        double sumOfSquaredErrors = initialErrorTerm;
        double sumOfSquaredWeights = initialWeightTerm;
        final double objective = this.beta * sumOfSquaredErrors + this.alpha * sumOfSquaredWeights;
        // Start above the objective so that at least one trial step is attempted.
        double current = objective + 1.0;

        this.lambda /= SCALE_LAMBDA;
        while (current >= objective && this.lambda < LAMBDA_MAX) {
            this.lambda *= SCALE_LAMBDA;
            // Damp the Hessian: restore the undamped diagonal, then add lambda + alpha.
            for (int i = 0; i < this.parametersLength; ++i) {
                this.hessian[i][i] = this.diagonal[i] + (this.lambda + this.alpha);
            }
            decomposition = new LUDecomposition(this.hessianMatrix);
            if (!decomposition.isNonsingular()) {
                // Cannot solve at this damping level; retry with a larger lambda.
                continue;
            }
            this.deltas = decomposition.Solve(this.gradient);
            sumOfSquaredWeights = this.updateWeights();
            sumOfSquaredErrors = this.calculateSumOfSquaredErrors();
            current = this.beta * sumOfSquaredErrors + this.alpha * sumOfSquaredWeights;
        }

        if (current >= objective) {
            // No trial step improved the objective. Undo the last (rejected) step so the
            // network and the reported error match the start of this iteration.
            NetworkCODEC.arrayToNetwork(this.weights, this.network);
            sumOfSquaredErrors = initialErrorTerm;
            sumOfSquaredWeights = initialWeightTerm;
        }
        // Let the next iteration start from a smaller damping factor.
        this.lambda /= SCALE_LAMBDA;

        // The inverse only exists for a nonsingular decomposition.
        if (this.useBayesianRegularization && decomposition != null && decomposition.isNonsingular()) {
            double trace = LevenbergMarquardtTraining.trace(decomposition.inverse());
            this.gamma = (double) this.parametersLength - this.alpha * trace;
            this.alpha = (double) this.parametersLength / (2.0 * sumOfSquaredWeights + trace);
            // NOTE(review): a zero error term makes this Infinity (or NaN); consider guarding.
            this.beta = Math.abs(((double) this.trainingLength - this.gamma) / (2.0 * sumOfSquaredErrors));
        }
        this.setError(sumOfSquaredErrors);
        this.postIteration();
    }

    /**
     * @param useBayesianRegularization whether to re-estimate {@code alpha} and {@code beta}
     *     after every iteration
     */
    public void setUseBayesianRegularization(boolean useBayesianRegularization) {
        this.useBayesianRegularization = useBayesianRegularization;
    }

    /**
     * Loads {@code weights + deltas} into the network. The stored start-of-iteration weights
     * are left unchanged.
     *
     * <p>Requires {@link #iteration()} to have populated the starting weights; the field is
     * null before the first iteration.
     *
     * @return half the sum of the squared updated weights
     */
    public double updateWeights() {
        double[] updated = this.weights.clone();
        double sum = 0.0;
        for (int i = 0; i < updated.length; ++i) {
            updated[i] += this.deltas[i];
            sum += updated[i] * updated[i];
        }
        NetworkCODEC.arrayToNetwork(updated, this.network);
        return sum / 2.0;
    }
}

