package org.encog.neural.networks.training.propagation.resilient;

import org.encog.ml.data.MLDataSet;
import org.encog.neural.flat.train.prop.RPROPType;
import org.encog.neural.flat.train.prop.TrainFlatNetworkResilient;
import org.encog.neural.networks.ContainsFlat;
import org.encog.neural.networks.training.TrainingError;
import org.encog.neural.networks.training.propagation.Propagation;
import org.encog.neural.networks.training.propagation.TrainingContinuation;
import org.encog.util.EngineArray;

/* loaded from: input_file:org/encog/neural/networks/training/propagation/resilient/ResilientPropagation.class */
/**
 * Propagation trainer that delegates each training iteration to a
 * {@link TrainFlatNetworkResilient} built over the network's flat representation.
 *
 * <p>Training can be paused and resumed: {@link #pause()} captures the trainer's
 * last-gradient and update-value arrays, and {@link #resume(TrainingContinuation)}
 * copies them back after validating them with {@link #isValidResume(TrainingContinuation)}.
 */
public class ResilientPropagation extends Propagation {
    /** Continuation key under which the last-gradient array is stored. */
    public static final String LAST_GRADIENTS = "LAST_GRADIENTS";
    /** Continuation key under which the update-value array is stored. */
    public static final String UPDATE_VALUES = "UPDATE_VALUES";

    // Default values forwarded to TrainFlatNetworkResilient by the 2-arg constructor.
    // NOTE(review): names follow the assumed parameter order of TrainFlatNetworkResilient
    // (zeroTolerance, initialUpdate, maxStep) — confirm against that class.
    private static final double DEFAULT_INITIAL_UPDATE = 0.1d;
    private static final double DEFAULT_MAX_STEP = 50.0d;
    private static final double ZERO_TOLERANCE = 1.0E-17d;

    /**
     * Creates a trainer with the default initial-update (0.1) and max-step (50.0) values.
     *
     * @param containsFlat the network to train; its flat form is what gets trained
     * @param mLDataSet the training data
     */
    public ResilientPropagation(ContainsFlat containsFlat, MLDataSet mLDataSet) {
        this(containsFlat, mLDataSet, DEFAULT_INITIAL_UPDATE, DEFAULT_MAX_STEP);
    }

    /**
     * Creates a trainer with explicit tuning values.
     *
     * @param containsFlat the network to train; its flat form is what gets trained
     * @param mLDataSet the training data
     * @param initialUpdate forwarded to {@link TrainFlatNetworkResilient} as its fourth
     *     argument (presumably the initial per-weight update value — TODO confirm)
     * @param maxStep forwarded to {@link TrainFlatNetworkResilient} as its fifth
     *     argument (presumably the maximum step size — TODO confirm)
     */
    public ResilientPropagation(ContainsFlat containsFlat, MLDataSet mLDataSet,
            double initialUpdate, double maxStep) {
        super(containsFlat, mLDataSet);
        setFlatTraining(new TrainFlatNetworkResilient(
                containsFlat.getFlat(), getTraining(), ZERO_TOLERANCE, initialUpdate, maxStep));
    }

    /**
     * Reports that this trainer supports pause/resume.
     *
     * @return always {@code true}
     */
    @Override
    public final boolean canContinue() {
        return true;
    }

    /**
     * Determines whether a continuation can be applied to this trainer.
     *
     * <p>A continuation is valid when:
     * <ul>
     *   <li>it contains both {@link #LAST_GRADIENTS} and {@link #UPDATE_VALUES};</li>
     *   <li>its training type matches this class's simple name;</li>
     *   <li>both payloads are {@code double[]};</li>
     *   <li>both payloads have one entry per network weight.</li>
     * </ul>
     *
     * @param trainingContinuation the saved state to check
     * @return {@code true} if {@link #resume(TrainingContinuation)} can safely apply it
     */
    public final boolean isValidResume(TrainingContinuation trainingContinuation) {
        if (!trainingContinuation.getContents().containsKey(LAST_GRADIENTS)
                || !trainingContinuation.getContents().containsKey(UPDATE_VALUES)) {
            return false;
        }
        // Compare from the known non-null side so a missing training type is
        // rejected instead of throwing NullPointerException.
        if (!getClass().getSimpleName().equals(trainingContinuation.getTrainingType())) {
            return false;
        }
        Object lastGradients = trainingContinuation.get(LAST_GRADIENTS);
        Object updateValues = trainingContinuation.get(UPDATE_VALUES);
        // Reject malformed payloads instead of failing later with ClassCastException.
        if (!(lastGradients instanceof double[]) || !(updateValues instanceof double[])) {
            return false;
        }
        int weightCount = ((ContainsFlat) getMethod()).getFlat().getWeights().length;
        // Both arrays are copied into the trainer by resume(), so both must match.
        return ((double[]) lastGradients).length == weightCount
                && ((double[]) updateValues).length == weightCount;
    }

    /**
     * Captures the trainer's current last-gradient and update-value arrays.
     *
     * <p>The arrays are copied so later training iterations cannot alter the
     * returned snapshot, even if the flat trainer exposes its working buffers.
     *
     * @return a continuation tagged with this class's simple name
     */
    @Override
    public final TrainingContinuation pause() {
        TrainFlatNetworkResilient flatTraining = getResilientTraining();
        TrainingContinuation trainingContinuation = new TrainingContinuation();
        trainingContinuation.setTrainingType(getClass().getSimpleName());
        trainingContinuation.set(LAST_GRADIENTS, copyOf(flatTraining.getLastGradient()));
        trainingContinuation.set(UPDATE_VALUES, copyOf(flatTraining.getUpdateValues()));
        return trainingContinuation;
    }

    /**
     * Restores state previously captured by {@link #pause()}.
     *
     * <p>The saved arrays are copied into the trainer's own last-gradient and
     * update-value arrays.
     *
     * @param trainingContinuation the saved state to apply
     * @throws TrainingError if {@link #isValidResume(TrainingContinuation)} rejects it
     */
    @Override
    public final void resume(TrainingContinuation trainingContinuation) {
        if (!isValidResume(trainingContinuation)) {
            throw new TrainingError("Invalid training resume data length");
        }
        TrainFlatNetworkResilient flatTraining = getResilientTraining();
        double[] lastGradients = (double[]) trainingContinuation.get(LAST_GRADIENTS);
        double[] updateValues = (double[]) trainingContinuation.get(UPDATE_VALUES);
        EngineArray.arrayCopy(lastGradients, flatTraining.getLastGradient());
        EngineArray.arrayCopy(updateValues, flatTraining.getUpdateValues());
    }

    /**
     * Sets the RPROP variant used by the underlying flat trainer.
     *
     * @param rPROPType the variant to use
     */
    public void setRPROPType(RPROPType rPROPType) {
        getResilientTraining().setRpropType(rPROPType);
    }

    /**
     * Returns the RPROP variant used by the underlying flat trainer.
     *
     * @return the current variant
     */
    public RPROPType getRPROPType() {
        return getResilientTraining().getRpropType();
    }

    /** Returns the flat trainer installed by the constructor, typed as resilient. */
    private TrainFlatNetworkResilient getResilientTraining() {
        return (TrainFlatNetworkResilient) getFlatTraining();
    }

    /** Null-tolerant array copy used to decouple paused state from live buffers. */
    private static double[] copyOf(double[] source) {
        return source == null ? null : source.clone();
    }
}
