/*
 * Decompiled with CFR 0.152.
 */
package org.encog.neural.networks.training.competitive;

import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import org.encog.engine.util.Format;
import org.encog.mathutil.matrices.Matrix;
import org.encog.neural.data.NeuralData;
import org.encog.neural.data.NeuralDataPair;
import org.encog.neural.data.NeuralDataSet;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.layers.Layer;
import org.encog.neural.networks.structure.FlatUpdateNeeded;
import org.encog.neural.networks.synapse.Synapse;
import org.encog.neural.networks.training.BasicTraining;
import org.encog.neural.networks.training.LearningRate;
import org.encog.neural.networks.training.competitive.BestMatchingUnit;
import org.encog.neural.networks.training.competitive.neighborhood.NeighborhoodFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Competitive (Kohonen-style) trainer for a network whose layers are tagged
 * {@code "INPUT"} and {@code "OUTPUT"}.
 *
 * <p>For every training pattern the best matching unit (BMU) of the output layer
 * is located. The weights feeding each output neuron are then moved toward the
 * input pattern, scaled by the learning rate and by the neighborhood function
 * evaluated between that neuron and the BMU. New weights are first accumulated
 * in a per-synapse correction matrix and copied into the synapse afterwards.
 *
 * <p>Optionally ({@link #setForceWinner(boolean)}), an output neuron that never
 * won during an iteration is forced to take over the least represented pattern.
 */
public class CompetitiveTraining
extends BasicTraining
implements LearningRate {
    /** Neighborhood function that scales updates by distance from the BMU. */
    private final NeighborhoodFunction neighborhood;
    /** Current learning rate. */
    private double learningRate;
    /** The network being trained. */
    private final BasicNetwork network;
    /** Layer tagged "INPUT". */
    private final Layer inputLayer;
    /** Layer tagged "OUTPUT". */
    private final Layer outputLayer;
    /** All synapses feeding the output layer; these are the ones trained. */
    private final Collection<Synapse> synapses;
    private final int inputNeuronCount;
    private final int outputNeuronCount;
    private final BestMatchingUnit bmuUtil;
    /** Per-synapse scratch matrix holding the new weights computed by training. */
    private final Map<Synapse, Matrix> correctionMatrix = new HashMap<Synapse, Matrix>();
    private boolean forceWinner;
    private double startRate;
    private double endRate;
    private double startRadius;
    private double endRadius;
    /** Per-call learning-rate step used by {@link #autoDecay()}. */
    private double autoDecayRate;
    /** Per-call radius step used by {@link #autoDecay()}. */
    private double autoDecayRadius;
    private final Logger logger = LoggerFactory.getLogger(this.getClass());
    /** Current neighborhood radius. */
    private double radius;

    /**
     * Creates a competitive trainer.
     *
     * @param network the network to train; its layers tagged "INPUT" and "OUTPUT"
     *        are looked up and must exist
     * @param learningRate the initial learning rate
     * @param training the training data
     * @param neighborhood the neighborhood function used to scale weight updates
     */
    public CompetitiveTraining(BasicNetwork network, double learningRate, NeuralDataSet training, NeighborhoodFunction neighborhood) {
        this.neighborhood = neighborhood;
        this.setTraining(training);
        this.learningRate = learningRate;
        this.network = network;
        this.inputLayer = network.getLayer("INPUT");
        this.outputLayer = network.getLayer("OUTPUT");
        this.synapses = network.getStructure().getPreviousSynapses(this.outputLayer);
        this.inputNeuronCount = this.inputLayer.getNeuronCount();
        this.outputNeuronCount = this.outputLayer.getNeuronCount();
        this.forceWinner = false;
        this.setError(0.0);
        // One correction matrix per trained synapse, shaped like its weight matrix.
        for (Synapse synapse : this.synapses) {
            Matrix weights = synapse.getMatrix();
            this.correctionMatrix.put(synapse, new Matrix(weights.getRows(), weights.getCols()));
        }
        this.bmuUtil = new BestMatchingUnit(this);
    }

    /**
     * Copies every synapse's correction matrix into that synapse's weights and
     * flags the network structure for re-flattening.
     */
    private void applyCorrection() {
        for (Map.Entry<Synapse, Matrix> entry : this.correctionMatrix.entrySet()) {
            entry.getKey().getMatrix().set(entry.getValue());
        }
        this.network.getStructure().setFlatUpdate(FlatUpdateNeeded.Flatten);
    }

    /**
     * Copies only the given synapse's correction matrix into its weights. Other
     * synapses' corrections may be stale (or all zero on the first iteration),
     * so they must not be touched while this synapse is being processed.
     *
     * @param synapse the synapse whose freshly computed correction is applied
     */
    private void applyCorrection(Synapse synapse) {
        synapse.getMatrix().set(this.correctionMatrix.get(synapse));
        this.network.getStructure().setFlatUpdate(FlatUpdateNeeded.Flatten);
    }

    /**
     * Takes one step of the schedule configured by
     * {@link #setAutoDecay(int, double, double, double, double)}.
     *
     * <p>The radius is stepped while it is above the end radius, and the learning
     * rate while it is above the end rate. The neighborhood then receives the
     * current radius.
     */
    public void autoDecay() {
        if (this.radius > this.endRadius) {
            this.radius += this.autoDecayRadius;
        }
        if (this.learningRate > this.endRate) {
            this.learningRate += this.autoDecayRate;
        }
        this.getNeighborhood().setRadius(this.radius);
    }

    /**
     * Overwrites the weights feeding output neuron {@code outputNeuron} with the
     * values of {@code pattern}.
     *
     * @param synapse the synapse whose weight column is overwritten
     * @param outputNeuron index of the output neuron (weight-matrix column)
     * @param pattern the input pattern to copy
     */
    private void copyInputPattern(Synapse synapse, int outputNeuron, NeuralData pattern) {
        Matrix weights = synapse.getMatrix();
        for (int input = 0; input < this.inputNeuronCount; ++input) {
            weights.set(input, outputNeuron, pattern.getData(input));
        }
    }

    /**
     * Decays the radius and the learning rate by the same fraction.
     *
     * <p>This is equivalent to {@code decay(rate, rate)}. It now also pushes the
     * new radius into the neighborhood function, as the two-argument overload
     * does.
     *
     * @param rate fraction removed from both values, e.g. 0.1 keeps 90%
     */
    public void decay(double rate) {
        this.decay(rate, rate);
    }

    /**
     * Decays the learning rate and the radius by separate fractions. The new
     * radius is passed to the neighborhood function.
     *
     * @param learningRateDecay fraction removed from the learning rate
     * @param radiusDecay fraction removed from the radius
     */
    public void decay(double learningRateDecay, double radiusDecay) {
        this.radius *= 1.0 - radiusDecay;
        this.learningRate *= 1.0 - learningRateDecay;
        this.getNeighborhood().setRadius(this.radius);
    }

    /**
     * Computes {@code w + h(current, bmu) * learningRate * (x - w)}, where
     * {@code h} is the neighborhood function.
     *
     * @param weight the current weight {@code w}
     * @param inputValue the input value {@code x}
     * @param currentNeuron the output neuron being updated
     * @param bmu the best matching unit for the pattern
     * @return the updated weight
     */
    private double determineNewWeight(double weight, double inputValue, int currentNeuron, int bmu) {
        return weight + this.neighborhood.function(currentNeuron, bmu) * this.learningRate * (inputValue - weight);
    }

    /**
     * Picks, among the output neurons that never won during this iteration, the
     * one with the highest activation for the least represented pattern. That
     * neuron's weights are overwritten with the pattern.
     *
     * @param synapse the synapse being trained
     * @param won per-output-neuron win counts for this iteration
     * @param leastRepresented the pattern whose BMU activation was lowest; may be
     *        {@code null} when no such pattern was found
     * @return {@code true} if a neuron was forced to win (weights modified
     *         directly); {@code false} if nothing was changed
     */
    private boolean forceWinners(Synapse synapse, int[] won, NeuralData leastRepresented) {
        if (leastRepresented == null) {
            // Nothing to hand over (e.g. empty training set); avoid compute(null).
            return false;
        }
        NeuralData output = this.network.compute(leastRepresented);
        int chosen = -1;
        double bestActivation = Double.NEGATIVE_INFINITY;
        for (int neuron = 0; neuron < won.length; ++neuron) {
            if (won[neuron] != 0) {
                continue;
            }
            double activation = output.getData(neuron);
            // The first non-winner is always taken, so NaN activations cannot
            // leave us without a candidate.
            if (chosen == -1 || activation > bestActivation) {
                bestActivation = activation;
                chosen = neuron;
            }
        }
        if (chosen == -1) {
            return false;
        }
        this.copyInputPattern(synapse, chosen, leastRepresented);
        return true;
    }

    public int getInputNeuronCount() {
        return this.inputNeuronCount;
    }

    public double getLearningRate() {
        return this.learningRate;
    }

    public NeighborhoodFunction getNeighborhood() {
        return this.neighborhood;
    }

    public BasicNetwork getNetwork() {
        return this.network;
    }

    public int getOutputNeuronCount() {
        return this.outputNeuronCount;
    }

    public boolean isForceWinner() {
        return this.forceWinner;
    }

    /**
     * Performs one training iteration over the whole training set.
     *
     * <p>For each synapse feeding the output layer, the BMU of every pattern is
     * found and the new weights are computed into the correction matrix. Each
     * pattern's update is computed from the synapse's current weights and
     * rewrites the whole correction matrix, so the matrix ends up holding the
     * update from the last pattern processed. Only that synapse's correction is
     * applied. When force-winner mode is on and a non-winning neuron was
     * reassigned, the correction is skipped because the weights were edited
     * directly. The error is set to the worst BMU distance seen.
     */
    public void iteration() {
        if (this.logger.isInfoEnabled()) {
            this.logger.info("Performing Competitive Training iteration.");
        }
        this.preIteration();
        this.bmuUtil.reset();
        int[] won = new int[this.outputNeuronCount];
        double leastRepresentedActivation = Double.MAX_VALUE;
        NeuralData leastRepresented = null;
        for (Synapse synapse : this.synapses) {
            Matrix correction = this.correctionMatrix.get(synapse);
            correction.clear();
            boolean trainedAny = false;
            for (NeuralDataPair pair : this.getTraining()) {
                NeuralData input = pair.getInput();
                int bmu = this.bmuUtil.calculateBMU(synapse, input);
                if (this.forceWinner) {
                    won[bmu]++;
                    // Track the pattern whose BMU fires weakest: the least represented one.
                    NeuralData output = this.network.compute(input);
                    if (output.getData(bmu) < leastRepresentedActivation) {
                        leastRepresentedActivation = output.getData(bmu);
                        leastRepresented = input;
                    }
                }
                this.train(bmu, synapse, input);
                trainedAny = true;
            }
            if (!trainedAny) {
                // The correction matrix is still all zeros; applying it would wipe the weights.
                continue;
            }
            if (this.forceWinner && this.forceWinners(synapse, won, leastRepresented)) {
                continue;
            }
            this.applyCorrection(synapse);
        }
        // forceWinners edits weights directly, so always request a re-flatten.
        this.network.getStructure().setFlatUpdate(FlatUpdateNeeded.Flatten);
        this.setError(this.bmuUtil.getWorstDistance());
        this.postIteration();
    }

    /**
     * Configures a linear schedule for {@link #autoDecay()} and resets the rate
     * and the radius to their start values.
     *
     * @param iterations number of {@link #autoDecay()} calls over which to move
     *        from start to end; must be positive
     * @param startRate initial learning rate
     * @param endRate final learning rate
     * @param startRadius initial neighborhood radius
     * @param endRadius final neighborhood radius
     * @throws IllegalArgumentException if {@code iterations <= 0}, which would
     *         otherwise produce infinite or NaN decay steps
     */
    public void setAutoDecay(int iterations, double startRate, double endRate, double startRadius, double endRadius) {
        if (iterations <= 0) {
            throw new IllegalArgumentException("iterations must be positive: " + iterations);
        }
        this.startRate = startRate;
        this.endRate = endRate;
        this.startRadius = startRadius;
        this.endRadius = endRadius;
        this.autoDecayRadius = (endRadius - startRadius) / (double) iterations;
        this.autoDecayRate = (endRate - startRate) / (double) iterations;
        this.setParams(this.startRate, this.startRadius);
    }

    public void setForceWinner(boolean forceWinner) {
        this.forceWinner = forceWinner;
    }

    public void setLearningRate(double learningRate) {
        this.learningRate = learningRate;
    }

    /**
     * Sets the learning rate and the radius, and passes the radius to the
     * neighborhood function.
     *
     * @param rate the new learning rate
     * @param radius the new neighborhood radius
     */
    public void setParams(double rate, double radius) {
        this.radius = radius;
        this.learningRate = rate;
        this.getNeighborhood().setRadius(radius);
    }

    @Override
    public String toString() {
        StringBuilder text = new StringBuilder();
        text.append("Rate=");
        text.append(Format.formatPercent(this.learningRate));
        text.append(", Radius=");
        text.append(Format.formatDouble(this.radius, 2));
        return text.toString();
    }

    /**
     * Computes new weights for every output neuron of {@code synapse} for one
     * pattern, given its BMU.
     */
    private void train(int bmu, Synapse synapse, NeuralData input) {
        for (int neuron = 0; neuron < this.outputNeuronCount; ++neuron) {
            this.trainPattern(synapse, input, neuron, bmu);
        }
    }

    /**
     * Trains a single pattern immediately: computes corrections for every
     * trained synapse, then applies all of them.
     *
     * @param pattern the input pattern
     */
    public void trainPattern(NeuralData pattern) {
        for (Synapse synapse : this.synapses) {
            int bmu = this.bmuUtil.calculateBMU(synapse, pattern);
            this.train(bmu, synapse, pattern);
        }
        this.applyCorrection();
    }

    /**
     * Writes the new weights feeding {@code currentNeuron} into the synapse's
     * correction matrix. The synapse's own weights are left unchanged.
     */
    private void trainPattern(Synapse synapse, NeuralData input, int currentNeuron, int bmu) {
        Matrix correction = this.correctionMatrix.get(synapse);
        Matrix weights = synapse.getMatrix();
        for (int i = 0; i < this.inputNeuronCount; ++i) {
            double updated = this.determineNewWeight(weights.get(i, currentNeuron), input.getData(i), currentNeuron, bmu);
            correction.set(i, currentNeuron, updated);
        }
    }
}

