/*
 * Decompiled with CFR 0.152.
 */
package org.encog.engine.opencl.kernels;

import java.util.Map;
import org.encog.engine.data.BasicEngineData;
import org.encog.engine.data.EngineData;
import org.encog.engine.data.EngineIndexableSet;
import org.encog.engine.network.activation.ActivationFunction;
import org.encog.engine.network.flat.FlatNetwork;
import org.encog.engine.network.train.prop.OpenCLTrainingProfile;
import org.encog.engine.opencl.EncogCLDevice;
import org.encog.engine.opencl.EncogCLQueue;
import org.encog.engine.opencl.exceptions.OpenCLError;
import org.encog.engine.opencl.exceptions.OutOfOpenCLResources;
import org.encog.engine.opencl.kernels.EncogKernel;
import org.encog.engine.util.EngineArray;
import org.encog.engine.util.ResourceLoader;
import org.jocl.CLException;
import org.jocl.cl_mem;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * OpenCL kernel wrapper used to train a {@link FlatNetwork}.
 *
 * <p>The constructor flattens the whole training set into two float arrays
 * (inputs and ideals). {@link #compile(Map, OpenCLTrainingProfile, FlatNetwork)}
 * builds the kernel source and, through {@link #init(OpenCLTrainingProfile)},
 * allocates the device buffers. {@link #calculate(int, int, boolean, int)}
 * uploads the current weights and parameters, runs the kernel and reads the
 * results back into the host-side arrays exposed by the getters.
 */
public class KernelNetworkTrain extends EncogKernel {
    /** Slot in the parameter array holding the network input count. */
    public static final int PARRAY_INPUT_COUNT = 0;
    /** Slot in the parameter array holding the network output count. */
    public static final int PARRAY_OUTPUT_COUNT = 1;
    /** Slot in the parameter array holding the number of layers. */
    public static final int PARRAY_LAYER_COUNT = 2;
    /** Slot in the parameter array holding the learn flag (1 or 0). */
    public static final int PARRAY_LEARN = 3;
    /** Slot in the parameter array holding the start value passed to calculate. */
    public static final int PARRAY_START = 4;
    /** Slot in the parameter array holding the items-per value passed to calculate. */
    public static final int PARRAY_ITEMS_PER = 5;
    /** Slot in the parameter array holding the iteration count passed to calculate. */
    public static final int PARRAY_ITERATIONS = 6;

    /** Message text that identifies an out-of-resources failure from OpenCL. */
    private static final String OUT_OF_RESOURCES_MESSAGE = "CL_OUT_OF_RESOURCES";

    // Device buffers; all of them are created in init() and freed in release().
    private cl_mem weightInArrayBuffer;
    private cl_mem weightOutArrayBuffer;
    private cl_mem layerIndexBuffer;
    private cl_mem layerCountBuffer;
    private cl_mem layerFeedCountBuffer;
    private cl_mem weightIndexBuffer;
    private cl_mem activationTypeBuffer;
    private cl_mem tempDataInBuffer;
    private cl_mem tempDataOutBuffer;

    /** Host copy of the network weights uploaded before each kernel run. */
    private final float[] weightInArray;
    /** Host array that receives the weight output buffer after each kernel run. */
    private final float[] weightOutArray;
    /** Scratch data uploaded before, and overwritten after, each kernel run. */
    private float[] tempDataArray;
    /** Sum of all layer counts; computed in the constructor, not read in this class. */
    private int layerDeltaSize;
    /** All training inputs, stored record after record. */
    private final float[] inputArray;
    /** All training ideals, stored record after record. */
    private final float[] idealArray;
    private cl_mem inputBuffer;
    private cl_mem idealBuffer;
    /** Integer parameters for the kernel, indexed by the PARRAY_* constants. */
    private final int[] paramArray;
    private cl_mem paramBuffer;
    private cl_mem errorBuffer;
    private cl_mem gradientOutBuffer;
    private cl_mem gradientInBuffer;
    private final FlatNetwork flat;
    /** Host array that receives the error buffer; allocated in init(). */
    private float[] errors;
    /** Zero-filled and uploaded before each run, then overwritten by the read-back. */
    private final float[] gradients;
    private final EngineIndexableSet training;
    private final EncogCLDevice device;
    private final int trainingLength;

    /**
     * Creates the kernel wrapper and copies the training set into flat float arrays.
     *
     * @param encogCLDevice      the device whose queue is used by {@link #calculate}
     * @param flatNetwork        the network whose weights and layout are trained
     * @param engineIndexableSet the training data; every record is read once here
     * @param n                  the length of the temp data array to allocate
     */
    public KernelNetworkTrain(EncogCLDevice encogCLDevice, FlatNetwork flatNetwork, EngineIndexableSet engineIndexableSet, int n) {
        super(encogCLDevice, "org/encog/engine/resources/KernelNetTrain.txt", "NetworkTrain");
        this.training = engineIndexableSet;
        this.trainingLength = (int)this.training.getRecordCount();
        this.device = encogCLDevice;
        this.flat = flatNetwork;

        final int weightCount = flatNetwork.getWeights().length;
        this.weightInArray = new float[weightCount];
        this.weightOutArray = new float[weightCount];
        this.tempDataArray = new float[n];
        this.gradients = new float[weightCount];

        this.layerDeltaSize = 0;
        for (int layerCount : flatNetwork.getLayerCounts()) {
            this.layerDeltaSize += layerCount;
        }

        final int inputCount = flatNetwork.getInputCount();
        final int outputCount = flatNetwork.getOutputCount();
        this.inputArray = new float[inputCount * this.trainingLength];
        this.idealArray = new float[outputCount * this.trainingLength];
        this.paramArray = new int[10];

        // A single pair object is reused as the read target for every record.
        final EngineData pair = BasicEngineData.createPair(inputCount, outputCount);
        int inputIndex = 0;
        int idealIndex = 0;
        for (int record = 0; record < this.trainingLength; ++record) {
            engineIndexableSet.getRecord(record, pair);
            for (int col = 0; col < inputCount; ++col) {
                this.inputArray[inputIndex++] = (float)pair.getInputArray()[col];
            }
            for (int col = 0; col < outputCount; ++col) {
                this.idealArray[idealIndex++] = (float)pair.getIdealArray()[col];
            }
        }
    }

    /**
     * Sets the global work size to the smaller of the two arguments, and the
     * local work size to the smaller of that value and the maximum work group size.
     *
     * @param n  first candidate size
     * @param n2 second candidate size
     */
    public void assignWorkgroupSizes(int n, int n2) {
        final int globalSize = Math.min(n, n2);
        this.setLocalWork(Math.min(this.getMaxWorkGroupSize(), globalSize));
        this.setGlobalWork(globalSize);
    }

    /**
     * Runs the kernel once and reads the results back into the host arrays.
     *
     * <p>The buffers used here are created by {@link #init(OpenCLTrainingProfile)},
     * which is invoked from {@link #compile(Map, OpenCLTrainingProfile, FlatNetwork)}.
     *
     * @param n  value stored in the {@link #PARRAY_START} slot
     * @param n2 value stored in the {@link #PARRAY_ITEMS_PER} slot
     * @param bl learn flag; stored as 1 or 0 in the {@link #PARRAY_LEARN} slot
     * @param n3 value stored in the {@link #PARRAY_ITERATIONS} slot
     * @throws OutOfOpenCLResources if OpenCL reports {@code CL_OUT_OF_RESOURCES}
     * @throws OpenCLError          on any other failure while running the kernel
     */
    public void calculate(int n, int n2, boolean bl, int n3) {
        this.prepareKernel();

        this.paramArray[PARRAY_LEARN] = bl ? 1 : 0;
        this.paramArray[PARRAY_START] = n;
        this.paramArray[PARRAY_ITEMS_PER] = n2;
        this.paramArray[PARRAY_ITERATIONS] = n3;

        // Snapshot the current network weights so they can be uploaded below.
        EngineArray.arrayCopy(this.flat.getWeights(), this.weightInArray);

        this.setArg(0, this.paramBuffer);
        this.setArg(1, this.errorBuffer);
        this.setArg(2, this.layerIndexBuffer);
        this.setArg(3, this.layerCountBuffer);
        this.setArg(4, this.layerFeedCountBuffer);
        this.setArg(5, this.weightIndexBuffer);
        this.setArg(6, this.inputBuffer);
        this.setArg(7, this.idealBuffer);
        this.setArg(8, this.weightInArrayBuffer);
        this.setArg(9, this.weightOutArrayBuffer);
        this.setArg(10, this.gradientOutBuffer);
        this.setArg(11, this.activationTypeBuffer);
        this.setArg(12, this.tempDataInBuffer);
        this.setArg(13, this.tempDataOutBuffer);
        this.setArg(14, this.gradientInBuffer);

        try {
            final EncogCLQueue queue = this.device.getQueue();

            // The gradient input is always uploaded as zeros.
            EngineArray.fill(this.gradients, 0.0f);

            queue.array2Buffer(this.weightInArray, this.weightInArrayBuffer);
            queue.array2Buffer(this.tempDataArray, this.tempDataInBuffer);
            queue.array2Buffer(this.gradients, this.gradientInBuffer);
            queue.array2Buffer(this.paramArray, this.paramBuffer);

            queue.execute(this);
            queue.waitFinish();

            queue.buffer2Array(this.errorBuffer, this.errors);
            queue.buffer2Array(this.weightOutArrayBuffer, this.weightOutArray);
            queue.buffer2Array(this.tempDataOutBuffer, this.tempDataArray);
            queue.buffer2Array(this.gradientOutBuffer, this.gradients);
        }
        catch (CLException cLException) {
            // Constant-first equals: a null message must not raise a
            // NullPointerException here and hide the original OpenCL failure.
            if (OUT_OF_RESOURCES_MESSAGE.equals(cLException.getMessage())) {
                throw new OutOfOpenCLResources(cLException);
            }
            throw new OpenCLError(cLException);
        }
        catch (Exception exception) {
            throw new OpenCLError(exception);
        }
    }

    /**
     * Builds the kernel source, compiles it, and initializes the buffers.
     *
     * <p>Two {@code #define} lines (ACTIVATION and DERIVATIVE) are generated from
     * the first activation function of {@code flatNetwork} and placed in front of
     * the source loaded from {@code getSourceName()}. Only the activation function
     * at index 0 is consulted.
     *
     * @param map                   options passed through to the one-argument compile
     * @param openCLTrainingProfile profile that calculates the kernel parameters and
     *                              is then handed to {@link #init(OpenCLTrainingProfile)}
     * @param flatNetwork           network supplying the activation function
     */
    public void compile(Map<String, String> map, OpenCLTrainingProfile openCLTrainingProfile, FlatNetwork flatNetwork) {
        final ActivationFunction activation = flatNetwork.getActivationFunctions()[0];

        final StringBuilder source = new StringBuilder();
        source.append("#define ACTIVATION(x,slope)");
        source.append(activation.getOpenCLExpression(false));
        source.append("\r\n");
        source.append("#define DERIVATIVE(x,slope)");
        source.append(activation.getOpenCLExpression(true));
        source.append("\r\n");
        source.append(ResourceLoader.loadString(this.getSourceName()));

        this.setCLSource(source.toString());
        this.compile(map);
        openCLTrainingProfile.calculateKernelParams(this, this.training);
        this.init(openCLTrainingProfile);
    }

    /**
     * @return the array the error buffer is read into by {@link #calculate};
     *         {@code null} until {@link #init(OpenCLTrainingProfile)} has run
     */
    public float[] getErrors() {
        return this.errors;
    }

    /**
     * @return the temp data array (the internal array itself, not a copy)
     */
    public float[] getTempDataArray() {
        return this.tempDataArray;
    }

    /**
     * @return the array the weight output buffer is read into by
     *         {@link #calculate} (the internal array itself, not a copy)
     */
    public float[] getWeightOutArray() {
        return this.weightOutArray;
    }

    /**
     * Allocates the error array, fills the fixed parameter slots and creates
     * all device buffers.
     *
     * @param openCLTrainingProfile profile supplying the kernel global workgroup
     *                              value, which sizes the error and gradient
     *                              output buffers
     */
    public void init(OpenCLTrainingProfile openCLTrainingProfile) {
        final int errorSize = openCLTrainingProfile.getKernelGlobalWorkgroup();
        final int gradientSize = errorSize * this.flat.getWeights().length;

        this.errors = new float[errorSize];

        this.paramArray[PARRAY_INPUT_COUNT] = this.flat.getInputCount();
        this.paramArray[PARRAY_OUTPUT_COUNT] = this.flat.getOutputCount();
        this.paramArray[PARRAY_LAYER_COUNT] = this.flat.getLayerCounts().length;

        this.inputBuffer = this.createArrayReadOnly(this.inputArray);
        this.idealBuffer = this.createArrayReadOnly(this.idealArray);
        this.errorBuffer = this.createFloatArrayWriteOnly(errorSize);
        this.gradientOutBuffer = this.createFloatArrayWriteOnly(gradientSize);
        this.gradientInBuffer = this.createArrayReadOnly(this.gradients);
        this.paramBuffer = this.createArrayReadOnly(this.paramArray);
        this.layerIndexBuffer = this.createArrayReadOnly(this.flat.getLayerIndex());
        this.layerCountBuffer = this.createArrayReadOnly(this.flat.getLayerCounts());
        this.layerFeedCountBuffer = this.createArrayReadOnly(this.flat.getLayerFeedCounts());
        this.weightInArrayBuffer = this.createArrayReadOnly(this.weightInArray);
        this.weightOutArrayBuffer = this.createFloatArrayWriteOnly(this.weightInArray.length);
        this.weightIndexBuffer = this.createArrayReadOnly(this.flat.getWeightIndex());
        // NOTE(review): this buffer is built from the layer counts, the same
        // array used for layerCountBuffer. Confirm whether that is intended.
        this.activationTypeBuffer = this.createArrayReadOnly(this.flat.getLayerCounts());
        this.tempDataInBuffer = this.createArrayReadOnly(this.tempDataArray);
        this.tempDataOutBuffer = this.createFloatArrayWriteOnly(this.tempDataArray.length);
    }

    /**
     * Releases the kernel via the superclass and then every buffer created in
     * {@link #init(OpenCLTrainingProfile)}.
     */
    @Override
    public void release() {
        super.release();
        this.releaseBuffer(this.activationTypeBuffer);
        this.releaseBuffer(this.errorBuffer);
        this.releaseBuffer(this.gradientOutBuffer);
        this.releaseBuffer(this.gradientInBuffer);
        this.releaseBuffer(this.idealBuffer);
        this.releaseBuffer(this.inputBuffer);
        this.releaseBuffer(this.layerCountBuffer);
        this.releaseBuffer(this.layerFeedCountBuffer);
        this.releaseBuffer(this.layerIndexBuffer);
        this.releaseBuffer(this.paramBuffer);
        this.releaseBuffer(this.tempDataInBuffer);
        this.releaseBuffer(this.tempDataOutBuffer);
        this.releaseBuffer(this.weightInArrayBuffer);
        this.releaseBuffer(this.weightIndexBuffer);
        this.releaseBuffer(this.weightOutArrayBuffer);
    }

    /**
     * Replaces the temp data array uploaded by {@link #calculate}.
     *
     * <p>The temp data buffers are sized from the array present when
     * {@link #init(OpenCLTrainingProfile)} runs; this setter does not resize them.
     *
     * @param fArray the new temp data array
     */
    public void setTempDataArray(float[] fArray) {
        this.tempDataArray = fArray;
    }
}

