package boofcv.deepboof;

import boofcv.alg.color.ColorYuv;
import boofcv.alg.filter.stat.ImageLocalNormalization;
import boofcv.struct.border.BorderType;
import boofcv.struct.convolve.Kernel1D;
import boofcv.struct.convolve.Kernel1D_F32;
import boofcv.struct.image.GrayF32;
import boofcv.struct.image.Planar;
import deepboof.datasets.UtilCifar10;
import deepboof.io.torch7.ParseBinaryTorch7;
import deepboof.io.torch7.SequenceAndParameters;
import deepboof.misc.TensorOps;
import deepboof.models.DeepModelIO;
import deepboof.models.YuvStatistics;
import deepboof.tensors.Tensor_F32;
import java.io.File;

/**
 * Image classifier that labels an image with one of the CIFAR-10 class names, using a network
 * parsed from a Torch7 binary model file.
 *
 * <p>Before inference each image is converted to YUV at {@value #inputSize}x{@value #inputSize}.
 * The Y band is locally normalized with a kernel and border type read from the model's
 * {@code YuvStatistics.txt}. The U and V bands are normalized with the mean and standard deviation
 * stored in the same file.
 *
 * <p>{@link #loadModel(File)} must be called before {@link #preprocess(Planar)}.
 */
public class ImageClassifierVggCifar10 extends BaseImageClassifier {
    /** Width and height, in pixels, of the image the network consumes. */
    static final int inputSize = 32;

    /** Reused 3-band YUV work buffer; this is what {@link #preprocess(Planar)} returns. */
    Planar<GrayF32> imageYuv;
    /** Local-normalization kernel for the Y band; built in {@link #loadModel(File)}. */
    Kernel1D_F32 kernel;
    /** Local normalizer for the Y band; built in {@link #loadModel(File)}. */
    ImageLocalNormalization<GrayF32> localNorm;
    /** Normalization statistics loaded from {@code YuvStatistics.txt}. */
    YuvStatistics stats;

    /** Creates an unloaded classifier whose categories are the CIFAR-10 class names. */
    public ImageClassifierVggCifar10() {
        super(inputSize);
        this.imageYuv = new Planar<>(GrayF32.class, inputSize, inputSize, 3);
        this.categories.addAll(UtilCifar10.getClassNames());
    }

    /**
     * Loads the normalization statistics and the network.
     *
     * <p>{@code file} is treated as a directory that must contain {@code YuvStatistics.txt} and
     * {@code model.net}.
     *
     * @param file directory holding the model files
     */
    @Override
    public void loadModel(File file) {
        // NOTE(review): the DeepBoof parsing/loading calls likely throw IOException, but this
        // (decompiled) signature declares none - confirm against the overridden loadModel.
        this.stats = DeepModelIO.load(new File(file, "YuvStatistics.txt"));

        // Raw cast kept from the original; the parser's generic return type is not visible here.
        SequenceAndParameters sequence =
                (SequenceAndParameters) new ParseBinaryTorch7().parseIntoBoof(new File(file, "model.net"));
        // Network input is channels x height x width for the 3-band YUV image.
        this.network = sequence.createForward(new int[]{3, inputSize, inputSize});
        this.tensorOutput = new Tensor_F32(TensorOps.WI(1, this.network.getOutputShape()));

        this.localNorm = new ImageLocalNormalization<>(GrayF32.class, BorderType.valueOf(this.stats.border));
        this.kernel = DataManipulationOps.create1D_F32(this.stats.kernel);
    }

    /**
     * Converts the input into the normalized YUV representation the network expects.
     *
     * @param planar input image, handed to the base class's preprocessing first
     * @return the internal YUV buffer; it is overwritten on every call, so copy it to retain it
     * @throws IllegalStateException if {@link #loadModel(File)} has not been called
     */
    @Override
    public Planar<GrayF32> preprocess(Planar<GrayF32> planar) {
        if (this.localNorm == null || this.kernel == null || this.stats == null) {
            throw new IllegalStateException("loadModel() must be called before preprocess()");
        }
        super.preprocess(planar);
        ColorYuv.rgbToYuv(this.imageRgb, this.imageYuv);

        // The Y band is normalized in place: the same image is both the input and the output.
        // NOTE(review): 255.0 / 1e-4 look like the max pixel value and a small stabilizing
        // delta - confirm against ImageLocalNormalization.zeroMeanStdOne.
        GrayF32 luma = this.imageYuv.getBand(0);
        this.localNorm.zeroMeanStdOne(this.kernel, luma, 255.0d, 1.0E-4d, luma);

        // U and V use the global statistics shipped with the model.
        DataManipulationOps.normalize(this.imageYuv.getBand(1), (float) this.stats.meanU, (float) this.stats.stdevU);
        DataManipulationOps.normalize(this.imageYuv.getBand(2), (float) this.stats.meanV, (float) this.stats.stdevV);
        return this.imageYuv;
    }
}
