/*
 * Decompiled with CFR 0.152.
 */
package org.encog.engine.network.train.prop;

import java.util.HashMap;
import java.util.Map;
import org.encog.engine.EncogEngine;
import org.encog.engine.EncogEngineError;
import org.encog.engine.data.EngineDataSet;
import org.encog.engine.data.EngineIndexableSet;
import org.encog.engine.network.flat.FlatNetwork;
import org.encog.engine.network.flat.ValidateForOpenCL;
import org.encog.engine.network.train.TrainFlatNetwork;
import org.encog.engine.network.train.prop.OpenCLTrainingProfile;
import org.encog.engine.opencl.kernels.KernelNetworkTrain;
import org.encog.engine.util.EngineArray;
import org.encog.engine.util.ErrorCalculation;
import org.encog.engine.util.ErrorCalculationMode;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Trains a {@link FlatNetwork} with an OpenCL kernel ({@link KernelNetworkTrain}).
 *
 * <p>Before the first call to {@link #iteration()} one of {@link #learnRPROP()},
 * {@link #learnRPROP(double, double)}, {@link #learnBPROP(double, double)} or
 * {@link #learnManhattan(double)} must be called. Each of those compiles a new kernel
 * with the matching {@code LEARN_*} option and seeds the kernel's temporary data
 * array with that method's parameters. Call {@link #finishTraining()} to release
 * the kernel when training is done.
 */
public class TrainFlatNetworkOpenCL implements TrainFlatNetwork {

    /** Learning type selected by {@link #learnRPROP(double, double)}. */
    public static final int LEARN_RPROP = 0;

    /** Learning type selected by {@link #learnBPROP(double, double)}. */
    public static final int LEARN_BPROP = 1;

    /** Learning type selected by {@link #learnManhattan(double)}. */
    public static final int LEARN_MANHATTAN = 2;

    /** Sentinel meaning no learnXXX method has successfully run yet. */
    private static final int LEARN_UNDEFINED = -1;

    /** Initial update value used by {@link #learnRPROP()}. */
    private static final double DEFAULT_INITIAL_UPDATE = 0.1;

    /** Maximum step used by {@link #learnRPROP()}. */
    private static final double DEFAULT_MAX_STEP = 50.0;

    /** Normalized error computed by the most recent {@link #iteration(int)}. */
    private double error;

    /** The network being trained; its weights receive the kernel's output. */
    private final FlatNetwork network;

    /** The training data, which the constructor requires to be indexable. */
    private final EngineIndexableSet training;

    /**
     * One of the LEARN_* constants. It starts as LEARN_UNDEFINED so that
     * iteration(int) can detect that no learnXXX method has been called.
     */
    private int learningType = LEARN_UNDEFINED;

    private double learningRate;

    private double momentum;

    private double initialUpdate;

    private double maxStep;

    /** Kernel installed by the most recent learnXXX call; null until then. */
    private KernelNetworkTrain kernel;

    /** True once {@link #kernel} has been released, so it is not released twice. */
    private boolean kernelReleased;

    private int iteration;

    /** Supplies the OpenCL device and the work-group / work-per-call sizing. */
    private final OpenCLTrainingProfile profile;

    /**
     * Construct an OpenCL trainer.
     *
     * @param flatNetwork the network to train; it is first passed to
     *        {@code ValidateForOpenCL.validate}
     * @param engineDataSet the training data; must be an {@link EngineIndexableSet}
     * @param openCLTrainingProfile the profile supplying device and work sizes
     * @throws EncogEngineError if the data is not indexable or OpenCL has not been
     *         enabled on the {@link EncogEngine}
     */
    public TrainFlatNetworkOpenCL(FlatNetwork flatNetwork, EngineDataSet engineDataSet, OpenCLTrainingProfile openCLTrainingProfile) {
        new ValidateForOpenCL().validate(flatNetwork);
        if (!(engineDataSet instanceof EngineIndexableSet)) {
            throw new EncogEngineError("Training data must be Indexable for this training type.");
        }
        if (EncogEngine.getInstance().getCL() == null) {
            throw new EncogEngineError("You must enable OpenCL before using this training type.");
        }
        this.profile = openCLTrainingProfile;
        this.network = flatNetwork;
        this.training = (EngineIndexableSet) engineDataSet;
    }

    /**
     * Run the kernel once and add the sum of its first getGlobalWork() error
     * entries to {@link #error}.
     *
     * <p>The arguments are forwarded unchanged to {@code KernelNetworkTrain.calculate}.
     * iteration(int) passes a running offset, a work-per-call size, {@code false}
     * for the full-size calls and {@code true} for the final remainder call, and an
     * iteration count. NOTE(review): the exact meaning of each argument is defined
     * by KernelNetworkTrain; confirm there.
     */
    private void callKernel(int start, int size, boolean learn, int iterations) {
        this.kernel.calculate(start, size, learn, iterations);
        double sum = 0.0;
        for (int i = 0; i < this.kernel.getGlobalWork(); ++i) {
            sum += this.kernel.getErrors()[i];
        }
        this.error += sum;
    }

    /**
     * Release the current kernel, if any. It is safe to call this more than once.
     * The kernel reference is kept so that {@link #getLastGradient()} and
     * {@link #getUpdateValues()} can still read its temp data array afterwards.
     */
    @Override
    public void finishTraining() {
        releaseKernel();
    }

    /** Release {@link #kernel} unless it is absent or already released. */
    private void releaseKernel() {
        if (this.kernel != null && !this.kernelReleased) {
            this.kernel.release();
            this.kernelReleased = true;
        }
    }

    /** Throw a descriptive error instead of an NPE when no kernel has been installed. */
    private void requireKernel() {
        if (this.kernel == null) {
            throw new EncogEngineError("No training kernel exists yet, you must first call one of the learnXXXX methods, such as learnRPROP.");
        }
    }

    /**
     * @return the error computed by the most recent iteration, divided by
     *         (record count x ideal size) and square-rooted in RMS mode
     */
    @Override
    public double getError() {
        return this.error;
    }

    /** @return the number of iterations performed (or the value last set) */
    @Override
    public int getIteration() {
        return this.iteration;
    }

    /**
     * Copy the first weightCount entries of the kernel's temp data array.
     *
     * <p>learnRPROP lays that array out as weightCount zeroed entries followed by
     * weightCount initial-update entries. Presumably the kernel stores the last
     * gradients in the first half; confirm against the kernel source. With the
     * other learning types the array has a different layout, so the result is not
     * meaningful there (learnBPROP stores the learning rate and momentum at indexes
     * 0 and 1).
     *
     * @return a new array of length weightCount
     * @throws EncogEngineError if no learnXXX method has been called
     */
    public double[] getLastGradient() {
        requireKernel();
        int weightCount = this.network.getWeights().length;
        double[] gradients = new double[weightCount];
        for (int i = 0; i < weightCount; ++i) {
            gradients[i] = this.kernel.getTempDataArray()[i];
        }
        return gradients;
    }

    /** @return the learning rate passed to {@link #learnBPROP} or {@link #learnManhattan} */
    public double getLearningRate() {
        return this.learningRate;
    }

    /** @return one of the LEARN_* constants, or -1 if no learnXXX method has succeeded */
    public int getLearningType() {
        return this.learningType;
    }

    /** @return the max step passed to {@link #learnRPROP(double, double)} */
    public double getMaxStep() {
        return this.maxStep;
    }

    /** @return the momentum passed to {@link #learnBPROP(double, double)} */
    public double getMomentum() {
        return this.momentum;
    }

    /** @return the network being trained */
    @Override
    public FlatNetwork getNetwork() {
        return this.network;
    }

    /** @return always 0; this trainer does not use a CPU thread count */
    @Override
    public int getNumThreads() {
        return 0;
    }

    /**
     * Build the preprocessor options passed to the kernel compiler.
     *
     * @param learnOption the LEARN_* option name, added with a null value
     *        (presumably a valueless define; confirm in KernelNetworkTrain.compile)
     * @return a new mutable map holding NEURON_COUNT, WEIGHT_COUNT and learnOption
     */
    private Map<String, String> getOptions(String learnOption) {
        Map<String, String> options = new HashMap<String, String>();
        options.put("NEURON_COUNT", String.valueOf(this.network.getNeuronCount()));
        options.put("WEIGHT_COUNT", String.valueOf(this.network.getWeights().length));
        options.put(learnOption, null);
        return options;
    }

    /** @return the training data supplied to the constructor */
    @Override
    public EngineDataSet getTraining() {
        return this.training;
    }

    /**
     * Copy the second weightCount entries of the kernel's temp data array.
     *
     * <p>learnRPROP initializes these entries to the initial update value. As with
     * {@link #getLastGradient()}, the result is only meaningful after learnRPROP.
     *
     * @return a new array of length weightCount
     * @throws EncogEngineError if no learnXXX method has been called
     */
    public double[] getUpdateValues() {
        requireKernel();
        int weightCount = this.network.getWeights().length;
        double[] updates = new double[weightCount];
        for (int i = 0; i < weightCount; ++i) {
            updates[i] = this.kernel.getTempDataArray()[weightCount + i];
        }
        return updates;
    }

    /** Perform a single training iteration. */
    @Override
    public void iteration() {
        this.iteration(1);
    }

    /**
     * Perform {@code count} training iterations on the OpenCL device.
     *
     * <p>The kernel is first called getKernelNumberOfCalls() times at full size,
     * then once more for the remainder. The accumulated error is normalized and
     * the kernel's output weights are copied back into the network.
     *
     * @param count number of iterations; must be 1 unless the profile makes no
     *        full-size calls (an OpenCL ratio of 1.0)
     * @throws EncogEngineError if no learnXXX method has succeeded, or if
     *         count &gt; 1 while the profile makes full-size calls
     */
    @Override
    public void iteration(int count) {
        if (this.learningType == LEARN_UNDEFINED || this.kernel == null) {
            throw new EncogEngineError("Learning type has not been defined yet, you must first call one of the learnXXXX methods, such as learnRPROP.");
        }
        int fullCalls = this.profile.getKernelNumberOfCalls();
        if (fullCalls > 0 && count > 1) {
            throw new EncogEngineError("Must use an OpenCL ratio of 1.0 if you are going to use an iteration count > 1.");
        }
        // Only count iterations once the request has been validated.
        this.iteration += count;
        this.error = 0.0;

        int offset = 0;
        this.kernel.setGlobalWork(this.profile.getKernelGlobalWorkgroup());
        this.kernel.setLocalWork(this.profile.getKernelLocalWorkgroup());
        for (int call = 0; call < fullCalls; ++call) {
            this.callKernel(offset, this.profile.getKernelWorkPerCall(), false, 1);
            offset += this.profile.getKernelWorkPerCall() * this.kernel.getGlobalWork();
        }

        // NOTE(review): local work is set to the remainder *global* size, so the
        // remainder presumably runs as a single work-group; confirm this is intended.
        this.kernel.setGlobalWork(this.profile.getKernelRemainderGlobal());
        this.kernel.setLocalWork(this.profile.getKernelRemainderGlobal());
        this.callKernel(offset, this.profile.getKernelRemainderPer(), true, count);

        // Widen to double before multiplying: an int product of record count and
        // ideal size silently overflows for large training sets.
        double elementCount = (double) this.training.getRecordCount() * this.training.getIdealSize();
        if (elementCount > 0) {
            this.error /= elementCount;
        }
        if (ErrorCalculation.getMode() == ErrorCalculationMode.RMS) {
            this.error = Math.sqrt(this.error);
        }
        EngineArray.arrayCopy(this.kernel.getWeightOutArray(), this.network.getWeights());
    }

    /**
     * Release any previous kernel and install a new one compiled with the given
     * LEARN_* option.
     *
     * <p>learningType is reset to LEARN_UNDEFINED until the caller finishes seeding
     * the kernel. A failed compile therefore leaves iteration(int) refusing to run,
     * rather than running an uncompiled kernel. The new kernel is stored before
     * compiling so that finishTraining() can still release it if compile throws.
     *
     * @param learnOption the LEARN_* define passed to the kernel compiler
     * @param tempDataSize forwarded as KernelNetworkTrain's last constructor
     *        argument; presumably the length of its temp data array (confirm)
     * @return the newly compiled kernel, which is also stored in {@link #kernel}
     */
    private KernelNetworkTrain installKernel(String learnOption, int tempDataSize) {
        releaseKernel();
        this.learningType = LEARN_UNDEFINED;
        Map<String, String> options = this.getOptions(learnOption);
        this.kernel = new KernelNetworkTrain(this.profile.getDevice(), this.network, this.training, tempDataSize);
        this.kernelReleased = false;
        this.kernel.compile(options, this.profile, this.network);
        return this.kernel;
    }

    /**
     * Select backpropagation. The learning rate and momentum are stored at temp
     * data indexes 0 and 1.
     *
     * @param learningRate the learning rate
     * @param momentum the momentum
     */
    public void learnBPROP(double learningRate, double momentum) {
        KernelNetworkTrain installed = installKernel("LEARN_BPROP", this.network.getWeights().length + 2);
        installed.getTempDataArray()[0] = (float) learningRate;
        installed.getTempDataArray()[1] = (float) momentum;
        this.learningRate = learningRate;
        this.momentum = momentum;
        this.learningType = LEARN_BPROP;
    }

    /**
     * Select the Manhattan update rule. The learning rate is stored at temp data
     * index 0.
     *
     * @param learningRate the learning rate
     */
    public void learnManhattan(double learningRate) {
        KernelNetworkTrain installed = installKernel("LEARN_MANHATTAN", 1);
        installed.getTempDataArray()[0] = (float) learningRate;
        this.learningRate = learningRate;
        this.learningType = LEARN_MANHATTAN;
    }

    /** Select RPROP with an initial update of 0.1 and a max step of 50.0. */
    public void learnRPROP() {
        this.learnRPROP(DEFAULT_INITIAL_UPDATE, DEFAULT_MAX_STEP);
    }

    /**
     * Select resilient propagation. The temp data array is seeded with weightCount
     * zeros followed by weightCount copies of {@code initialUpdate}.
     *
     * <p>NOTE(review): {@code maxStep} is only stored in a field. Nothing visible
     * here passes it to the kernel, so confirm whether the kernel hard-codes its
     * own max step.
     *
     * @param initialUpdate the initial per-weight update value
     * @param maxStep the maximum step
     */
    public void learnRPROP(double initialUpdate, double maxStep) {
        int weightCount = this.network.getWeights().length;
        KernelNetworkTrain installed = installKernel("LEARN_RPROP", weightCount * 2);
        for (int i = 0; i < weightCount; ++i) {
            installed.getTempDataArray()[i] = 0.0f;
            installed.getTempDataArray()[i + weightCount] = (float) initialUpdate;
        }
        this.initialUpdate = initialUpdate;
        this.maxStep = maxStep;
        this.learningType = LEARN_RPROP;
    }

    /** @param n the value to report from {@link #getIteration()} */
    @Override
    public void setIteration(int n) {
        this.iteration = n;
    }

    /** Ignored; this trainer does not use a CPU thread count. */
    @Override
    public void setNumThreads(int n) {
    }
}

