/*
 * Decompiled with CFR 0.152.
 */
package org.encog.neural.networks.training.svd;

import org.encog.engine.network.rbf.RadialBasisFunction;
import org.encog.engine.util.ObjectPair;
import org.encog.neural.data.NeuralDataSet;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.layers.Layer;
import org.encog.neural.networks.layers.RadialBasisFunctionLayer;
import org.encog.neural.networks.structure.FlatUpdateNeeded;
import org.encog.neural.networks.training.BasicTraining;
import org.encog.neural.networks.training.TrainingError;
import org.encog.neural.networks.training.svd.SVD;
import org.encog.util.simple.TrainingSetUtil;

/**
 * Trainer that fits a radial basis function network by handing its training
 * data, first-synapse weight matrix and basis functions to
 * {@link SVD#svdfit}.
 *
 * <p>The network must expose a layer tagged {@code "OUTPUT"} with exactly one
 * neuron and a layer tagged {@code "RBF"} that is a
 * {@link RadialBasisFunctionLayer}; both conditions are checked at
 * construction time.
 */
public class SVDTraining
extends BasicTraining {
    /** The network being trained. */
    private final BasicNetwork network;

    /** The network layer tagged "RBF", which supplies the basis functions. */
    private final RadialBasisFunctionLayer rbfLayer;

    /**
     * Creates a trainer for the given network and training data.
     *
     * @param basicNetwork  the network to train
     * @param neuralDataSet the training data, stored via {@code setTraining}
     * @throws TrainingError if the network has no "OUTPUT" layer, if that
     *         layer does not have exactly one neuron, or if the layer tagged
     *         "RBF" is missing or is not a {@link RadialBasisFunctionLayer}
     */
    public SVDTraining(BasicNetwork basicNetwork, NeuralDataSet neuralDataSet) {
        Layer outputLayer = basicNetwork.getLayer("OUTPUT");
        if (outputLayer == null) {
            throw new TrainingError("SVD requires an output layer.");
        }
        if (outputLayer.getNeuronCount() != 1) {
            throw new TrainingError("SVD requires an output layer with a single neuron.");
        }
        // Look the layer up once; instanceof rejects both a missing layer (null)
        // and a layer of the wrong type, so the cast below can never fail with
        // a ClassCastException.
        Layer taggedRbfLayer = basicNetwork.getLayer("RBF");
        if (!(taggedRbfLayer instanceof RadialBasisFunctionLayer)) {
            throw new TrainingError("SVD is only tested to work on radial basis function networks.");
        }
        this.rbfLayer = (RadialBasisFunctionLayer)taggedRbfLayer;
        this.setTraining(neuralDataSet);
        this.network = basicNetwork;
    }

    /**
     * @return the network passed to the constructor
     */
    public BasicNetwork getNetwork() {
        return this.network;
    }

    /**
     * Performs one training pass: collects the RBF layer's basis functions,
     * converts the training set to arrays, runs {@link SVD#svdfit} against the
     * data of the first synapse's matrix, records the returned value as the
     * training error and flags the network structure for re-flattening.
     */
    public void iteration() {
        int neuronCount = this.rbfLayer.getNeuronCount();
        RadialBasisFunction[] layerFunctions = this.rbfLayer.getRadialBasisFunction();

        // Take one basis function per neuron of the RBF layer.
        RadialBasisFunction[] basisFunctions = new RadialBasisFunction[neuronCount];
        for (int i = 0; i < neuronCount; ++i) {
            basisFunctions[i] = layerFunctions[i];
        }

        // NOTE(review): A/B are presumably the input and ideal arrays - confirm
        // against TrainingSetUtil.trainingToArray.
        ObjectPair<double[][], double[][]> trainingArrays = TrainingSetUtil.trainingToArray(this.getTraining());

        // Data of the first synapse's matrix; presumably svdfit updates it in
        // place, which would explain the flatten request below - confirm
        // against SVD.svdfit.
        double[][] weights = this.network.getStructure().getSynapses().get(0).getMatrix().getData();

        this.setError(SVD.svdfit(trainingArrays.getA(), trainingArrays.getB(), weights, basisFunctions));
        this.network.getStructure().setFlatUpdate(FlatUpdateNeeded.Flatten);
    }
}

