/*
 * Decompiled with CFR 0.152.
 */
package org.encog.neural.networks.structure;

import java.util.Arrays;
import java.util.List;
import org.encog.engine.util.EngineArray;
import org.encog.neural.NeuralNetworkError;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.layers.BasicLayer;
import org.encog.neural.networks.layers.ContextLayer;
import org.encog.neural.networks.layers.Layer;
import org.encog.neural.networks.structure.FlatUpdateNeeded;
import org.encog.neural.networks.synapse.Synapse;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Static helpers that copy the weight and bias values of a {@link BasicNetwork}
 * to and from a single {@code double[]}, and that compare two networks by those
 * encoded values.
 *
 * <p>For every layer the values are laid out per destination neuron: first the
 * weights of the selected incoming synapse, then the bias weight (only when the
 * destination layer reports a bias), then the weights of the incoming
 * {@link ContextLayer} synapse (only when one exists). {@link #networkSize},
 * {@link #networkToArray} and {@link #arrayToNetwork} all walk the layers in
 * this same order, so their indices line up.
 *
 * <p>This class cannot be instantiated.
 */
public final class NetworkCODEC {
    private static final Logger LOGGER = LoggerFactory.getLogger(NetworkCODEC.class);

    /**
     * Writes the values of {@code dArray} into the synapse matrices and bias
     * weights of {@code basicNetwork}, then sets the structure's flat-update
     * state to {@link FlatUpdateNeeded#Flatten}.
     *
     * <p>No length validation is performed: an array shorter than the number of
     * values consumed fails with {@link ArrayIndexOutOfBoundsException}, and
     * surplus trailing elements are ignored.
     *
     * @param dArray       encoded values, in the layout described on this class
     * @param basicNetwork network whose weights are overwritten
     */
    public static void arrayToNetwork(double[] dArray, BasicNetwork basicNetwork) {
        int index = 0;
        for (Layer layer : basicNetwork.getStructure().getLayers()) {
            index = NetworkCODEC.processLayer(basicNetwork, layer, dArray, index);
        }
        basicNetwork.getStructure().setFlatUpdate(FlatUpdateNeeded.Flatten);
    }

    /**
     * Compares two networks by their encoded values, to {@code n} decimal places.
     *
     * <p>Each value is multiplied by {@code 10^n} and cast to {@code long}
     * (truncation toward zero) before comparison, so values are truncated, not
     * rounded.
     *
     * @param basicNetwork  first network
     * @param basicNetwork2 second network
     * @param n             number of decimal places to compare
     * @return {@code true} when both encodings have the same length and every
     *         pair of truncated values matches
     * @throws NeuralNetworkError when {@code 10^n} is infinite or exceeds the
     *         range of {@code long}; this is only checked after the two
     *         encodings were found to have the same length
     */
    public static boolean equals(BasicNetwork basicNetwork, BasicNetwork basicNetwork2, int n) {
        double[] first = NetworkCODEC.networkToArray(basicNetwork);
        double[] second = NetworkCODEC.networkToArray(basicNetwork2);
        if (first.length != second.length) {
            return false;
        }

        double scale = Math.pow(10.0, n);
        // (double) Long.MAX_VALUE is 2^63, the same value as the literal 9.223372036854776E18.
        if (Double.isInfinite(scale) || scale > (double) Long.MAX_VALUE) {
            String message = "Precision of " + n + " decimal places is not supported.";
            // The logger filters by level itself; no explicit isErrorEnabled() guard is needed.
            LOGGER.error(message);
            throw new NeuralNetworkError(message);
        }

        for (int i = 0; i < first.length; ++i) {
            long scaledFirst = (long) (first[i] * scale);
            long scaledSecond = (long) (second[i] * scale);
            if (scaledFirst != scaledSecond) {
                return false;
            }
        }
        return true;
    }

    /**
     * Compares two networks by their encoded values using
     * {@link Arrays#equals(double[], double[])}.
     *
     * @param basicNetwork  first network
     * @param basicNetwork2 second network
     * @return {@code true} when both encodings have the same length and contents
     */
    public static boolean equals(BasicNetwork basicNetwork, BasicNetwork basicNetwork2) {
        // Arrays.equals already returns false for arrays of different length.
        return Arrays.equals(
                NetworkCODEC.networkToArray(basicNetwork),
                NetworkCODEC.networkToArray(basicNetwork2));
    }

    /**
     * Returns the number of values in the encoded form of {@code basicNetwork}.
     *
     * <p>When the structure's flat copy is usable (see
     * {@link #hasUsableFlatWeights}) the length of its weight array is returned;
     * otherwise the count is derived from the layers' incoming synapses.
     *
     * @param basicNetwork network to measure
     * @return number of encoded values
     */
    public static int networkSize(BasicNetwork basicNetwork) {
        if (hasUsableFlatWeights(basicNetwork)) {
            return basicNetwork.getStructure().getFlat().getWeights().length;
        }

        int total = 0;
        for (Layer layer : basicNetwork.getStructure().getLayers()) {
            LayerSynapses synapses = LayerSynapses.of(basicNetwork, layer);
            if (!synapses.hasWeights()) {
                continue;
            }
            int toCount = synapses.primary.getToNeuronCount();
            if (toCount <= 0) {
                // Nothing is encoded for this layer; also avoids touching the
                // destination layer at all, as a per-neuron loop would.
                continue;
            }

            // Every destination neuron contributes the same number of values,
            // so compute that once instead of re-adding it per neuron.
            int perNeuron = synapses.primary.getFromNeuronCount();
            if (synapses.primary.getToLayer().hasBias()) {
                ++perNeuron;
            }
            if (synapses.context != null) {
                perNeuron += synapses.context.getFromNeuronCount();
            }
            total += toCount * perNeuron;
        }
        return total;
    }

    /**
     * Encodes the weights and bias values of {@code basicNetwork} as an array.
     *
     * <p>When the structure's flat copy is usable (see
     * {@link #hasUsableFlatWeights}) a copy of its weight array is returned;
     * otherwise the values are read from the layers' synapse matrices and bias
     * weights in the layout described on this class.
     *
     * @param basicNetwork network to encode
     * @return a newly allocated array holding the encoded values
     */
    public static double[] networkToArray(BasicNetwork basicNetwork) {
        if (hasUsableFlatWeights(basicNetwork)) {
            return EngineArray.arrayCopy(basicNetwork.getStructure().getFlat().getWeights());
        }

        // Size is only needed on this path, so it is computed after the fast path.
        double[] encoded = new double[NetworkCODEC.networkSize(basicNetwork)];
        int index = 0;
        for (Layer layer : basicNetwork.getStructure().getLayers()) {
            LayerSynapses synapses = LayerSynapses.of(basicNetwork, layer);
            if (!synapses.hasWeights()) {
                continue;
            }
            Synapse primary = synapses.primary;
            Synapse context = synapses.context;
            for (int to = 0; to < primary.getToNeuronCount(); ++to) {
                for (int from = 0; from < primary.getFromNeuronCount(); ++from) {
                    encoded[index++] = primary.getMatrix().get(from, to);
                }
                if (primary.getToLayer().hasBias()) {
                    encoded[index++] = primary.getToLayer().getBiasWeights()[to];
                }
                if (context != null) {
                    for (int from = 0; from < context.getFromNeuronCount(); ++from) {
                        encoded[index++] = context.getMatrix().get(from, to);
                    }
                }
            }
        }
        return encoded;
    }

    /**
     * Copies values from {@code dArray}, starting at index {@code n}, into the
     * incoming synapses and bias weights of {@code layer}.
     *
     * <p>For a context synapse, a connection whose <em>current</em> weight has a
     * magnitude below the structure's connection limit is written as
     * {@code 0.0} instead of the incoming value; the incoming value is still
     * consumed from the array.
     *
     * <p>NOTE(review): the primary synapse's weights are written without any
     * connection-limit check — confirm that applying the limit only to the
     * context synapse is intended.
     *
     * @param basicNetwork network that owns {@code layer}
     * @param layer        layer whose incoming weights are written
     * @param dArray       encoded values
     * @param n            index of the first value to consume
     * @return index of the first value not consumed by this layer
     */
    private static int processLayer(BasicNetwork basicNetwork, Layer layer, double[] dArray, int n) {
        int index = n;
        LayerSynapses synapses = LayerSynapses.of(basicNetwork, layer);
        if (!synapses.hasWeights()) {
            return index;
        }

        Synapse primary = synapses.primary;
        Synapse context = synapses.context;
        for (int to = 0; to < primary.getToNeuronCount(); ++to) {
            for (int from = 0; from < primary.getFromNeuronCount(); ++from) {
                primary.getMatrix().set(from, to, dArray[index++]);
            }
            if (primary.getToLayer().hasBias()) {
                primary.getToLayer().getBiasWeights()[to] = dArray[index++];
            }
            if (context == null) {
                continue;
            }
            for (int from = 0; from < context.getFromNeuronCount(); ++from) {
                double incoming = dArray[index++];
                double current = context.getMatrix().get(from, to);
                if (Math.abs(current) < basicNetwork.getStructure().getConnectionLimit()) {
                    incoming = 0.0;
                }
                context.getMatrix().set(from, to, incoming);
            }
        }
        return index;
    }

    /**
     * Reports whether the encoded form can be taken directly from the
     * structure's flat copy: the flat copy must exist and the flat-update state
     * must be {@link FlatUpdateNeeded#None} or {@link FlatUpdateNeeded#Unflatten}.
     *
     * <p>NOTE(review): presumably those two states mean the flat weights are at
     * least as current as the layered network — confirm against
     * {@code FlatUpdateNeeded}.
     */
    private static boolean hasUsableFlatWeights(BasicNetwork basicNetwork) {
        if (basicNetwork.getStructure().getFlat() == null) {
            return false;
        }
        FlatUpdateNeeded state = basicNetwork.getStructure().getFlatUpdate();
        return state == FlatUpdateNeeded.None || state == FlatUpdateNeeded.Unflatten;
    }

    /**
     * The pair of incoming synapses that the codec reads or writes for one
     * layer. Either field may be {@code null}.
     */
    private static final class LayerSynapses {
        /** Incoming BasicLayer synapse, or the fallback synapse; may be {@code null}. */
        final Synapse primary;
        /** Incoming ContextLayer synapse; may be {@code null}. */
        final Synapse context;

        private LayerSynapses(Synapse primary, Synapse context) {
            this.primary = primary;
            this.context = context;
        }

        /**
         * Selects the synapses for {@code layer}: the previous synapse found for
         * {@link BasicLayer} and the one found for {@link ContextLayer}. When
         * neither exists, the first of the layer's previous synapses (if any)
         * is used as the primary one.
         */
        static LayerSynapses of(BasicNetwork basicNetwork, Layer layer) {
            Synapse primary = basicNetwork.getStructure()
                    .findPreviousSynapseByLayerType(layer, BasicLayer.class);
            Synapse context = basicNetwork.getStructure()
                    .findPreviousSynapseByLayerType(layer, ContextLayer.class);
            if (primary == null && context == null) {
                List<Synapse> incoming = basicNetwork.getStructure().getPreviousSynapses(layer);
                if (incoming.size() > 0) {
                    primary = incoming.get(0);
                }
            }
            return new LayerSynapses(primary, context);
        }

        /** Returns {@code true} when a primary synapse exists and has a matrix. */
        boolean hasWeights() {
            return primary != null && primary.getMatrix() != null;
        }
    }

    /** Not instantiable: this class only provides static helpers. */
    private NetworkCODEC() {
    }
}

