/*
 * Decompiled with CFR 0.152.
 */
package org.encog.neural.networks.training.propagation.resilient;

import java.util.Arrays;

import org.encog.engine.network.train.prop.OpenCLTrainingProfile;
import org.encog.engine.network.train.prop.TrainFlatNetworkOpenCL;
import org.encog.engine.network.train.prop.TrainFlatNetworkResilient;
import org.encog.engine.util.EngineArray;
import org.encog.neural.data.NeuralDataSet;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.training.TrainingError;
import org.encog.neural.networks.training.propagation.Propagation;
import org.encog.neural.networks.training.propagation.TrainingContinuation;

/**
 * Resilient propagation (RPROP) trainer for a {@link BasicNetwork}.
 *
 * <p>The weight updates are delegated to a flat-network trainer chosen in the constructor:
 * {@link TrainFlatNetworkResilient} when no OpenCL profile is supplied, otherwise
 * {@link TrainFlatNetworkOpenCL} configured through {@code learnRPROP}.
 *
 * <p>Training can be paused and resumed. {@link #pause()} captures the flat trainer's last
 * gradients and update values under the keys {@link #LAST_GRADIENTS} and
 * {@link #UPDATE_VALUES}. {@link #resume(TrainingContinuation)} copies them back after
 * validating them with {@link #isValidResume(TrainingContinuation)}.
 */
public class ResilientPropagation
extends Propagation {
    /** Continuation key for the gradients recorded by the previous iteration. */
    public static final String LAST_GRADIENTS = "LAST_GRADIENTS";
    /** Continuation key for the RPROP update (step) values. */
    public static final String UPDATE_VALUES = "UPDATE_VALUES";

    /** Initial update value used by the constructors that do not take one. */
    private static final double DEFAULT_INITIAL_UPDATE = 0.1;
    /** Maximum step used by the constructors that do not take one. */
    private static final double DEFAULT_MAX_STEP = 50.0;

    /**
     * Creates a CPU-based RPROP trainer with the default initial update and maximum step.
     *
     * @param basicNetwork the network to train
     * @param neuralDataSet the training data
     */
    public ResilientPropagation(BasicNetwork basicNetwork, NeuralDataSet neuralDataSet) {
        this(basicNetwork, neuralDataSet, null, DEFAULT_INITIAL_UPDATE, DEFAULT_MAX_STEP);
    }

    /**
     * Creates an RPROP trainer with the default initial update and maximum step.
     *
     * @param basicNetwork the network to train
     * @param neuralDataSet the training data
     * @param openCLTrainingProfile the OpenCL profile to train with, or {@code null} to train on the CPU
     */
    public ResilientPropagation(BasicNetwork basicNetwork, NeuralDataSet neuralDataSet, OpenCLTrainingProfile openCLTrainingProfile) {
        this(basicNetwork, neuralDataSet, openCLTrainingProfile, DEFAULT_INITIAL_UPDATE, DEFAULT_MAX_STEP);
    }

    /**
     * Creates an RPROP trainer.
     *
     * @param basicNetwork the network to train
     * @param neuralDataSet the training data
     * @param openCLTrainingProfile the OpenCL profile to train with, or {@code null} to train on the CPU
     * @param initialUpdate the starting update value applied to every update-value entry
     * @param maxStep the maximum step; currently only forwarded on the OpenCL path (see note below)
     */
    public ResilientPropagation(BasicNetwork basicNetwork, NeuralDataSet neuralDataSet, OpenCLTrainingProfile openCLTrainingProfile, double initialUpdate, double maxStep) {
        super(basicNetwork, neuralDataSet);
        if (openCLTrainingProfile == null) {
            TrainFlatNetworkResilient cpuTrainer = new TrainFlatNetworkResilient(basicNetwork.getStructure().getFlat(), this.getTraining());
            // The two-argument constructor is not given initialUpdate, so seed the trainer's
            // (live) update-value array with it. Without this the argument was silently ignored
            // on the CPU path while the OpenCL path honored it.
            Arrays.fill(cpuTrainer.getUpdateValues(), initialUpdate);
            // NOTE(review): maxStep is not forwarded to the CPU trainer because no setter is visible
            // here. TrainFlatNetworkResilient may offer a constructor accepting it; confirm and wire up.
            this.setFlatTraining(cpuTrainer);
        } else {
            TrainFlatNetworkOpenCL openCLTrainer = new TrainFlatNetworkOpenCL(basicNetwork.getStructure().getFlat(), this.getTraining(), openCLTrainingProfile);
            openCLTrainer.learnRPROP(initialUpdate, maxStep);
            this.setFlatTraining(openCLTrainer);
        }
    }

    /**
     * Reports that this trainer supports pause/resume.
     *
     * @return always {@code true}
     */
    public boolean canContinue() {
        return true;
    }

    /**
     * Checks whether a continuation can be resumed by this trainer.
     *
     * <p>Both {@link #LAST_GRADIENTS} and {@link #UPDATE_VALUES} must be present, must be
     * {@code double[]}, and must have exactly as many entries as the network structure's
     * {@code calculateSize()}. Checking both arrays prevents {@link #resume} from overrunning or
     * only partially filling the trainer's arrays.
     *
     * @param trainingContinuation the saved training state
     * @return {@code true} if the state can be resumed
     */
    public boolean isValidResume(TrainingContinuation trainingContinuation) {
        if (!trainingContinuation.getContents().containsKey(LAST_GRADIENTS) || !trainingContinuation.getContents().containsKey(UPDATE_VALUES)) {
            return false;
        }
        Object lastGradients = trainingContinuation.get(LAST_GRADIENTS);
        Object updateValues = trainingContinuation.get(UPDATE_VALUES);
        // instanceof is false for null, so this also rejects null entries instead of throwing.
        if (!(lastGradients instanceof double[]) || !(updateValues instanceof double[])) {
            return false;
        }
        int expectedLength = this.getNetwork().getStructure().calculateSize();
        return ((double[]) lastGradients).length == expectedLength
            && ((double[]) updateValues).length == expectedLength;
    }

    /**
     * Captures the current RPROP state so training can be resumed later.
     *
     * <p>The arrays are copied. Continued training therefore does not alter the returned snapshot.
     *
     * @return a continuation holding copies of the last gradients and update values
     * @throws TrainingError if the flat trainer is of an unsupported type
     */
    public TrainingContinuation pause() {
        TrainingContinuation trainingContinuation = new TrainingContinuation();
        trainingContinuation.set(LAST_GRADIENTS, this.flatLastGradient().clone());
        trainingContinuation.set(UPDATE_VALUES, this.flatUpdateValues().clone());
        return trainingContinuation;
    }

    /**
     * Restores RPROP state previously produced by {@link #pause()}.
     *
     * @param trainingContinuation the saved training state
     * @throws TrainingError if the state fails {@link #isValidResume(TrainingContinuation)}, or the
     *     flat trainer is of an unsupported type
     */
    public void resume(TrainingContinuation trainingContinuation) {
        if (!this.isValidResume(trainingContinuation)) {
            throw new TrainingError("Invalid training resume data length");
        }
        double[] savedLastGradient = (double[]) trainingContinuation.get(LAST_GRADIENTS);
        double[] savedUpdateValues = (double[]) trainingContinuation.get(UPDATE_VALUES);
        EngineArray.arrayCopy(savedLastGradient, this.flatLastGradient());
        EngineArray.arrayCopy(savedUpdateValues, this.flatUpdateValues());
    }

    /** Returns the flat trainer's live last-gradient array, or fails for an unsupported trainer type. */
    private double[] flatLastGradient() {
        Object flat = this.getFlatTraining();
        if (flat instanceof TrainFlatNetworkResilient) {
            return ((TrainFlatNetworkResilient) flat).getLastGradient();
        }
        if (flat instanceof TrainFlatNetworkOpenCL) {
            return ((TrainFlatNetworkOpenCL) flat).getLastGradient();
        }
        throw unsupportedFlatTraining(flat);
    }

    /** Returns the flat trainer's live update-value array, or fails for an unsupported trainer type. */
    private double[] flatUpdateValues() {
        Object flat = this.getFlatTraining();
        if (flat instanceof TrainFlatNetworkResilient) {
            return ((TrainFlatNetworkResilient) flat).getUpdateValues();
        }
        if (flat instanceof TrainFlatNetworkOpenCL) {
            return ((TrainFlatNetworkOpenCL) flat).getUpdateValues();
        }
        throw unsupportedFlatTraining(flat);
    }

    /** Builds the error reported when the flat trainer is neither supported RPROP implementation. */
    private static TrainingError unsupportedFlatTraining(Object flat) {
        String typeName = (flat == null) ? "null" : flat.getClass().getName();
        return new TrainingError("Unsupported flat training type for RPROP pause/resume: " + typeName);
    }
}

