package smile.classification;

import com.github.mikephil.charting.utils.Utils;
import java.lang.reflect.Array;
import java.util.Arrays;
import smile.classification.Classifier;
import smile.math.Math;
import smile.math.matrix.DenseMatrix;
import smile.math.matrix.EVD;
import smile.math.matrix.Matrix;

/* loaded from: classes2.dex */
/**
 * Quadratic discriminant analysis (QDA) classifier.
 *
 * <p>Each class {@code c} gets its own mean vector {@code mu[c]} and its own covariance matrix.
 * The covariance is stored in eigen-decomposed form: {@code scaling[c]} holds the matrix returned
 * by {@code EVD.a()} (presumably the eigenvectors) and {@code ev[c]} holds {@code EVD.b()}
 * (presumably the eigenvalues). A sample {@code x} is assigned to the class maximizing
 *
 * <pre>
 *   ct[c] - 0.5 * sum_j (scaling[c]' (x - mu[c]))_j^2 / ev[c][j]
 *   with ct[c] = log(priori[c]) - 0.5 * sum_j log(ev[c][j])
 * </pre>
 */
public class QDA implements SoftClassifier<double[]> {
    private static final long serialVersionUID = 1;

    /** Default tolerance used to reject (near-)singular class covariance matrices. */
    private static final double DEFAULT_TOL = 1.0E-4d;

    /** Per-class constant term: log prior minus half the sum of the log eigenvalues. */
    private final double[] ct;
    /** Per-class eigenvalues of the class covariance matrix. */
    private final double[][] ev;
    /** Number of classes. */
    private final int k;
    /** Per-class mean vectors, {@code k x p}. */
    private final double[][] mu;
    /** Dimensionality of the input vectors. */
    private final int p;
    /** Prior probability of each class. */
    private final double[] priori;
    /** Per-class eigen-decomposition matrices that replace the covariance after training. */
    private final DenseMatrix[] scaling;

    /**
     * Trainer for {@link QDA}. No setter for {@code b} or {@code c} exists in this class, so the
     * trainer estimates priors from class frequencies and uses a tolerance of 1e-4.
     */
    public static class Trainer extends ClassifierTrainer<double[]> {
        /** Prior probabilities; {@code null} means "estimate from class frequencies". */
        private double[] b;
        /** Tolerance for singular covariance detection. */
        private double c = DEFAULT_TOL;

        @Override // smile.classification.ClassifierTrainer
        public QDA a(double[][] dArr, int[] iArr) {
            return new QDA(dArr, iArr, this.b, this.c);
        }
    }

    /** Trains with priors estimated from class frequencies and the default tolerance (1e-4). */
    public QDA(double[][] x, int[] y) {
        this(x, y, (double[]) null);
    }

    /** Trains with priors estimated from class frequencies and the given tolerance. */
    public QDA(double[][] x, int[] y, double tol) {
        this(x, y, null, tol);
    }

    /** Trains with the given priors (or estimated ones if {@code null}) and the default tolerance. */
    public QDA(double[][] x, int[] y, double[] priori) {
        this(x, y, priori, DEFAULT_TOL);
    }

    /**
     * Trains a QDA model.
     *
     * @param x training samples; every row must have the same length
     * @param y class labels; the distinct labels must be exactly {@code 0..k-1} with {@code k >= 2}
     * @param priori prior probability of each class (each in (0, 1), summing to 1), or {@code null}
     *     to estimate priors from the class frequencies in {@code y}. The array is copied.
     * @param tol singularity tolerance: a class covariance whose diagonal entry or eigenvalue is
     *     below {@code tol * tol} is rejected; must be non-negative
     * @throws IllegalArgumentException if any argument is invalid, a class has fewer than two
     *     samples, or a class covariance matrix is close to singular
     */
    public QDA(double[][] x, int[] y, double[] priori, double tol) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }
        if (priori != null) {
            checkPriori(priori);
        }

        int[] labels = sortedUniqueLabels(y);
        checkLabels(labels);
        int k = labels.length;
        if (k < 2) {
            throw new IllegalArgumentException("Only one class.");
        }
        if (priori != null && k != priori.length) {
            throw new IllegalArgumentException("The number of classes and the number of priori probabilities don't match.");
        }
        if (tol < 0.0) {
            throw new IllegalArgumentException("Invalid tol: " + tol);
        }
        int n = x.length;
        if (n <= k) {
            throw new IllegalArgumentException(String.format("Sample size is too small: %d <= %d", n, k));
        }
        int p = x[0].length;

        // Per-class sample counts and coordinate sums.
        int[] counts = new int[k];
        double[][] mean = new double[k][p];
        for (int i = 0; i < n; i++) {
            if (x[i].length != p) {
                throw new IllegalArgumentException(String.format("Invalid input vector size of sample %d: %d, expected: %d", i, x[i].length, p));
            }
            int c = y[i];
            counts[c]++;
            for (int j = 0; j < p; j++) {
                mean[c][j] += x[i][j];
            }
        }

        // Turn sums into means and allocate one zero scatter matrix per class.
        DenseMatrix[] cov = new DenseMatrix[k];
        for (int c = 0; c < k; c++) {
            if (counts[c] <= 1) {
                throw new IllegalArgumentException(String.format("Class %d has only one sample.", c));
            }
            cov[c] = Matrix.CC.zeros(p, p);
            for (int j = 0; j < p; j++) {
                mean[c][j] /= counts[c];
            }
        }

        double[] prior;
        if (priori == null) {
            prior = new double[k];
            for (int c = 0; c < k; c++) {
                // Cast before dividing: int / int would truncate every estimated prior to 0.
                prior[c] = (double) counts[c] / n;
            }
        } else {
            prior = priori.clone();
        }

        // Accumulate the lower triangle of each class scatter matrix.
        for (int i = 0; i < n; i++) {
            int c = y[i];
            double[] xi = x[i];
            double[] m = mean[c];
            for (int j = 0; j < p; j++) {
                double dj = xi[j] - m[j];
                for (int l = 0; l <= j; l++) {
                    cov[c].add(j, l, dj * (xi[l] - m[l]));
                }
            }
        }

        // Normalize by (n_c - 1), mirror into the upper triangle, then eigen-decompose.
        double tol2 = tol * tol;
        double[][] eigenvalues = new double[k][];
        for (int c = 0; c < k; c++) {
            DenseMatrix s = cov[c];
            for (int j = 0; j < p; j++) {
                for (int l = 0; l <= j; l++) {
                    s.div(j, l, counts[c] - 1);
                    s.set(l, j, s.get(j, l));
                }
                if (s.get(j, j) < tol2) {
                    throw new IllegalArgumentException(String.format("Class %d covariance matrix (variable %d) is close to singular.", c, j));
                }
            }
            s.setSymmetric(true);
            EVD eigen = s.eigen();
            double[] values = eigen.b();
            for (double v : values) {
                if (v < tol2) {
                    throw new IllegalArgumentException(String.format("Class %d covariance matrix is close to singular.", c));
                }
            }
            eigenvalues[c] = values;
            cov[c] = eigen.a();
        }

        // Constant part of each discriminant: log prior - 0.5 * sum(log eigenvalues).
        double[] constant = new double[k];
        for (int c = 0; c < k; c++) {
            double logDet = 0.0;
            for (int j = 0; j < p; j++) {
                logDet += java.lang.Math.log(eigenvalues[c][j]);
            }
            constant[c] = java.lang.Math.log(prior[c]) - 0.5 * logDet;
        }

        this.k = k;
        this.p = p;
        this.mu = mean;
        this.priori = prior;
        this.ev = eigenvalues;
        this.scaling = cov;
        this.ct = constant;
    }

    /** Validates user-supplied priors: at least two, each in (0, 1) (NaN rejected), summing to 1. */
    private static void checkPriori(double[] priori) {
        if (priori.length < 2) {
            throw new IllegalArgumentException("Invalid number of priori probabilities: " + priori.length);
        }
        double sum = 0.0;
        for (double prob : priori) {
            // Negated range test so that NaN is rejected as well.
            if (!(prob > 0.0 && prob < 1.0)) {
                throw new IllegalArgumentException("Invalid priori probability: " + prob);
            }
            sum += prob;
        }
        if (java.lang.Math.abs(sum - 1.0) > 1.0E-10) {
            throw new IllegalArgumentException("The sum of priori probabilities is not one: " + sum);
        }
    }

    /** Returns the distinct values of {@code y} in ascending order. */
    private static int[] sortedUniqueLabels(int[] y) {
        int[] sorted = y.clone();
        Arrays.sort(sorted);
        int m = 0;
        for (int i = 0; i < sorted.length; i++) {
            if (m == 0 || sorted[i] != sorted[m - 1]) {
                sorted[m++] = sorted[i];
            }
        }
        return Arrays.copyOf(sorted, m);
    }

    /** Requires sorted distinct labels to be exactly {@code 0, 1, ..., labels.length - 1}. */
    private static void checkLabels(int[] labels) {
        for (int i = 0; i < labels.length; i++) {
            if (labels[i] < 0) {
                throw new IllegalArgumentException("Negative class label: " + labels[i]);
            }
            // Labels are distinct and ascending, so the first index with labels[i] != i means class i is absent.
            if (labels[i] != i) {
                throw new IllegalArgumentException("Missing class: " + i);
            }
        }
    }

    /** Returns a copy of the prior probability of each class. */
    public double[] getPriori() {
        return this.priori.clone();
    }

    /** Predicts the class label of {@code x} without computing posterior probabilities. */
    @Override // smile.classification.Classifier
    public int predict(double[] x) {
        return predict(x, (double[]) null);
    }

    /**
     * Predicts the class label of {@code x}.
     *
     * @param x input vector of length {@code p}
     * @param posteriori if non-null, an array of length {@code k} that receives the posterior
     *     probability of each class (softmax of the discriminant scores)
     * @return the index of the class with the highest discriminant score
     * @throws IllegalArgumentException if {@code x} or {@code posteriori} has the wrong length
     */
    @Override // smile.classification.SoftClassifier
    public int predict(double[] x, double[] posteriori) {
        if (x.length != this.p) {
            throw new IllegalArgumentException(String.format("Invalid input vector size: %d, expected: %d", x.length, this.p));
        }
        if (posteriori != null && posteriori.length != this.k) {
            throw new IllegalArgumentException(String.format("Invalid posteriori vector size: %d, expected: %d", posteriori.length, this.k));
        }

        double[] diff = new double[this.p];
        double[] projected = new double[this.p];
        int best = 0;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int c = 0; c < this.k; c++) {
            for (int j = 0; j < this.p; j++) {
                diff[j] = x[j] - this.mu[c][j];
            }
            // Presumably projected = scaling[c]' * diff, i.e. diff in the eigenbasis of class c.
            this.scaling[c].atx(diff, projected);
            double mahalanobis = 0.0;
            for (int j = 0; j < this.p; j++) {
                mahalanobis += (projected[j] * projected[j]) / this.ev[c][j];
            }
            double score = this.ct[c] - 0.5 * mahalanobis;
            if (score > bestScore) {
                best = c;
                bestScore = score;
            }
            if (posteriori != null) {
                posteriori[c] = score;
            }
        }

        if (posteriori != null) {
            // Softmax shifted by the maximum score for numerical stability.
            double sum = 0.0;
            for (int c = 0; c < this.k; c++) {
                posteriori[c] = java.lang.Math.exp(posteriori[c] - bestScore);
                sum += posteriori[c];
            }
            for (int c = 0; c < this.k; c++) {
                posteriori[c] /= sum;
            }
        }
        return best;
    }

    /** Predicts the class label of each row of {@code x} via the interface's default implementation. */
    @Override // smile.classification.Classifier
    public int[] predict(double[][] x) {
        return Classifier.CC.$default$predict(this, x);
    }
}
