package com.hankcs.hanlp.classification.classifiers;

import com.github.mikephil.charting.utils.Utils;
import com.hankcs.hanlp.classification.corpus.Document;
import com.hankcs.hanlp.classification.corpus.IDataSet;
import com.hankcs.hanlp.classification.features.BaseFeatureData;
import com.hankcs.hanlp.classification.features.ChiSquareFeatureExtractor;
import com.hankcs.hanlp.classification.models.AbstractModel;
import com.hankcs.hanlp.classification.models.NaiveBayesModel;
import com.hankcs.hanlp.classification.utilities.io.ConsoleLogger;
import com.hankcs.hanlp.collection.trie.bintrie.BinTrie;
import com.hankcs.hanlp.utility.MathUtility;
import java.util.Iterator;
import java.util.Map;
import java.util.TreeMap;

/* loaded from: classes2.dex */
/**
 * Multinomial naive Bayes text classifier.
 *
 * <p>Training selects features with a chi-square test, then estimates per-category log priors
 * and add-one smoothed per-feature log likelihoods. Classification sums the log prior and the
 * term-frequency-weighted log likelihoods for each category.
 */
public class NaiveBayesClassifier extends AbstractClassifier {
    /** The trained model; {@code null} until {@link #train(IDataSet)} runs or a model is supplied. */
    private NaiveBayesModel model;

    /** Creates an untrained classifier. */
    public NaiveBayesClassifier() {
        this(null);
    }

    /**
     * Creates a classifier backed by an existing model.
     *
     * @param naiveBayesModel the model to use, or {@code null} for an untrained classifier
     */
    public NaiveBayesClassifier(NaiveBayesModel naiveBayesModel) {
        this.model = naiveBayesModel;
    }

    /**
     * Scores a document against every category of the model.
     *
     * @param document the document to score; only features present in the model contribute
     * @return one score per category, indexed by category id; log-space scores unless
     *         {@code configProbabilityEnabled} is set, in which case the array is passed through
     *         {@link MathUtility#normalizeExp(double[])} first
     * @throws IllegalStateException if no model has been trained or supplied
     * @throws IllegalArgumentException if {@code document} is {@code null}
     */
    @Override
    public double[] categorize(Document document) throws IllegalArgumentException, IllegalStateException {
        NaiveBayesModel currentModel = this.model;
        if (currentModel == null) {
            throw new IllegalStateException("未训练模型！无法执行预测！");
        }
        if (document == null) {
            throw new IllegalArgumentException("参数 document == null");
        }
        double[] scores = new double[currentModel.catalog.length];
        for (Map.Entry<Integer, Double> prior : currentModel.logPriors.entrySet()) {
            Integer category = prior.getKey();
            double logProbability = prior.getValue().doubleValue();
            for (Map.Entry<Integer, int[]> term : document.tfMap.entrySet()) {
                Integer feature = term.getKey();
                // Features that were not selected during training carry no evidence.
                if (!currentModel.logLikelihoods.containsKey(feature)) {
                    continue;
                }
                // The first slot of the tfMap value is read as the term's occurrence count.
                int occurrences = term.getValue()[0];
                logProbability += occurrences * currentModel.logLikelihoods.get(feature).get(category).doubleValue();
            }
            scores[category.intValue()] = logProbability;
        }
        if (this.configProbabilityEnabled) {
            MathUtility.normalizeExp(scores);
        }
        return scores;
    }

    /** Returns the current model, or {@code null} if the classifier is untrained. */
    @Override
    public AbstractModel getModel() {
        return this.model;
    }

    /** Returns the current model with its concrete type, or {@code null} if untrained. */
    public NaiveBayesModel getNaiveBayesModel() {
        return this.model;
    }

    /**
     * Tokenizes the text with the model's tokenizer and predicts its category scores.
     *
     * @param str the raw text to classify
     * @return the result of the inherited document-based {@code predict} overload
     * @throws IllegalStateException if no model has been trained or supplied
     * @throws IllegalArgumentException if {@code str} is {@code null}
     */
    @Override
    public Map<String, Double> predict(String str) throws IllegalArgumentException, IllegalStateException {
        if (this.model == null) {
            throw new IllegalStateException("未训练模型！无法执行预测！");
        }
        if (str == null) {
            throw new IllegalArgumentException("参数 text == null");
        }
        NaiveBayesModel currentModel = this.model;
        return predict(new Document(currentModel.wordIdTrie, currentModel.tokenizer.segment(str)));
    }

    /**
     * Runs chi-square feature selection and compacts the feature statistics to the survivors.
     *
     * <p>Selected features are renumbered densely from 0 in the iteration order of the
     * chi-square result; the returned data's {@code wordIdTrie} maps each selected word to its
     * new id and {@code featureCategoryJointCount} holds only the selected rows.
     *
     * @param iDataSet the training data set
     * @return the basic feature data restricted to the selected features
     */
    protected BaseFeatureData selectFeatures(IDataSet iDataSet) {
        ChiSquareFeatureExtractor extractor = new ChiSquareFeatureExtractor();
        ConsoleLogger.logger.start("使用卡方检测选择特征中...", new Object[0]);
        BaseFeatureData featureData = ChiSquareFeatureExtractor.extractBasicFeatureData(iDataSet);
        Map<Integer, Double> selectedFeatures = extractor.chi_square(featureData);
        int selectedCount = selectedFeatures.size();
        int originalCount = featureData.featureCategoryJointCount.length;
        // Rows are shared with the original table, so only the outer dimension is allocated.
        int[][] selectedJointCount = new int[selectedCount][];
        featureData.wordIdTrie = new BinTrie<>();
        String[] wordIdArray = iDataSet.getLexicon().getWordIdArray();
        int newId = 0;
        for (Integer oldId : selectedFeatures.keySet()) {
            selectedJointCount[newId] = featureData.featureCategoryJointCount[oldId.intValue()];
            featureData.wordIdTrie.put(wordIdArray[oldId.intValue()], Integer.valueOf(newId));
            newId++;
        }
        // Floating-point division: integer division would truncate the ratio to 0 or 1.
        double selectedPercent = (double) selectedCount / originalCount * 100.0d;
        ConsoleLogger.logger.finish(",选中特征数:%d / %d = %.2f%%\n", Integer.valueOf(selectedCount), Integer.valueOf(originalCount), Double.valueOf(selectedPercent));
        featureData.featureCategoryJointCount = selectedJointCount;
        return featureData;
    }

    /**
     * Trains a new model from the data set and installs it as the current model.
     *
     * <p>For each category {@code c} the prior is {@code log(count(c) / n)}; for each selected
     * feature {@code f} the likelihood is {@code log((count(f, c) + 1) / (total(c) + d))}, where
     * {@code total(c)} is the sum of all selected feature counts in {@code c} and {@code d} is
     * the number of selected features (add-one smoothing).
     *
     * @param iDataSet the training data set
     */
    @Override
    public void train(IDataSet iDataSet) {
        ConsoleLogger.logger.out("原始数据集大小:%d\n", Integer.valueOf(iDataSet.size()));
        BaseFeatureData featureData = selectFeatures(iDataSet);
        int[][] jointCount = featureData.featureCategoryJointCount;
        int[] categoryCounts = featureData.categoryCounts;

        NaiveBayesModel trained = new NaiveBayesModel();
        this.model = trained;
        trained.f34086n = featureData.f34083n;
        trained.f34085d = jointCount.length;
        trained.f34084c = categoryCounts.length;

        // Log priors; the cast forces floating-point division so the ratio is not truncated.
        trained.logPriors = new TreeMap<Integer, Double>();
        for (int category = 0; category < categoryCounts.length; category++) {
            double prior = (double) categoryCounts[category] / trained.f34086n;
            trained.logPriors.put(Integer.valueOf(category), Double.valueOf(Math.log(prior)));
        }

        // Total occurrences of all selected features within each category.
        Map<Integer, Double> featureTotalByCategory = new TreeMap<Integer, Double>();
        for (Integer category : trained.logPriors.keySet()) {
            double total = 0.0d;
            for (int feature = 0; feature < jointCount.length; feature++) {
                total += jointCount[feature][category.intValue()];
            }
            featureTotalByCategory.put(category, Double.valueOf(total));
        }

        // Add-one smoothed log likelihoods, stored as feature -> (category -> value).
        // NOTE(review): relies on NaiveBayesModel initialising logLikelihoods itself - confirm.
        for (Integer category : trained.logPriors.keySet()) {
            double denominator = featureTotalByCategory.get(category).doubleValue() + trained.f34085d;
            for (int feature = 0; feature < jointCount.length; feature++) {
                double logLikelihood = Math.log((jointCount[feature][category.intValue()] + 1.0d) / denominator);
                Integer featureKey = Integer.valueOf(feature);
                if (!trained.logLikelihoods.containsKey(featureKey)) {
                    trained.logLikelihoods.put(featureKey, new TreeMap<Integer, Double>());
                }
                trained.logLikelihoods.get(featureKey).put(category, Double.valueOf(logLikelihood));
            }
        }
        ConsoleLogger.logger.out("贝叶斯统计结束\n", new Object[0]);
        trained.catalog = iDataSet.getCatalog().toArray();
        trained.tokenizer = iDataSet.getTokenizer();
        trained.wordIdTrie = featureData.wordIdTrie;
    }
}
