/*
 * Decompiled with CFR 0.152.
 */
package org.encog.neural.networks.training.cpn;

import org.encog.engine.util.ErrorCalculation;
import org.encog.neural.data.NeuralData;
import org.encog.neural.data.NeuralDataPair;
import org.encog.neural.data.NeuralDataSet;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.structure.FlatUpdateNeeded;
import org.encog.neural.networks.training.BasicTraining;
import org.encog.neural.networks.training.LearningRate;
import org.encog.neural.networks.training.cpn.FindCPN;

/**
 * Trains the instar-to-outstar weights of a counter-propagation (CPN) network.
 * <p>
 * For every training pair the input is pushed through the instar synapse, a
 * single winning instar neuron is chosen via {@link FindCPN#winner}, and only
 * that neuron's row of the outstar weight matrix is moved toward the pair's
 * ideal output by a fraction {@code learningRate} of the remaining difference.
 */
public class TrainOutstar
extends BasicTraining
implements LearningRate {
    /** Fraction of the remaining (ideal - weight) gap applied per update. */
    private double learningRate;
    /** The network whose outstar weights are trained. */
    private final BasicNetwork network;
    /** Training pairs; ideal vectors supply the outstar targets. */
    private final NeuralDataSet training;
    /** True until the outstar weights have been seeded from the training set. */
    private boolean mustInit = true;
    /** Locates the instar/outstar layers and synapses inside {@link #network}. */
    private final FindCPN parts;

    /**
     * Creates an outstar trainer.
     *
     * @param basicNetwork  the network to train
     * @param neuralDataSet the training data
     * @param d             the learning rate
     */
    public TrainOutstar(BasicNetwork basicNetwork, NeuralDataSet neuralDataSet, double d) {
        this.network = basicNetwork;
        this.training = neuralDataSet;
        this.learningRate = d;
        this.parts = new FindCPN(this.network);
    }

    /**
     * @return the current learning rate
     */
    public double getLearningRate() {
        return this.learningRate;
    }

    /**
     * @return the network being trained
     */
    public BasicNetwork getNetwork() {
        return this.network;
    }

    /**
     * Seeds the outstar weights: row {@code n} of the outstar matrix receives
     * the ideal output of the {@code n}-th training pair. The training set is
     * scanned only once.
     * <p>
     * NOTE(review): this assumes the training set has no more pairs than the
     * outstar matrix has rows; an extra pair produces an out-of-range row
     * index in {@code Matrix.set} — confirm how callers size the network.
     */
    private void initWeight() {
        final int outputCount = this.parts.getOutstarLayer().getNeuronCount();
        int row = 0;
        for (NeuralDataPair pair : this.training) {
            NeuralData ideal = pair.getIdeal();
            for (int col = 0; col < outputCount; ++col) {
                this.parts.getOutstarSynapse().getMatrix().set(row, col, ideal.getData(col));
            }
            ++row;
        }
        this.mustInit = false;
        // Weights were changed directly on the matrix; ask for the flat copy to be rebuilt.
        this.network.getStructure().setFlatUpdate(FlatUpdateNeeded.Flatten);
    }

    /**
     * Performs one pass over the training set. For each pair, the winning
     * instar neuron's outstar weight row is moved toward the ideal output. The
     * iteration error is then accumulated by comparing that updated row (the
     * values this trainer fits) with the ideal output. On the first call the
     * weights are seeded via {@link #initWeight()}.
     */
    public void iteration() {
        if (this.mustInit) {
            this.initWeight();
        }
        final int outputCount = this.parts.getOutstarLayer().getNeuronCount();
        ErrorCalculation errorCalculation = new ErrorCalculation();
        for (NeuralDataPair pair : this.training) {
            NeuralData instarOutput = this.parts.getInstarSynapse().compute(pair.getInput());
            int winner = this.parts.winner(instarOutput);
            NeuralData ideal = pair.getIdeal();
            double[] actual = new double[outputCount];
            for (int i = 0; i < outputCount; ++i) {
                double current = this.parts.getOutstarSynapse().getMatrix().get(winner, i);
                double delta = this.learningRate * (ideal.getData(i) - current);
                this.parts.getOutstarSynapse().getMatrix().add(winner, i, delta);
                actual[i] = this.parts.getOutstarSynapse().getMatrix().get(winner, i);
            }
            // Compare output-sized vectors. The previous code passed the instar
            // activations here, whose width is the instar neuron count, not the
            // output count.
            errorCalculation.updateError(actual, ideal.getData());
        }
        this.setError(errorCalculation.calculate());
        // Weights were changed directly on the matrix; ask for the flat copy to be rebuilt.
        this.network.getStructure().setFlatUpdate(FlatUpdateNeeded.Flatten);
    }

    /**
     * @param d the new learning rate
     */
    public void setLearningRate(double d) {
        this.learningRate = d;
    }
}

