/*
 * Decompiled with CFR 0.152.
 */
package org.encog.engine.opencl.kernels;

import java.util.HashMap;
import org.encog.engine.data.BasicEngineData;
import org.encog.engine.data.EngineData;
import org.encog.engine.data.EngineIndexableSet;
import org.encog.engine.network.activation.ActivationFunction;
import org.encog.engine.network.flat.FlatNetwork;
import org.encog.engine.opencl.EncogCLDevice;
import org.encog.engine.opencl.EncogCLQueue;
import org.encog.engine.opencl.exceptions.OpenCLError;
import org.encog.engine.opencl.exceptions.OutOfOpenCLResources;
import org.encog.engine.opencl.kernels.EncogKernel;
import org.encog.engine.util.EngineArray;
import org.encog.engine.util.ErrorCalculation;
import org.encog.engine.util.ResourceLoader;
import org.jocl.CLException;
import org.jocl.cl_mem;

/**
 * Wrapper around the "NetworkCalc" OpenCL kernel (source loaded from
 * {@code org/encog/engine/resources/KernelNetCalc.txt}).
 *
 * <p>Expected call order, as required by the code below: {@link #setFlat(FlatNetwork)}
 * (allocates network-structure buffers and compiles the kernel), then
 * {@link #setTraining(EngineIndexableSet)} (uploads input/ideal data and allocates
 * error/output buffers), then {@link #calculate(int, int)} (runs the kernel and reads
 * back errors and layer outputs). Call {@link #release()} to free all device buffers.
 */
public class KernelNetworkCalc
extends EncogKernel {
    /** Slot in the kernel parameter array holding the network input count. */
    public static final int PARRAY_INPUT_COUNT = 0;
    /** Slot in the kernel parameter array holding the network output count. */
    public static final int PARRAY_OUTPUT_COUNT = 1;
    /** Slot in the kernel parameter array holding the number of layers. */
    public static final int PARRAY_LAYER_COUNT = 2;
    /** Slot in the kernel parameter array; not written by this class. */
    public static final int PARRAY_LEARN = 3;
    /** Slot in the kernel parameter array written with the {@code n} argument of {@link #calculate(int, int)}. */
    public static final int PARRAY_START = 4;
    /** Slot in the kernel parameter array written with the {@code n2} argument of {@link #calculate(int, int)}. */
    public static final int PARRAY_ITEMS_PER = 5;
    /** Slot in the kernel parameter array; not written by this class. */
    public static final int PARRAY_ITERATIONS = 6;
    /** Number of ints in the kernel parameter array. */
    private static final int PARAM_COUNT = 10;
    /** Local work-group size used for every launch. */
    private static final int LOCAL_WORK_SIZE = 64;
    private cl_mem weightInArrayBuffer;
    private cl_mem layerIndexBuffer;
    private cl_mem layerCountBuffer;
    private cl_mem layerFeedCountBuffer;
    private cl_mem weightIndexBuffer;
    private float[] weightInArray;
    private float[] inputArray;
    private float[] idealArray;
    private cl_mem inputBuffer;
    private cl_mem layerOutputBuffer;
    private cl_mem idealBuffer;
    private float[] layerOutput;
    private int[] paramArray;
    private cl_mem paramBuffer;
    private cl_mem errorBuffer;
    private FlatNetwork flat;
    private float[] errors;
    private EngineIndexableSet training;
    private final EncogCLDevice device;
    private int trainingLength;

    /**
     * Creates the kernel wrapper for a device and allocates the parameter buffer.
     *
     * @param encogCLDevice the OpenCL device whose queue is used by {@link #calculate(int, int)}
     */
    public KernelNetworkCalc(EncogCLDevice encogCLDevice) {
        super(encogCLDevice, "org/encog/engine/resources/KernelNetCalc.txt", "NetworkCalc");
        this.device = encogCLDevice;
        this.paramArray = new int[PARAM_COUNT];
        this.paramBuffer = this.createArrayReadOnly(this.paramArray);
    }

    /**
     * Uploads the current network weights and parameters, runs the kernel, and reads
     * the per-record errors and the layer outputs back into host arrays.
     *
     * @param n  value stored in {@link #PARRAY_START} (presumably the first record to process)
     * @param n2 value stored in {@link #PARRAY_ITEMS_PER}; also used as the global work size
     * @throws IllegalStateException if {@link #setFlat} or {@link #setTraining} has not been called
     * @throws OutOfOpenCLResources  if OpenCL reports {@code CL_OUT_OF_RESOURCES}
     * @throws OpenCLError           for any other failure while enqueuing or transferring
     */
    public void calculate(int n, int n2) {
        if (this.flat == null || this.training == null) {
            throw new IllegalStateException("setFlat() and setTraining() must be called before calculate()");
        }
        this.prepareKernel();
        this.paramArray[PARRAY_START] = n;
        this.paramArray[PARRAY_ITEMS_PER] = n2;
        this.setGlobalWork(n2);
        // NOTE(review): OpenCL 1.x rejects a global size that is not a multiple of the
        // local size; confirm callers pass such an n2 or that setGlobalWork rounds up.
        this.setLocalWork(LOCAL_WORK_SIZE);
        // Copy the network's current weights into the float staging array uploaded below.
        EngineArray.arrayCopy(this.flat.getWeights(), this.weightInArray);
        this.setArg(0, this.paramBuffer);
        this.setArg(1, this.errorBuffer);
        this.setArg(2, this.layerIndexBuffer);
        this.setArg(3, this.layerCountBuffer);
        this.setArg(4, this.layerFeedCountBuffer);
        this.setArg(5, this.weightIndexBuffer);
        this.setArg(6, this.inputBuffer);
        this.setArg(7, this.idealBuffer);
        this.setArg(8, this.weightInArrayBuffer);
        this.setArg(9, this.layerOutputBuffer);
        try {
            EncogCLQueue queue = this.device.getQueue();
            queue.array2Buffer(this.weightInArray, this.weightInArrayBuffer);
            queue.array2Buffer(this.paramArray, this.paramBuffer);
            queue.execute(this);
            queue.waitFinish();
            queue.buffer2Array(this.errorBuffer, this.errors);
            queue.buffer2Array(this.layerOutputBuffer, this.layerOutput);
        }
        catch (CLException cLException) {
            // Constant-first comparison: getMessage() may be null.
            if ("CL_OUT_OF_RESOURCES".equals(cLException.getMessage())) {
                throw new OutOfOpenCLResources(cLException);
            }
            throw new OpenCLError(cLException);
        }
        catch (Exception exception) {
            throw new OpenCLError(exception);
        }
    }

    /**
     * Builds and compiles the kernel program for the given network.
     *
     * <p>An {@code ACTIVATION(x,slope)} macro is prepended to the loaded kernel source.
     * The macro is generated from the network's <em>first</em> activation function only.
     * NOTE(review): networks whose layers use different activations are presumably not
     * supported here; confirm.
     * {@code NEURON_COUNT} and {@code WEIGHT_COUNT} are passed to the superclass compile.
     *
     * @param flatNetwork the network whose activation and sizes specialize the kernel
     */
    public void compile(FlatNetwork flatNetwork) {
        ActivationFunction activationFunction = flatNetwork.getActivationFunctions()[0];
        StringBuilder stringBuilder = new StringBuilder();
        stringBuilder.append("#define ACTIVATION(x,slope)");
        stringBuilder.append(activationFunction.getOpenCLExpression(false));
        stringBuilder.append("\r\n");
        stringBuilder.append(ResourceLoader.loadString(this.getSourceName()));
        this.setCLSource(stringBuilder.toString());
        HashMap<String, String> hashMap = new HashMap<String, String>();
        hashMap.put("NEURON_COUNT", "" + flatNetwork.getNeuronCount());
        hashMap.put("WEIGHT_COUNT", "" + flatNetwork.getWeights().length);
        this.compile(hashMap);
    }

    /**
     * Returns the host-side error array filled by the last {@link #calculate(int, int)}.
     *
     * @return one entry per training record, or {@code null} before {@link #setTraining}
     */
    public float[] getErrors() {
        return this.errors;
    }

    /**
     * Releases the kernel and every device buffer this object allocated.
     * Each buffer field is cleared, so calling this twice is harmless.
     */
    public void release() {
        super.release();
        this.errorBuffer = this.releaseIfAllocated(this.errorBuffer);
        this.idealBuffer = this.releaseIfAllocated(this.idealBuffer);
        this.inputBuffer = this.releaseIfAllocated(this.inputBuffer);
        this.layerCountBuffer = this.releaseIfAllocated(this.layerCountBuffer);
        this.layerFeedCountBuffer = this.releaseIfAllocated(this.layerFeedCountBuffer);
        this.layerIndexBuffer = this.releaseIfAllocated(this.layerIndexBuffer);
        // Previously leaked: allocated by allocateCommon() but never freed here.
        this.layerOutputBuffer = this.releaseIfAllocated(this.layerOutputBuffer);
        this.paramBuffer = this.releaseIfAllocated(this.paramBuffer);
        this.weightInArrayBuffer = this.releaseIfAllocated(this.weightInArrayBuffer);
        this.weightIndexBuffer = this.releaseIfAllocated(this.weightIndexBuffer);
    }

    /**
     * Returns the network assigned with {@link #setFlat}.
     *
     * @return the network, or {@code null} if none has been assigned
     */
    public FlatNetwork getFlat() {
        return this.flat;
    }

    /**
     * Assigns the network to evaluate.
     *
     * <p>This stores the structural sizes in the parameter array and re-creates the
     * structure and weight buffers. It re-sizes the layer-output storage if training
     * data is already set, then recompiles the kernel. Input and ideal buffers are not
     * rebuilt, so call {@link #setTraining} again if the input or output counts changed.
     *
     * @param flatNetwork the network to evaluate
     */
    public void setFlat(FlatNetwork flatNetwork) {
        this.flat = flatNetwork;
        this.weightInArray = new float[flatNetwork.getWeights().length];
        this.paramArray[PARRAY_INPUT_COUNT] = flatNetwork.getInputCount();
        this.paramArray[PARRAY_OUTPUT_COUNT] = flatNetwork.getOutputCount();
        this.paramArray[PARRAY_LAYER_COUNT] = flatNetwork.getLayerCounts().length;
        // Drop buffers sized for any previously assigned network.
        this.layerCountBuffer = this.releaseIfAllocated(this.layerCountBuffer);
        this.layerFeedCountBuffer = this.releaseIfAllocated(this.layerFeedCountBuffer);
        this.layerIndexBuffer = this.releaseIfAllocated(this.layerIndexBuffer);
        this.weightInArrayBuffer = this.releaseIfAllocated(this.weightInArrayBuffer);
        this.weightIndexBuffer = this.releaseIfAllocated(this.weightIndexBuffer);
        this.layerIndexBuffer = this.createArrayReadOnly(flatNetwork.getLayerIndex());
        this.layerCountBuffer = this.createArrayReadOnly(flatNetwork.getLayerCounts());
        this.layerFeedCountBuffer = this.createArrayReadOnly(flatNetwork.getLayerFeedCounts());
        this.weightInArrayBuffer = this.createArrayReadOnly(this.weightInArray);
        this.weightIndexBuffer = this.createArrayReadOnly(flatNetwork.getWeightIndex());
        this.allocateCommon();
        this.compile(flatNetwork);
    }

    /**
     * (Re)allocates the layer-output host array and device buffer.
     * Its size is {@code flat.getLayerOutput().length * trainingLength}.
     * This does nothing until both the network and the training data are set.
     */
    private void allocateCommon() {
        if (this.training == null || this.flat == null) {
            return;
        }
        this.layerOutputBuffer = this.releaseIfAllocated(this.layerOutputBuffer);
        this.layerOutput = new float[this.flat.getLayerOutput().length * this.trainingLength];
        this.layerOutputBuffer = this.createFloatArrayWriteOnly(this.layerOutput.length);
    }

    /**
     * Releases a device buffer if it has been allocated.
     *
     * @param buffer the buffer to release; may be {@code null}
     * @return always {@code null}, so the caller can clear its field in the same statement
     */
    private cl_mem releaseIfAllocated(cl_mem buffer) {
        if (buffer != null) {
            this.releaseBuffer(buffer);
        }
        return null;
    }

    /**
     * Returns the training data assigned with {@link #setTraining}.
     *
     * @return the training set, or {@code null} if none has been assigned
     */
    public EngineIndexableSet getTraining() {
        return this.training;
    }

    /**
     * Assigns the training data and uploads it to the device.
     *
     * <p>Every record is flattened into contiguous float input and ideal arrays. The
     * strides are the network's input and output counts, which are the same counts
     * written to the parameter array. The error, input, ideal and layer-output buffers
     * are then (re)created.
     *
     * @param engineIndexableSet the training records
     * @throws IllegalStateException if {@link #setFlat} has not been called yet
     */
    public void setTraining(EngineIndexableSet engineIndexableSet) {
        if (this.flat == null) {
            throw new IllegalStateException("setFlat() must be called before setTraining()");
        }
        this.training = engineIndexableSet;
        this.trainingLength = (int) engineIndexableSet.getRecordCount();
        int inputCount = this.flat.getInputCount();
        int idealCount = this.flat.getOutputCount();
        EngineData pair = BasicEngineData.createPair(inputCount, idealCount);
        // Size by the same counts the copy loop uses. Sizing by the set's own
        // input/ideal sizes could overflow (or pad) the arrays when they differ.
        this.inputArray = new float[inputCount * this.trainingLength];
        this.idealArray = new float[idealCount * this.trainingLength];
        int inputIndex = 0;
        int idealIndex = 0;
        for (int record = 0; record < this.trainingLength; ++record) {
            engineIndexableSet.getRecord(record, pair);
            double[] input = pair.getInputArray();
            double[] ideal = pair.getIdealArray();
            for (int i = 0; i < inputCount; ++i) {
                this.inputArray[inputIndex++] = (float) input[i];
            }
            for (int i = 0; i < idealCount; ++i) {
                this.idealArray[idealIndex++] = (float) ideal[i];
            }
        }
        this.errors = new float[this.trainingLength];
        this.errorBuffer = this.releaseIfAllocated(this.errorBuffer);
        this.idealBuffer = this.releaseIfAllocated(this.idealBuffer);
        this.inputBuffer = this.releaseIfAllocated(this.inputBuffer);
        this.errorBuffer = this.createFloatArrayWriteOnly(this.errors.length);
        this.inputBuffer = this.createArrayReadOnly(this.inputArray);
        this.idealBuffer = this.createArrayReadOnly(this.idealArray);
        this.allocateCommon();
    }

    /**
     * Averages the errors from the last {@link #calculate(int, int)} over
     * {@code records * outputCount}.
     *
     * @return the summed errors divided by {@code errors.length * flat.getOutputCount()};
     *         {@code NaN} when that product is zero
     */
    public double getError() {
        double sum = 0.0;
        for (float error : this.errors) {
            sum += error;
        }
        // Widen before multiplying so a large record count cannot overflow int.
        return sum / ((double) this.errors.length * this.flat.getOutputCount());
    }
}

