/*
 * Decompiled with CFR 0.152.
 */
package org.encog.neural.networks.training.cpn;

import org.encog.engine.util.BoundMath;
import org.encog.neural.data.NeuralData;
import org.encog.neural.data.NeuralDataPair;
import org.encog.neural.data.NeuralDataSet;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.structure.FlatUpdateNeeded;
import org.encog.neural.networks.training.BasicTraining;
import org.encog.neural.networks.training.LearningRate;
import org.encog.neural.networks.training.cpn.FindCPN;

public class TrainInstar
extends BasicTraining
implements LearningRate {
    private final BasicNetwork network;
    private final NeuralDataSet training;
    private double learningRate;
    private boolean mustInit = true;
    private final FindCPN parts;

    public TrainInstar(BasicNetwork basicNetwork, NeuralDataSet neuralDataSet, double d) {
        this.network = basicNetwork;
        this.training = neuralDataSet;
        this.learningRate = d;
        this.parts = new FindCPN(basicNetwork);
    }

    public double getLearningRate() {
        return this.learningRate;
    }

    public BasicNetwork getNetwork() {
        return this.network;
    }

    private void initWeights() {
        int n = 0;
        for (NeuralDataPair neuralDataPair : this.training) {
            for (int i = 0; i < this.parts.getInputLayer().getNeuronCount(); ++i) {
                this.parts.getInstarSynapse().getMatrix().set(i, n, neuralDataPair.getInput().getData(i));
            }
            ++n;
        }
        this.network.getStructure().setFlatUpdate(FlatUpdateNeeded.Flatten);
        this.mustInit = false;
    }

    public void iteration() {
        if (this.mustInit) {
            this.initWeights();
        }
        double d = Double.NEGATIVE_INFINITY;
        for (NeuralDataPair neuralDataPair : this.training) {
            double d2;
            int n;
            NeuralData neuralData = this.parts.getInstarSynapse().compute(neuralDataPair.getInput());
            int n2 = this.parts.winner(neuralData);
            double d3 = 0.0;
            for (n = 0; n < neuralDataPair.getInput().size(); ++n) {
                d2 = neuralDataPair.getInput().getData(n) - this.parts.getInstarSynapse().getMatrix().get(n, n2);
                d3 += d2 * d2;
            }
            if ((d3 = BoundMath.sqrt(d3)) > d) {
                d = d3;
            }
            for (n = 0; n < this.parts.getInstarSynapse().getFromNeuronCount(); ++n) {
                d2 = this.learningRate * (neuralDataPair.getInput().getData(n) - this.parts.getInstarSynapse().getMatrix().get(n, n2));
                this.parts.getInstarSynapse().getMatrix().add(n, n2, d2);
            }
        }
        this.network.getStructure().setFlatUpdate(FlatUpdateNeeded.Flatten);
        this.setError(d);
    }

    public void setLearningRate(double d) {
        this.learningRate = d;
    }
}

