/*
 * Decompiled with CFR 0.152.
 */
package org.encog.neural.networks.training.svm;

import org.encog.engine.util.ErrorCalculation;
import org.encog.mathutil.libsvm.svm;
import org.encog.mathutil.libsvm.svm_parameter;
import org.encog.mathutil.libsvm.svm_problem;
import org.encog.neural.data.NeuralDataSet;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.svm.KernelType;
import org.encog.neural.networks.svm.SVMNetwork;
import org.encog.neural.networks.training.BasicTraining;
import org.encog.neural.networks.training.svm.EncodeSVMProblem;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Trains an {@link SVMNetwork} with libsvm, building one SVM model per network output.
 *
 * <p>For networks using the radial basis function kernel, each call to {@link #iteration()}
 * evaluates one (C, gamma) point of a grid search for every output, using n-fold
 * cross-validation, and remembers the best-scoring pair. {@link #finishTraining()} then trains
 * the final models with those pairs. The grid walks C from {@code constBegin} to
 * {@code constEnd} (inclusive) in steps of {@code constStep}; each time C wraps around, gamma
 * advances by {@code gammaStep}. The search is done once gamma passes {@code gammaEnd}.
 *
 * <p>For any other kernel, a single call to {@link #iteration()} trains every model with
 * C = 1 and gamma = 1 / inputCount, and marks training as done.
 *
 * <p>NOTE(review): grid values are handed to libsvm directly as {@code C} and {@code gamma}.
 * libsvm's own parameter check rejects C &lt;= 0 and gamma &lt; 0, so the negative defaults
 * below produce invalid grid points. They resemble the log2 exponents used by libsvm's
 * grid.py; confirm the intended meaning before relying on the defaults.
 */
public class SVMTrain extends BasicTraining {
    private static final Logger LOGGER = LoggerFactory.getLogger(SVMTrain.class);

    /** Default first value of C visited by the grid search. */
    public static final double DEFAULT_CONST_BEGIN = -5.0;
    /** Default last value (inclusive) of C visited by the grid search. */
    public static final double DEFAULT_CONST_END = 15.0;
    /** Default increment applied to C between grid points. */
    public static final double DEFAULT_CONST_STEP = 2.0;
    /** Default first value of gamma visited by the grid search. */
    public static final double DEFAULT_GAMMA_BEGIN = -10.0;
    /** Default last value (inclusive) of gamma visited by the grid search. */
    public static final double DEFAULT_GAMMA_END = 10.0;
    /** Default increment applied to gamma between grid points. */
    public static final double DEFAULT_GAMMA_STEP = 1.0;

    /**
     * Gammas at or below this threshold are treated as "not specified" by
     * {@link #train(int, double, double)} and replaced by 1 / inputCount.
     */
    private static final double MIN_GAMMA = 1.0E-7;

    /** Scale of the classification accuracy returned by {@link #crossValidate}. */
    private static final double HUNDRED_PERCENT = 100.0;

    /** The network being trained; holds one libsvm parameter set and model per output. */
    private final SVMNetwork network;
    /** The training data encoded as one libsvm problem per network output. */
    private final svm_problem[] problem;
    /** Number of folds used for cross-validation. */
    private int fold = 5;
    private double constBegin = DEFAULT_CONST_BEGIN;
    private double constStep = DEFAULT_CONST_STEP;
    private double constEnd = DEFAULT_CONST_END;
    private double gammaBegin = DEFAULT_GAMMA_BEGIN;
    private double gammaEnd = DEFAULT_GAMMA_END;
    private double gammaStep = DEFAULT_GAMMA_STEP;
    /** Per-output best C found so far by the search. */
    private double[] bestConst;
    /** Per-output best gamma found so far by the search. */
    private double[] bestGamma;
    /** Per-output lowest error found so far; +infinity until a grid point scores. */
    private double[] bestError;
    /** Per-output C of the next grid point to evaluate. */
    private double[] currentConst;
    /** Per-output gamma of the next grid point to evaluate. */
    private double[] currentGamma;
    /** Whether the search state arrays have been allocated (done lazily by iteration()). */
    private boolean isSetup;
    /** Whether the search (or the single non-RBF training pass) has completed. */
    private boolean trainingDone;

    /**
     * Creates a trainer and encodes the training data into one libsvm problem per output.
     *
     * @param network the network to train; must be an {@link SVMNetwork} (otherwise the
     *     cast throws {@link ClassCastException})
     * @param training the training data
     */
    public SVMTrain(BasicNetwork network, NeuralDataSet training) {
        this.network = (SVMNetwork) network;
        setTraining(training);
        this.isSetup = false;
        this.trainingDone = false;
        final int outputCount = this.network.getOutputCount();
        this.problem = new svm_problem[outputCount];
        for (int i = 0; i < outputCount; ++i) {
            this.problem[i] = EncodeSVMProblem.encode(training, i);
        }
    }

    /** Trains every output's model with C = 1 and gamma = 1 / inputCount. */
    public void train() {
        train(1.0 / this.network.getInputCount(), 1.0);
    }

    /**
     * Trains the model of one output with the given parameters and stores it in the network.
     * This also writes C and gamma into that output's parameter set.
     *
     * @param index the output whose model is trained
     * @param gamma the kernel gamma; values at or below {@code 1e-7} are treated as
     *     unspecified and replaced by 1 / inputCount
     * @param c the cost parameter C
     */
    public void train(int index, double gamma, double c) {
        final svm_parameter param = this.network.getParams()[index];
        param.C = c;
        // Use the caller's gamma when it is a usable positive value; otherwise fall back to the
        // conventional default of 1 / number of inputs.
        param.gamma = gamma > MIN_GAMMA ? gamma : 1.0 / this.network.getInputCount();
        this.network.getModels()[index] = svm.svm_train(this.problem[index], param);
    }

    /**
     * Cross-validates one output with the given parameters.
     * As a side effect, C and gamma are written into that output's parameter set.
     *
     * @param index the output to evaluate
     * @param gamma the kernel gamma (used as-is)
     * @param c the cost parameter C
     * @return for epsilon-SVR / nu-SVR, the error computed by {@link ErrorCalculation};
     *     for every other SVM type, the percentage (0-100) of correctly predicted samples
     */
    public double crossValidate(int index, double gamma, double c) {
        final svm_problem prob = this.problem[index];
        final svm_parameter param = this.network.getParams()[index];
        param.C = c;
        param.gamma = gamma;
        // Size the target array from the problem actually being validated.
        final double[] target = new double[prob.l];
        svm.svm_cross_validation(prob, param, this.fold, target);
        return evaluate(param, prob, target);
    }

    /** Scores cross-validation predictions; see {@link #crossValidate} for the return contract. */
    private double evaluate(svm_parameter param, svm_problem prob, double[] target) {
        if (isRegression(param)) {
            final ErrorCalculation errorCalculation = new ErrorCalculation();
            for (int i = 0; i < prob.l; ++i) {
                errorCalculation.updateError(target[i], prob.y[i]);
            }
            return errorCalculation.calculate();
        }
        int correct = 0;
        for (int i = 0; i < prob.l; ++i) {
            if (target[i] == prob.y[i]) {
                ++correct;
            }
        }
        return HUNDRED_PERCENT * correct / prob.l;
    }

    /** Returns true when the parameter set describes a regression (epsilon-SVR or nu-SVR). */
    private static boolean isRegression(svm_parameter param) {
        return param.svm_type == svm_parameter.EPSILON_SVR
                || param.svm_type == svm_parameter.NU_SVR;
    }

    /**
     * Converts a {@link #crossValidate} score into an error, where lower is better.
     * Classification scores are accuracies, so they are turned into misclassification
     * percentages; regression scores are already errors.
     */
    private static double toError(svm_parameter param, double score) {
        return isRegression(param) ? score : HUNDRED_PERCENT - score;
    }

    /** Allocates the per-output search state and positions every output at the first grid point. */
    private void setup() {
        final int outputCount = this.network.getOutputCount();
        this.currentConst = new double[outputCount];
        this.currentGamma = new double[outputCount];
        this.bestConst = new double[outputCount];
        this.bestGamma = new double[outputCount];
        this.bestError = new double[outputCount];
        for (int i = 0; i < outputCount; ++i) {
            this.currentConst[i] = this.constBegin;
            this.currentGamma[i] = this.gammaBegin;
            this.bestError[i] = Double.POSITIVE_INFINITY;
        }
        this.isSetup = true;
    }

    /**
     * Performs one training step. For an RBF kernel, this evaluates the next grid point for
     * every output. For other kernels, it trains all models once and marks training as done.
     * Does nothing once training is done.
     */
    public void iteration() {
        if (this.trainingDone) {
            return;
        }
        if (!this.isSetup) {
            setup();
        }
        preIteration();
        if (this.network.getKernelType() == KernelType.RadialBasisFunction) {
            searchGridStep();
        } else {
            // There is no grid to search for other kernels: one pass is the whole training,
            // so stop here instead of retraining identically on every call.
            train();
            this.trainingDone = true;
        }
        postIteration();
    }

    /**
     * Cross-validates the current grid point for every output and updates each output's best
     * result. Advances every output to the next grid point, then reports the mean best error.
     */
    private void searchGridStep() {
        final int outputCount = this.network.getOutputCount();
        double errorSum = 0.0;
        for (int i = 0; i < outputCount; ++i) {
            final double score = crossValidate(i, this.currentGamma[i], this.currentConst[i]);
            final double error = toError(this.network.getParams()[i], score);
            if (error < this.bestError[i]) {
                this.bestConst[i] = this.currentConst[i];
                this.bestGamma[i] = this.currentGamma[i];
                this.bestError[i] = error;
            }
            // C is the inner loop of the grid; gamma advances each time C wraps around.
            this.currentConst[i] += this.constStep;
            if (this.currentConst[i] > this.constEnd) {
                this.currentConst[i] = this.constBegin;
                this.currentGamma[i] += this.gammaStep;
                if (this.currentGamma[i] > this.gammaEnd) {
                    this.trainingDone = true;
                }
            }
            errorSum += this.bestError[i];
        }
        setError(errorSum / outputCount);
    }

    /** Returns the encoded libsvm problems, one per output (the internal array, not a copy). */
    public svm_problem[] getProblem() {
        return this.problem;
    }

    /** Returns the number of cross-validation folds. */
    public int getFold() {
        return this.fold;
    }

    /** Sets the number of cross-validation folds. */
    public void setFold(int fold) {
        this.fold = fold;
    }

    /** Returns the first C value of the grid search. */
    public double getConstBegin() {
        return this.constBegin;
    }

    /** Sets the first C value of the grid search. */
    public void setConstBegin(double constBegin) {
        this.constBegin = constBegin;
    }

    /** Returns the C increment between grid points. */
    public double getConstStep() {
        return this.constStep;
    }

    /** Sets the C increment between grid points. */
    public void setConstStep(double constStep) {
        this.constStep = constStep;
    }

    /** Returns the last (inclusive) C value of the grid search. */
    public double getConstEnd() {
        return this.constEnd;
    }

    /** Sets the last (inclusive) C value of the grid search. */
    public void setConstEnd(double constEnd) {
        this.constEnd = constEnd;
    }

    /** Returns the first gamma value of the grid search. */
    public double getGammaBegin() {
        return this.gammaBegin;
    }

    /** Sets the first gamma value of the grid search. */
    public void setGammaBegin(double gammaBegin) {
        this.gammaBegin = gammaBegin;
    }

    /** Returns the last (inclusive) gamma value of the grid search. */
    public double getGammaEnd() {
        return this.gammaEnd;
    }

    /** Sets the last (inclusive) gamma value of the grid search. */
    public void setGammaEnd(double gammaEnd) {
        this.gammaEnd = gammaEnd;
    }

    /** Returns the gamma increment between grid points. */
    public double getGammaStep() {
        return this.gammaStep;
    }

    /** Sets the gamma increment between grid points. */
    public void setGammaStep(double gammaStep) {
        this.gammaStep = gammaStep;
    }

    /**
     * Retrains each output with the best (gamma, C) pair found by the grid search.
     * Outputs with no usable search result keep their current model. This covers the case
     * where {@link #iteration()} never ran, non-RBF kernels, and the case where every
     * cross-validation scored NaN.
     */
    public void finishTraining() {
        if (!this.isSetup) {
            // No search state exists yet, so there is nothing to apply.
            return;
        }
        for (int i = 0; i < this.network.getOutputCount(); ++i) {
            // bestError stays +infinity when no grid point improved on it; retraining with
            // the zero-initialised bestGamma/bestConst would overwrite a good model with C = 0.
            if (!Double.isInfinite(this.bestError[i])) {
                train(i, this.bestGamma[i], this.bestConst[i]);
            }
        }
    }

    /** Returns the network being trained. */
    public BasicNetwork getNetwork() {
        return this.network;
    }

    /** Returns whether the grid search (or the single non-RBF training pass) has completed. */
    public boolean isTrainingDone() {
        return this.trainingDone;
    }

    /**
     * Trains every output's model with the same parameters.
     *
     * @param gamma the kernel gamma (see {@link #train(int, double, double)} for defaulting)
     * @param c the cost parameter C
     */
    public void train(double gamma, double c) {
        for (int i = 0; i < this.network.getOutputCount(); ++i) {
            train(i, gamma, c);
        }
    }
}

