package org.encog.neural.networks.training.propagation.quick;

import org.encog.ml.data.MLDataSet;
import org.encog.neural.flat.train.prop.TrainFlatNetworkQPROP;
import org.encog.neural.networks.ContainsFlat;
import org.encog.neural.networks.training.LearningRate;
import org.encog.neural.networks.training.TrainingError;
import org.encog.neural.networks.training.propagation.Propagation;
import org.encog.neural.networks.training.propagation.TrainingContinuation;
import org.encog.util.EngineArray;
import org.encog.util.validate.ValidateNetwork;

/* loaded from: input_file:org/encog/neural/networks/training/propagation/quick/QuickPropagation.class */
public class QuickPropagation extends Propagation implements LearningRate {

    /**
     * Key under which {@link #pause()} stores the last gradients in a
     * {@link TrainingContinuation}, and which {@link #resume(TrainingContinuation)} reads back.
     */
    public static final String LAST_GRADIENTS = "LAST_GRADIENTS";

    /**
     * Creates a quick-propagation trainer, passing {@code 2.0} as the third argument of
     * {@link #QuickPropagation(ContainsFlat, MLDataSet, double)}.
     *
     * @param containsFlat the network to train
     * @param mLDataSet the training data
     */
    public QuickPropagation(ContainsFlat containsFlat, MLDataSet mLDataSet) {
        this(containsFlat, mLDataSet, 2.0d);
    }

    /**
     * Creates a quick-propagation trainer backed by a {@link TrainFlatNetworkQPROP}.
     *
     * @param containsFlat the network to train; validated against the data set
     * @param mLDataSet the training data
     * @param d passed to the {@link TrainFlatNetworkQPROP} constructor; presumably the initial
     *     learning rate (NOTE(review): confirm against TrainFlatNetworkQPROP)
     */
    public QuickPropagation(ContainsFlat containsFlat, MLDataSet mLDataSet, double d) {
        super(containsFlat, mLDataSet);
        ValidateNetwork.validateMethodToData(containsFlat, mLDataSet);
        setFlatTraining(new TrainFlatNetworkQPROP(containsFlat.getFlat(), getTraining(), d));
    }

    /** Returns the flat trainer installed by the constructor, typed as the QPROP trainer. */
    private TrainFlatNetworkQPROP qprop() {
        return (TrainFlatNetworkQPROP) getFlatTraining();
    }

    /**
     * Reports whether this trainer can continue from a paused state.
     *
     * <p>NOTE(review): this returns {@code false} even though {@link #pause()} and
     * {@link #resume(TrainingContinuation)} are implemented; confirm whether that is intended.
     *
     * @return always {@code false}
     */
    @Override // org.encog.ml.train.MLTrain
    public final boolean canContinue() {
        return false;
    }

    /**
     * @return the last-delta array held by the flat QPROP trainer (not a copy)
     */
    public final double[] getLastDelta() {
        return qprop().getLastDelta();
    }

    /**
     * @return the learning rate currently configured on the flat QPROP trainer
     */
    @Override // org.encog.neural.networks.training.LearningRate
    public final double getLearningRate() {
        return qprop().getLearningRate();
    }

    /**
     * Checks whether a continuation can be used to resume this trainer.
     *
     * @param trainingContinuation the continuation to check
     * @return {@code true} only if the continuation contains a {@code double[]} under
     *     {@link #LAST_GRADIENTS}, its training type equals this class's simple name, and the
     *     array length equals the network's weight count; {@code false} otherwise, including when
     *     the stored type is null or the stored value is missing or not a {@code double[]}
     */
    public final boolean isValidResume(TrainingContinuation trainingContinuation) {
        if (!trainingContinuation.getContents().containsKey(LAST_GRADIENTS)) {
            return false;
        }
        // Compare from the non-null side so a continuation with no type set is rejected, not an NPE.
        if (!getClass().getSimpleName().equals(trainingContinuation.getTrainingType())) {
            return false;
        }
        Object saved = trainingContinuation.get(LAST_GRADIENTS);
        // Guard against a null or foreign value instead of failing with NPE/ClassCastException.
        if (!(saved instanceof double[])) {
            return false;
        }
        int weightCount = ((ContainsFlat) getMethod()).getFlat().getWeights().length;
        return ((double[]) saved).length == weightCount;
    }

    /**
     * Captures the trainer's last gradients so training can later be resumed.
     *
     * @return a continuation tagged with this class's simple name, holding a snapshot of the last
     *     gradients under {@link #LAST_GRADIENTS}
     */
    @Override // org.encog.ml.train.MLTrain
    public final TrainingContinuation pause() {
        TrainingContinuation trainingContinuation = new TrainingContinuation();
        trainingContinuation.setTrainingType(getClass().getSimpleName());
        // Store a copy: the trainer keeps mutating its own array on later iterations,
        // which would otherwise silently alter the saved state.
        trainingContinuation.set(LAST_GRADIENTS, qprop().getLastGradient().clone());
        return trainingContinuation;
    }

    /**
     * Restores the last gradients saved by {@link #pause()} into the flat QPROP trainer.
     *
     * @param trainingContinuation a continuation accepted by {@link #isValidResume}
     * @throws TrainingError if the continuation fails {@link #isValidResume}
     */
    @Override // org.encog.ml.train.MLTrain
    public final void resume(TrainingContinuation trainingContinuation) {
        if (!isValidResume(trainingContinuation)) {
            throw new TrainingError("Invalid training resume data length");
        }
        EngineArray.arrayCopy((double[]) trainingContinuation.get(LAST_GRADIENTS), qprop().getLastGradient());
    }

    /**
     * Sets the learning rate on the flat QPROP trainer.
     *
     * @param d the new learning rate
     */
    @Override // org.encog.neural.networks.training.LearningRate
    public final void setLearningRate(double d) {
        qprop().setLearningRate(d);
    }

    /**
     * @return the output epsilon configured on the flat QPROP trainer
     */
    public double getOutputEpsilon() {
        return qprop().getOutputEpsilon();
    }

    /**
     * @return the shrink value configured on the flat QPROP trainer
     */
    public double getShrink() {
        return qprop().getShrink();
    }

    /**
     * Sets the shrink value on the flat QPROP trainer.
     *
     * @param d the new shrink value
     */
    public void setShrink(double d) {
        qprop().setShrink(d);
    }

    /**
     * Sets the output epsilon on the flat QPROP trainer.
     *
     * <p>Previously this method called itself unconditionally and always ended in a
     * {@link StackOverflowError}; it now delegates like the other accessors.
     * NOTE(review): assumes TrainFlatNetworkQPROP exposes {@code setOutputEpsilon(double)}
     * alongside the {@code getOutputEpsilon()} used above — confirm.
     *
     * @param d the new output epsilon
     */
    public void setOutputEpsilon(double d) {
        qprop().setOutputEpsilon(d);
    }
}
