package org.encog.neural.neat.training;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import org.encog.mathutil.randomize.RangeRandomizer;
import org.encog.ml.MLMethod;
import org.encog.ml.TrainingImplementationType;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.genetic.GeneticAlgorithm;
import org.encog.ml.genetic.genome.Chromosome;
import org.encog.ml.genetic.genome.Genome;
import org.encog.ml.genetic.genome.GenomeComparator;
import org.encog.ml.genetic.population.Population;
import org.encog.ml.genetic.species.BasicSpecies;
import org.encog.ml.genetic.species.Species;
import org.encog.ml.train.MLTrain;
import org.encog.ml.train.strategy.Strategy;
import org.encog.neural.neat.NEATNetwork;
import org.encog.neural.neat.NEATPopulation;
import org.encog.neural.networks.training.CalculateScore;
import org.encog.neural.networks.training.TrainingError;
import org.encog.neural.networks.training.genetic.GeneticScoreAdapter;
import org.encog.neural.networks.training.propagation.TrainingContinuation;

/* loaded from: input_file:org/encog/neural/neat/training/NEATTraining.class */
/**
 * Trains a population of {@link NEATGenome}s with the NEAT genetic algorithm.
 *
 * <p>Each call to {@link #iteration()} breeds a new generation species by species (the species
 * leader is carried over, the remaining offspring come from crossover or cloning followed by
 * structural and weight mutation), tops the generation up with tournament selection, then
 * re-scores, re-speciates and recomputes how many offspring each species may spawn.
 *
 * <p>Strategies, pausing and resuming are not supported by this trainer.
 */
public class NEATTraining extends GeneticAlgorithm implements MLTrain {

    /** How many times breeding re-draws the second parent when it is the same genome as the first. */
    private static final int PARENT_REDRAW_ATTEMPTS = 5;

    /** Population-wide average of the adjusted scores; recomputed every generation. */
    private double averageFitAdjustment;
    /** Best score observed in any generation so far. */
    private double bestEverScore;
    /** Organism of the genome that produced {@link #bestEverScore}; null until one is recorded. */
    private NEATNetwork bestEverNetwork;
    private final int inputCount;
    private final int outputCount;
    private double paramActivationMutationRate = 0.1d;
    private double paramChanceAddLink = 0.07d;
    private double paramChanceAddNode = 0.04d;
    private double paramChanceAddRecurrentLink = 0.05d;
    private double paramCompatibilityThreshold = 0.26d;
    private double paramCrossoverRate = 0.7d;
    private double paramMaxActivationPerturbation = 0.1d;
    /** Target maximum species count; a value below 1 disables threshold auto-adjustment. */
    private int paramMaxNumberOfSpecies = 0;
    private double paramMaxPermittedNeurons = 100.0d;
    private double paramMaxWeightPerturbation = 0.5d;
    private double paramMutationRate = 0.2d;
    private int paramNumAddLinkAttempts = 5;
    private int paramNumGensAllowedNoImprovement = 15;
    private int paramNumTrysToFindLoopedLink = 5;
    private int paramNumTrysToFindOldLink = 5;
    private double paramProbabilityWeightReplaced = 0.1d;
    /** Population-wide sum of the adjusted scores; recomputed every generation. */
    private double totalFitAdjustment;
    private boolean snapshot;
    private int iteration;

    /**
     * Creates a trainer over a freshly generated {@link NEATPopulation}.
     *
     * @param calculateScore scores the networks decoded from genomes
     * @param inputCount     number of inputs every genome must have
     * @param outputCount    number of outputs every genome must have
     * @param populationSize third argument forwarded to the {@link NEATPopulation} constructor
     */
    public NEATTraining(final CalculateScore calculateScore, final int inputCount,
            final int outputCount, final int populationSize) {
        this.inputCount = inputCount;
        this.outputCount = outputCount;
        setCalculateScore(new GeneticScoreAdapter(calculateScore));
        setComparator(new GenomeComparator(getCalculateScore()));
        setPopulation(new NEATPopulation(inputCount, outputCount, populationSize));
        init();
    }

    /**
     * Creates a trainer over an existing population of NEAT genomes. Input and output counts are
     * taken from the first genome; every other genome must match them.
     *
     * @param calculateScore scores the networks decoded from genomes
     * @param population     a non-empty population containing only {@link NEATGenome}s
     * @throws TrainingError if the population is empty, holds a non-NEAT genome, or holds genomes
     *                       with mismatched input/output sizes
     */
    public NEATTraining(final CalculateScore calculateScore, final Population population) {
        if (population.size() < 1) {
            throw new TrainingError("Population can not be empty.");
        }
        final Genome first = population.getGenomes().get(0);
        // Validate before casting so callers get the documented TrainingError, not a ClassCastException.
        if (!(first instanceof NEATGenome)) {
            throw new TrainingError("Population can only contain objects of NEATGenome.");
        }
        final NEATGenome template = (NEATGenome) first;
        setCalculateScore(new GeneticScoreAdapter(calculateScore));
        setComparator(new GenomeComparator(getCalculateScore()));
        setPopulation(population);
        this.inputCount = template.getInputCount();
        this.outputCount = template.getOutputCount();
        init();
    }

    /**
     * Appends {@code nodeID} to {@code list} unless an equal ID is already present.
     *
     * @param nodeID neuron ID to record
     * @param list   list of distinct neuron IDs, modified in place
     */
    public void addNeuronID(final long nodeID, final List<Long> list) {
        if (!list.contains(nodeID)) {
            list.add(nodeID);
        }
    }

    /**
     * Not supported.
     *
     * @throws TrainingError always
     */
    @Override
    public void addStrategy(final Strategy strategy) {
        throw new TrainingError("Strategies are not supported by this training method.");
    }

    /**
     * Nudges the compatibility threshold by 0.01 to steer the species count: up when there are
     * more species than {@link #getParamMaxNumberOfSpecies()}, down when there are fewer than two.
     * Does nothing when the maximum species count is below 1.
     */
    public void adjustCompatibilityThreshold() {
        if (this.paramMaxNumberOfSpecies < 1) {
            return;
        }
        final int speciesCount = getPopulation().getSpecies().size();
        if (speciesCount > this.paramMaxNumberOfSpecies) {
            // A larger threshold lets more genomes join existing species, so fewer species form.
            this.paramCompatibilityThreshold += 0.01d;
        } else if (speciesCount < 2) {
            this.paramCompatibilityThreshold -= 0.01d;
        }
    }

    /**
     * Sets every genome's adjusted score to its raw score, after applying the population's
     * young-species bonus or old-species penalty, divided by the size of its species.
     */
    public void adjustSpeciesScore() {
        for (final Species species : getPopulation().getSpecies()) {
            final int memberCount = species.getMembers().size();
            final boolean young = species.getAge() < getPopulation().getYoungBonusAgeThreshold();
            final boolean old = species.getAge() > getPopulation().getOldAgeThreshold();
            for (final Genome genome : species.getMembers()) {
                double score = genome.getScore();
                if (young) {
                    score = getComparator().applyBonus(score, getPopulation().getYoungScoreBonus());
                }
                if (old) {
                    score = getComparator().applyPenalty(score, getPopulation().getOldAgePenalty());
                }
                genome.setAdjustedScore(score / memberCount);
            }
        }
    }

    /**
     * Always returns {@code false}; this trainer cannot be paused and resumed.
     *
     * @return {@code false}
     */
    @Override
    public boolean canContinue() {
        return false;
    }

    /**
     * Produces a child genome by aligning both parents' link genes on innovation ID.
     *
     * <p>Matching genes are inherited from a randomly chosen parent. Disjoint and excess genes are
     * inherited only from the "best" parent: the one with the better score, or with fewer genes
     * on a score tie, or a coin flip when both score and gene count are equal. The child's neurons
     * are created from the sorted set of neuron IDs referenced by its inherited links.
     *
     * @param mom first parent
     * @param dad second parent
     * @return a new genome with a freshly assigned genome ID, bound to this trainer and population
     */
    public NEATGenome crossover(final NEATGenome mom, final NEATGenome dad) {
        final NEATParent best = chooseBestParent(mom, dad);

        final Chromosome babyNeurons = new Chromosome();
        final Chromosome babyLinks = new Chromosome();
        final List<Long> neuronIDs = new ArrayList<Long>();

        int momIndex = 0;
        int dadIndex = 0;
        while (momIndex < mom.getNumGenes() || dadIndex < dad.getNumGenes()) {
            final NEATLinkGene momGene = momIndex < mom.getNumGenes()
                    ? (NEATLinkGene) mom.getLinks().get(momIndex) : null;
            final NEATLinkGene dadGene = dadIndex < dad.getNumGenes()
                    ? (NEATLinkGene) dad.getLinks().get(dadIndex) : null;

            // Reset every step: a gene from the non-best parent is simply not inherited.
            NEATLinkGene selected = null;
            if (momGene == null) {
                // Mom is exhausted; dadGene is an excess gene (loop guard ensures it is non-null).
                if (best == NEATParent.Dad) {
                    selected = dadGene;
                }
                dadIndex++;
            } else if (dadGene == null) {
                if (best == NEATParent.Mom) {
                    selected = momGene;
                }
                momIndex++;
            } else if (momGene.getInnovationId() < dadGene.getInnovationId()) {
                if (best == NEATParent.Mom) {
                    selected = momGene;
                }
                momIndex++;
            } else if (dadGene.getInnovationId() < momGene.getInnovationId()) {
                if (best == NEATParent.Dad) {
                    selected = dadGene;
                }
                dadIndex++;
            } else {
                // Matching innovation: inherit from either parent with equal probability.
                selected = Math.random() < 0.5d ? momGene : dadGene;
                momIndex++;
                dadIndex++;
            }

            if (selected == null) {
                continue;
            }
            final boolean duplicate = babyLinks.size() > 0
                    && ((NEATLinkGene) babyLinks.get(babyLinks.size() - 1)).getInnovationId()
                    == selected.getInnovationId();
            if (!duplicate) {
                babyLinks.add(selected);
            }
            addNeuronID(selected.getFromNeuronID(), neuronIDs);
            addNeuronID(selected.getToNeuronID(), neuronIDs);
        }

        Collections.sort(neuronIDs);
        for (final Long neuronID : neuronIDs) {
            babyNeurons.add(getInnovations().createNeuronFromID(neuronID));
        }

        final NEATGenome baby = new NEATGenome(getPopulation().assignGenomeID(), babyNeurons,
                babyLinks, mom.getInputCount(), mom.getOutputCount());
        baby.setGeneticAlgorithm(this);
        baby.setPopulation(getPopulation());
        return baby;
    }

    /** Picks the parent whose disjoint/excess genes the child inherits. */
    private NEATParent chooseBestParent(final NEATGenome mom, final NEATGenome dad) {
        if (mom.getScore() == dad.getScore()) {
            if (mom.getNumGenes() == dad.getNumGenes()) {
                return Math.random() < 0.5d ? NEATParent.Mom : NEATParent.Dad;
            }
            // On a score tie prefer the smaller genome.
            return mom.getNumGenes() < dad.getNumGenes() ? NEATParent.Mom : NEATParent.Dad;
        }
        return getComparator().isBetterThan(mom.getScore(), dad.getScore())
                ? NEATParent.Mom : NEATParent.Dad;
    }

    /** Does nothing; there is no training state to release. */
    @Override
    public void finishTraining() {
    }

    /**
     * Returns the best score recorded so far.
     *
     * @return the best-ever score
     */
    @Override
    public double getError() {
        return this.bestEverScore;
    }

    /**
     * Returns {@link TrainingImplementationType#Iterative}.
     *
     * @return the implementation type
     */
    @Override
    public TrainingImplementationType getImplementationType() {
        return TrainingImplementationType.Iterative;
    }

    /**
     * Returns the population's innovation list as a {@link NEATInnovationList}.
     *
     * @return the innovation list
     */
    public NEATInnovationList getInnovations() {
        return (NEATInnovationList) getPopulation().getInnovations();
    }

    /** @return the number of inputs every genome has */
    public int getInputCount() {
        return this.inputCount;
    }

    /** @return the number of completed {@link #iteration()} calls, unless overridden via {@link #setIteration(int)} */
    @Override
    public int getIteration() {
        return this.iteration;
    }

    /**
     * Returns the network recorded alongside the best-ever score.
     *
     * @return the best network so far, or {@code null} if none has been recorded
     */
    @Override
    public MLMethod getMethod() {
        return this.bestEverNetwork;
    }

    /** @return the number of outputs every genome has */
    public int getOutputCount() {
        return this.outputCount;
    }

    /** @return rate passed to {@code NEATGenome.mutateActivationResponse} */
    public double getParamActivationMutationRate() {
        return this.paramActivationMutationRate;
    }

    /** @return chance passed to {@code NEATGenome.addLink} */
    public double getParamChanceAddLink() {
        return this.paramChanceAddLink;
    }

    /** @return chance passed to {@code NEATGenome.addNeuron} */
    public double getParamChanceAddNode() {
        return this.paramChanceAddNode;
    }

    /** @return recurrent-link chance passed to {@code NEATGenome.addLink} */
    public double getParamChanceAddRecurrentLink() {
        return this.paramChanceAddRecurrentLink;
    }

    /** @return maximum compatibility score at which a genome joins an existing species */
    public double getParamCompatibilityThreshold() {
        return this.paramCompatibilityThreshold;
    }

    /** @return probability that breeding uses crossover rather than cloning */
    public double getParamCrossoverRate() {
        return this.paramCrossoverRate;
    }

    /** @return perturbation bound passed to {@code NEATGenome.mutateActivationResponse} */
    public double getParamMaxActivationPerturbation() {
        return this.paramMaxActivationPerturbation;
    }

    /** @return target maximum species count (below 1 disables threshold adjustment) */
    public int getParamMaxNumberOfSpecies() {
        return this.paramMaxNumberOfSpecies;
    }

    /** @return neuron count at or above which offspring no longer attempt to add neurons */
    public double getParamMaxPermittedNeurons() {
        return this.paramMaxPermittedNeurons;
    }

    /** @return perturbation bound passed to {@code NEATGenome.mutateWeights} */
    public double getParamMaxWeightPerturbation() {
        return this.paramMaxWeightPerturbation;
    }

    /** @return mutation rate passed to {@code NEATGenome.mutateWeights} */
    public double getParamMutationRate() {
        return this.paramMutationRate;
    }

    /** @return attempt count passed to {@code NEATGenome.addLink} */
    public int getParamNumAddLinkAttempts() {
        return this.paramNumAddLinkAttempts;
    }

    /** @return generations a species may go without improvement before it may be removed */
    public int getParamNumGensAllowedNoImprovement() {
        return this.paramNumGensAllowedNoImprovement;
    }

    /** @return looped-link attempt count passed to {@code NEATGenome.addLink} */
    public int getParamNumTrysToFindLoopedLink() {
        return this.paramNumTrysToFindLoopedLink;
    }

    /** @return attempt count passed to {@code NEATGenome.addNeuron} */
    public int getParamNumTrysToFindOldLink() {
        return this.paramNumTrysToFindOldLink;
    }

    /** @return replacement probability passed to {@code NEATGenome.mutateWeights} */
    public double getParamProbabilityWeightReplaced() {
        return this.paramProbabilityWeightReplaced;
    }

    /**
     * Returns a new, empty, mutable list; strategies are not supported.
     *
     * @return an empty list
     */
    @Override
    public List<Strategy> getStrategies() {
        return new ArrayList<Strategy>();
    }

    /**
     * Returns {@code null}; this trainer is driven by a score function, not a data set.
     *
     * @return {@code null}
     */
    @Override
    public MLDataSet getTraining() {
        return null;
    }

    /**
     * Validates the population, binds every genome to this trainer, then scores and speciates
     * the initial generation.
     *
     * @throws TrainingError if a genome is not a {@link NEATGenome} or has mismatched I/O sizes
     */
    private void init() {
        // Seed with the worst representable score so the first evaluated best genome is recorded.
        // (Double.MIN_VALUE is the smallest *positive* double and would reject scores <= 0.)
        this.bestEverScore = getCalculateScore().shouldMinimize() ? Double.MAX_VALUE : -Double.MAX_VALUE;

        for (final Genome genome : getPopulation().getGenomes()) {
            if (!(genome instanceof NEATGenome)) {
                throw new TrainingError("Population can only contain objects of NEATGenome.");
            }
            final NEATGenome neatGenome = (NEATGenome) genome;
            if (neatGenome.getInputCount() != this.inputCount
                    || neatGenome.getOutputCount() != this.outputCount) {
                throw new TrainingError("All NEATGenome's must have the same input and output sizes as the base network.");
            }
            neatGenome.setGeneticAlgorithm(this);
        }
        getPopulation().claim(this);
        resetAndKill();
        sortAndRecord();
        speciateAndCalculateSpawnLevels();
    }

    /** @return the snapshot flag as last set by {@link #setSnapshot(boolean)} */
    public boolean isSnapshot() {
        return this.snapshot;
    }

    /**
     * Always returns {@code false}; this trainer has no stopping criterion of its own.
     *
     * @return {@code false}
     */
    @Override
    public boolean isTrainingDone() {
        return false;
    }

    /**
     * Runs one generation.
     *
     * <p>For each species, in order, spawns {@code round(species.getNumToSpawn())} genomes: the
     * species leader first, then bred and mutated offspring. Breeding attempts that yield no
     * child (no distinct second parent found) are dropped. Spawning stops once the new generation
     * reaches the current population size. Any shortfall is filled by tournament selection. The
     * population is then replaced, pruned, re-scored and re-speciated.
     */
    @Override
    public void iteration() {
        this.iteration++;
        final int populationSize = getPopulation().size();
        final List<Genome> newPopulation = new ArrayList<Genome>();
        int spawned = 0;

        for (final Species species : getPopulation().getSpecies()) {
            if (spawned >= populationSize) {
                break;
            }
            int numToSpawn = (int) Math.round(species.getNumToSpawn());
            boolean leaderCopied = false;
            while (numToSpawn-- > 0) {
                final NEATGenome baby;
                if (!leaderCopied) {
                    // Elitism: the species leader always survives unchanged.
                    baby = (NEATGenome) species.getLeader();
                    leaderCopied = true;
                } else {
                    baby = breedOffspring(species);
                    if (baby != null) {
                        mutateOffspring(baby);
                    }
                }
                if (baby != null) {
                    baby.sortGenes();
                    newPopulation.add(baby);
                    spawned++;
                    if (spawned == populationSize) {
                        numToSpawn = 0;
                    }
                }
            }
        }

        while (newPopulation.size() < populationSize) {
            newPopulation.add(tournamentSelection(populationSize / 5));
        }

        getPopulation().clear();
        getPopulation().addAll(newPopulation);
        resetAndKill();
        sortAndRecord();
        speciateAndCalculateSpawnLevels();
    }

    /**
     * Breeds one unmutated child from a species: clones the sole member of a one-member species,
     * otherwise crosses two distinct parents with probability {@link #paramCrossoverRate} or
     * clones one parent.
     *
     * @return the child, or {@code null} if crossover was chosen but no distinct second parent
     *         was found within {@link #PARENT_REDRAW_ATTEMPTS} re-draws
     */
    private NEATGenome breedOffspring(final Species species) {
        if (species.getMembers().size() == 1) {
            return new NEATGenome((NEATGenome) species.chooseParent());
        }
        final NEATGenome mom = (NEATGenome) species.chooseParent();
        if (Math.random() >= this.paramCrossoverRate) {
            return new NEATGenome(mom);
        }
        NEATGenome dad = (NEATGenome) species.chooseParent();
        int redraws = PARENT_REDRAW_ATTEMPTS;
        while (mom.getGenomeID() == dad.getGenomeID() && redraws-- > 0) {
            dad = (NEATGenome) species.chooseParent();
        }
        return mom.getGenomeID() != dad.getGenomeID() ? crossover(mom, dad) : null;
    }

    /** Assigns a new genome ID and applies structural, weight and activation mutations. */
    private void mutateOffspring(final NEATGenome baby) {
        baby.setGenomeID(getPopulation().assignGenomeID());
        if (baby.getNeurons().size() < this.paramMaxPermittedNeurons) {
            baby.addNeuron(this.paramChanceAddNode, this.paramNumTrysToFindOldLink);
        }
        baby.addLink(this.paramChanceAddLink, this.paramChanceAddRecurrentLink,
                this.paramNumTrysToFindLoopedLink, this.paramNumAddLinkAttempts);
        baby.mutateWeights(this.paramMutationRate, this.paramProbabilityWeightReplaced,
                this.paramMaxWeightPerturbation);
        baby.mutateActivationResponse(this.paramActivationMutationRate,
                this.paramMaxActivationPerturbation);
    }

    /**
     * Runs {@link #iteration()} the given number of times.
     *
     * @param count number of generations to run
     */
    @Override
    public void iteration(final int count) {
        for (int i = 0; i < count; i++) {
            iteration();
        }
    }

    /**
     * Not supported.
     *
     * @return {@code null}
     */
    @Override
    public TrainingContinuation pause() {
        return null;
    }

    /**
     * Clears the fitness-adjustment totals, calls {@code purge()} on every species, and removes
     * species that have gone more than {@link #paramNumGensAllowedNoImprovement} generations
     * without improvement and whose best score is worse than the best-ever score.
     */
    public void resetAndKill() {
        this.totalFitAdjustment = 0.0d;
        this.averageFitAdjustment = 0.0d;
        // Iterate over a snapshot because stagnant species are removed from the live collection.
        for (final Object element : getPopulation().getSpecies().toArray()) {
            final Species species = (Species) element;
            species.purge();
            final boolean stagnant =
                    species.getGensNoImprovement() > this.paramNumGensAllowedNoImprovement;
            if (stagnant && getComparator().isBetterThan(this.bestEverScore, species.getBestScore())) {
                getPopulation().getSpecies().remove(species);
            }
        }
    }

    /** Does nothing; pausing is not supported. */
    @Override
    public void resume(final TrainingContinuation state) {
    }

    /** Ignored; the error is derived from the best-ever score. */
    @Override
    public void setError(final double error) {
    }

    /** @param iteration value to report from {@link #getIteration()} */
    @Override
    public void setIteration(final int iteration) {
        this.iteration = iteration;
    }

    /** @param rate rate passed to {@code NEATGenome.mutateActivationResponse} */
    public void setParamActivationMutationRate(final double rate) {
        this.paramActivationMutationRate = rate;
    }

    /** @param chance chance passed to {@code NEATGenome.addLink} */
    public void setParamChanceAddLink(final double chance) {
        this.paramChanceAddLink = chance;
    }

    /** @param chance chance passed to {@code NEATGenome.addNeuron} */
    public void setParamChanceAddNode(final double chance) {
        this.paramChanceAddNode = chance;
    }

    /** @param chance recurrent-link chance passed to {@code NEATGenome.addLink} */
    public void setParamChanceAddRecurrentLink(final double chance) {
        this.paramChanceAddRecurrentLink = chance;
    }

    /** @param threshold maximum compatibility score at which a genome joins an existing species */
    public void setParamCompatibilityThreshold(final double threshold) {
        this.paramCompatibilityThreshold = threshold;
    }

    /** @param rate probability that breeding uses crossover rather than cloning */
    public void setParamCrossoverRate(final double rate) {
        this.paramCrossoverRate = rate;
    }

    /** @param perturbation bound passed to {@code NEATGenome.mutateActivationResponse} */
    public void setParamMaxActivationPerturbation(final double perturbation) {
        this.paramMaxActivationPerturbation = perturbation;
    }

    /** @param maxSpecies target maximum species count (below 1 disables threshold adjustment) */
    public void setParamMaxNumberOfSpecies(final int maxSpecies) {
        this.paramMaxNumberOfSpecies = maxSpecies;
    }

    /** @param maxNeurons neuron count at or above which offspring no longer attempt to add neurons */
    public void setParamMaxPermittedNeurons(final double maxNeurons) {
        this.paramMaxPermittedNeurons = maxNeurons;
    }

    /** @param perturbation bound passed to {@code NEATGenome.mutateWeights} */
    public void setParamMaxWeightPerturbation(final double perturbation) {
        this.paramMaxWeightPerturbation = perturbation;
    }

    /** @param rate mutation rate passed to {@code NEATGenome.mutateWeights} */
    public void setParamMutationRate(final double rate) {
        this.paramMutationRate = rate;
    }

    /** @param attempts attempt count passed to {@code NEATGenome.addLink} */
    public void setParamNumAddLinkAttempts(final int attempts) {
        this.paramNumAddLinkAttempts = attempts;
    }

    /** @param generations generations a species may go without improvement before it may be removed */
    public void setParamNumGensAllowedNoImprovement(final int generations) {
        this.paramNumGensAllowedNoImprovement = generations;
    }

    /** @param tries looped-link attempt count passed to {@code NEATGenome.addLink} */
    public void setParamNumTrysToFindLoopedLink(final int tries) {
        this.paramNumTrysToFindLoopedLink = tries;
    }

    /** @param tries attempt count passed to {@code NEATGenome.addNeuron} */
    public void setParamNumTrysToFindOldLink(final int tries) {
        this.paramNumTrysToFindOldLink = tries;
    }

    /** @param probability replacement probability passed to {@code NEATGenome.mutateWeights} */
    public void setParamProbabilityWeightReplaced(final double probability) {
        this.paramProbabilityWeightReplaced = probability;
    }

    /** @param snapshot new value of the snapshot flag */
    public void setSnapshot(final boolean snapshot) {
        this.snapshot = snapshot;
    }

    /**
     * Decodes and scores every genome, sorts the population, and records the population's best
     * genome if it beats the best-ever score.
     */
    public void sortAndRecord() {
        for (final Genome genome : getPopulation().getGenomes()) {
            genome.decode();
            calculateScore(genome);
        }
        getPopulation().sort();

        final Genome best = getPopulation().getBest();
        final double score = best.getScore();
        if (getComparator().isBetterThan(score, this.bestEverScore)) {
            this.bestEverScore = score;
            this.bestEverNetwork = (NEATNetwork) best.getOrganism();
        }
    }

    /**
     * Re-speciates the population and computes spawn amounts.
     *
     * <p>Each genome joins the first species whose leader has a compatibility score no greater
     * than the threshold, or founds a new species. Adjusted scores are then computed, each
     * genome's spawn amount is set to its adjusted score divided by the population average, and
     * each species recalculates its spawn amount.
     */
    public void speciateAndCalculateSpawnLevels() {
        adjustCompatibilityThreshold();

        for (final Genome genome : getPopulation().getGenomes()) {
            final NEATGenome neatGenome = (NEATGenome) genome;
            boolean placed = false;
            for (final Species species : getPopulation().getSpecies()) {
                final NEATGenome leader = (NEATGenome) species.getLeader();
                if (neatGenome.getCompatibilityScore(leader) <= this.paramCompatibilityThreshold) {
                    addSpeciesMember(species, neatGenome);
                    neatGenome.setSpeciesID(species.getSpeciesID());
                    placed = true;
                    break;
                }
            }
            if (!placed) {
                getPopulation().getSpecies().add(new BasicSpecies(getPopulation(), neatGenome,
                        getPopulation().assignSpeciesID()));
            }
        }

        adjustSpeciesScore();

        // Recompute from scratch so this public method is correct even without resetAndKill().
        this.totalFitAdjustment = 0.0d;
        for (final Genome genome : getPopulation().getGenomes()) {
            this.totalFitAdjustment += ((NEATGenome) genome).getAdjustedScore();
        }
        this.averageFitAdjustment = this.totalFitAdjustment / getPopulation().size();

        for (final Genome genome : getPopulation().getGenomes()) {
            final NEATGenome neatGenome = (NEATGenome) genome;
            neatGenome.setAmountToSpawn(neatGenome.getAdjustedScore() / this.averageFitAdjustment);
        }
        for (final Species species : getPopulation().getSpecies()) {
            species.calculateSpawnAmount();
        }
    }

    /**
     * Picks a genome by tournament: draws {@code numComparisons} genomes uniformly at random (at
     * least one) and returns the best of them according to the trainer's comparator.
     *
     * @param numComparisons number of random draws; values below 1 are treated as 1
     * @return the winning genome (the population instance itself, not a copy)
     * @throws TrainingError if the population is empty
     */
    public NEATGenome tournamentSelection(final int numComparisons) {
        final int size = getPopulation().size();
        if (size < 1) {
            throw new TrainingError("Population can not be empty.");
        }
        int bestIndex = randomGenomeIndex(size);
        double bestScore = getPopulation().get(bestIndex).getScore();
        for (int round = 1; round < numComparisons; round++) {
            final int candidate = randomGenomeIndex(size);
            final double score = getPopulation().get(candidate).getScore();
            // Use the comparator so selection is correct for both minimising and maximising scores.
            if (getComparator().isBetterThan(score, bestScore)) {
                bestIndex = candidate;
                bestScore = score;
            }
        }
        return (NEATGenome) getPopulation().get(bestIndex);
    }

    /** Returns a uniformly random index in {@code [0, size)}. */
    private static int randomGenomeIndex(final int size) {
        // NOTE(review): RangeRandomizer.randomize(min, max) is assumed to yield a value in
        // [min, max); the clamp guards against floating-point rounding up to max.
        return Math.min(size - 1, (int) RangeRandomizer.randomize(0.0d, size));
    }
}
