/*
 * RPROP step-size constants (used by the LEARN_RPROP branch of NetworkTrain).
 */
#define POSITIVE_ETA        1.2f        /* factor applied to the step when the gradient keeps its sign */
#define NEGATIVE_ETA        0.5f        /* factor applied to the step when the gradient changes sign */
#define DELTA_MIN           0.00001f    /* lower bound for a shrunken step */
#define MAX_STEP            50.0f       /* upper bound for a grown step */

/*
 * Slot indices into the integer `params` array passed to NetworkTrain.
 */
#define PARRAY_INPUT_COUNT  0           /* number of input values per training item */
#define PARRAY_OUTPUT_COUNT 1           /* number of ideal/output values per training item */
#define PARRAY_LAYER_COUNT  2           /* number of layers in the network */
#define PARRAY_LEARN        3           /* > 0: work-item 0 updates the weights after each iteration */
#define PARRAY_START        4           /* offset (in items) of the first training item */
#define PARRAY_ITEMS_PER    5           /* training items processed by each work-item */
#define ITERATIONS          6           /* number of training iterations to run in one launch */

/*
 * NetworkTrain: runs params[ITERATIONS] training iterations of a layered
 * feed-forward network.
 *
 * Each work-item processes params[PARRAY_ITEMS_PER] training items per
 * iteration: a forward pass, then a backward pass that accumulates gradients
 * into the work-item's own WEIGHT_COUNT-sized slice of gradientsOut and the
 * squared error into errors[taskIndex]. After a barrier, work-item 0 sums all
 * slices (plus gradientsIn) into slice 0 and, when params[PARRAY_LEARN] > 0,
 * updates weightsIn in place using the rule selected by LEARN_RPROP,
 * LEARN_BPROP or LEARN_MANHATTAN.
 *
 * NEURON_COUNT, WEIGHT_COUNT, ACTIVATION, DERIVATIVE and the LEARN_* switches
 * are not defined in this file; presumably the host supplies them at build
 * time - TODO confirm.
 *
 * Layer numbering as used below: layer 0 is the output layer (its outputs are
 * read from layerOutput[0..outputSize)), layer layerCount-1 is the input layer.
 *
 * Parameters:
 *   params          - integer settings, indexed by the PARRAY_* / ITERATIONS slots.
 *   errors          - one entry per work-item; sum of squared output errors for
 *                     the last iteration.
 *   layerIndex      - start offset of each layer inside layerOutput/layerDelta.
 *   layerCounts     - number of values read from each layer when feeding the
 *                     next one (presumably includes bias neurons - confirm).
 *   layerFeedCounts - number of neurons in each layer whose output is computed.
 *   weightIndex     - offset of each layer's first weight in the weight arrays.
 *   input, ideal    - training data; item k starts at k*inputSize / k*outputSize.
 *   weightsIn       - current weights; updated in place by work-item 0.
 *   weightsOut      - receives a copy of weightsIn when the kernel finishes.
 *   gradientsOut    - WEIGHT_COUNT entries per work-item; slice 0 holds the
 *                     summed gradient after the reduction.
 *   activationType  - not referenced by this kernel.
 *   tempDataIn      - learning state, read and written by work-item 0. RPROP:
 *                     [0..WEIGHT_COUNT) previous gradients,
 *                     [WEIGHT_COUNT..2*WEIGHT_COUNT) per-weight step sizes.
 *                     BPROP: [0] gradient multiplier, [1] previous-delta
 *                     multiplier (presumably learning rate and momentum),
 *                     [2..] previous deltas. MANHATTAN: [0] fixed step.
 *   tempDataOut     - receives the first 2*WEIGHT_COUNT entries of tempDataIn.
 *   gradientsIn     - added to the summed gradient before learning.
 *
 * NOTE(review): weightsIn, tempDataIn and tempDataOut are written although
 * declared read_only, and errors/gradientsOut are read although declared
 * write_only. The qualifiers are left untouched to keep the signature stable.
 *
 * NOTE(review): barrier() only synchronizes work-items of one work-group, so
 * the reduction by work-item 0 presumably requires the kernel to be launched
 * as a single work-group - confirm against the host code.
 */
kernel void NetworkTrain(
    global read_only int *params,
    global write_only float *errors,
    global read_only int *layerIndex,
    global read_only int *layerCounts,
    global read_only int *layerFeedCounts,
    global read_only int *weightIndex,
    global read_only float* input,
    global read_only float* ideal,
    global read_only float* weightsIn,
    global write_only float* weightsOut,
    global write_only float *gradientsOut,
    global read_only int *activationType,
    global read_only float *tempDataIn,
    global read_only float *tempDataOut,
    global read_only float *gradientsIn
    )
{
	// per-work-item scratch: neuron outputs and error deltas for one item
	private float layerOutput[NEURON_COUNT];
	private float layerDelta[NEURON_COUNT];

	int taskIndex = get_global_id(0);
	int globalSize = get_global_size(0);

	int inputSize = params[PARRAY_INPUT_COUNT];
	int outputSize = params[PARRAY_OUTPUT_COUNT];
	int layerCount = params[PARRAY_LAYER_COUNT];
	int trainingOffset = params[PARRAY_START];
	int itemsPer = params[PARRAY_ITEMS_PER];

	// start of this work-item's private slice of gradientsOut
	int gradientOffset = (taskIndex*WEIGHT_COUNT);

	int iterations = params[ITERATIONS];

	while( (iterations--)> 0 )
	{
		// clear out the gradients and errors
		errors[taskIndex] = 0;
		for(int i=0;i<WEIGHT_COUNT;i++)
		{
			gradientsOut[gradientOffset+i] = 0;
		}

		for(int trainIndex=0;trainIndex<itemsPer;trainIndex++)
		{
			int subtaskIndex = (taskIndex*itemsPer)+trainIndex+trainingOffset;

			// part 1: forward pass
			int taskInputIndex = subtaskIndex * inputSize;
			int taskIdealIndex = subtaskIndex * outputSize;

			// the input layer occupies the tail of layerOutput
			int sourceIndex = NEURON_COUNT - layerCounts[layerCount-1];

			// preset every slot to 1; slots that are never overwritten keep
			// that value (presumably the bias neurons - confirm)
			for(int i=0;i<NEURON_COUNT;i++)
				layerOutput[i] = 1;

			// load the input into the layer output array, this feeds the first layer.
			for(int i=0;i<inputSize;i++)
				layerOutput[sourceIndex+i] = input[taskInputIndex+i];

			for (int currentLayer = layerCount - 1; currentLayer > 0; currentLayer--)
			{
				int fromIndex = layerIndex[currentLayer];
				int toIndex = layerIndex[currentLayer - 1];
				int fromSize = layerCounts[currentLayer];
				int toSize = layerFeedCounts[currentLayer - 1];
				int index = weightIndex[currentLayer - 1];

				// weights are consumed sequentially: fromSize weights per target neuron
				global float *wptr = weightsIn+index;
				for (int x = 0; x < toSize; x++)
				{
					float sum = 0;
					for (int y = 0; y < fromSize; y++)
					{
						sum += *(wptr++) * layerOutput[fromIndex + y];
					}

					layerOutput[toIndex + x] = ACTIVATION(sum, 1.0);
				}
			}

			// part 2: backward pass
			// process the output layer first

			float e = 0;

			for(int i=0;i<outputSize;i++)
			{
				float diff = ideal[taskIdealIndex+i] - layerOutput[i];
				e+=diff*diff;
				layerDelta[i] = diff * DERIVATIVE(layerOutput[i], 1.0);
			}

			errors[taskIndex] += e;

			// process hidden layers

			for(int currentLevel = 0; (currentLevel<layerCount-1); currentLevel++)
			{
				int fromLayerIndex = layerIndex[currentLevel + 1];
				int toLayerIndex = layerIndex[currentLevel];
				int fromLayerSize = layerCounts[currentLevel + 1];
				int toLayerSize = layerFeedCounts[currentLevel];

				int index = weightIndex[currentLevel];

				// handle weights
				int yi = fromLayerIndex;
				for (int y = 0; y < fromLayerSize; y++)
				{
					float output = layerOutput[yi];
					float sum = 0;
					// first weight leaving source neuron y; successive target
					// neurons are fromLayerSize entries apart
					int wi = index+y;
					global float * gptr = gradientsOut+wi+gradientOffset;
					global float * wptr = weightsIn + wi;
					float * dptr = layerDelta + toLayerIndex;
					for (int x = 0; x < toLayerSize; x++)
					{
						*gptr += output * (*dptr);
						sum += (*wptr) * (*dptr);
						wptr+=fromLayerSize;
						gptr+=fromLayerSize;
						dptr++;
					}

					layerDelta[yi] = sum * DERIVATIVE(
						layerOutput[yi],
						1.0);

					yi++;
				}
			}
		}

		// now that the gradients have been calculated, update the network
		barrier(CLK_GLOBAL_MEM_FENCE);

		if( taskIndex==0 )
		{
			// loop over all gradients and sum them into the first "global" task
			for(int i=0;i<WEIGHT_COUNT;i++)
			{
				gradientsOut[i]+=gradientsIn[i];
				for(int j=1;j<globalSize;j++)
				{
					gradientsOut[i] += gradientsOut[(j*WEIGHT_COUNT)+i];
				}
			}
		}

		if( taskIndex==0 && params[PARRAY_LEARN]>0 )
		{
			// teach the weights
#ifdef LEARN_RPROP
			global float *wptr = weightsIn;
			global float *gptr = gradientsOut;
			for(int i=0;i<WEIGHT_COUNT;i++)
			{
				// sign of (current gradient * previous gradient)
				int change = sign((*gptr) * tempDataIn[i]);
				float weightChange = 0;

				// if the gradient has retained its sign, then we increase the
				// delta so that it will converge faster
				if (change > 0)
				{
					float delta = tempDataIn[i+WEIGHT_COUNT]
						* POSITIVE_ETA;
					delta = min(delta, MAX_STEP);
					weightChange = sign(*gptr) * delta;
					tempDataIn[i+WEIGHT_COUNT] = delta;
					tempDataIn[i] = *gptr;
				}
				else if (change < 0)
				{
					// if change<0, then the sign has changed, and the last
					// delta was too big
					float delta = tempDataIn[i+WEIGHT_COUNT]
						* NEGATIVE_ETA;
					delta = max(delta, DELTA_MIN);
					tempDataIn[i+WEIGHT_COUNT] = delta;
					// set the previous gradient to zero so that there will be no
					// adjustment the next iteration
					tempDataIn[i] = 0;
				}
				else
				{
					// change==0: the step size is left unchanged and applied
					// as-is. The step size lives at [i+WEIGHT_COUNT]; slot [i]
					// holds the previous gradient (0 after a sign change), so
					// reading the step from [i] would freeze the weight.
					float delta = tempDataIn[i+WEIGHT_COUNT];
					weightChange = sign(*gptr) * delta;
					tempDataIn[i] = *gptr;
				}

				*(wptr++)+=weightChange;
				gptr++;
			}
#endif
#ifdef LEARN_BPROP
			// NOTE(review): the stored previous delta excludes the
			// tempDataIn[1]-scaled term that was also applied to the weight;
			// confirm whether the full applied change should be stored.
			for(int i=0;i<WEIGHT_COUNT;i++)
			{
				float delta = (gradientsOut[i]*tempDataIn[0]);
				weightsIn[i]+=delta+(tempDataIn[i+2]*tempDataIn[1]);
				tempDataIn[i+2] = delta;
			}
#endif
#ifdef LEARN_MANHATTAN
			for(int i=0;i<WEIGHT_COUNT;i++)
			{
				int direction = sign(gradientsOut[i]);
				weightsIn[i]+=tempDataIn[0]*direction;
			}
#endif
		}

		// keep other work-items from starting the next iteration until the
		// weights have been updated
		barrier(CLK_GLOBAL_MEM_FENCE);
	}

	if( taskIndex==0 )
	{
		// finally, after all is done, return the weights to the CPU
		for(int i=0;i<WEIGHT_COUNT;i++)
		{
			weightsOut[i] = weightsIn[i];
		}

		// NOTE(review): always copies 2*WEIGHT_COUNT entries regardless of
		// the learning mode; confirm the buffers are sized for that.
		for(int i=0;i<(WEIGHT_COUNT*2);i++)
		{
			tempDataOut[i] = tempDataIn[i];
		}
	}

}