package org.neuroph.nnet.learning;

import java.util.ArrayList;
import java.util.Iterator;
import org.neuroph.core.Connection;
import org.neuroph.core.Layer;
import org.neuroph.core.Neuron;
import org.neuroph.core.transfer.Gaussian;
import org.neuroph.nnet.learning.kmeans.Cluster;
import org.neuroph.nnet.learning.kmeans.KMeansClustering;
import org.neuroph.nnet.learning.kmeans.KVector;
import org.neuroph.nnet.learning.knn.KNearestNeighbour;

/* loaded from: input_file:org/neuroph/nnet/learning/RBFLearning.class */
public class RBFLearning extends LMS {

    /**
     * Number of nearest centroids used when estimating each RBF neuron's Gaussian width (sigma).
     */
    int k = 2;

    /**
     * Initialises the layer at index 1 (treated here as the RBF layer) before LMS training starts.
     *
     * <p>Runs k-means over the training set with one cluster per neuron in that layer. It then
     * copies the centroid of cluster n into the input-connection weights of neuron n. Finally it
     * sets neuron n's {@link Gaussian} sigma from the distances between that same centroid and its
     * {@link #k} nearest centroids.
     */
    @Override
    public void onStart() {
        super.onStart();

        Layer rbfLayer = this.neuralNetwork.getLayerAt(1);

        KMeansClustering kMeans = new KMeansClustering(getTrainingSet());
        // One cluster per RBF neuron, so that neuron n can be paired with clusters[n].
        kMeans.setNumberOfClusters(rbfLayer.getNeuronsCount());
        kMeans.doClustering();
        Cluster[] clusters = kMeans.getClusters();

        initCentroidWeights(rbfLayer, clusters);
        initGaussianWidths(rbfLayer, clusters);
    }

    /**
     * Copies the centroid of {@code clusters[n]} into the input-connection weights of the n-th
     * neuron of {@code rbfLayer}, one centroid component per input connection, in order.
     */
    private static void initCentroidWeights(Layer rbfLayer, Cluster[] clusters) {
        int n = 0;
        for (Neuron neuron : rbfLayer.getNeurons()) {
            double[] centroid = clusters[n].getCentroid().getValues();
            int c = 0;
            for (Connection connection : neuron.getInputConnections()) {
                connection.getWeight().setValue(centroid[c]);
                c++;
            }
            n++;
        }
    }

    /**
     * Sets the sigma of the n-th neuron's {@link Gaussian} transfer function from the centroid of
     * {@code clusters[n]} and its {@link #k} nearest centroids.
     *
     * @throws ClassCastException if a neuron's transfer function is not a {@link Gaussian}
     */
    private void initGaussianWidths(Layer rbfLayer, Cluster[] clusters) {
        ArrayList<KVector> centroids = new ArrayList<>(clusters.length);
        for (Cluster cluster : clusters) {
            centroids.add(cluster.getCentroid());
        }
        KNearestNeighbour knn = new KNearestNeighbour();
        knn.setDataSet(centroids);

        // Index by cluster position instead of iterating `centroids`. Neuron n received
        // clusters[n]'s centroid as its weights, so its width must be computed from that same
        // centroid and written to that same neuron (previously every sigma went to neuron 0).
        // NOTE(review): KNearestNeighbour may reorder the list it is given; indexing the array
        // keeps the pairing independent of that.
        int neuronCount = rbfLayer.getNeuronsCount();
        for (int n = 0; n < neuronCount; n++) {
            KVector centroid = clusters[n].getCentroid();
            // NOTE(review): the centroid itself is part of the KNN data set, so it may come back
            // as one of its own neighbours (distance 0) — confirm whether that is intended.
            KVector[] neighbours = knn.getKNearestNeighbours(centroid, this.k);
            double sigma = calculateSigma(centroid, neighbours);
            ((Gaussian) rbfLayer.getNeuronAt(n).getTransferFunction()).setSigma(sigma);
        }
    }

    /**
     * Returns the root mean square of {@code centroid.distanceFrom(neighbour)} over all
     * {@code neighbours}.
     *
     * @throws IllegalArgumentException if {@code neighbours} is empty (e.g. {@code k == 0}); the
     *     previous implementation silently returned NaN in that case
     */
    private static double calculateSigma(KVector centroid, KVector[] neighbours) {
        if (neighbours.length == 0) {
            throw new IllegalArgumentException("at least one neighbour is required to compute sigma");
        }
        double sumOfSquares = 0.0d;
        for (KVector neighbour : neighbours) {
            double distance = centroid.distanceFrom(neighbour);
            sumOfSquares += distance * distance;
        }
        return Math.sqrt(sumOfSquares / neighbours.length);
    }
}
