/*
 * Decompiled with CFR 0.152.
 */
package org.encog.neural.prune;

import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import org.encog.mathutil.matrices.Matrix;
import org.encog.mathutil.matrices.MatrixMath;
import org.encog.mathutil.randomize.Distort;
import org.encog.neural.NeuralNetworkError;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.layers.Layer;
import org.encog.neural.networks.synapse.Synapse;
import org.encog.persist.EncogPersistedObject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Selectively grows, shrinks, prunes or perturbs individual neurons of a layer in a
 * {@link BasicNetwork}, keeping the layer's bias array and the weight matrices of its
 * inbound and outbound synapses consistent with the new neuron count.
 */
public class PruneSelective {
    private final BasicNetwork network;
    private final Logger logger = LoggerFactory.getLogger(this.getClass());

    /**
     * @param basicNetwork the network whose layers will be modified
     */
    public PruneSelective(BasicNetwork basicNetwork) {
        this.network = basicNetwork;
    }

    /**
     * Changes the number of neurons in a layer. Growing appends neurons whose bias and
     * connection weights are zero; shrinking removes the least significant neurons as
     * ranked by {@link #determineNeuronSignificance(Layer, int)}.
     *
     * @param layer the layer to resize
     * @param neuronCount the desired number of neurons; must be positive
     * @throws NeuralNetworkError if {@code neuronCount} is zero or negative
     */
    public void changeNeuronCount(Layer layer, int neuronCount) {
        if (neuronCount == 0) {
            throw new NeuralNetworkError("Can't decrease to zero neurons.");
        }
        if (neuronCount < 0) {
            throw new NeuralNetworkError("Neuron count must be positive, got " + neuronCount + ".");
        }
        int current = layer.getNeuronCount();
        if (neuronCount == current) {
            return;
        }
        if (neuronCount > current) {
            this.increaseNeuronCount(layer, neuronCount);
        } else {
            this.decreaseNeuronCount(layer, neuronCount);
        }
    }

    /**
     * Removes the weakest neurons until the layer holds {@code neuronCount} neurons.
     */
    private void decreaseNeuronCount(Layer layer, int neuronCount) {
        int toRemove = layer.getNeuronCount() - neuronCount;
        int[] weakest = this.findWeakestNeurons(layer, toRemove);
        // Prune from the highest index downwards so that removing one neuron never shifts
        // the index of a neuron that is still waiting to be removed.
        Arrays.sort(weakest);
        for (int i = weakest.length - 1; i >= 0; --i) {
            this.prune(layer, weakest[i]);
        }
    }

    /**
     * Scores how much a neuron contributes to the network. The score is the absolute value
     * of the signed sum of the neuron's bias weight (if the layer has a bias), every
     * outgoing weight (row {@code neuron} of each outbound synapse matrix) and every
     * incoming weight (column {@code neuron} of each inbound synapse matrix). Synapses
     * without a weight matrix are skipped.
     *
     * <p>NOTE(review): because the weights are summed with their signs before taking the
     * absolute value, large positive and negative weights can cancel out. Kept as-is to
     * preserve the existing ranking.
     *
     * @param layer the layer containing the neuron
     * @param n index of the neuron within {@code layer}
     * @return a non-negative significance score; lower means less significant
     */
    public double determineNeuronSignificance(Layer layer, int n) {
        double sum = 0.0;
        if (layer.hasBias()) {
            sum += layer.getBiasWeight(n);
        }
        for (Synapse outbound : layer.getNext()) {
            Matrix weights = outbound.getMatrix();
            if (weights == null) {
                continue;
            }
            for (int to = 0; to < outbound.getToNeuronCount(); ++to) {
                sum += weights.get(n, to);
            }
        }
        for (Synapse inbound : this.network.getStructure().getPreviousSynapses(layer)) {
            Matrix weights = inbound.getMatrix();
            if (weights == null) {
                continue;
            }
            for (int from = 0; from < inbound.getFromNeuronCount(); ++from) {
                sum += weights.get(from, n);
            }
        }
        return Math.abs(sum);
    }

    /**
     * Finds the {@code count} neurons with the lowest significance score.
     *
     * @param layer the layer to inspect
     * @param count how many neurons to select; clamped to {@code [0, neuronCount]}
     * @return the indices of the weakest neurons, in no particular order
     */
    private int[] findWeakestNeurons(Layer layer, int count) {
        int neuronCount = layer.getNeuronCount();
        int size = Math.max(0, Math.min(count, neuronCount));
        int[] weakest = new int[size];
        double[] significance = new double[size];
        if (size == 0) {
            return weakest;
        }
        for (int i = 0; i < size; ++i) {
            weakest[i] = i;
            significance[i] = this.determineNeuronSignificance(layer, i);
        }
        // Track the strongest of the currently kept candidates: a new neuron belongs in
        // the set only if it is weaker than that one, and then it replaces exactly it.
        int strongestSlot = indexOfMax(significance);
        for (int neuron = size; neuron < neuronCount; ++neuron) {
            double score = this.determineNeuronSignificance(layer, neuron);
            if (score < significance[strongestSlot]) {
                weakest[strongestSlot] = neuron;
                significance[strongestSlot] = score;
                strongestSlot = indexOfMax(significance);
            }
        }
        return weakest;
    }

    /** Returns the index of the largest element of a non-empty array. */
    private static int indexOfMax(double[] values) {
        int best = 0;
        for (int i = 1; i < values.length; ++i) {
            if (values[i] > values[best]) {
                best = i;
            }
        }
        return best;
    }

    /**
     * @return the network this pruner operates on
     */
    public BasicNetwork getNetwork() {
        return this.network;
    }

    /**
     * Grows a layer to {@code neuronCount} neurons. Existing bias and connection weights
     * are copied; the newly added neurons start with zero bias and zero weights.
     */
    private void increaseNeuronCount(Layer layer, int neuronCount) {
        int oldCount = layer.getNeuronCount();

        if (layer.hasBias()) {
            double[] bias = new double[neuronCount];
            for (int i = 0; i < oldCount; ++i) {
                bias[i] = layer.getBiasWeight(i);
            }
            layer.setBiasWeights(bias);
        }

        // Outbound synapses: this layer's neurons are the matrix rows.
        for (Synapse outbound : layer.getNext()) {
            Matrix old = outbound.getMatrix();
            if (old == null) {
                continue;
            }
            int cols = outbound.getToNeuronCount();
            Matrix grown = new Matrix(neuronCount, cols);
            for (int row = 0; row < oldCount; ++row) {
                for (int col = 0; col < cols; ++col) {
                    grown.set(row, col, old.get(row, col));
                }
            }
            outbound.setMatrix(grown);
        }

        // Inbound synapses: this layer's neurons are the matrix columns.
        for (Synapse inbound : this.network.getStructure().getPreviousSynapses(layer)) {
            Matrix old = inbound.getMatrix();
            if (old == null) {
                continue;
            }
            int rows = inbound.getFromNeuronCount();
            Matrix grown = new Matrix(rows, neuronCount);
            for (int row = 0; row < rows; ++row) {
                for (int col = 0; col < oldCount; ++col) {
                    grown.set(row, col, old.get(row, col));
                }
            }
            inbound.setMatrix(grown);
        }

        layer.setNeuronCount(neuronCount);
    }

    /**
     * Removes a single neuron from a layer: its row from every outbound synapse matrix,
     * its column from every inbound synapse matrix, and its bias weight (if any).
     *
     * @param layer the layer to shrink by one neuron
     * @param n index of the neuron to remove
     * @throws NeuralNetworkError if {@code n} is not a valid neuron index; this is checked
     *     before anything is modified
     */
    public void prune(Layer layer, int n) {
        int count = layer.getNeuronCount();
        if (n < 0 || n >= count) {
            throw new NeuralNetworkError(
                    "Neuron index " + n + " is out of range for a layer of " + count + " neurons.");
        }
        for (Synapse outbound : layer.getNext()) {
            if (outbound.getMatrix() == null) {
                continue;
            }
            outbound.setMatrix(MatrixMath.deleteRow(outbound.getMatrix(), n));
        }
        // Only synapses that actually feed this layer lose a column; other synapses
        // leaving the same source layers are untouched.
        for (Synapse inbound : this.network.getStructure().getPreviousSynapses(layer)) {
            if (inbound.getMatrix() == null) {
                continue;
            }
            inbound.setMatrix(MatrixMath.deleteCol(inbound.getMatrix(), n));
        }
        if (layer.hasBias()) {
            double[] bias = new double[count - 1];
            int next = 0;
            for (int i = 0; i < count; ++i) {
                if (i != n) {
                    bias[next++] = layer.getBiasWeight(i);
                }
            }
            layer.setBiasWeights(bias);
        }
        layer.setNeuronCount(count - 1);
    }

    /**
     * Randomly perturbs a neuron's bias weight and all of its incoming and outgoing
     * weights using {@link Distort}. Synapses without a weight matrix are skipped.
     *
     * @param d the distortion amount passed to {@code new Distort(d)}; its scale is
     *     defined by {@link Distort}
     * @param layer the layer containing the neuron
     * @param n index of the neuron to stimulate
     */
    public void stimulateNeuron(double d, Layer layer, int n) {
        Distort distort = new Distort(d);
        if (layer.hasBias()) {
            layer.setBiasWeight(n, distort.randomize(layer.getBiasWeight(n)));
        }
        for (Synapse outbound : layer.getNext()) {
            Matrix weights = outbound.getMatrix();
            if (weights == null) {
                continue;
            }
            for (int to = 0; to < outbound.getToNeuronCount(); ++to) {
                weights.set(n, to, distort.randomize(weights.get(n, to)));
            }
        }
        for (Synapse inbound : this.network.getStructure().getPreviousSynapses(layer)) {
            Matrix weights = inbound.getMatrix();
            if (weights == null) {
                continue;
            }
            for (int from = 0; from < inbound.getFromNeuronCount(); ++from) {
                weights.set(from, n, distort.randomize(weights.get(from, n)));
            }
        }
    }

    /**
     * Applies {@link #stimulateNeuron(double, Layer, int)} to the {@code n} least
     * significant neurons of a layer (at most the layer's neuron count).
     *
     * @param layer the layer to stimulate
     * @param n how many of the weakest neurons to perturb
     * @param d the distortion amount forwarded to {@code stimulateNeuron}
     */
    public void stimulateWeakNeurons(Layer layer, int n, double d) {
        for (int neuron : this.findWeakestNeurons(layer, n)) {
            this.stimulateNeuron(d, layer, neuron);
        }
    }
}

