/*
 * Decompiled with CFR 0.152.
 */
package org.encog.engine.network.train.prop;

import org.encog.engine.EncogEngineError;
import org.encog.engine.concurrency.DetermineWorkload;
import org.encog.engine.concurrency.EngineConcurrency;
import org.encog.engine.concurrency.TaskGroup;
import org.encog.engine.data.EngineDataSet;
import org.encog.engine.data.EngineIndexableSet;
import org.encog.engine.network.flat.FlatNetwork;
import org.encog.engine.network.train.TrainFlatNetwork;
import org.encog.engine.network.train.gradient.FlatGradientWorker;
import org.encog.engine.network.train.gradient.GradientWorkerCPU;
import org.encog.engine.util.EngineArray;
import org.encog.engine.util.IntRange;

/**
 * Base class for propagation-style training of a {@link FlatNetwork}.
 *
 * <p>Each {@link #iteration()} splits the training records across one or more
 * {@link FlatGradientWorker}s, accumulates the gradients they hand back through
 * {@link #report(double[], double, Throwable)}, and then asks the concrete
 * subclass, via {@link #updateWeight(double[], double[], int)}, how much to
 * change each weight of the network.
 */
public abstract class TrainFlatNetworkProp implements TrainFlatNetwork {
    /** Requested worker thread count, passed to {@link DetermineWorkload}; 0 by default. */
    private int numThreads;
    /** Gradients summed over all workers for the current pass; zeroed by the learn step. */
    protected double[] gradients;
    /**
     * Per-weight array (one slot per network weight) passed to {@link #updateWeight}.
     * NOTE(review): presumably holds each weight's previous gradient for subclasses — confirm.
     */
    private final double[] lastGradient;
    /** The network being trained; its weight array is updated in place. */
    protected final FlatNetwork network;
    /** The training data exactly as supplied to the constructor. */
    private final EngineDataSet training;
    /** The same training data, viewed as an indexable set (verified in the constructor). */
    private final EngineIndexableSet indexable;
    /** Gradient workers; created lazily on the first gradient calculation and then reused. */
    private FlatGradientWorker[] workers;
    /** Sum of the errors reported by workers during the current gradient pass. */
    private double totalError;
    /** Error of the most recent gradient pass: the reported errors divided by the worker count. */
    protected double currentError;
    /** Most recent failure reported by a worker; rethrown at the end of each iteration. */
    private Throwable reportedException;
    /** Number of times {@link #iteration()} has been called (or the value set explicitly). */
    private int iteration;

    /**
     * Construct a propagation trainer.
     *
     * @param flatNetwork   the network to train
     * @param engineDataSet the training data; must implement {@link EngineIndexableSet}
     * @throws EncogEngineError if the training data is not indexable (this includes {@code null})
     */
    public TrainFlatNetworkProp(FlatNetwork flatNetwork, EngineDataSet engineDataSet) {
        if (!(engineDataSet instanceof EngineIndexableSet)) {
            throw new EncogEngineError("Training data must be Indexable for this training type.");
        }
        this.training = engineDataSet;
        this.network = flatNetwork;
        final int weightCount = this.network.getWeights().length;
        this.gradients = new double[weightCount];
        this.lastGradient = new double[weightCount];
        this.indexable = (EngineIndexableSet) engineDataSet;
        this.numThreads = 0;
        this.reportedException = null;
    }

    /**
     * Run every worker over its share of the training data so that gradients and
     * error are accumulated for the current weights.
     *
     * <p>Workers are created on first use. With more than one worker, the tasks are
     * submitted to {@link EngineConcurrency} and this method waits on their task group.
     * A single worker is run directly on the calling thread.
     *
     * <p>Gradients are added into {@link #gradients}; they are not cleared here.
     * {@link #currentError} is set to the total reported error divided by the number
     * of workers.
     */
    public void calculateGradients() {
        if (this.workers == null) {
            this.init();
        }
        // Only the first worker starts from a cleared context. Later workers keep the
        // layer output copied into them by copyContexts() at the end of the previous
        // iteration.
        this.workers[0].getNetwork().clearContext();
        this.totalError = 0.0;
        if (this.workers.length > 1) {
            final TaskGroup group = EngineConcurrency.getInstance().createTaskGroup();
            for (final FlatGradientWorker worker : this.workers) {
                EngineConcurrency.getInstance().processTask(worker, group);
            }
            group.waitForComplete();
        } else {
            this.workers[0].run();
        }
        this.currentError = this.totalError / this.workers.length;
    }

    /**
     * Hand each worker's final layer output on to the next worker, and the last
     * worker's output to {@link #network}.
     *
     * <p>After this call, worker {@code i + 1} starts from the state in which worker
     * {@code i} finished its slice of the data. NOTE(review): this appears intended to
     * carry recurrent context across worker boundaries — confirm.
     */
    private void copyContexts() {
        final int last = this.workers.length - 1;
        // Copy back-to-front so every source is read before it is overwritten. A
        // front-to-back walk would cascade worker 0's output into every later worker
        // and into the network.
        EngineArray.arrayCopy(this.workers[last].getNetwork().getLayerOutput(),
                this.network.getLayerOutput());
        for (int i = last - 1; i >= 0; --i) {
            EngineArray.arrayCopy(this.workers[i].getNetwork().getLayerOutput(),
                    this.workers[i + 1].getNetwork().getLayerOutput());
        }
    }

    /** Called when training is finished; this base implementation does nothing. */
    public void finishTraining() {
    }

    /** @return the error computed by the most recent gradient pass */
    public double getError() {
        return this.currentError;
    }

    /** @return the live per-weight array handed to {@link #updateWeight} (not a copy) */
    public double[] getLastGradient() {
        return this.lastGradient;
    }

    /** @return the network being trained */
    public FlatNetwork getNetwork() {
        return this.network;
    }

    /** @return the requested number of worker threads */
    public int getNumThreads() {
        return this.numThreads;
    }

    /** @return the training data supplied to the constructor */
    public EngineDataSet getTraining() {
        return this.training;
    }

    /**
     * Create one gradient worker per range produced by {@link DetermineWorkload}.
     *
     * <p>Each worker receives:
     * <ul>
     *   <li>its own clone of {@link #network};</li>
     *   <li>an additional view of the training data obtained through
     *       {@code openAdditional()};</li>
     *   <li>the low/high record bounds of its {@link IntRange}.</li>
     * </ul>
     */
    private void init() {
        // NOTE(review): the record count is narrowed from long to int here.
        final DetermineWorkload workload =
                new DetermineWorkload(this.numThreads, (int) this.indexable.getRecordCount());
        this.workers = new FlatGradientWorker[workload.getThreadCount()];
        int index = 0;
        for (final IntRange range : workload.calculateWorkers()) {
            this.workers[index++] = new GradientWorkerCPU(this.network.clone(), this,
                    this.indexable.openAdditional(), range.getLow(), range.getHigh());
        }
    }

    /**
     * Perform one training iteration. The steps are:
     * <ol>
     *   <li>compute the gradients;</li>
     *   <li>update the network's weights, honouring the connection limit if the
     *       network is limited;</li>
     *   <li>copy the new weights to every worker;</li>
     *   <li>pass layer outputs along the chain of workers.</li>
     * </ol>
     *
     * @throws EncogEngineError wrapping a failure reported by a worker. NOTE(review): the
     *         weights are still updated before this is thrown, and the stored failure is
     *         never cleared, so every later iteration throws as well.
     */
    public void iteration() {
        ++this.iteration;
        this.calculateGradients();
        if (this.network.isLimited()) {
            this.learnLimited();
        } else {
            this.learn();
        }
        final double[] weights = this.network.getWeights();
        for (final FlatGradientWorker worker : this.workers) {
            EngineArray.arrayCopy(weights, 0, worker.getWeights(), 0, weights.length);
        }
        this.copyContexts();
        if (this.reportedException != null) {
            throw new EncogEngineError(this.reportedException);
        }
    }

    /**
     * Add {@link #updateWeight}'s result to every weight of {@link #network}, then
     * reset the accumulated gradients to zero for the next pass.
     */
    protected void learn() {
        final double[] weights = this.network.getWeights();
        for (int i = 0; i < this.gradients.length; ++i) {
            weights[i] += this.updateWeight(this.gradients, this.lastGradient, i);
            this.gradients[i] = 0.0;
        }
    }

    /**
     * Like {@link #learn()}, but for a connection-limited network.
     *
     * <p>Any weight whose magnitude is below the network's connection limit is forced
     * to zero instead of being updated. The accumulated gradients are reset to zero
     * either way.
     */
    protected void learnLimited() {
        final double limit = this.network.getConnectionLimit();
        final double[] weights = this.network.getWeights();
        for (int i = 0; i < this.gradients.length; ++i) {
            // Compare magnitudes: a signed comparison would prune every negative
            // weight regardless of its size.
            if (Math.abs(weights[i]) < limit) {
                weights[i] = 0.0;
            } else {
                weights[i] += this.updateWeight(this.gradients, this.lastGradient, i);
            }
            this.gradients[i] = 0.0;
        }
    }

    /**
     * Callback through which the gradient workers hand back their results.
     *
     * <p>On success, the worker's gradients are added element-wise into
     * {@link #gradients} and its error is added to the running total. On failure, the
     * throwable is stored and later rethrown by {@link #iteration()}.
     *
     * <p>The method is synchronized because workers may run on separate threads.
     *
     * @param workerGradients gradients computed by the worker (unused on failure)
     * @param workerError     error computed by the worker (unused on failure)
     * @param throwable       the worker's failure, or {@code null} if it succeeded
     */
    public synchronized void report(double[] workerGradients, double workerError,
            Throwable throwable) {
        if (throwable != null) {
            this.reportedException = throwable;
            return;
        }
        for (int i = 0; i < workerGradients.length; ++i) {
            this.gradients[i] += workerGradients[i];
        }
        this.totalError += workerError;
    }

    /**
     * Set the requested number of worker threads.
     *
     * <p>This only has an effect before the first gradient calculation, because the
     * workers are created once and then reused.
     *
     * @param n the thread count passed to {@link DetermineWorkload}
     */
    public void setNumThreads(int n) {
        this.numThreads = n;
    }

    /**
     * Compute the change to apply to a single weight.
     *
     * @param gradients    the accumulated gradients for all weights
     * @param lastGradient the trainer's per-weight array (see {@link #getLastGradient()})
     * @param index        index of the weight being updated
     * @return the amount to add to the weight at {@code index}
     */
    public abstract double updateWeight(double[] gradients, double[] lastGradient, int index);

    /**
     * Perform {@code count} training iterations.
     *
     * @param count number of iterations; values of zero or less do nothing
     */
    public void iteration(int count) {
        for (int i = 0; i < count; ++i) {
            this.iteration();
        }
    }

    /** @return the current iteration counter */
    public int getIteration() {
        return this.iteration;
    }

    /** @param n the new value for the iteration counter */
    public void setIteration(int n) {
        this.iteration = n;
    }
}

