/*
 * Decompiled with CFR 0.152.
 */
package optas.regression.gaussian.inf;

import Jama.CholeskyDecomposition;
import Jama.Matrix;
import optas.regression.GaussianProcessRegression;
import optas.regression.gaussian.HyperParameter;
import optas.regression.gaussian.cov.CovarianceFunction;
import optas.regression.gaussian.cov.covSEiso;
import optas.regression.gaussian.inf.Inference;
import optas.regression.gaussian.mean.Constant;
import optas.regression.gaussian.mean.Linear;
import optas.regression.gaussian.mean.MeanFunction;
import optas.regression.gaussian.mean.Sum;
import optas.regression.likelihood.Gaussian;
import optas.regression.likelihood.LikelihoodFunction;

/**
 * Exact Gaussian-process inference for a Gaussian likelihood, mirroring the structure of the
 * GPML {@code infExact} routine.
 *
 * <p>After {@link #inference} has run, {@link #getNLZ()} returns the negative log marginal
 * likelihood of the training data. When test inputs were supplied, {@link #getMu()} and
 * {@link #getSigma2()} return the second and third matrices produced by the likelihood's
 * {@code calc} method for the latent predictive means and variances.
 *
 * <p>Results of the most recent call are kept in mutable fields, so an instance must not be
 * shared between threads without external synchronization.
 */
public class infExact
implements Inference {
    /** Set to {@code true} to dump intermediate matrices to standard output. */
    private static final boolean DEBUG = false;

    /** {@code (K + sn2*I)^{-1} (y - m)} from the most recent call. */
    Matrix alpha = null;
    /** Column vector filled with {@code 1/sqrt(sn2)}, one entry per training point. */
    Matrix sw = null;
    /** Cholesky factor (as returned by Jama's {@code getL()}) of {@code K/sn2 + I}. */
    Matrix L = null;
    /** First matrix returned by {@code lik.calc}; presumably log predictive probabilities — confirm. */
    Matrix lp = null;
    /** Second matrix returned by {@code lik.calc}; exposed via {@link #getMu()}. */
    Matrix ymu = null;
    /** Third matrix returned by {@code lik.calc}; exposed via {@link #getSigma2()}. */
    Matrix ys2 = null;
    /** Negative log marginal likelihood from the most recent successful call. */
    double nlZ;

    /**
     * Runs exact inference on the training data and, if {@code xs} is non-null, predicts at the
     * test inputs.
     *
     * <p>With {@code sn2 = exp(2 * hyp.lik[0])}, the method factors {@code K/sn2 + I = L L'},
     * computes {@code alpha = (K + sn2*I)^{-1} (y - m)} and the negative log marginal likelihood
     * {@code nlZ = (y-m)'alpha/2 + sum(log(diag(L))) + n*log(2*pi*sn2)/2}. For test inputs it
     * computes the latent mean {@code fmu = ms + Ks' alpha} and variance
     * {@code fs2 = max(kss - diag(Ks' (K + sn2*I)^{-1} Ks), 0)}, then hands both to
     * {@code lik.calc} to obtain the stored predictive outputs.
     *
     * <p>A non-Gaussian likelihood or an empty/null training set is reported on standard output
     * and the method returns without updating any stored result.
     *
     * @param hyp  hyperparameters; {@code hyp.lik[0]} is the log noise standard deviation
     * @param mean mean function evaluated at training and test inputs
     * @param cov  covariance function
     * @param lik  likelihood; must be a {@link Gaussian}
     * @param x    training inputs, one row per training point
     * @param y    training targets; must have at least {@code x.length} entries
     * @param xs   test inputs, one row per test point, or {@code null} to skip prediction
     */
    @Override
    public void inference(HyperParameter hyp, MeanFunction mean, CovarianceFunction cov, LikelihoodFunction lik, double[][] x, double[] y, double[][] xs) {
        if (!(lik instanceof Gaussian)) {
            System.out.println("Exact inference only possible with Gaussian likelihood");
            return;
        }
        if (x == null || x.length == 0) {
            System.out.println("Empty training dataset provided");
            return;
        }
        int n = x.length; // number of training points

        Matrix K = cov.eval(hyp.cov, x);
        debugPrint("K", K);
        Matrix m = mean.eval(hyp.mean, x);
        debugPrint("m", m);

        // Factor K/sn2 + I instead of K + sn2*I (as GPML does); K is scaled in place.
        double sn2 = Math.exp(2.0 * hyp.lik[0]);
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < n; ++j) {
                double scaled = K.get(i, j) / sn2;
                K.set(i, j, i == j ? scaled + 1.0 : scaled);
            }
        }
        CholeskyDecomposition chol = K.chol();

        Matrix yMinusM = new Matrix(n, 1);
        for (int i = 0; i < n; ++i) {
            yMinusM.set(i, 0, y[i] - m.get(i, 0));
        }
        // chol.solve yields (K/sn2 + I)^{-1}(y - m); dividing by sn2 gives (K + sn2*I)^{-1}(y - m).
        this.alpha = chol.solve(yMinusM);
        this.L = chol.getL();
        for (int i = 0; i < n; ++i) {
            this.alpha.set(i, 0, this.alpha.get(i, 0) / sn2);
        }
        this.sw = new Matrix(n, 1, 1.0 / Math.sqrt(sn2));
        debugPrint("alpha", this.alpha);

        // nlZ = (y-m)'alpha/2 + sum(log(diag(L))) + n*log(2*pi*sn2)/2
        this.nlZ = 0.0;
        for (int i = 0; i < n; ++i) {
            this.nlZ += yMinusM.get(i, 0) * this.alpha.get(i, 0) / 2.0;
        }
        for (int i = 0; i < n; ++i) {
            this.nlZ += Math.log(this.L.get(i, i));
        }
        this.nlZ += (double) n * Math.log(Math.PI * 2 * sn2) / 2.0;
        System.out.println("nlZ is " + this.nlZ);

        if (xs == null) {
            return;
        }
        int ns = xs.length; // number of test points
        Matrix kss = cov.selfVariance(hyp.cov, xs);
        debugPrint("kss", kss);
        Matrix Ks = cov.crossVariance(hyp.cov, x, xs);
        debugPrint("ks", Ks);
        Matrix ms = mean.eval(hyp.mean, xs);
        Matrix fmu = ms.plus(Ks.transpose().times(this.alpha));
        debugPrint("fmu", fmu);

        Matrix fs2 = new Matrix(1, ns);
        if (isLowerTriangular(this.L)) {
            // V = L^{-1} (sw .* Ks), hence sum(V.*V, 1) = diag(Ks' (K + sn2*I)^{-1} Ks).
            Matrix scaledKs = Ks.copy();
            for (int p = 0; p < scaledKs.getRowDimension(); ++p) {
                for (int t = 0; t < scaledKs.getColumnDimension(); ++t) {
                    scaledKs.set(p, t, Ks.get(p, t) * this.sw.get(p, 0));
                }
            }
            debugPrint("L", this.L);
            // Solve L V = sw .* Ks directly rather than forming L^{-1} explicitly.
            Matrix V = this.L.solve(scaledKs);
            debugPrint("V", V);
            for (int t = 0; t < ns; ++t) {
                double sum = 0.0;
                for (int p = 0; p < n; ++p) {
                    sum += V.get(p, t) * V.get(p, t);
                }
                fs2.set(0, t, Math.max(kss.get(t, 0) - sum, 0.0));
            }
            debugPrint("fs2", fs2);
        } else {
            // GPML's non-Cholesky convention, where L holds -(K + sn2*I)^{-1}:
            // fs2 = kss + sum(Ks .* (L*Ks), 1)'. Since L here comes from a Cholesky
            // factorization this branch is expected to be unreachable; it is kept in step
            // with GPML so that a future change of L's source stays correct.
            Matrix LKs = this.L.times(Ks);
            for (int t = 0; t < ns; ++t) {
                double sum = 0.0;
                for (int p = 0; p < n; ++p) {
                    sum += Ks.get(p, t) * LKs.get(p, t);
                }
                fs2.set(0, t, Math.max(0.0, kss.get(t, 0) + sum));
            }
        }

        // No targets are passed (null), so the likelihood only produces predictive quantities.
        Matrix[] result = lik.calc(hyp.lik, null, fmu.getColumnPackedCopy(), fs2.getColumnPackedCopy());
        this.lp = result[0];
        this.ymu = result[1];
        this.ys2 = result[2];
        debugPrint("lp", this.lp);
    }

    /**
     * Returns the negative log marginal likelihood computed by the most recent call to
     * {@link #inference}.
     */
    @Override
    public double getNLZ() {
        return this.nlZ;
    }

    /**
     * Demo entry point: runs inference on random training data from
     * {@link GaussianProcessRegression} and predicts at 101 evenly spaced points in [-1.9, 1.9].
     */
    public static void main(String[] args) {
        double[][] x = GaussianProcessRegression.generateRandomX();
        double[] y = GaussianProcessRegression.generateRandomY();
        HyperParameter hyp = new HyperParameter();
        hyp.cov = new double[]{Math.log(0.25), Math.log(1.0)};
        hyp.mean = new double[]{0.5, 1.0};
        hyp.lik = new double[]{Math.log(0.1)};
        Sum mean = new Sum(new Linear(), new Constant());
        infExact inf = new infExact();
        covSEiso cov = new covSEiso();
        Gaussian lik = new Gaussian();
        double[][] z = new double[101][1];
        for (int i = 0; i < 101; ++i) {
            z[i][0] = -1.9 + (double)i / 100.0 * 3.8;
        }
        inf.inference(hyp, mean, cov, lik, x, y, z);
    }

    /** Returns the predictive means from the most recent prediction, or {@code null} if none. */
    @Override
    public Matrix getMu() {
        return this.ymu;
    }

    /** Returns the predictive variances from the most recent prediction, or {@code null} if none. */
    @Override
    public Matrix getSigma2() {
        return this.ys2;
    }

    /** Returns {@code true} when every entry strictly above the main diagonal of {@code a} is zero. */
    private static boolean isLowerTriangular(Matrix a) {
        for (int r = 0; r < a.getRowDimension(); ++r) {
            for (int c = r + 1; c < a.getColumnDimension(); ++c) {
                if (a.get(r, c) != 0.0) {
                    return false;
                }
            }
        }
        return true;
    }

    /** Prints {@code mat} row by row under a banner naming it, when {@link #DEBUG} is enabled. */
    private static void debugPrint(String label, Matrix mat) {
        if (!DEBUG) {
            return;
        }
        System.out.println("**********" + label + "**************");
        for (int r = 0; r < mat.getRowDimension(); ++r) {
            for (int c = 0; c < mat.getColumnDimension(); ++c) {
                System.out.print(mat.get(r, c) + "\t");
            }
            System.out.println("");
        }
    }
}

