/*
 * Decompiled with CFR 0.152.
 */
package jams.components.optimizer.gradient;

import java.io.BufferedWriter;
import java.io.FileWriter;
import java.io.IOException;
import java.util.Random;
import java.util.StringTokenizer;
import java.util.Vector;
import org.unijena.jams.data.JAMSDouble;
import org.unijena.jams.data.JAMSInteger;
import org.unijena.jams.data.JAMSString;
import org.unijena.jams.model.JAMSComponent;
import org.unijena.jams.model.JAMSComponentDescription;
import org.unijena.jams.model.JAMSContext;
import org.unijena.jams.model.JAMSVarDescription;

@JAMSComponentDescription(title="GradientDescent", author="Christian Fischer", description="distance driven monte carlo optimization with gradient descent")
public class SmartGradientDescent
extends JAMSContext {
    @JAMSVarDescription(access=JAMSVarDescription.AccessType.READ, update=JAMSVarDescription.UpdateType.INIT, description="List of parameter identifiers to be sampled")
    public JAMSString parameterIDs;
    @JAMSVarDescription(access=JAMSVarDescription.AccessType.READ, update=JAMSVarDescription.UpdateType.INIT, description="List of parameter value bounaries corresponding to parameter identifiers")
    public JAMSString boundaries;
    @JAMSVarDescription(access=JAMSVarDescription.AccessType.READ, update=JAMSVarDescription.UpdateType.INIT, description="efficiency methods")
    public JAMSString effMethodName;
    @JAMSVarDescription(access=JAMSVarDescription.AccessType.READWRITE, update=JAMSVarDescription.UpdateType.RUN, description="efficiency values; note: only first value is optimized")
    public JAMSDouble[] effValue;
    @JAMSVarDescription(access=JAMSVarDescription.AccessType.READ, update=JAMSVarDescription.UpdateType.RUN, description="file with optimization information and best parameter set")
    public JAMSString resultFile;
    @JAMSVarDescription(access=JAMSVarDescription.AccessType.READ, update=JAMSVarDescription.UpdateType.RUN, description="workspace directory")
    public JAMSString dirName;
    @JAMSVarDescription(access=JAMSVarDescription.AccessType.READ, update=JAMSVarDescription.UpdateType.RUN, description="number of monte carlo runs to determine current minimal distance")
    public JAMSInteger MonteCarloParameter;
    @JAMSVarDescription(access=JAMSVarDescription.AccessType.READ, update=JAMSVarDescription.UpdateType.RUN, description="minimal distance until optimization is stopped")
    public JAMSDouble MinimalDistance;
    @JAMSVarDescription(access=JAMSVarDescription.AccessType.READ, update=JAMSVarDescription.UpdateType.INIT, description="worst efficiency which is accepted")
    public JAMSDouble ValueBoundary;
    @JAMSVarDescription(access=JAMSVarDescription.AccessType.READWRITE, update=JAMSVarDescription.UpdateType.RUN, description="Minimization,Mazimation or absolute optimization")
    public JAMSInteger OptimizationType;
    static final int MAXIMIZATION = 1;
    static final int MINIMIZATION = 2;
    static final int ABSMAXIMIZATION = 3;
    static final int ABSMINIMIZATION = 4;
    String[] parameterNames;
    JAMSDouble[] parameters = null;
    double[] lowBound;
    double[] upBound;
    Vector<double[]> pointList = null;
    int numVisitedPoints;
    double[] bestpoint;
    double bestvalue;
    double lowLipschitzBound;
    double alpha_min = 0.01;
    double diff_min = 1.0E-7;
    double approxError = 0.01;
    Random generator;
    BufferedWriter writer;

    public void init() {
        String key;
        this.generator = new Random();
        this.generator.setSeed(System.currentTimeMillis());
        this.pointList = new Vector(10000, 0);
        this.numVisitedPoints = 0;
        this.lowLipschitzBound = 1.0;
        this.TranslateToMaximization(this.ValueBoundary);
        StringTokenizer tok = new StringTokenizer(this.parameterIDs.getValue(), ";");
        this.parameters = new JAMSDouble[tok.countTokens()];
        this.parameterNames = new String[tok.countTokens()];
        int i = 0;
        while (tok.hasMoreTokens()) {
            this.parameterNames[i] = key = tok.nextToken();
            this.parameters[i] = (JAMSDouble)this.getModel().getRuntime().getDataHandles().get(key);
            ++i;
        }
        tok = new StringTokenizer(this.boundaries.getValue(), ";");
        int n = tok.countTokens();
        this.lowBound = new double[n];
        this.upBound = new double[n];
        if (n != i) {
            this.getModel().getRuntime().sendHalt("Component " + this.getInstanceName() + ": Different number of parameterIDs and boundaries!");
        }
        i = 0;
        while (tok.hasMoreTokens()) {
            key = tok.nextToken();
            key = key.substring(1, key.length() - 1);
            StringTokenizer boundTok = new StringTokenizer(key, ">");
            this.lowBound[i] = Double.parseDouble(boundTok.nextToken());
            this.upBound[i] = Double.parseDouble(boundTok.nextToken());
            if (this.upBound[i] <= this.lowBound[i]) {
                this.getModel().getRuntime().sendHalt("Component " + this.getInstanceName() + ": upBound must be higher than lowBound!");
            }
            ++i;
        }
        i = 0;
        tok = new StringTokenizer(this.effMethodName.getValue(), ";");
        String[] effNames = new String[tok.countTokens()];
        i = 0;
        while (tok.hasMoreTokens()) {
            effNames[i] = key = tok.nextToken();
            ++i;
        }
        if (effNames.length != 1) {
            this.getModel().getRuntime().sendHalt("Cant process multiobjective optimization problems!!");
        }
        this.bestvalue = Double.NEGATIVE_INFINITY;
    }

    private double[] RandomSampler() {
        int paras = this.parameterNames.length;
        double[] sample = new double[paras];
        for (int i = 0; i < paras; ++i) {
            double d = this.generator.nextDouble();
            sample[i] = this.lowBound[i] + d * (this.upBound[i] - this.lowBound[i]);
        }
        return sample;
    }

    private double CalcDistance(double[] x0, double[] x1) {
        double dist = 0.0;
        for (int i = 0; i < this.parameters.length; ++i) {
            dist += (x0[i] - x1[i]) * (x0[i] - x1[i]);
        }
        return Math.sqrt(dist);
    }

    private void UpdateLipschitz(double[] x1) {
        if (this.numVisitedPoints > 10000) {
            return;
        }
        for (int i = 0; i < this.numVisitedPoints; ++i) {
            double L;
            double[] x0 = this.pointList.get(i);
            double f0 = x0[this.parameters.length];
            double f1 = x1[this.parameters.length];
            double distance = this.CalcDistance(x0, x1);
            if (distance < 0.001 || !(this.lowLipschitzBound < (L = Math.abs((f1 - f0) / distance)))) continue;
            this.lowLipschitzBound = L;
            this.UpdateForbiddenCircles();
        }
    }

    private void UpdateForbiddenCircles() {
        for (int i = 0; i < this.numVisitedPoints; ++i) {
            double[] x0 = this.pointList.get(i);
            double f0 = x0[this.parameters.length];
            x0[this.parameters.length + 1] = (f0 - this.bestvalue) / this.lowLipschitzBound;
        }
    }

    private void TranslateToMaximization(JAMSDouble value) {
        if (this.OptimizationType.getValue() != 1) {
            if (this.OptimizationType.getValue() == 2) {
                value.setValue(-value.getValue());
            } else if (this.OptimizationType.getValue() == 3) {
                value.setValue(Math.abs(value.getValue()));
            } else if (this.OptimizationType.getValue() == 4) {
                value.setValue(-Math.abs(value.getValue()));
            }
        }
    }

    private boolean singleRun() {
        double f0;
        double[] nextPoint = new double[this.parameters.length + 2];
        for (int i = 0; i < this.parameters.length; ++i) {
            if (!(this.parameters[i].getValue() < 0.0) && !(this.parameters[i].getValue() > 0.0) || this.parameters[i].getValue() < 0.0 && this.parameters[i].getValue() > 0.0) {
                System.out.println("Single Run Failed!");
                return false;
            }
            nextPoint[i] = this.parameters[i].getValue();
        }
        this.runEnumerator.reset();
        while (this.runEnumerator.hasNext() && this.doRun) {
            JAMSComponent comp = this.runEnumerator.next();
            try {
                comp.init();
            }
            catch (Exception e) {
                System.out.println(e.toString());
            }
        }
        this.runEnumerator.reset();
        while (this.runEnumerator.hasNext() && this.doRun) {
            JAMSComponent comp = this.runEnumerator.next();
            try {
                comp.run();
            }
            catch (Exception e) {
                System.out.println(e.getMessage());
            }
        }
        this.runEnumerator.reset();
        while (this.runEnumerator.hasNext() && this.doRun) {
            JAMSComponent comp = this.runEnumerator.next();
            try {
                comp.cleanup();
            }
            catch (Exception e) {
                System.out.println(e.getMessage());
            }
        }
        this.TranslateToMaximization(this.effValue[0]);
        if (this.effValue[0].getValue() < this.ValueBoundary.getValue()) {
            this.effValue[0].setValue(this.ValueBoundary.getValue());
        }
        nextPoint[this.parameters.length] = f0 = this.effValue[0].getValue();
        this.UpdateLipschitz(nextPoint);
        nextPoint[this.parameters.length + 1] = Math.abs((f0 - this.bestvalue) / this.lowLipschitzBound);
        this.pointList.add(nextPoint);
        ++this.numVisitedPoints;
        if (f0 > this.bestvalue) {
            this.bestvalue = f0;
            this.bestpoint = nextPoint;
            this.UpdateForbiddenCircles();
            try {
                double realValue = this.OptimizationType.getValue() == 2 || this.OptimizationType.getValue() == 4 ? -this.bestvalue : this.bestvalue;
                String output = "A new best point has been found! Value: " + realValue;
                this.getModel().getRuntime().println(output);
                this.writer.write(output);
                this.writer.newLine();
                output = "Parameters: ";
                for (int k = 0; k < this.parameters.length; ++k) {
                    output = output + this.bestpoint[k] + ",";
                }
                this.getModel().getRuntime().println(output);
                this.writer.write(output);
                this.writer.newLine();
                this.writer.flush();
            }
            catch (Exception e) {
                System.out.println("Could not write to output file because:" + e.toString());
            }
        }
        return true;
    }

    private boolean IsSampleValid(JAMSDouble[] sample) {
        int paras = this.parameterNames.length;
        boolean criticalPara = false;
        double criticalParaValue = 0.0;
        for (int i = 0; i < paras; ++i) {
            if (!(sample[i].getValue() < this.lowBound[i]) && !(sample[i].getValue() > this.upBound[i])) continue;
            return false;
        }
        return true;
    }

    private double calcMinimalDist(double[] x) {
        double dist = 0.0;
        double mindist = Double.POSITIVE_INFINITY;
        for (int i = 0; i < this.numVisitedPoints; ++i) {
            double[] otherpoint = this.pointList.get(i);
            dist = this.CalcDistance(otherpoint, x);
            if ((dist -= otherpoint[this.parameters.length + 1]) < 0.0) {
                return 0.0;
            }
            if (!(dist < mindist)) continue;
            mindist = dist;
        }
        return mindist;
    }

    public void GradientDescent(double[] x) {
        double[] grad = new double[this.parameters.length];
        double alpha = 0.1;
        double diff = 1.0;
        while (alpha > this.alpha_min && diff > this.diff_min) {
            int i;
            int i2;
            for (i2 = 0; i2 < this.parameters.length; ++i2) {
                this.parameters[i2].setValue(x[i2]);
            }
            if (!this.singleRun()) {
                return;
            }
            double y1 = this.effValue[0].getValue();
            if (y1 <= this.ValueBoundary.getValue()) {
                return;
            }
            for (i2 = 0; i2 < this.parameters.length; ++i2) {
                for (int j = 0; j < this.parameters.length; ++j) {
                    if (j == i2) {
                        this.parameters[j].setValue(x[j] + this.approxError);
                        continue;
                    }
                    this.parameters[j].setValue(x[j]);
                }
                if (!this.IsSampleValid(this.parameters)) {
                    grad[i2] = 0.0;
                    continue;
                }
                if (!this.singleRun()) {
                    return;
                }
                double y2 = this.effValue[0].getValue();
                grad[i2] = (y2 - y1) / this.approxError;
            }
            alpha *= 4.0;
            do {
                for (i2 = 0; i2 < this.parameters.length; ++i2) {
                    this.parameters[i2].setValue(x[i2] + alpha * grad[i2]);
                }
                if (!this.IsSampleValid(this.parameters)) continue;
                if (!this.singleRun()) {
                    return;
                }
                if (this.effValue[0].getValue() > y1) break;
            } while (!((alpha /= 2.0) < this.alpha_min));
            String info = "Gradient:\t";
            for (i = 0; i < this.parameters.length; ++i) {
                int n = i;
                x[n] = x[n] + alpha * grad[i];
                info = info + grad[i] + "\t";
            }
            this.getModel().getRuntime().println(info);
            info = "Stelle:\t\t";
            for (i = 0; i < this.parameters.length; ++i) {
                info = info + this.parameters[i].getValue() + "\t";
            }
            this.getModel().getRuntime().println(info);
            this.getModel().getRuntime().println("Funktionswert:\t" + y1 + "\t Alpha: " + alpha);
        }
    }

    public void run() {
        String output;
        try {
            this.writer = new BufferedWriter(new FileWriter(this.dirName.getValue() + this.resultFile.getValue()));
        }
        catch (Exception e) {
            System.out.println("Could not open result file, becauce:" + e.toString());
        }
        if (this.runEnumerator == null) {
            this.runEnumerator = this.getChildrenEnumerator();
        }
        while (true) {
            double[] x = null;
            double distance = -1.0;
            for (int i = 0; i < this.MonteCarloParameter.getValue(); ++i) {
                double[] mc_point = this.RandomSampler();
                double distanceOfMCPoint = this.calcMinimalDist(mc_point);
                if (!(distanceOfMCPoint > distance)) continue;
                x = mc_point;
                distance = distanceOfMCPoint;
            }
            output = "current minimal distance: " + distance + "\n" + "lowerlipschitzbound: " + this.lowLipschitzBound;
            this.getModel().getRuntime().println(output);
            try {
                this.writer.newLine();
                this.writer.write(output);
                this.writer.newLine();
            }
            catch (Exception e) {
                System.out.println("Could not write to output file because:" + e.toString());
            }
            if (distance < this.MinimalDistance.getValue()) {
                output = "optimization has stopped because: MinimalDistance has reached limit";
                this.getModel().getRuntime().println(output);
                try {
                    this.writer.write(output);
                    this.writer.newLine();
                }
                catch (Exception e) {
                    System.out.println("Could not write to output file because:" + e.toString());
                }
                break;
            }
            this.GradientDescent(x);
        }
        try {
            output = "Result of Optimization: Value: " + this.bestvalue;
            this.getModel().getRuntime().println(output);
            this.writer.write(output);
            this.writer.newLine();
            output = "Parameters: ";
            for (int k = 0; k < this.bestpoint.length; ++k) {
                output = output + this.bestpoint[k] + ",";
            }
            this.getModel().getRuntime().println(output);
            this.writer.write(output);
            this.writer.newLine();
            this.writer.flush();
            this.writer.close();
        }
        catch (Exception e) {
            System.out.println("Could not close output file because:" + e.toString());
        }
    }
}

