/*
 * Decompiled with CFR 0.152.
 */
package org.encog.neural.prune;

import java.util.ArrayList;
import java.util.List;
import org.encog.engine.StatusReportable;
import org.encog.engine.concurrency.job.ConcurrentJob;
import org.encog.engine.concurrency.job.JobUnitContext;
import org.encog.neural.data.NeuralDataSet;
import org.encog.neural.data.buffer.BufferedNeuralDataSet;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.layers.Layer;
import org.encog.neural.networks.training.propagation.resilient.ResilientPropagation;
import org.encog.neural.networks.training.strategy.StopTrainingStrategy;
import org.encog.neural.pattern.NeuralNetworkPattern;
import org.encog.neural.prune.HiddenLayerParams;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
public class PruneIncremental
extends ConcurrentJob {
    private boolean done = false;
    private final Logger logger = LoggerFactory.getLogger(this.getClass());
    private final NeuralDataSet training;
    private final NeuralNetworkPattern pattern;
    private final List<HiddenLayerParams> hidden = new ArrayList<HiddenLayerParams>();
    private final int iterations;
    private final BasicNetwork[] topNetworks;
    private final double[] topErrors;
    private BasicNetwork bestNetwork;
    private int currentTry;
    private final StatusReportable report;
    private int[] hiddenCounts;
    private double high;
    private double low;
    private double[][] results;
    private int hidden1Size;
    private int hidden2Size;
    private final int weightTries;

    public static String networkToString(BasicNetwork basicNetwork) {
        StringBuilder stringBuilder = new StringBuilder();
        int n = 1;
        Layer layer = basicNetwork.getLayer("INPUT");
        Layer[] layerArray = new Layer[basicNetwork.getStructure().getLayers().size()];
        boolean bl = false;
        while (layer.getNext().size() > 0 && !bl) {
            layer = layer.getNext().get(0).getToLayer();
            for (int i = 0; i < layerArray.length; ++i) {
                if (layer == layerArray[i]) {
                    bl = true;
                    break;
                }
                if (layerArray[i] != null) continue;
                layerArray[i] = layer;
                break;
            }
            if (layer.getNext().size() <= 0 || bl) continue;
            if (stringBuilder.length() > 0) {
                stringBuilder.append(",");
            }
            stringBuilder.append("H");
            stringBuilder.append(n++);
            stringBuilder.append("=");
            stringBuilder.append(layer.getNeuronCount());
        }
        return stringBuilder.toString();
    }

    public PruneIncremental(NeuralDataSet neuralDataSet, NeuralNetworkPattern neuralNetworkPattern, int n, int n2, int n3, StatusReportable statusReportable) {
        super(statusReportable);
        this.training = neuralDataSet;
        this.pattern = neuralNetworkPattern;
        this.iterations = n;
        this.report = statusReportable;
        this.weightTries = n2;
        this.topNetworks = new BasicNetwork[n3];
        this.topErrors = new double[n3];
    }

    public void addHiddenLayer(int n, int n2) {
        HiddenLayerParams hiddenLayerParams = new HiddenLayerParams(n, n2);
        this.hidden.add(hiddenLayerParams);
    }

    private BasicNetwork generateNetwork() {
        this.pattern.clear();
        for (int n : this.hiddenCounts) {
            if (n <= 0) continue;
            this.pattern.addHiddenLayer(n);
        }
        return this.pattern.generate();
    }

    public BasicNetwork getBestNetwork() {
        return this.bestNetwork;
    }

    public List<HiddenLayerParams> getHidden() {
        return this.hidden;
    }

    public int getHidden1Size() {
        return this.hidden1Size;
    }

    public int getHidden2Size() {
        return this.hidden2Size;
    }

    public double getHigh() {
        return this.high;
    }

    public int getIterations() {
        return this.iterations;
    }

    public double getLow() {
        return this.low;
    }

    public NeuralNetworkPattern getPattern() {
        return this.pattern;
    }

    public double[][] getResults() {
        return this.results;
    }

    public double[] getTopErrors() {
        return this.topErrors;
    }

    public BasicNetwork[] getTopNetworks() {
        return this.topNetworks;
    }

    public NeuralDataSet getTraining() {
        return this.training;
    }

    private boolean increaseHiddenCounts() {
        int n = 0;
        do {
            HiddenLayerParams hiddenLayerParams = this.hidden.get(n);
            int n2 = n;
            this.hiddenCounts[n2] = this.hiddenCounts[n2] + 1;
            if (this.hiddenCounts[n] <= hiddenLayerParams.getMax()) {
                return true;
            }
            this.hiddenCounts[n] = hiddenLayerParams.getMin();
        } while (++n < this.hiddenCounts.length);
        return false;
    }

    public void init() {
        if (this.hidden.size() == 1) {
            this.hidden1Size = this.hidden.get(0).getMax() - this.hidden.get(0).getMin() + 1;
            this.hidden2Size = 0;
            this.results = new double[this.hidden1Size][1];
        } else if (this.hidden.size() == 2) {
            this.hidden1Size = this.hidden.get(0).getMax() - this.hidden.get(0).getMin() + 1;
            this.hidden2Size = this.hidden.get(1).getMax() - this.hidden.get(1).getMin() + 1;
            this.results = new double[this.hidden1Size][this.hidden2Size];
        } else {
            this.hidden1Size = 0;
            this.hidden2Size = 0;
            this.results = null;
        }
        this.high = Double.NEGATIVE_INFINITY;
        this.low = Double.POSITIVE_INFINITY;
    }

    @Override
    public int loadWorkload() {
        int n = 1;
        for (HiddenLayerParams hiddenLayerParams : this.hidden) {
            n *= hiddenLayerParams.getMax() - hiddenLayerParams.getMin() + 1;
        }
        this.init();
        return n;
    }

    @Override
    public void performJobUnit(JobUnitContext jobUnitContext) {
        int n;
        int n2;
        BasicNetwork basicNetwork = (BasicNetwork)jobUnitContext.getJobUnit();
        BufferedNeuralDataSet bufferedNeuralDataSet = null;
        NeuralDataSet neuralDataSet = this.training;
        if (this.training instanceof BufferedNeuralDataSet) {
            bufferedNeuralDataSet = (BufferedNeuralDataSet)this.training;
            neuralDataSet = bufferedNeuralDataSet.openAdditional();
        }
        double d = Double.POSITIVE_INFINITY;
        for (n2 = 0; n2 < this.weightTries; ++n2) {
            basicNetwork.reset();
            ResilientPropagation resilientPropagation = new ResilientPropagation(basicNetwork, neuralDataSet);
            StopTrainingStrategy stopTrainingStrategy = new StopTrainingStrategy(0.001, 5);
            resilientPropagation.addStrategy(stopTrainingStrategy);
            resilientPropagation.setNumThreads(1);
            for (n = 0; n < this.iterations && !this.getShouldStop() && !stopTrainingStrategy.shouldStop(); ++n) {
                resilientPropagation.iteration();
            }
            d = Math.min(d, resilientPropagation.getError());
        }
        if (bufferedNeuralDataSet != null) {
            bufferedNeuralDataSet.close();
        }
        if (!this.getShouldStop()) {
            this.high = Math.max(this.high, d);
            this.low = Math.min(this.low, d);
            if (this.hidden1Size > 0) {
                int n3;
                int n4;
                if (basicNetwork.getStructure().getLayers().size() > 3) {
                    n4 = basicNetwork.getStructure().getLayers().get(1).getNeuronCount();
                    n2 = basicNetwork.getStructure().getLayers().get(2).getNeuronCount();
                } else {
                    n4 = 0;
                    n2 = basicNetwork.getStructure().getLayers().get(1).getNeuronCount();
                }
                if (this.hidden2Size == 0) {
                    n3 = n2 - this.hidden.get(0).getMin();
                    n = 0;
                } else {
                    n3 = n2 - this.hidden.get(0).getMin();
                    n = n4 - this.hidden.get(1).getMin();
                }
                this.results[n3][n] = d;
            }
            ++this.currentTry;
            this.updateBest(basicNetwork, d);
            this.reportStatus(jobUnitContext, "Current: " + PruneIncremental.networkToString(basicNetwork) + "; Best: " + PruneIncremental.networkToString(this.bestNetwork));
        }
    }

    @Override
    public void process() {
        if (this.hidden.size() == 0 && this.logger.isErrorEnabled()) {
            this.logger.error("To calculate the optimal hidden size, at least one hidden layer must be defined.");
        }
        this.hiddenCounts = new int[this.hidden.size()];
        this.bestNetwork = null;
        int n = 0;
        for (HiddenLayerParams hiddenLayerParams : this.hidden) {
            this.hiddenCounts[n++] = hiddenLayerParams.getMin();
        }
        if (this.hiddenCounts[0] == 0 && this.logger.isErrorEnabled()) {
            this.logger.error("To calculate the optimal hidden size, at least one neuron must be the minimum for the first hidden layer.");
        }
        super.process();
    }

    @Override
    public Object requestNextTask() {
        if (this.done || this.getShouldStop()) {
            return null;
        }
        BasicNetwork basicNetwork = this.generateNetwork();
        if (!this.increaseHiddenCounts()) {
            this.done = true;
        }
        return basicNetwork;
    }

    private synchronized void updateBest(BasicNetwork basicNetwork, double d) {
        this.high = Math.max(this.high, d);
        this.low = Math.min(this.low, d);
        int n = -1;
        for (int i = 0; i < this.topNetworks.length; ++i) {
            if (this.topNetworks[i] == null) {
                n = i;
                break;
            }
            if (!(this.topErrors[i] > d) || n != -1 && !(this.topErrors[n] < this.topErrors[i])) continue;
            n = i;
        }
        if (n != -1) {
            this.topErrors[n] = d;
            this.topNetworks[n] = basicNetwork;
        }
        BasicNetwork basicNetwork2 = null;
        for (BasicNetwork basicNetwork3 : this.topNetworks) {
            if (basicNetwork3 == null) continue;
            if (basicNetwork2 == null) {
                basicNetwork2 = basicNetwork3;
                continue;
            }
            if (basicNetwork3.getStructure().calculateSize() >= basicNetwork2.getStructure().calculateSize()) continue;
            basicNetwork2 = basicNetwork3;
        }
        if (basicNetwork2 != this.bestNetwork) {
            if (this.logger.isDebugEnabled()) {
                this.logger.debug("Prune found new best network: error=" + d + ", network=" + basicNetwork2);
            }
            this.bestNetwork = basicNetwork2;
        }
    }
}

