/*
 * Decompiled with CFR 0.152.
 */
package org.encog.neural.networks.training.propagation.back;

import org.encog.engine.network.train.prop.OpenCLTrainingProfile;
import org.encog.engine.network.train.prop.TrainFlatNetworkBackPropagation;
import org.encog.engine.network.train.prop.TrainFlatNetworkOpenCL;
import org.encog.neural.data.NeuralDataSet;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.training.LearningRate;
import org.encog.neural.networks.training.Momentum;
import org.encog.neural.networks.training.TrainingError;
import org.encog.neural.networks.training.propagation.Propagation;
import org.encog.neural.networks.training.propagation.TrainingContinuation;
import org.encog.neural.networks.training.strategy.SmartLearningRate;
import org.encog.neural.networks.training.strategy.SmartMomentum;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Backpropagation trainer that delegates the actual work to a flat-network
 * trainer installed via {@code setFlatTraining}. Depending on whether an
 * {@link OpenCLTrainingProfile} is supplied, the delegate is either a
 * {@link TrainFlatNetworkBackPropagation} or a {@link TrainFlatNetworkOpenCL}.
 *
 * <p>The learning-rate, momentum and last-delta accessors, as well as
 * pause/resume, all operate on the delegate by casting it to
 * {@link TrainFlatNetworkBackPropagation}.
 */
public class Backpropagation
extends Propagation
implements Momentum,
LearningRate {
    /** Key under which the last weight deltas are stored in a {@link TrainingContinuation}. */
    public static final String LAST_DELTA = "LAST_DELTA";
    private final Logger logger = LoggerFactory.getLogger(this.getClass());

    /**
     * Creates a trainer with both numeric parameters set to {@code 0.0} and
     * registers a {@link SmartLearningRate} and a {@link SmartMomentum} strategy.
     *
     * @param basicNetwork  the network to train
     * @param neuralDataSet the training data
     */
    public Backpropagation(BasicNetwork basicNetwork, NeuralDataSet neuralDataSet) {
        this(basicNetwork, neuralDataSet, null, 0.0, 0.0);
        this.addStrategy(new SmartLearningRate());
        this.addStrategy(new SmartMomentum());
    }

    /**
     * Creates a trainer without an OpenCL profile.
     *
     * @param basicNetwork  the network to train
     * @param neuralDataSet the training data
     * @param d             first numeric argument forwarded to the flat trainer
     *                      (presumably the learning rate - TODO confirm against
     *                      TrainFlatNetworkBackPropagation's constructor)
     * @param d2            second numeric argument forwarded to the flat trainer
     *                      (presumably the momentum - TODO confirm)
     */
    public Backpropagation(BasicNetwork basicNetwork, NeuralDataSet neuralDataSet, double d, double d2) {
        this(basicNetwork, neuralDataSet, null, d, d2);
    }

    /**
     * Creates a trainer, choosing the flat-network delegate based on the profile.
     *
     * @param basicNetwork          the network to train
     * @param neuralDataSet         the training data
     * @param openCLTrainingProfile OpenCL profile to train with, or {@code null}
     *                              to use a {@link TrainFlatNetworkBackPropagation}
     * @param d                     first numeric argument forwarded to the delegate
     *                              (presumably the learning rate - TODO confirm)
     * @param d2                    second numeric argument forwarded to the delegate
     *                              (presumably the momentum - TODO confirm)
     */
    public Backpropagation(BasicNetwork basicNetwork, NeuralDataSet neuralDataSet, OpenCLTrainingProfile openCLTrainingProfile, double d, double d2) {
        super(basicNetwork, neuralDataSet);
        if (openCLTrainingProfile == null) {
            TrainFlatNetworkBackPropagation trainFlatNetworkBackPropagation = new TrainFlatNetworkBackPropagation(basicNetwork.getStructure().getFlat(), this.getTraining(), d, d2);
            this.setFlatTraining(trainFlatNetworkBackPropagation);
        } else {
            TrainFlatNetworkOpenCL trainFlatNetworkOpenCL = new TrainFlatNetworkOpenCL(basicNetwork.getStructure().getFlat(), this.getTraining(), openCLTrainingProfile);
            trainFlatNetworkOpenCL.learnBPROP(d, d2);
            this.setFlatTraining(trainFlatNetworkOpenCL);
        }
    }

    /**
     * Returns the flat trainer viewed as a {@link TrainFlatNetworkBackPropagation}.
     *
     * <p>NOTE(review): when this object was constructed with an OpenCL profile the
     * installed trainer is a {@link TrainFlatNetworkOpenCL}; unless that type is
     * also a {@code TrainFlatNetworkBackPropagation} this cast will throw a
     * {@link ClassCastException} - confirm against the trainer hierarchy.
     *
     * @return the flat trainer, cast to the backpropagation implementation
     */
    private TrainFlatNetworkBackPropagation flatBackprop() {
        return (TrainFlatNetworkBackPropagation)this.getFlatTraining();
    }

    /**
     * @return the last-delta array held by the flat trainer
     */
    public double[] getLastDelta() {
        return this.flatBackprop().getLastDelta();
    }

    /**
     * @return the learning rate currently held by the flat trainer
     */
    public double getLearningRate() {
        return this.flatBackprop().getLearningRate();
    }

    /**
     * @return the momentum currently held by the flat trainer
     */
    public double getMomentum() {
        return this.flatBackprop().getMomentum();
    }

    /**
     * Checks whether a continuation can be used to resume this trainer: it must
     * contain a {@code double[]} under {@link #LAST_DELTA} whose length equals
     * the size reported by the network structure.
     *
     * @param trainingContinuation the continuation to inspect
     * @return {@code true} if the continuation holds a compatible delta array
     */
    public boolean isValidResume(TrainingContinuation trainingContinuation) {
        if (!trainingContinuation.getContents().containsKey(LAST_DELTA)) {
            return false;
        }
        Object stored = trainingContinuation.get(LAST_DELTA);
        // A null or wrongly-typed entry is simply "not valid" rather than a crash.
        if (!(stored instanceof double[])) {
            return false;
        }
        return ((double[])stored).length == this.getNetwork().getStructure().calculateSize();
    }

    /**
     * Captures the flat trainer's last-delta array in a new continuation under
     * {@link #LAST_DELTA}.
     *
     * @return a continuation that can later be passed to {@link #resume}
     */
    public TrainingContinuation pause() {
        TrainingContinuation trainingContinuation = new TrainingContinuation();
        trainingContinuation.set(LAST_DELTA, this.flatBackprop().getLastDelta());
        return trainingContinuation;
    }

    /**
     * Restores the flat trainer's last-delta array from a continuation.
     *
     * @param trainingContinuation the continuation produced by {@link #pause}
     * @throws TrainingError if {@link #isValidResume} rejects the continuation
     */
    public void resume(TrainingContinuation trainingContinuation) {
        if (!this.isValidResume(trainingContinuation)) {
            throw new TrainingError("Invalid training resume data length");
        }
        this.flatBackprop().setLastDelta((double[])trainingContinuation.get(LAST_DELTA));
    }

    /**
     * Sets the learning rate on the flat trainer.
     *
     * @param d the new learning rate
     */
    public void setLearningRate(double d) {
        this.flatBackprop().setLearningRate(d);
    }

    /**
     * Sets the momentum on the flat trainer.
     *
     * @param d the new momentum
     */
    public void setMomentum(double d) {
        // Must target the momentum setter; routing this to setLearningRate would
        // clobber the learning rate and leave the momentum unchanged.
        this.flatBackprop().setMomentum(d);
    }
}

