/*
 * Decompiled with CFR 0.152.
 */
package org.encog.neural.networks.structure;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import org.encog.engine.network.activation.ActivationFunction;
import org.encog.engine.network.activation.ActivationLinear;
import org.encog.engine.network.flat.FlatLayer;
import org.encog.engine.network.flat.FlatNetwork;
import org.encog.engine.network.flat.FlatNetworkRBF;
import org.encog.engine.util.EngineArray;
import org.encog.engine.util.ObjectPair;
import org.encog.mathutil.matrices.Matrix;
import org.encog.neural.NeuralNetworkError;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.layers.BasicLayer;
import org.encog.neural.networks.layers.ContextLayer;
import org.encog.neural.networks.layers.Layer;
import org.encog.neural.networks.layers.RadialBasisFunctionLayer;
import org.encog.neural.networks.logic.FeedforwardLogic;
import org.encog.neural.networks.logic.SimpleRecurrentLogic;
import org.encog.neural.networks.structure.FlatUpdateNeeded;
import org.encog.neural.networks.structure.LayerComparator;
import org.encog.neural.networks.structure.NetworkCODEC;
import org.encog.neural.networks.structure.SynapseComparator;
import org.encog.neural.networks.structure.ValidateForFlat;
import org.encog.neural.networks.synapse.Synapse;
import org.encog.persist.EncogPersistedObject;
import org.encog.util.obj.ReflectionUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Describes the structure of a {@link BasicNetwork}: the layers reachable from
 * its tagged layers, the synapses between them, an optional connection limit,
 * and, when the network allows it, a flattened {@link FlatNetwork} copy that is
 * kept in sync through {@link #flattenWeights()} and {@link #unflattenWeights()}.
 */
public class NeuralStructure
implements Serializable {
    private static final long serialVersionUID = -2929683885395737817L;
    private static final Logger LOGGER = LoggerFactory.getLogger(NeuralStructure.class);

    /** All layers reachable from the network's tagged layers; rebuilt by finalizeLayers(). */
    private final List<Layer> layers = new ArrayList<Layer>();
    /** Distinct outbound synapses of every layer; rebuilt by finalizeSynapses(). */
    private final List<Synapse> synapses = new ArrayList<Synapse>();
    /** The network this structure describes. */
    private final BasicNetwork network;
    /** Weights whose absolute value is below this are zeroed when connectionLimited is set. */
    private double connectionLimit;
    /** Whether the connection limit is in effect. */
    private boolean connectionLimited;
    /** Next layer ID handed out by getNextID(). */
    private int nextID = 1;
    /** Flattened copy of the network; null until flatten() succeeds (transient, so null after deserialization). */
    private transient FlatNetwork flat;
    /** Pending synchronization between the object network and flat (transient, so null after deserialization). */
    private transient FlatUpdateNeeded flatUpdate;

    /**
     * Creates an empty structure for the given network.
     *
     * @param basicNetwork the network whose structure this object describes
     */
    public NeuralStructure(BasicNetwork basicNetwork) {
        this.network = basicNetwork;
        this.flatUpdate = FlatUpdateNeeded.None;
    }

    /**
     * Assigns an ID to every layer that does not have one yet, then re-sorts
     * the layers and synapses.
     */
    public void assignID() {
        for (Layer layer : this.layers) {
            this.assignID(layer);
        }
        this.sort();
    }

    /**
     * Gives the layer the next free ID if its ID is still -1 (unassigned).
     *
     * @param layer the layer to label
     */
    public void assignID(Layer layer) {
        if (layer.getID() == -1) {
            layer.setID(this.getNextID());
        }
    }

    /**
     * @return the size reported by {@link NetworkCODEC#networkSize} for the network
     */
    public int calculateSize() {
        return NetworkCODEC.networkSize(this.network);
    }

    /**
     * Determines whether any layer matches the given type, as decided by
     * {@link ReflectionUtil#isInstanceOf}.
     *
     * @param clazz the layer type to look for
     * @return true if at least one layer matches
     */
    public boolean containsLayerType(Class<?> clazz) {
        for (Layer layer : this.layers) {
            if (ReflectionUtil.isInstanceOf(layer.getClass(), clazz)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Counts the layers that are not context layers. Uses the same
     * {@code instanceof} test that flatten() uses to skip context layers, so
     * the flat-layer array sized from this count is filled completely.
     */
    private int countNonContext() {
        int count = 0;
        for (Layer layer : this.layers) {
            if (!(layer instanceof ContextLayer)) {
                ++count;
            }
        }
        return count;
    }

    /**
     * Zeroes every synapse weight whose absolute value is below the connection
     * limit. Does nothing unless a connection limit is in effect.
     */
    public void enforceLimit() {
        if (!this.connectionLimited) {
            return;
        }
        for (Synapse synapse : this.synapses) {
            Matrix matrix = synapse.getMatrix();
            if (matrix == null) {
                // Synapses without a weight matrix have nothing to limit.
                continue;
            }
            for (int row = 0; row < matrix.getRows(); ++row) {
                for (int col = 0; col < matrix.getCols(); ++col) {
                    if (Math.abs(matrix.get(row, col)) < this.connectionLimit) {
                        matrix.set(row, col, 0.0);
                    }
                }
            }
        }
    }

    /**
     * Rebuilds the layer list from the network's tagged layers, clears the
     * input-layer bias for feedforward/simple-recurrent logic, advances
     * nextID past every existing layer ID, and sorts.
     */
    private void finalizeLayers() {
        Class<?> logicClass = this.network.getLogic().getClass();
        if (logicClass == FeedforwardLogic.class || logicClass == SimpleRecurrentLogic.class) {
            Layer inputLayer = this.network.getLayer("INPUT");
            // A network without a tagged input layer has no input bias to clear.
            if (inputLayer != null) {
                inputLayer.setBiasWeights(null);
            }
        }
        List<Layer> reachable = new ArrayList<Layer>();
        this.layers.clear();
        for (Layer tagged : this.network.getLayerTags().values()) {
            this.getLayers(reachable, tagged);
        }
        this.layers.addAll(reachable);
        for (Layer layer : this.layers) {
            if (layer.getID() >= this.nextID) {
                this.nextID = layer.getID() + 1;
            }
        }
        this.sort();
    }

    /**
     * Reads the CONNECTION_LIMIT property of the network. The limit is enabled
     * only when the property is present and parses as a double; the fields are
     * updated only after a successful parse.
     *
     * @throws NeuralNetworkError if the property is present but not a number
     */
    private void finalizeLimit() {
        String limitText = this.network.getPropertyString("CONNECTION_LIMIT");
        if (limitText == null) {
            this.connectionLimited = false;
            this.connectionLimit = 0.0;
            return;
        }
        double limit;
        try {
            limit = Double.parseDouble(limitText);
        }
        catch (NumberFormatException numberFormatException) {
            throw new NeuralNetworkError("Invalid property(CONNECTION_LIMIT):" + limitText);
        }
        this.connectionLimited = true;
        this.connectionLimit = limit;
    }

    /**
     * Finalizes the structure after the network has been assembled: rebuilds
     * layers and synapses, reads the connection limit, assigns IDs,
     * initializes the network logic, enforces the limit and builds the flat
     * network.
     */
    public void finalizeStructure() {
        this.finalizeLayers();
        this.finalizeSynapses();
        this.finalizeLimit();
        Collections.sort(this.layers);
        this.assignID();
        this.network.getLogic().init(this.network);
        this.enforceLimit();
        this.flatten();
    }

    /** Rebuilds the synapse list from the distinct outbound synapses of every layer. */
    private void finalizeSynapses() {
        HashSet<Synapse> distinct = new HashSet<Synapse>();
        for (Layer layer : this.layers) {
            distinct.addAll(layer.getNext());
        }
        this.synapses.clear();
        this.synapses.addAll(distinct);
    }

    /**
     * Returns the bias activation of the BasicLayer that the given layer feeds,
     * or 0 when there is no such layer or it has no bias.
     */
    private double findNextBias(Layer layer) {
        if (layer.getNext().size() == 0) {
            return 0.0;
        }
        Synapse synapse = this.network.getStructure().findNextSynapseByLayerType(layer, BasicLayer.class);
        if (synapse == null) {
            return 0.0;
        }
        Layer next = synapse.getToLayer();
        return next.hasBias() ? next.getBiasActivation() : 0.0;
    }

    /**
     * Finds the first outbound synapse of a layer whose target layer is
     * exactly of the given class (subclasses do not match).
     *
     * @param layer the source layer
     * @param clazz the exact class of the target layer
     * @return the synapse, or null if none matches
     */
    public Synapse findNextSynapseByLayerType(Layer layer, Class<? extends Layer> clazz) {
        for (Synapse synapse : layer.getNext()) {
            if (synapse.getToLayer().getClass() == clazz) {
                return synapse;
            }
        }
        return null;
    }

    /**
     * Finds the first inbound synapse of a layer whose source layer is
     * exactly of the given class (subclasses do not match).
     *
     * @param layer the target layer
     * @param clazz the exact class of the source layer
     * @return the synapse, or null if none matches
     */
    public Synapse findPreviousSynapseByLayerType(Layer layer, Class<? extends Layer> clazz) {
        for (Synapse synapse : this.getPreviousSynapses(layer)) {
            if (synapse.getFromLayer().getClass() == clazz) {
                return synapse;
            }
        }
        return null;
    }

    /**
     * Finds the synapse connecting two layers (compared by identity).
     *
     * @param fromLayer the source layer
     * @param toLayer the target layer
     * @param required if true, a missing synapse is logged and reported as an error
     * @return the synapse, or null if none exists and {@code required} is false
     * @throws NeuralNetworkError if {@code required} is true and no synapse exists
     */
    public Synapse findSynapse(Layer fromLayer, Layer toLayer, boolean required) {
        for (Synapse synapse : this.synapses) {
            if (synapse.getFromLayer() == fromLayer && synapse.getToLayer() == toLayer) {
                return synapse;
            }
        }
        if (required) {
            String message = "This operation requires a network with a synapse between the " + this.nameLayer(fromLayer) + " layer to the " + this.nameLayer(toLayer) + " layer.";
            LOGGER.error(message);
            throw new NeuralNetworkError(message);
        }
        return null;
    }

    /**
     * Builds the flat representation of the network.
     *
     * If {@link ValidateForFlat} rejects the network (non-null result), no flat
     * network is built and flatUpdate becomes Never. A three-layer network
     * whose middle layer is an RBF layer becomes a {@link FlatNetworkRBF}.
     * Otherwise every non-context layer becomes a {@link FlatLayer}, and
     * context layers are linked into the flat layers they feed.
     *
     * @throws NeuralNetworkError if the context wiring is inconsistent, or if
     *         an RBF network uses bias
     */
    public void flatten() {
        this.flat = null;
        ValidateForFlat validator = new ValidateForFlat();
        if (validator.isValid(this.network) != null) {
            this.flatUpdate = FlatUpdateNeeded.Never;
            return;
        }
        if (this.layers.size() == 3 && this.layers.get(1) instanceof RadialBasisFunctionLayer) {
            this.flattenRBF((RadialBasisFunctionLayer)this.layers.get(1));
            return;
        }
        Map<Layer, FlatLayer> regularToFlat = new HashMap<Layer, FlatLayer>();
        Map<FlatLayer, Layer> flatToRegular = new HashMap<FlatLayer, Layer>();
        List<ObjectPair<Layer, Layer>> contexts = new ArrayList<ObjectPair<Layer, Layer>>();
        FlatLayer[] flatLayers = new FlatLayer[this.countNonContext()];
        // Layers are placed in reverse order of this.layers.
        int index = flatLayers.length - 1;
        for (Layer layer : this.layers) {
            if (layer instanceof ContextLayer) {
                contexts.add(this.describeContext(layer));
                continue;
            }
            FlatLayer flatLayer = this.createFlatLayer(layer);
            regularToFlat.put(layer, flatLayer);
            flatToRegular.put(flatLayer, layer);
            flatLayers[index--] = flatLayer;
        }
        this.linkContexts(contexts, regularToFlat);
        this.flat = new FlatNetwork(flatLayers);
        this.assignContextIndexes(flatLayers, flatToRegular);
        this.flattenWeights();
        this.flatUpdate = FlatUpdateNeeded.None;
    }

    /** Builds the flat network for a three-layer RBF network; bias is not allowed on any layer. */
    private void flattenRBF(RadialBasisFunctionLayer rbfLayer) {
        for (Layer layer : this.layers) {
            if (layer.hasBias()) {
                throw new NeuralNetworkError("Bias cannot be used with an RBF neural network.");
            }
        }
        this.flat = new FlatNetworkRBF(this.network.getInputCount(), rbfLayer.getNeuronCount(), this.network.getOutputCount(), rbfLayer.getRadialBasisFunction());
        this.flattenWeights();
        this.flatUpdate = FlatUpdateNeeded.None;
    }

    /**
     * Returns (layer feeding the context, layer the context feeds) for a
     * context layer; both must be BasicLayers.
     */
    private ObjectPair<Layer, Layer> describeContext(Layer contextLayer) {
        Synapse inbound = this.network.getStructure().findPreviousSynapseByLayerType(contextLayer, BasicLayer.class);
        Synapse outbound = this.network.getStructure().findNextSynapseByLayerType(contextLayer, BasicLayer.class);
        if (inbound == null) {
            throw new NeuralNetworkError("Context layer must be connected to by one BasicLayer.");
        }
        if (outbound == null) {
            throw new NeuralNetworkError("Context layer must connect to by one BasicLayer.");
        }
        return new ObjectPair<Layer, Layer>(inbound.getFromLayer(), outbound.getToLayer());
    }

    /** Creates the flat counterpart of a regular layer; a missing activation becomes linear with parameter 1.0. */
    private FlatLayer createFlatLayer(Layer layer) {
        double bias = this.findNextBias(layer);
        ActivationFunction activation;
        double[] params;
        if (layer.getActivationFunction() == null) {
            activation = new ActivationLinear();
            params = new double[]{1.0};
        } else {
            activation = layer.getActivationFunction();
            params = activation.getParams();
        }
        return new FlatLayer(activation, layer.getNeuronCount(), bias, params);
    }

    /**
     * For each (source, target) context pair, marks the flat layer that feeds
     * the target through a BasicLayer synapse as being context-fed by the
     * source's flat layer.
     */
    private void linkContexts(List<ObjectPair<Layer, Layer>> contexts, Map<Layer, FlatLayer> regularToFlat) {
        for (ObjectPair<Layer, Layer> context : contexts) {
            Synapse synapse = this.network.getStructure().findPreviousSynapseByLayerType(context.getB(), BasicLayer.class);
            FlatLayer to = synapse == null ? null : regularToFlat.get(synapse.getFromLayer());
            if (to == null) {
                throw new NeuralNetworkError("Can't find BasicLayer feeding the layer a context layer connects to.");
            }
            to.setContextFedBy(regularToFlat.get(context.getA()));
        }
    }

    /** Stores on each ContextLayer the offset of its data within the flat network's layer output. */
    private void assignContextIndexes(FlatLayer[] flatLayers, Map<FlatLayer, Layer> flatToRegular) {
        int count = flatLayers.length;
        for (int i = 0; i < count; ++i) {
            FlatLayer fedBy = flatLayers[i].getContextFedBy();
            if (fedBy == null) {
                continue;
            }
            if (i + 1 >= count) {
                // Previously surfaced as an ArrayIndexOutOfBoundsException.
                throw new NeuralNetworkError("Can't find parent synapse to context layer.");
            }
            Layer target = flatToRegular.get(flatLayers[i + 1]);
            Synapse synapse = this.findPreviousSynapseByLayerType(target, ContextLayer.class);
            if (synapse == null) {
                throw new NeuralNetworkError("Can't find parent synapse to context layer.");
            }
            ContextLayer contextLayer = (ContextLayer)synapse.getFromLayer();
            int fedByIndex = NeuralStructure.indexOf(flatLayers, fedBy);
            if (fedByIndex == -1) {
                throw new NeuralNetworkError("Can't find layer feeding context.");
            }
            contextLayer.setFlatContextIndex(this.flat.getContextTargetOffset()[fedByIndex]);
        }
    }

    /** Returns the index of {@code target} in {@code flatLayers} by identity, or -1. */
    private static int indexOf(FlatLayer[] flatLayers, FlatLayer target) {
        for (int i = 0; i < flatLayers.length; ++i) {
            if (flatLayers[i] == target) {
                return i;
            }
        }
        return -1;
    }

    /**
     * Copies the weights, context-layer data and connection limit from the
     * object network into the flat network. Does nothing if there is no flat
     * network.
     */
    public void flattenWeights() {
        if (this.flat == null) {
            return;
        }
        // NOTE(review): flatUpdate is Flatten while the copy runs (kept from the original); intent not visible here.
        this.flatUpdate = FlatUpdateNeeded.Flatten;
        double[] flatWeights = this.flat.getWeights();
        double[] networkWeights = NetworkCODEC.networkToArray(this.network);
        EngineArray.arrayCopy(networkWeights, flatWeights);
        this.flatUpdate = FlatUpdateNeeded.None;
        for (Layer layer : this.layers) {
            if (!(layer instanceof ContextLayer)) {
                continue;
            }
            ContextLayer contextLayer = (ContextLayer)layer;
            if (contextLayer.getFlatContextIndex() == -1) {
                continue;
            }
            EngineArray.arrayCopy(contextLayer.getContext().getData(), 0, this.flat.getLayerOutput(), contextLayer.getFlatContextIndex(), contextLayer.getContext().size());
        }
        if (this.connectionLimited) {
            this.flat.setConnectionLimit(this.connectionLimit);
        } else {
            this.flat.clearConnectionLimit();
        }
    }

    /** @return the connection limit (meaningful only when connection-limited) */
    public double getConnectionLimit() {
        return this.connectionLimit;
    }

    /** @return the flat network, or null if none has been built */
    public FlatNetwork getFlat() {
        return this.flat;
    }

    /** @return the pending flat/object synchronization state */
    public FlatUpdateNeeded getFlatUpdate() {
        return this.flatUpdate;
    }

    /** @return the live (mutable) list of layers */
    public List<Layer> getLayers() {
        return this.layers;
    }

    /** Adds {@code layer} and every layer reachable through its outbound synapses to {@code list}, skipping duplicates. */
    private void getLayers(List<Layer> list, Layer layer) {
        if (!list.contains(layer)) {
            list.add(layer);
        }
        for (Synapse synapse : layer.getNext()) {
            Layer next = synapse.getToLayer();
            if (!list.contains(next)) {
                this.getLayers(list, next);
            }
        }
    }

    /** @return the network this structure describes */
    public BasicNetwork getNetwork() {
        return this.network;
    }

    /** @return the next free layer ID; each call advances the counter */
    public int getNextID() {
        return this.nextID++;
    }

    /**
     * @param layer the target layer
     * @return the distinct layers that have a synapse into {@code layer}
     */
    public Collection<Layer> getPreviousLayers(Layer layer) {
        HashSet<Layer> previous = new HashSet<Layer>();
        for (Layer candidate : this.layers) {
            for (Synapse synapse : candidate.getNext()) {
                if (synapse.getToLayer() == layer) {
                    previous.add(synapse.getFromLayer());
                }
            }
        }
        return previous;
    }

    /**
     * @param layer the target layer
     * @return the distinct synapses whose target is {@code layer}, in synapse-list order
     */
    public List<Synapse> getPreviousSynapses(Layer layer) {
        ArrayList<Synapse> previous = new ArrayList<Synapse>();
        for (Synapse synapse : this.synapses) {
            if (synapse.getToLayer() == layer && !previous.contains(synapse)) {
                previous.add(synapse);
            }
        }
        return previous;
    }

    /** @return the live (mutable) list of synapses */
    public List<Synapse> getSynapses() {
        return this.synapses;
    }

    /** @return true if a connection limit is in effect */
    public boolean isConnectionLimited() {
        return this.connectionLimited;
    }

    /** @return true if any layer is a context layer */
    public boolean isRecurrent() {
        for (Layer layer : this.layers) {
            if (layer instanceof ContextLayer) {
                return true;
            }
        }
        return false;
    }

    /**
     * @param layer the layer to look up (by identity)
     * @return every tag under which the network registers {@code layer}
     */
    public List<String> nameLayer(Layer layer) {
        ArrayList<String> names = new ArrayList<String>();
        for (Map.Entry<String, Layer> entry : this.network.getLayerTags().entrySet()) {
            if (entry.getValue() == layer) {
                names.add(entry.getKey());
            }
        }
        return names;
    }

    /** @param flatUpdateNeeded the new pending synchronization state */
    public void setFlatUpdate(FlatUpdateNeeded flatUpdateNeeded) {
        this.flatUpdate = flatUpdateNeeded;
    }

    /** Sorts layers with {@link LayerComparator} and synapses with {@link SynapseComparator}. */
    public void sort() {
        Collections.sort(this.layers, new LayerComparator(this));
        Collections.sort(this.synapses, new SynapseComparator(this));
    }

    /**
     * Copies the weights and context-layer data from the flat network back
     * into the object network. Does nothing if there is no flat network.
     */
    public void unflattenWeights() {
        if (this.flat == null) {
            return;
        }
        NetworkCODEC.arrayToNetwork(this.flat.getWeights(), this.network);
        this.flatUpdate = FlatUpdateNeeded.None;
        for (Layer layer : this.layers) {
            if (!(layer instanceof ContextLayer)) {
                continue;
            }
            ContextLayer contextLayer = (ContextLayer)layer;
            if (contextLayer.getFlatContextIndex() == -1) {
                continue;
            }
            EngineArray.arrayCopy(this.flat.getLayerOutput(), contextLayer.getFlatContextIndex(), contextLayer.getContext().getData(), 0, contextLayer.getContext().size());
        }
    }

    /**
     * Performs whatever synchronization flatUpdate says is pending, then marks
     * it done. None and Never are no-ops.
     */
    public void updateFlatNetwork() {
        if (this.flatUpdate == null) {
            // flatUpdate is transient, so it is null after Java deserialization.
            this.flattenWeights();
            this.flatUpdate = FlatUpdateNeeded.None;
            return;
        }
        switch (this.flatUpdate) {
            case Flatten: {
                this.flattenWeights();
                break;
            }
            case Unflatten: {
                this.unflattenWeights();
                break;
            }
            case None: 
            case Never: {
                return;
            }
        }
        this.flatUpdate = FlatUpdateNeeded.None;
    }
}

