/*
 * Decompiled with CFR 0.152.
 */
package optas.datamining.kernels;

import Jama.Matrix;
import optas.datamining.kernels.Kernel;

/**
 * Neural-network (arcsine) covariance function with a diagonal weight matrix.
 *
 * <p>For inputs {@code x} and {@code y} of length {@code n = inputDim}, the kernel value is
 * <pre>
 *   k(x, y) = (2/pi) * asin( 2 x~'S y~ / sqrt((1 + 2 x~'S x~) * (1 + 2 y~'S y~)) )
 * </pre>
 * where {@code x~ = (x_0, ..., x_{n-1}, 1)} is the input augmented with a trailing constant 1
 * and {@code S = diag(theta[0], ..., theta[n])}. When {@code index1 == index2},
 * {@code theta[n + 1]^2} is added as a noise term.
 *
 * <p>Parameter layout: {@code theta[0..n-1]} (named "l_i") multiply the input coordinates
 * inside the quadratic forms, {@code theta[n]} ("bias") weights the constant augmentation, and
 * {@code theta[n + 1]} ("sigma") is the noise amplitude.
 * NOTE(review): despite the "l_" prefix, {@code theta[i]} acts as a weight/variance (it
 * multiplies, not divides, the coordinate) in {@link #kernel} and {@link #dkernel}.
 */
public class NeuralNetwork extends Kernel {
    /** 2 / pi, the prefactor of the arcsine kernel. */
    private static final double TWO_OVER_PI = 0.6366197723675814;

    /** {@code diag(theta[0..inputDim])} as a Jama matrix; rebuilt by {@link #SetParameter}. */
    Matrix nnKernel = null;
    /** Diagonal of {@code S}: input weights followed by the bias weight at the last index. */
    double[] diag = null;

    /**
     * Creates a kernel for inputs of the given dimension.
     *
     * @param inputDim number of input coordinates; the kernel then has {@code inputDim + 2}
     *     parameters (one weight per input, a bias weight and a noise amplitude)
     */
    public NeuralNetwork(int inputDim) {
        this.inputDim = inputDim;
        this.parameterCount = inputDim + 2;
        this.KernelParameterCount = inputDim + 2;
    }

    /**
     * Returns the parameter names in {@code theta} order: {@code l_0 .. l_{n-1}}, {@code bias},
     * {@code sigma}.
     */
    @Override
    public String[] getParameterNames() {
        // The superclass call is relied upon to (re)allocate KernelParameterNames.
        super.getParameterNames();
        for (int i = 0; i < this.inputDim; ++i) {
            this.KernelParameterNames[i] = "l_" + i;
        }
        this.KernelParameterNames[this.inputDim] = "bias";
        this.KernelParameterNames[this.inputDim + 1] = "sigma";
        return this.KernelParameterNames;
    }

    /**
     * Installs a new parameter vector and rebuilds the cached weight diagonal.
     *
     * @param theta parameters in the layout described on the class; the array is stored by
     *     reference, while its first {@code inputDim + 1} entries are copied into {@link #diag}
     * @return {@code false} (leaving the state untouched) if {@code theta} is too short,
     *     {@code true} otherwise
     */
    @Override
    public boolean SetParameter(double[] theta) {
        if (theta.length < this.parameterCount) {
            return false;
        }
        this.theta = theta;
        int weightCount = this.KernelParameterCount - 1; // all parameters except the noise
        this.nnKernel = new Matrix(weightCount, weightCount, 0.0);
        this.diag = new double[weightCount];
        for (int i = 0; i < weightCount; ++i) {
            this.nnKernel.set(i, i, theta[i]);
            this.diag[i] = theta[i];
        }
        return true;
    }

    /**
     * Returns {@code sum_i ((x[i] - y[i]) / theta[i])^2} over the length of {@code x}.
     * Not used by {@link #kernel} or {@link #dkernel} in this class.
     */
    public double SqrDistance2(double[] x, double[] y) {
        double sum = 0.0;
        for (int i = 0; i < x.length; ++i) {
            double tmp = (x[i] - y[i]) / this.theta[i];
            sum += tmp * tmp;
        }
        return sum;
    }

    /**
     * Computes the quadratic forms of the augmented inputs {@code x~ = (x, 1)} and
     * {@code y~ = (y, 1)} under {@code S = diag(diag)}.
     *
     * @return {@code {x~'S y~, x~'S x~, y~'S y~}}
     */
    private double[] augmentedForms(double[] x, double[] y) {
        double bias = this.diag[this.diag.length - 1];
        double xSy = bias;
        double xSx = bias;
        double ySy = bias;
        for (int i = 0; i < x.length; ++i) {
            xSy += x[i] * this.diag[i] * y[i];
            xSx += x[i] * this.diag[i] * x[i];
            ySy += y[i] * this.diag[i] * y[i];
        }
        return new double[] {xSy, xSx, ySy};
    }

    /**
     * Evaluates the kernel between {@code x} and {@code y}.
     *
     * @param index1 index of {@code x} in the data set; compared with {@code index2} only
     * @param index2 index of {@code y}; when equal to {@code index1}, {@code sigma^2} is added
     * @return the arcsine kernel value plus the optional noise term
     */
    @Override
    public double kernel(double[] x, double[] y, int index1, int index2) {
        double[] forms = augmentedForms(x, y);
        double noise = 0.0;
        if (index1 == index2) {
            double sigma = this.theta[this.KernelParameterCount - 1];
            noise = sigma * sigma;
        }
        return TWO_OVER_PI
                * Math.asin(2.0 * forms[0] / Math.sqrt((1.0 + 2.0 * forms[1]) * (1.0 + 2.0 * forms[2])))
                + noise;
    }

    /**
     * Partial derivative of the (noise-free) kernel with respect to {@code theta[d]}.
     *
     * <p>Uses the same layout as {@link #kernel}: {@code d < inputDim} selects the weight of
     * input coordinate {@code d}, and {@code d == inputDim} selects the bias weight.
     *
     * @throws IllegalArgumentException if {@code d} is outside {@code [0, inputDim]}; the noise
     *     parameter only contributes on the diagonal, which this signature cannot identify
     */
    @Override
    public double dkernel(double[] x, double[] y, int d) {
        int biasIndex = this.diag.length - 1;
        if (d < 0 || d > biasIndex) {
            throw new IllegalArgumentException("dkernel: parameter index " + d
                    + " outside [0, " + biasIndex + "]");
        }
        double[] forms = augmentedForms(x, y);
        double a = 1.0 + 2.0 * forms[1];
        double b = 1.0 + 2.0 * forms[2];
        double u = 2.0 * forms[0];
        double v = Math.sqrt(a * b);
        // d/dz asin(z) = 1 / sqrt(1 - z^2) with z = u / v.
        double outerDerivative = TWO_OVER_PI / Math.sqrt(1.0 - u * u / (v * v));
        // Augmented coordinate d: the constant 1 for the bias, otherwise the input value.
        double xd = d == biasIndex ? 1.0 : x[d];
        double yd = d == biasIndex ? 1.0 : y[d];
        double du = 2.0 * xd * yd;
        // v = sqrt(a * b)  =>  dv = (da * b + a * db) / (2 v), with da = 2 xd^2, db = 2 yd^2.
        double dv = (2.0 * xd * xd * b + 2.0 * yd * yd * a) / (2.0 * v);
        // Quotient rule for z = u / v.
        return outerDerivative * (du / v - u * dv / (v * v));
    }
}

