package org.encog.neural.networks.training.propagation;

import org.encog.EncogError;
import org.encog.ml.MLMethod;
import org.encog.ml.TrainingImplementationType;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.train.BasicTraining;
import org.encog.neural.error.ErrorFunction;
import org.encog.neural.flat.FlatNetwork;
import org.encog.neural.flat.train.TrainFlatNetwork;
import org.encog.neural.flat.train.prop.TrainFlatNetworkProp;
import org.encog.neural.networks.ContainsFlat;
import org.encog.neural.networks.training.Train;
import org.encog.util.EncogValidate;
import org.encog.util.concurrency.MultiThreadable;
import org.encog.util.logging.EncogLogging;

/* loaded from: input_file:org/encog/neural/networks/training/propagation/Propagation.class */
/**
 * Base class for propagation-style trainers.
 *
 * <p>Each training iteration is delegated to a {@link TrainFlatNetwork}. The
 * resulting error is then copied back into this trainer. The constructor does
 * not create that flat trainer. It must be installed with
 * {@link #setFlatTraining(TrainFlatNetwork)} before any method that delegates
 * to it is used.
 */
public abstract class Propagation extends BasicTraining implements Train, MultiThreadable {
    /** The network being trained; returned by {@link #getMethod()}. */
    private final ContainsFlat network;

    /**
     * Never assigned within this class, so it remains {@code null}.
     * NOTE(review): confirm whether this was meant to track the network's flat form.
     */
    private FlatNetwork currentFlatNetwork;

    /** The trainer that performs the actual iterations; installed via {@link #setFlatTraining}. */
    private TrainFlatNetwork flatTraining;

    /**
     * Creates a propagation trainer for the given network and data.
     *
     * @param containsFlat the network to train
     * @param mLDataSet the training data
     */
    public Propagation(ContainsFlat containsFlat, MLDataSet mLDataSet) {
        super(TrainingImplementationType.Iterative);
        this.network = containsFlat;
        setTraining(mLDataSet);
    }

    /**
     * Finishes training.
     *
     * <p>Runs the superclass's finishing logic first, then tells the flat
     * trainer to finish.
     */
    @Override // org.encog.ml.train.BasicTraining, org.encog.ml.train.MLTrain
    public final void finishTraining() {
        super.finishTraining();
        this.flatTraining.finishTraining();
    }

    /**
     * Returns the current flat network.
     *
     * @return the current flat network; the backing field is never assigned in
     *     this class, so this is currently always {@code null}
     */
    public final FlatNetwork getCurrentFlatNetwork() {
        return this.currentFlatNetwork;
    }

    /**
     * Returns the flat trainer that performs the iterations.
     *
     * @return the flat trainer, or {@code null} if none has been set yet
     */
    public final TrainFlatNetwork getFlatTraining() {
        return this.flatTraining;
    }

    /**
     * Returns the network being trained.
     *
     * @return the network passed to the constructor
     */
    @Override // org.encog.ml.train.MLTrain
    public final MLMethod getMethod() {
        return this.network;
    }

    /**
     * Returns the number of threads used by the flat trainer.
     *
     * @return the flat trainer's thread count
     */
    @Override // org.encog.util.concurrency.MultiThreadable
    public final int getThreadCount() {
        return this.flatTraining.getNumThreads();
    }

    /**
     * Performs a single training iteration.
     *
     * <p>The pre-iteration hooks run first. One iteration is then delegated to
     * the flat trainer and its error is recorded. Finally the post-iteration
     * hooks run.
     *
     * @throws EncogError if the flat trainer fails with an
     *     {@link ArrayIndexOutOfBoundsException}
     */
    @Override // org.encog.ml.train.MLTrain
    public final void iteration() {
        try {
            preIteration();
            this.flatTraining.iteration();
            setError(this.flatTraining.getError());
            postIteration();
            EncogLogging.log(1, "Training iteration done, error: " + getError());
        } catch (ArrayIndexOutOfBoundsException e) {
            throw iterationFailure(e);
        }
    }

    /**
     * Performs several training iterations in one call to the flat trainer.
     *
     * <p>The pre- and post-iteration hooks each run only once around the batch.
     * Afterwards the iteration counter and error are copied from the flat
     * trainer.
     *
     * @param i the number of iterations to request from the flat trainer
     * @throws EncogError if the flat trainer fails with an
     *     {@link ArrayIndexOutOfBoundsException}
     */
    @Override // org.encog.ml.train.BasicTraining, org.encog.ml.train.MLTrain
    public final void iteration(int i) {
        try {
            preIteration();
            this.flatTraining.iteration(i);
            setIteration(this.flatTraining.getIteration());
            setError(this.flatTraining.getError());
            postIteration();
            EncogLogging.log(1, "Training iterations done, error: " + getError());
        } catch (ArrayIndexOutOfBoundsException e) {
            throw iterationFailure(e);
        }
    }

    /**
     * Installs the flat trainer that performs the iterations.
     *
     * @param trainFlatNetwork the flat trainer to delegate to
     */
    public final void setFlatTraining(TrainFlatNetwork trainFlatNetwork) {
        this.flatTraining = trainFlatNetwork;
    }

    /**
     * Sets the number of threads used by the flat trainer.
     *
     * @param i the thread count to pass to the flat trainer
     */
    @Override // org.encog.util.concurrency.MultiThreadable
    public final void setThreadCount(int i) {
        this.flatTraining.setNumThreads(i);
    }

    /**
     * Enables or disables the flat-spot fix on the underlying propagation trainer.
     *
     * @param z whether to apply the flat-spot fix
     * @throws EncogError if no flat trainer is set, or it is not a {@link TrainFlatNetworkProp}
     */
    public void fixFlatSpot(boolean z) {
        propTraining().fixFlatSpot(z);
    }

    /**
     * Sets the error function used by the underlying propagation trainer.
     *
     * @param errorFunction the error function to use
     * @throws EncogError if no flat trainer is set, or it is not a {@link TrainFlatNetworkProp}
     */
    public void setErrorFunction(ErrorFunction errorFunction) {
        propTraining().setErrorFunction(errorFunction);
    }

    /**
     * Returns the flat trainer as a {@link TrainFlatNetworkProp}, failing with a clear message otherwise.
     *
     * <p>This replaces an unchecked cast that produced a message-less
     * {@code ClassCastException} or {@code NullPointerException}.
     */
    private TrainFlatNetworkProp propTraining() {
        if (this.flatTraining instanceof TrainFlatNetworkProp) {
            return (TrainFlatNetworkProp) this.flatTraining;
        }
        if (this.flatTraining == null) {
            throw new EncogError("No flat trainer has been set on "
                    + getClass().getName() + "; call setFlatTraining first.");
        }
        throw new EncogError("Flat trainer " + this.flatTraining.getClass().getName()
                + " is not a TrainFlatNetworkProp; this operation is not supported.");
    }

    /**
     * Builds the error thrown when the flat trainer fails with an out-of-bounds index.
     *
     * <p>First the network is validated against the training data, so that
     * {@code EncogValidate} can report a mismatch. NOTE(review): presumably it
     * throws a more descriptive error in that case; confirm. If validation
     * passes, the original exception is returned wrapped in an
     * {@link EncogError}.
     */
    private EncogError iterationFailure(ArrayIndexOutOfBoundsException e) {
        EncogValidate.validateNetworkForTraining(this.network, getTraining());
        return new EncogError(e);
    }
}
