package org.encog.neural.networks.training.pnn;

import org.encog.ml.MLMethod;
import org.encog.ml.TrainingImplementationType;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLData;
import org.encog.ml.data.basic.BasicMLDataPair;
import org.encog.ml.data.basic.BasicMLDataSet;
import org.encog.ml.train.BasicTraining;
import org.encog.neural.networks.training.propagation.TrainingContinuation;
import org.encog.neural.pnn.BasicPNN;
import org.encog.neural.pnn.PNNKernelType;
import org.encog.neural.pnn.PNNOutputMode;

/* loaded from: input_file:org/encog/neural/networks/training/pnn/TrainBasicPNN.class */
/**
 * Trains a {@link BasicPNN} by optimising its per-input kernel widths (sigmas).
 *
 * <p>On each {@link #iteration()} the trainer:
 * <ol>
 *   <li>on the first call only, copies the training set into the network's sample store;</li>
 *   <li>picks a starting sigma vector: the network's current sigmas if it is already trained,
 *       otherwise the single sigma chosen by {@code GlobalMinimumSearch.findBestRange};</li>
 *   <li>refines the per-input sigmas with {@code DeriveMinimum}.</li>
 * </ol>
 *
 * <p>The error being minimised is a leave-one-out error. Each stored sample is evaluated with its
 * own index excluded from the kernel sum.
 */
public class TrainBasicPNN extends BasicTraining implements CalculationCriteria {
    /** Default value of {@link #getMaxError()}. It is passed to both sigma searches. */
    public static final double DEFAULT_MAX_ERROR = 0.0d;
    /** Default value of {@link #getMinImprovement()}. It is passed to {@code DeriveMinimum}. */
    public static final double DEFAULT_MIN_IMPROVEMENT = 1.0E-4d;
    /** Default lower bound of the initial single-sigma range search. */
    public static final double DEFAULT_SIGMA_LOW = 1.0E-4d;
    /** Default upper bound of the initial single-sigma range search. */
    public static final double DEFAULT_SIGMA_HIGH = 10.0d;
    /** Default point count passed to the initial range search (presumably the number of sigmas tried). */
    public static final int DEFAULT_NUM_SIGMAS = 10;

    /** Floor applied to each kernel value and to the classification kernel sum, so no division is by zero. */
    private static final double KERNEL_FLOOR = 1.0E-40d;

    /**
     * First-derivative accumulators, indexed {@code slot * inputCount + inputIndex}.
     * Slots {@code 0..outputCount-1} belong to each output (or class). Outside classification mode,
     * one extra slot at {@code outputCount} holds the kernel-sum terms.
     */
    private double[] v;
    /** Second-derivative accumulators, with the same layout as {@link #v}. */
    private double[] w;
    /** Squared, sigma-scaled distance per input for the sample currently being processed. */
    private double[] dsqr;
    /** The network whose sigmas are trained. */
    private final BasicPNN network;
    /** Data copied into the network's sample store on the first iteration. */
    private final MLDataSet training;
    private double maxError;
    private double minImprovement;
    private double sigmaLow;
    private double sigmaHigh;
    private int numSigmas;
    /** Whether {@link #training} has already been copied into the network. */
    private boolean samplesLoaded;

    /**
     * Creates a trainer that uses the default search settings.
     *
     * @param network the network to train
     * @param training the training data, copied into the network on the first iteration
     */
    public TrainBasicPNN(final BasicPNN network, final MLDataSet training) {
        super(TrainingImplementationType.OnePass);
        this.network = network;
        this.training = training;
        this.maxError = DEFAULT_MAX_ERROR;
        this.minImprovement = DEFAULT_MIN_IMPROVEMENT;
        this.sigmaLow = DEFAULT_SIGMA_LOW;
        this.sigmaHigh = DEFAULT_SIGMA_HIGH;
        this.numSigmas = DEFAULT_NUM_SIGMAS;
        this.samplesLoaded = false;
    }

    /**
     * Installs one sigma per input and returns the resulting error over the network's samples.
     *
     * @param sigmas the sigma to use for each input; at least {@code inputCount} values
     * @param deriv receives the first derivatives when {@code derivMode} is true
     * @param deriv2 receives the second derivatives when {@code derivMode} is true
     * @param derivMode whether to compute and copy out derivatives
     * @return the error reported by {@link #calculateError(MLDataSet, boolean)}
     */
    @Override // org.encog.neural.networks.training.pnn.CalculationCriteria
    public final double calcErrorWithMultipleSigma(final double[] sigmas, final double[] deriv,
            final double[] deriv2, final boolean derivMode) {
        final int inputCount = this.network.getInputCount();
        System.arraycopy(sigmas, 0, this.network.getSigma(), 0, inputCount);
        if (!derivMode) {
            return calculateError(this.network.getSamples(), false);
        }
        final double error = calculateError(this.network.getSamples(), true);
        System.arraycopy(this.network.getDeriv(), 0, deriv, 0, inputCount);
        System.arraycopy(this.network.getDeriv2(), 0, deriv2, 0, inputCount);
        return error;
    }

    /**
     * Sets every input's sigma to the same value and returns the error over the network's samples.
     *
     * @param sigma the sigma to use for all inputs
     * @return the error reported by {@link #calculateError(MLDataSet, boolean)} without derivatives
     */
    @Override // org.encog.neural.networks.training.pnn.CalculationCriteria
    public final double calcErrorWithSingleSigma(final double sigma) {
        final double[] sigmas = this.network.getSigma();
        for (int i = 0; i < this.network.getInputCount(); i++) {
            sigmas[i] = sigma;
        }
        return calculateError(this.network.getSamples(), false);
    }

    /**
     * Computes the mean leave-one-out error of the network over {@code data}, optionally
     * accumulating sigma derivatives into {@code network.getDeriv()} / {@code getDeriv2()}.
     *
     * <p>Each record is evaluated with the network's exclude index set to that record's position.
     * When {@code data} is the network's own sample set, this stops a record from contributing to
     * its own prediction. The exclude index is reset to -1 afterwards, and the error is also stored
     * through {@code network.setError}.
     *
     * <p>NOTE(review): if {@code data} is not the network's sample set, the exclude index still
     * hides the stored sample at the same position. Confirm that is acceptable for external callers.
     *
     * @param data the records to evaluate
     * @param deriv whether to accumulate first and second derivatives
     * @return the mean error, additionally divided by the output count in unsupervised and
     *     regression modes
     */
    public final double calculateError(final MLDataSet data, final boolean deriv) {
        final int inputCount = this.network.getInputCount();
        final int outputCount = this.network.getOutputCount();
        final PNNOutputMode mode = this.network.getOutputMode();

        if (deriv) {
            final int derivCount = this.network.isSeparateClass() ? inputCount * outputCount : inputCount;
            final double[] d1 = this.network.getDeriv();
            final double[] d2 = this.network.getDeriv2();
            for (int i = 0; i < derivCount; i++) {
                d1[i] = 0.0d;
                d2[i] = 0.0d;
            }
        }

        final long recordCount = data.getRecordCount();
        final MLDataPair pair = BasicMLDataPair.createPair(data.getInputSize(), data.getIdealSize());
        final double[] out = new double[outputCount];
        double totalError = 0.0d;

        for (int r = 0; r < recordCount; r++) {
            data.getRecord(r, pair);
            // Leave-one-out: hide exactly the record being evaluated. computeDeriv compares this
            // index against forward sample positions.
            this.network.setExclude(r);

            final MLData input = pair.getInput();
            final MLData ideal = pair.getIdeal();
            double recordError = 0.0d;

            if (mode == PNNOutputMode.Unsupervised) {
                // The network reconstructs its input, so the input is also the derivative target.
                copyOutputs(deriv ? computeDeriv(input, input) : this.network.compute(input), out);
                for (int i = 0; i < outputCount; i++) {
                    final double diff = input.getData(i) - out[i];
                    recordError += diff * diff;
                }
            } else if (mode == PNNOutputMode.Classification) {
                final int targetClass = (int) ideal.getData(0);
                final MLData output = deriv ? computeDeriv(input, ideal) : this.network.compute(input);
                // NOTE(review): only out[0] is filled from the network; the other entries stay 0.
                // The score below therefore treats out[0] as class 0's value. Confirm intent.
                out[0] = output.getData(0);
                for (int i = 0; i < out.length; i++) {
                    final double diff = (i == targetClass) ? 1.0d - out[i] : out[i];
                    recordError += diff * diff;
                }
            } else if (mode == PNNOutputMode.Regression) {
                copyOutputs(deriv ? computeDeriv(input, ideal) : this.network.compute(input), out);
                for (int i = 0; i < outputCount; i++) {
                    final double diff = ideal.getData(i) - out[i];
                    recordError += diff * diff;
                }
            }
            totalError += recordError;
        }

        this.network.setExclude(-1);
        this.network.setError(totalError / recordCount);

        if (deriv) {
            final double[] d1 = this.network.getDeriv();
            final double[] d2 = this.network.getDeriv2();
            for (int i = 0; i < d1.length; i++) {
                d1[i] = d1[i] / recordCount;
                d2[i] = d2[i] / recordCount;
            }
        }

        if (mode == PNNOutputMode.Unsupervised || mode == PNNOutputMode.Regression) {
            this.network.setError(this.network.getError() / outputCount);
            if (deriv) {
                final double[] d1 = this.network.getDeriv();
                final double[] d2 = this.network.getDeriv2();
                for (int i = 0; i < inputCount; i++) {
                    d1[i] = d1[i] / outputCount;
                    d2[i] = d2[i] / outputCount;
                }
            }
        }
        return this.network.getError();
    }

    /**
     * Copies the first {@code target.length} values of {@code source} into {@code target}.
     */
    private static void copyOutputs(final MLData source, final double[] target) {
        for (int i = 0; i < target.length; i++) {
            target[i] = source.getData(i);
        }
    }

    /**
     * Always {@code false}: training cannot be paused and resumed.
     *
     * @return false
     */
    @Override // org.encog.ml.train.MLTrain
    public final boolean canContinue() {
        return false;
    }

    /**
     * Computes the network output for {@code input} and adds the squared-error derivative terms
     * for each sigma into {@code network.getDeriv()} and {@code network.getDeriv2()}.
     *
     * <p>The stored sample whose index equals {@code network.getExclude()} is skipped. The caller
     * must zero the derivative arrays first ({@link #calculateError} does this).
     *
     * <p>NOTE(review): the derivative scaling ({@code 2 / sigma}, {@code 2 / sigma^2}) matches
     * a Gaussian kernel. For the Reciprocal kernel, dK/dsigma carries an extra factor of K, so those
     * derivatives are presumably approximate. TODO confirm.
     *
     * @param input the case being evaluated
     * @param target the expected result: the class index in element 0 for classification,
     *     otherwise one value per output
     * @return for classification, a one-element placeholder holding 0.0; otherwise the computed
     *     outputs, one value per output
     */
    public final MLData computeDeriv(final MLData input, final MLData target) {
        final int inputCount = this.network.getInputCount();
        final int outputCount = this.network.getOutputCount();
        final boolean classification = this.network.getOutputMode() == PNNOutputMode.Classification;
        final int slots = classification ? outputCount : outputCount + 1;
        // Start of the kernel-sum slot; only used outside classification mode.
        final int sumOffset = classification ? 0 : outputCount * inputCount;

        if (this.v == null || this.w == null || this.dsqr == null
                || this.v.length < slots * inputCount || this.dsqr.length < inputCount) {
            allocateWorkBuffers();
        }
        for (int i = 0; i < slots * inputCount; i++) {
            this.v[i] = 0.0d;
            this.w[i] = 0.0d;
        }

        final double[] sigma = this.network.getSigma();
        final PNNKernelType kernelType = this.network.getKernel();
        final MLDataSet samples = this.network.getSamples();
        final int exclude = this.network.getExclude();
        final MLDataPair pair = BasicMLDataPair.createPair(samples.getInputSize(), samples.getIdealSize());
        final double[] out = new double[outputCount];
        double psum = 0.0d;

        for (int r = 0; r < samples.getRecordCount(); r++) {
            if (r == exclude) {
                continue;
            }
            samples.getRecord(r, pair);
            final MLData sampleInput = pair.getInput();
            final MLData sampleIdeal = pair.getIdeal();

            double dist = 0.0d;
            for (int i = 0; i < inputCount; i++) {
                final double scaled = (input.getData(i) - sampleInput.getData(i)) / sigma[i];
                this.dsqr[i] = scaled * scaled;
                dist += this.dsqr[i];
            }
            if (kernelType == PNNKernelType.Gaussian) {
                dist = Math.exp(-dist);
            } else if (kernelType == PNNKernelType.Reciprocal) {
                dist = 1.0d / (1.0d + dist);
            }
            // Derivative terms use the unclamped kernel; output sums use the clamped one.
            final double trueDist = dist;
            if (dist < KERNEL_FLOOR) {
                dist = KERNEL_FLOOR;
            }

            if (classification) {
                final int sampleClass = (int) sampleIdeal.getData(0);
                out[sampleClass] += dist;
                accumulateKernelTerms(sampleClass * inputCount, inputCount, trueDist, 1.0d);
            } else if (this.network.getOutputMode() == PNNOutputMode.Unsupervised) {
                // Autoassociative: each output's target is the sample's matching input value.
                for (int o = 0; o < outputCount; o++) {
                    out[o] += dist * sampleInput.getData(o);
                    accumulateKernelTerms(o * inputCount, inputCount, trueDist, sampleInput.getData(o));
                }
                accumulateKernelTerms(sumOffset, inputCount, trueDist, 1.0d);
                psum += dist;
            } else if (this.network.getOutputMode() == PNNOutputMode.Regression) {
                for (int o = 0; o < outputCount; o++) {
                    out[o] += dist * sampleIdeal.getData(o);
                    accumulateKernelTerms(o * inputCount, inputCount, trueDist, sampleIdeal.getData(o));
                }
                accumulateKernelTerms(sumOffset, inputCount, trueDist, 1.0d);
                psum += dist;
            }
        }

        if (classification) {
            // Apply class priors (negative prior = leave unweighted), then rebuild the sum.
            psum = 0.0d;
            for (int c = 0; c < outputCount; c++) {
                if (this.network.getPriors()[c] >= 0.0d) {
                    out[c] *= this.network.getPriors()[c] / this.network.getCountPer()[c];
                }
                psum += out[c];
            }
            if (psum < KERNEL_FLOOR) {
                psum = KERNEL_FLOOR;
            }
        }
        for (int o = 0; o < outputCount; o++) {
            out[o] /= psum;
        }

        final double[] deriv = this.network.getDeriv();
        final double[] deriv2 = this.network.getDeriv2();
        for (int i = 0; i < inputCount; i++) {
            final double s = sigma[i];
            // Derivatives of the kernel sum (the denominator of each output).
            double sumDeriv;
            double sumDeriv2;
            if (classification) {
                sumDeriv = 0.0d;
                sumDeriv2 = 0.0d;
            } else {
                sumDeriv = (this.v[sumOffset + i] * 2.0d) / (psum * s);
                sumDeriv2 = (this.w[sumOffset + i] * 2.0d) / ((psum * s) * s);
            }

            for (int o = 0; o < outputCount; o++) {
                final int idx = (o * inputCount) + i;
                if (classification && this.network.getPriors()[o] >= 0.0d) {
                    final double priorScale = this.network.getPriors()[o] / this.network.getCountPer()[o];
                    this.v[idx] *= priorScale;
                    this.w[idx] *= priorScale;
                }
                this.v[idx] *= 2.0d / (psum * s);
                this.w[idx] *= 2.0d / ((psum * s) * s);
                if (classification) {
                    // In classification the denominator is the sum of all class numerators.
                    sumDeriv += this.v[idx];
                    sumDeriv2 += this.w[idx];
                }
            }

            for (int o = 0; o < outputCount; o++) {
                final int idx = (o * inputCount) + i;
                // Quotient rule for out[o] = numerator / psum.
                final double outDeriv = this.v[idx] - (out[o] * sumDeriv);
                final double outDeriv2 = ((this.w[idx] + (((2.0d * out[o]) * sumDeriv) * sumDeriv))
                        - ((2.0d * this.v[idx]) * sumDeriv)) - (out[o] * sumDeriv2);

                final double residual;
                if (!classification) {
                    residual = out[o] - target.getData(o);
                } else if (o == (int) target.getData(0)) {
                    residual = out[o] - 1.0d;
                } else {
                    residual = out[o];
                }
                final double errorGrad = 2.0d * residual;
                deriv[i] = deriv[i] + (errorGrad * outDeriv);
                deriv2[i] = deriv2[i] + (errorGrad * outDeriv2) + (2.0d * outDeriv * outDeriv);
            }
        }

        if (classification) {
            // NOTE(review): placeholder output. Callers read element 0 as the network output,
            // so derivative-mode classification error always sees 0.0 here. Confirm intent.
            final BasicMLData placeholder = new BasicMLData(1);
            placeholder.setData(0, 0.0d);
            return placeholder;
        }
        return new BasicMLData(out);
    }

    /**
     * Adds one sample's kernel-derivative terms to the {@code inputCount} entries of
     * {@link #v} and {@link #w} starting at {@code offset}.
     *
     * @param offset first index of the slot to update
     * @param inputCount number of inputs (the slot width)
     * @param kernel the unclamped kernel value for the sample
     * @param scale multiplier for the terms (a target value, or 1.0 for kernel-sum/class slots)
     */
    private void accumulateKernelTerms(final int offset, final int inputCount,
            final double kernel, final double scale) {
        for (int i = 0; i < inputCount; i++) {
            final double term = kernel * this.dsqr[i] * scale;
            this.v[offset + i] += term;
            this.w[offset + i] += term * ((2.0d * this.dsqr[i]) - 3.0d);
        }
    }

    /**
     * (Re)allocates the derivative scratch buffers to match the network's current shape.
     */
    private void allocateWorkBuffers() {
        final int inputCount = this.network.getInputCount();
        final int slots = this.network.getOutputMode() == PNNOutputMode.Classification
                ? this.network.getOutputCount()
                : this.network.getOutputCount() + 1;
        this.dsqr = new double[inputCount];
        this.v = new double[inputCount * slots];
        this.w = new double[inputCount * slots];
    }

    /**
     * @return the max-error value passed to both sigma searches
     */
    public final double getMaxError() {
        return this.maxError;
    }

    /**
     * @return the network being trained
     */
    @Override // org.encog.ml.train.MLTrain
    public final MLMethod getMethod() {
        return this.network;
    }

    /**
     * @return the minimum-improvement value passed to {@code DeriveMinimum}
     */
    public final double getMinImprovement() {
        return this.minImprovement;
    }

    /**
     * @return the point count passed to the initial sigma range search
     */
    public final int getNumSigmas() {
        return this.numSigmas;
    }

    /**
     * @return the upper bound of the initial sigma range search
     */
    public final double getSigmaHigh() {
        return this.sigmaHigh;
    }

    /**
     * @return the lower bound of the initial sigma range search
     */
    public final double getSigmaLow() {
        return this.sigmaLow;
    }

    /**
     * Runs one full sigma optimisation.
     *
     * <p>Loads the samples on the first call, then chooses a starting point: the current sigmas if
     * the network is already trained, otherwise a range search. Finally it refines the sigmas with
     * {@code DeriveMinimum} and marks the network as trained.
     */
    @Override // org.encog.ml.train.MLTrain
    public final void iteration() {
        if (!this.samplesLoaded) {
            this.network.setSamples(new BasicMLDataSet(this.training));
            this.samplesLoaded = true;
        }

        final GlobalMinimumSearch globalMinimum = new GlobalMinimumSearch();
        final DeriveMinimum deriveMinimum = new DeriveMinimum();
        allocateWorkBuffers();

        final int inputCount = this.network.getInputCount();
        final double[] sigmas = new double[inputCount];
        // Scratch arrays required by DeriveMinimum.calculate.
        final double[] work1 = new double[inputCount];
        final double[] work2 = new double[inputCount];
        final double[] work3 = new double[inputCount];
        final double[] work4 = new double[inputCount];
        final double[] work5 = new double[inputCount];

        if (this.network.isTrained()) {
            System.arraycopy(this.network.getSigma(), 0, sigmas, 0, inputCount);
            globalMinimum.setY2(1.0E30d);
        } else {
            globalMinimum.findBestRange(this.sigmaLow, this.sigmaHigh, this.numSigmas, true,
                    this.maxError, this);
            final double start = globalMinimum.getX2();
            for (int i = 0; i < inputCount; i++) {
                sigmas[i] = start;
            }
        }

        globalMinimum.setY2(deriveMinimum.calculate(32767, this.maxError, 1.0E-8d,
                this.minImprovement, this, inputCount, sigmas, globalMinimum.getY2(),
                work1, work2, work3, work4, work5));

        System.arraycopy(sigmas, 0, this.network.getSigma(), 0, inputCount);
        this.network.setError(Math.abs(globalMinimum.getY2()));
        this.network.setTrained(true);
    }

    /**
     * Pausing is not supported.
     *
     * @return null
     */
    @Override // org.encog.ml.train.MLTrain
    public final TrainingContinuation pause() {
        return null;
    }

    /**
     * Resuming is not supported; this method does nothing.
     *
     * @param trainingContinuation ignored
     */
    @Override // org.encog.ml.train.MLTrain
    public void resume(final TrainingContinuation trainingContinuation) {
    }

    /**
     * @param maxError the max-error value passed to both sigma searches
     */
    public final void setMaxError(final double maxError) {
        this.maxError = maxError;
    }

    /**
     * @param minImprovement the minimum-improvement value passed to {@code DeriveMinimum}
     */
    public final void setMinImprovement(final double minImprovement) {
        this.minImprovement = minImprovement;
    }

    /**
     * @param numSigmas the point count passed to the initial sigma range search
     */
    public final void setNumSigmas(final int numSigmas) {
        this.numSigmas = numSigmas;
    }

    /**
     * @param sigmaHigh the upper bound of the initial sigma range search
     */
    public final void setSigmaHigh(final double sigmaHigh) {
        this.sigmaHigh = sigmaHigh;
    }

    /**
     * @param sigmaLow the lower bound of the initial sigma range search
     */
    public final void setSigmaLow(final double sigmaLow) {
        this.sigmaLow = sigmaLow;
    }
}
