/*
 * Decompiled with CFR 0.152.
 */
package org.encog.engine.network.train.gradient;

import org.encog.engine.data.BasicEngineData;
import org.encog.engine.data.EngineData;
import org.encog.engine.data.EngineIndexableSet;
import org.encog.engine.network.activation.ActivationFunction;
import org.encog.engine.network.flat.FlatNetwork;
import org.encog.engine.network.train.gradient.FlatGradientWorker;
import org.encog.engine.network.train.prop.TrainFlatNetworkProp;
import org.encog.engine.util.EngineArray;
import org.encog.engine.util.ErrorCalculation;
import org.encog.engine.util.Stopwatch;

/**
 * Back-propagation gradient calculator for a {@link FlatNetwork} that works on
 * a contiguous slice of a training set.
 *
 * <p>Each call to {@link #run()} processes the training records with indices
 * {@code low} through {@code high} (both inclusive). For every record it
 * computes the network output, updates the batch error and back-propagates the
 * error, adding the resulting per-weight gradients to a local accumulator. At
 * the end of the run the summed gradients and the batch error are handed to
 * the owning {@link TrainFlatNetworkProp} via {@code report}.
 *
 * <p>NOTE(review): the per-run scratch state ({@code gradients},
 * {@code layerDelta}, {@code actual}, {@code errorCalculation}) is mutated
 * without synchronization, so a single instance must not run concurrently
 * with itself.
 */
public class GradientWorkerCPU implements FlatGradientWorker {
    /** Network whose output is computed for each record. */
    private final FlatNetwork network;

    /** Accumulates the error of every record processed during one run. */
    private final ErrorCalculation errorCalculation = new ErrorCalculation();

    /** Scratch buffer receiving the network output for the current record. */
    private final double[] actual;

    /** Per-neuron error deltas, indexed exactly like {@link #layerOutput}. */
    private final double[] layerDelta;

    /** Per-layer neuron counts; used as the size of the "from" side of a level. */
    private final int[] layerCounts;

    /**
     * Per-layer counts of neurons on the "to" side of a level.
     * NOTE(review): presumably excludes bias neurons - confirm against FlatNetwork.
     */
    private final int[] layerFeedCounts;

    /** Offset of each layer's first neuron within {@link #layerOutput} / {@link #layerDelta}. */
    private final int[] layerIndex;

    /** Offset of each level's first weight within {@link #weights} / {@link #gradients}. */
    private final int[] weightIndex;

    /** Per-neuron outputs, read after {@code network.compute} has run for a record. */
    private final double[] layerOutput;

    /** Gradient accumulator for the current run, one entry per weight. */
    private final double[] gradients;

    /**
     * Network weights as returned by {@code FlatNetwork.getWeights()}.
     * NOTE(review): presumably the network's live array rather than a copy - confirm.
     */
    private final double[] weights;

    /** Reusable holder that each training record is copied into. */
    private final EngineData pair;

    /** Training data this worker reads records from. */
    private final EngineIndexableSet training;

    /** Index of the first record processed by a run (inclusive). */
    private final int low;

    /** Index of the last record processed by a run (inclusive). */
    private final int high;

    /** Trainer that receives the gradients/error, or the failure, of each run. */
    private final TrainFlatNetworkProp owner;

    /** Duration of the last successful run, in stopwatch ticks. */
    private long elapsedTime;

    /** Times each run. */
    private final Stopwatch stopwatch;

    /**
     * Creates a worker for one slice of the training set.
     *
     * @param network  the network to compute gradients for
     * @param owner    the trainer that receives the results of each run
     * @param training the training data to read records from
     * @param low      index of the first record to process (inclusive)
     * @param high     index of the last record to process (inclusive)
     */
    public GradientWorkerCPU(FlatNetwork network, TrainFlatNetworkProp owner,
            EngineIndexableSet training, int low, int high) {
        this.network = network;
        this.training = training;
        this.low = low;
        this.high = high;
        this.owner = owner;
        this.stopwatch = new Stopwatch();
        this.layerDelta = new double[network.getLayerOutput().length];
        this.gradients = new double[network.getWeights().length];
        this.actual = new double[network.getOutputCount()];
        this.weights = network.getWeights();
        this.layerIndex = network.getLayerIndex();
        this.layerCounts = network.getLayerCounts();
        this.weightIndex = network.getWeightIndex();
        this.layerOutput = network.getLayerOutput();
        this.layerFeedCounts = network.getLayerFeedCounts();
        this.pair = BasicEngineData.createPair(network.getInputCount(), network.getOutputCount());
    }

    /**
     * Returns how long the last successful run took.
     *
     * @return elapsed stopwatch ticks of the last run that completed without error
     */
    public long getElapsedTime() {
        return this.elapsedTime;
    }

    /**
     * Returns the network this worker computes gradients for.
     *
     * @return the network passed to the constructor
     */
    public FlatNetwork getNetwork() {
        return this.network;
    }

    /**
     * Returns the weight array used during back-propagation.
     *
     * @return the array obtained from {@code network.getWeights()} at construction time
     */
    public double[] getWeights() {
        return this.weights;
    }

    /**
     * Back-propagates the error of one training record and adds its gradients
     * to {@link #gradients}.
     *
     * @param input the record's input values
     * @param ideal the record's expected output values
     */
    private void process(double[] input, double[] ideal) {
        this.network.compute(input, this.actual);
        this.errorCalculation.updateError(this.actual, ideal);

        // Output deltas fill the first actual.length slots of layerDelta and use
        // activation function 0, i.e. level 0 is treated as the output layer.
        // The derivative is evaluated at the activation's output value.
        final ActivationFunction outputActivation = this.network.getActivationFunctions()[0];
        for (int i = 0; i < this.actual.length; ++i) {
            this.layerDelta[i] = outputActivation.derivativeFunction(this.actual[i])
                    * (ideal[i] - this.actual[i]);
        }

        // Each level turns the deltas of layer `level` into gradients for the
        // weights feeding it and into deltas for layer `level + 1`.
        final int endLevel = this.network.getEndTraining();
        for (int level = this.network.getBeginTraining(); level < endLevel; ++level) {
            this.processLevel(level);
        }
    }

    /**
     * Accumulates gradients for the weights between layer {@code level + 1}
     * (the "from" side) and layer {@code level} (the "to" side). It also
     * computes the deltas of the "from" layer.
     *
     * @param level index of the "to" layer of this level
     */
    private void processLevel(int level) {
        final int fromLayerIndex = this.layerIndex[level + 1];
        final int toLayerIndex = this.layerIndex[level];
        final int fromLayerSize = this.layerCounts[level + 1];
        final int toLayerSize = this.layerFeedCounts[level];
        final int weightStart = this.weightIndex[level];
        final ActivationFunction activation = this.network.getActivationFunctions()[level + 1];

        for (int from = 0; from < fromLayerSize; ++from) {
            final int fromNeuron = fromLayerIndex + from;
            final double output = this.layerOutput[fromNeuron];
            double sum = 0.0;
            // Weights are stored one row of fromLayerSize entries per "to" neuron,
            // so successive targets of the same "from" neuron are fromLayerSize apart.
            int w = weightStart + from;
            for (int to = 0; to < toLayerSize; ++to) {
                final double delta = this.layerDelta[toLayerIndex + to];
                this.gradients[w] += output * delta;
                sum += this.weights[w] * delta;
                w += fromLayerSize;
            }
            this.layerDelta[fromNeuron] = sum * activation.derivativeFunction(output);
        }
    }

    /**
     * Processes every record in {@code [low, high]} and reports the summed
     * gradients and the batch error to the owner.
     *
     * <p>Any {@link Throwable} raised during the run is caught and passed to the
     * owner as {@code report(null, 0.0, throwable)} instead of propagating. In
     * that case {@link #getElapsedTime()} keeps the value of the previous
     * successful run.
     */
    public void run() {
        try {
            this.stopwatch.reset();
            this.stopwatch.start();
            // Zero the accumulator up front so that a previous run which aborted
            // part way through cannot leak partial gradients into this one.
            EngineArray.fill(this.gradients, 0.0);
            this.errorCalculation.reset();
            for (int i = this.low; i <= this.high; ++i) {
                this.training.getRecord(i, this.pair);
                this.process(this.pair.getInputArray(), this.pair.getIdealArray());
            }
            final double error = this.errorCalculation.calculate();
            this.owner.report(this.gradients, error, null);
            this.stopwatch.stop();
            this.elapsedTime = this.stopwatch.getElapsedTicks();
        } catch (Throwable throwable) {
            this.owner.report(null, 0.0, throwable);
        }
    }
}

