package com.alibaba.dashscope.tokenizers;

import com.alibaba.dashscope.exception.NoSpecialTokenExists;
import com.alibaba.dashscope.exception.UnSupportedSpecialTokenMode;
import com.alibaba.dashscope.utils.StringUtils;
import io.reactivex.annotations.SchedulerSupport;
import java.io.BufferedReader;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * Byte-level BPE tokenizer for Qwen models.
 *
 * <p>The mergeable-rank vocabulary is read once from the {@code qwen.tiktoken} classpath resource,
 * one {@code <base64-bytes> <rank>} pair per line; a token's rank doubles as its token id. The
 * special tokens {@code <|endoftext|>}, {@code <|im_start|>}, {@code <|im_end|>} and
 * {@code <|extra_0|>} .. {@code <|extra_204|>} get consecutive ids starting at 151643.
 *
 * <p>All lookup tables are built in the static initializer and never modified afterwards, and
 * instances hold no mutable state of their own.
 */
public class QwenTokenizer implements Tokenizer {
    private static final String ENDOFTEXT = "<|endoftext|>";
    private static final String IMEND = "<|im_end|>";
    private static final String IMSTART = "<|im_start|>";
    /** Pre-tokenization regex: text is cut into these chunks, each BPE-encoded independently. */
    private static final String PATTERN_STRING = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+";
    /** Compiled once; {@link Pattern} is immutable and safe to share across threads. */
    private static final Pattern CHUNK_PATTERN = Pattern.compile(PATTERN_STRING);
    private static final String SPECIAL_END = "|>";
    private static final String SPECIAL_START = "<|";
    private static final int SPECIAL_START_ID = 151643;
    /** Number of {@code <|extra_N|>} tokens that follow {@code <|im_end|>}. */
    private static final int EXTRA_TOKEN_COUNT = 205;
    private static final String TOKEN_RANK_SEPARATOR = " ";
    private static final String vocabularyBpeFile = "qwen.tiktoken";

    /** Token id -> raw bytes of that token; slots for ids that are not assigned stay null. */
    private static final byte[][] decodeMap;
    /** Byte sequence -> BPE rank (which is also the token id). Unmodifiable. */
    private static final Map<EncodeBytesEntity, Integer> mergeableRanks;
    /** Special-token text -> token id, in declaration order. Unmodifiable. */
    private static final Map<String, Integer> specialTokens;

    static {
        Map<String, Integer> special = new LinkedHashMap<>();
        special.put(ENDOFTEXT, SPECIAL_START_ID);
        special.put(IMSTART, SPECIAL_START_ID + 1);
        special.put(IMEND, SPECIAL_START_ID + 2);
        for (int i = 0; i < EXTRA_TOKEN_COUNT; i++) {
            // Plain concatenation: String.format("%d") would use the default locale's digits.
            special.put("<|extra_" + i + "|>", SPECIAL_START_ID + 3 + i);
        }
        specialTokens = Collections.unmodifiableMap(special);
        mergeableRanks = Collections.unmodifiableMap(loadMergeableRanks());
        decodeMap = buildDecodeMap(mergeableRanks, specialTokens);
    }

    /**
     * Reads the BPE vocabulary from the classpath.
     *
     * @return a mutable map from token bytes to rank, in file order
     * @throws IllegalStateException if the resource is missing or a line is malformed
     * @throws RuntimeException if reading the resource fails
     */
    private static Map<EncodeBytesEntity, Integer> loadMergeableRanks() {
        InputStream in = QwenTokenizer.class.getClassLoader().getResourceAsStream(vocabularyBpeFile);
        if (in == null) {
            throw new IllegalStateException("Could not find " + vocabularyBpeFile + " in resources");
        }
        Map<EncodeBytesEntity, Integer> ranks = new LinkedHashMap<>();
        // try-with-resources closes the reader, and with it the underlying resource stream.
        try (BufferedReader reader =
                new BufferedReader(new InputStreamReader(in, StandardCharsets.UTF_8))) {
            String line;
            int lineNumber = 0;
            while ((line = reader.readLine()) != null) {
                lineNumber++;
                if (line.trim().isEmpty()) {
                    continue; // tolerate blank / trailing lines
                }
                String[] parts = line.split(TOKEN_RANK_SEPARATOR);
                if (parts.length < 2) {
                    throw new IllegalStateException(
                            "Malformed line " + lineNumber + " in " + vocabularyBpeFile);
                }
                byte[] tokenBytes = Base64.getDecoder().decode(parts[0]);
                int rank = Integer.parseInt(parts[1]);
                ranks.put(new EncodeBytesEntity(tokenBytes, rank), rank);
            }
        } catch (IOException e) {
            throw new RuntimeException("Could not load qwen.tiktoken from resources", e);
        }
        return ranks;
    }

    /**
     * Builds the id -> bytes table used by {@link #decode}. The table is sized by the largest id,
     * so it does not assume the ids are contiguous.
     */
    private static byte[][] buildDecodeMap(
            Map<EncodeBytesEntity, Integer> ranks, Map<String, Integer> special) {
        int size = 0;
        for (int id : ranks.values()) {
            size = Math.max(size, id + 1);
        }
        for (int id : special.values()) {
            size = Math.max(size, id + 1);
        }
        byte[][] table = new byte[size][];
        for (Map.Entry<EncodeBytesEntity, Integer> entry : ranks.entrySet()) {
            byte[] bytes = entry.getKey().bytes;
            table[entry.getValue()] = Arrays.copyOf(bytes, bytes.length);
        }
        for (Map.Entry<String, Integer> entry : special.entrySet()) {
            table[entry.getValue()] = entry.getKey().getBytes(StandardCharsets.UTF_8);
        }
        return table;
    }

    /**
     * BPE-encodes one pre-tokenized chunk.
     *
     * <p>The chunk starts as its individual UTF-8 bytes. The loop repeatedly merges the
     * adjacent pair with the lowest rank until no adjacent pair is in the vocabulary.
     *
     * @return the token ids (ranks) of the final pieces, left to right
     */
    private List<Integer> encodeChunk(String str) {
        byte[] bytes = str.getBytes(StandardCharsets.UTF_8);
        EncodeBytesEntity[] parts = new EncodeBytesEntity[bytes.length];
        for (int i = 0; i < bytes.length; i++) {
            EncodeBytesEntity single = new EncodeBytesEntity(new byte[] {bytes[i]});
            single.rank = mergeableRanks.get(single);
            parts[i] = single;
        }
        EncodeBytesEntity best;
        while (parts.length >= 2 && (best = findLowestRankPair(parts)) != null) {
            parts = merge(parts, best);
        }
        List<Integer> ids = new ArrayList<>(parts.length);
        for (EncodeBytesEntity part : parts) {
            ids.add(part.rank);
        }
        return ids;
    }

    /**
     * Returns the first adjacent pair with the strictly lowest rank in the vocabulary, with its
     * {@code rank} field set. Returns {@code null} if no adjacent pair is mergeable.
     */
    private EncodeBytesEntity findLowestRankPair(EncodeBytesEntity[] parts) {
        EncodeBytesEntity best = null;
        int bestRank = Integer.MAX_VALUE;
        for (int i = 0; i + 1 < parts.length; i++) {
            EncodeBytesEntity candidate = mergePair(parts[i], parts[i + 1]);
            Integer rank = mergeableRanks.get(candidate);
            // Strict '<' keeps the leftmost occurrence when several pairs share the best rank.
            if (rank != null && rank < bestRank) {
                candidate.rank = rank;
                best = candidate;
                bestRank = rank;
            }
        }
        return best;
    }

    /**
     * Replaces every non-overlapping occurrence (scanning left to right) of {@code pair} in
     * {@code parts} with the merged entity.
     */
    private EncodeBytesEntity[] merge(EncodeBytesEntity[] parts, EncodeBytesEntity pair) {
        EncodeBytesEntity[] out = new EncodeBytesEntity[parts.length];
        int written = 0;
        int i = 0;
        while (i < parts.length) {
            if (i + 1 < parts.length && mergePair(parts[i], parts[i + 1]).equals(pair)) {
                out[written++] = pair;
                i += 2;
            } else {
                out[written++] = parts[i];
                i++;
            }
        }
        return Arrays.copyOf(out, written);
    }

    /** Concatenates the bytes of two pieces into a new, unranked entity. */
    private EncodeBytesEntity mergePair(EncodeBytesEntity left, EncodeBytesEntity right) {
        byte[] joined = Arrays.copyOf(left.bytes, left.bytes.length + right.bytes.length);
        System.arraycopy(right.bytes, 0, joined, left.bytes.length, right.bytes.length);
        return new EncodeBytesEntity(joined);
    }

    /**
     * Splits {@code str} around all known special tokens if it might contain one. The cheap
     * check for both {@code "<|"} and {@code "|>"} skips the full split for ordinary text.
     */
    private List<String> splitWithSpecial(String str) {
        if (str.contains(SPECIAL_START) && str.contains(SPECIAL_END)) {
            return StringUtils.splitByStrings(str, specialTokens.keySet());
        }
        List<String> single = new ArrayList<>(1);
        single.add(str);
        return single;
    }

    /** Returns true if {@code text} contains any special token as a substring. */
    private static boolean containsSpecialToken(String text) {
        for (String token : specialTokens.keySet()) {
            if (text.contains(token)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Converts token ids back to text.
     *
     * <p>The bytes of all tokens are concatenated before UTF-8 decoding. A single byte-level BPE
     * token may hold only part of a multi-byte character, and decoding token by token would
     * corrupt such characters.
     *
     * @param list token ids to decode
     * @return the decoded text
     * @throws IllegalArgumentException if an id is null or has no entry in the vocabulary
     */
    @Override // com.alibaba.dashscope.tokenizers.Tokenizer
    public String decode(List<Integer> list) {
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        for (Integer id : list) {
            if (id == null || id < 0 || id >= decodeMap.length || decodeMap[id] == null) {
                throw new IllegalArgumentException("Unknown token id: " + id);
            }
            byte[] tokenBytes = decodeMap[id];
            buffer.write(tokenBytes, 0, tokenBytes.length);
        }
        return new String(buffer.toByteArray(), StandardCharsets.UTF_8);
    }

    /**
     * Encodes text to token ids, with configurable handling of special tokens.
     *
     * @param text the text to encode
     * @param allowedSpecial one of:
     *     <ul>
     *       <li>{@code null} or {@code "all"}: special tokens in the text map to their ids.
     *       <li>{@code "none"}: the text, including any special-token strings, is encoded as
     *           ordinary text.
     *       <li>{@code "none_raise"}: like {@code "none"}, but throws {@link NoSpecialTokenExists}
     *           when the text contains <em>no</em> special token. NOTE(review): this is the
     *           existing behavior, consistent with the exception's name; confirm it is the
     *           intended contract.
     *     </ul>
     *
     * @return the token ids
     * @throws NoSpecialTokenExists in {@code "none_raise"} mode when no special token is present
     * @throws UnSupportedSpecialTokenMode for any other {@code allowedSpecial} value
     */
    @Override // com.alibaba.dashscope.tokenizers.Tokenizer
    public List<Integer> encode(String text, String allowedSpecial)
            throws NoSpecialTokenExists, UnSupportedSpecialTokenMode {
        String mode = allowedSpecial == null ? "all" : allowedSpecial;
        Map<String, Integer> allowed;
        if ("all".equals(mode)) {
            allowed = specialTokens;
        } else if ("none".equals(mode)) {
            allowed = Collections.emptyMap();
        } else if ("none_raise".equals(mode)) {
            if (!containsSpecialToken(text)) {
                throw new NoSpecialTokenExists(String.format("No special token in %s", text));
            }
            allowed = Collections.emptyMap();
        } else {
            throw new UnSupportedSpecialTokenMode(
                    String.format("UnSupport allowedSpecial: %s", mode));
        }
        if (allowed.isEmpty()) {
            return encodeOrdinary(text);
        }
        List<Integer> ids = new ArrayList<>();
        for (String piece : splitWithSpecial(text)) {
            Integer specialId = allowed.get(piece);
            if (specialId != null) {
                ids.add(specialId);
            } else {
                ids.addAll(encodeOrdinary(piece));
            }
        }
        return ids;
    }

    /**
     * Encodes text without recognizing special tokens. The text is cut into chunks by the
     * pre-tokenization regex, and each chunk is BPE-encoded on its own.
     *
     * @param str the text to encode
     * @return the token ids
     */
    @Override // com.alibaba.dashscope.tokenizers.Tokenizer
    public List<Integer> encodeOrdinary(String str) {
        List<Integer> ids = new ArrayList<>();
        Matcher matcher = CHUNK_PATTERN.matcher(str);
        while (matcher.find()) {
            ids.addAll(encodeChunk(matcher.group()));
        }
        return ids;
    }
}
