diff --git a/src/Microsoft.ML.Tokenizers/Model/SentencePieceBaseModel.cs b/src/Microsoft.ML.Tokenizers/Model/SentencePieceBaseModel.cs
index 5bd204f501..8786eeb9ac 100644
--- a/src/Microsoft.ML.Tokenizers/Model/SentencePieceBaseModel.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/SentencePieceBaseModel.cs
@@ -59,6 +59,53 @@ internal SentencePieceBaseModel(ModelProto modelProto, bool addBos = false, bool
                 specialTokens);
         }
 
+        // Builds the shared model state from explicit settings instead of a ModelProto.
+        // Special tokens are indexed in both directions (token -> id and id -> token) and matched with one
+        // alternation regex; the normalizer receives the same whitespace/prefix options stored on the model.
+        internal SentencePieceBaseModel(
+            bool addBos, bool addEos,
+            string bosToken, int bosId,
+            string eosToken, int eosId,
+            string unkToken, int unkId,
+            bool addDummyPrefix, bool escapeWhiteSpaces,
+            bool treatWhitespaceAsSuffix, bool byteFallback,
+            ReadOnlySpan<byte> precompiledCharsmap, bool removeExtraWhitespaces,
+            IReadOnlyDictionary<string, int>? specialTokens)
+        {
+            AddBeginningOfSentence = addBos;
+            AddEndOfSentence = addEos;
+            BeginningOfSentenceToken = bosToken;
+            BeginningOfSentenceId = bosId;
+            EndOfSentenceToken = eosToken;
+            EndOfSentenceId = eosId;
+            UnknownToken = unkToken;
+            UnknownId = unkId;
+            AddDummyPrefix = addDummyPrefix;
+            EscapeWhiteSpaces = escapeWhiteSpaces;
+            TreatWhitespaceAsSuffix = treatWhitespaceAsSuffix;
+            ByteFallback = byteFallback;
+            SpecialTokens = specialTokens;
+
+            if (specialTokens is not null && specialTokens.Count > 0)
+            {
+                InternalSpecialTokens = new Dictionary<StringSpanOrdinalKey, int>();
+                SpecialTokensReverse = new Dictionary<int, string>();
+
+                foreach (var item in specialTokens)
+                {
+                    InternalSpecialTokens.Add(new StringSpanOrdinalKey(item.Key), item.Value);
+                    SpecialTokensReverse.Add(item.Value, item.Key);
+                }
+
+                SpecialTokensRegex = new Regex(string.Join("|", specialTokens.Keys.Select(s => Regex.Escape(s))), RegexOptions.Compiled);
+            }
+
+            Normalizer = new SentencePieceNormalizer(
+                precompiledCharsmap, removeExtraWhitespaces,
+                addDummyPrefix, escapeWhiteSpaces,
+                treatWhitespaceAsSuffix, specialTokens);
+        }
+
         internal Regex? SpecialTokensRegex { get; }
 
         internal Dictionary<StringSpanOrdinalKey, int>? InternalSpecialTokens { get; }
diff --git a/src/Microsoft.ML.Tokenizers/Model/SentencePieceTokenizer.cs b/src/Microsoft.ML.Tokenizers/Model/SentencePieceTokenizer.cs
index cb945d24fa..464b89019a 100644
--- a/src/Microsoft.ML.Tokenizers/Model/SentencePieceTokenizer.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/SentencePieceTokenizer.cs
@@ -7,6 +7,7 @@
 using System.Buffers;
 using System.Collections.Generic;
 using System.IO;
+using System.Text.Json;
 
 namespace Microsoft.ML.Tokenizers
 {
@@ -30,6 +31,11 @@ internal SentencePieceTokenizer(ModelProto modelProto, bool addBos, bool addEos,
             };
         }
 
+        private SentencePieceTokenizer(SentencePieceBaseModel model)
+        {
+            _model = model;
+        }
+
         /// <summary>
         /// The special tokens.
         /// </summary>
@@ -457,5 +463,756 @@ public static SentencePieceTokenizer Create(
 
             return new SentencePieceTokenizer(modelProto, addBeginningOfSentence, addEndOfSentence, specialTokens);
         }
+
+        /// <summary>
+        /// Creates a Unigram <see cref="SentencePieceTokenizer"/> from an in-memory vocabulary of (piece, score) pairs.
+        /// </summary>
+        /// <param name="vocab">
+        /// The vocabulary as an ordered sequence of (piece, score) pairs. The position of each pair
+        /// in the sequence determines its token ID.
+        /// </param>
+        /// <param name="unkId">The index (token ID) of the unknown token in <paramref name="vocab"/>.</param>
+        /// <param name="addBeginningOfSentence">Whether to emit the beginning-of-sentence token during encoding.</param>
+        /// <param name="addEndOfSentence">Whether to emit the end-of-sentence token during encoding.</param>
+        /// <param name="precompiledCharsMap">
+        /// Optional precompiled character normalization map (as found in the SentencePiece <c>normalizer_spec.precompiled_charsmap</c>
+        /// field or in the Hugging Face <c>tokenizer.json</c> <c>normalizer.precompiled_charsmap</c> property).
+        /// Pass <see langword="default"/> to skip precompiled normalization.
+        /// </param>
+        /// <param name="addDummyPrefix">Whether to prepend the dummy whitespace prefix character (U+2581) at the start of the input.</param>
+        /// <param name="escapeWhiteSpaces">Whether to replace spaces with the dummy whitespace character (U+2581) during normalization.</param>
+        /// <param name="treatWhitespaceAsSuffix">Whether to emit the U+2581 character at the end of the last token rather than the beginning of the first token.</param>
+        /// <param name="byteFallback">Whether unknown characters are decomposed into UTF-8 byte pieces (&lt;0x00&gt;..&lt;0xFF&gt;) instead of the unknown token.</param>
+        /// <param name="specialTokens">Additional special tokens to recognize, supplied as a mapping of token string to token ID.</param>
+        /// <returns>A new <see cref="SentencePieceTokenizer"/> instance.</returns>
+        /// <exception cref="ArgumentNullException"><paramref name="vocab"/> is <see langword="null"/>.</exception>
+        /// <remarks>
+        /// <para>
+        /// The beginning-of-sentence and end-of-sentence token IDs are auto-detected by looking for pieces
+        /// named &lt;s&gt; and &lt;/s&gt; in <paramref name="vocab"/>. If a piece is not found it is
+        /// treated as absent; requesting <paramref name="addBeginningOfSentence"/> or <paramref name="addEndOfSentence"/>
+        /// when the corresponding piece is absent throws an exception. A &lt;pad&gt; piece
+        /// is likewise detected automatically when present.
+        /// </para>
+        /// <para>
+        /// When creating the tokenizer, ensure that the vocabulary is sourced from a trusted provider.
+        /// </para>
+        /// </remarks>
+        public static SentencePieceTokenizer Create(
+            IEnumerable<(string Piece, float Score)> vocab,
+            int unkId,
+            bool addBeginningOfSentence = true,
+            bool addEndOfSentence = false,
+            ReadOnlySpan<byte> precompiledCharsMap = default,
+            bool addDummyPrefix = true,
+            bool escapeWhiteSpaces = true,
+            bool treatWhitespaceAsSuffix = false,
+            bool byteFallback = false,
+            IReadOnlyDictionary<string, int>? specialTokens = null)
+        {
+            if (vocab is null)
+            {
+                throw new ArgumentNullException(nameof(vocab));
+            }
+
+            // Reuse the caller's list when it is already indexable; otherwise materialize the sequence once so
+            // token IDs (positions) are stable.
+            IReadOnlyList<(string Piece, float Score)> pieces = vocab as IReadOnlyList<(string Piece, float Score)>
+                ?? new List<(string Piece, float Score)>(vocab);
+
+            SentencePieceUnigramModel model = new SentencePieceUnigramModel(
+                pieces, unkId, addBeginningOfSentence, addEndOfSentence,
+                precompiledCharsMap, addDummyPrefix, escapeWhiteSpaces,
+                treatWhitespaceAsSuffix, removeExtraWhitespaces: true, byteFallback, specialTokens);
+
+            return new SentencePieceTokenizer(model);
+        }
+
+        /// <summary>
+        /// Creates a Unigram <see cref="SentencePieceTokenizer"/> by parsing a Hugging Face <c>tokenizer.json</c>
+        /// that contains a Unigram model (<c>model.type == "Unigram"</c>).
+        /// </summary>
+        /// <param name="tokenizerJsonStream">A stream containing the UTF-8-encoded <c>tokenizer.json</c> content.</param>
+        /// <param name="addBeginningOfSentence">Whether to emit the beginning-of-sentence token during encoding.</param>
+        /// <param name="addEndOfSentence">Whether to emit the end-of-sentence token during encoding.</param>
+        /// <param name="specialTokens">Additional special tokens to recognize, supplied as a mapping of token string to token ID.</param>
+        /// <returns>A new <see cref="SentencePieceTokenizer"/> instance.</returns>
+        /// <exception cref="ArgumentNullException"><paramref name="tokenizerJsonStream"/> is <see langword="null"/>.</exception>
+        /// <exception cref="JsonException">The stream does not contain valid JSON.</exception>
+        /// <exception cref="InvalidDataException">A required field is missing, has the wrong JSON type, or is out of range.</exception>
+        /// <exception cref="NotSupportedException">The file uses a post-processor template or Metaspace replacement that is not supported.</exception>
+        /// <remarks>
+        /// <para>The following fields are read from the JSON:</para>
+        /// <list type="bullet">
+        /// <item><description><c>model.vocab</c> — array of <c>[piece, score]</c> pairs (required).</description></item>
+        /// <item><description><c>model.unk_id</c> — index of the unknown token (required).</description></item>
+        /// <item><description><c>model.byte_fallback</c> — whether unknown characters fall back to UTF-8 byte pieces.</description></item>
+        /// <item><description><c>added_tokens</c> — special tokens (those with <c>"special": true</c>) and their IDs.</description></item>
+        /// <item><description><c>normalizer.precompiled_charsmap</c> (base64) — normalization map; also searched inside a <c>Sequence</c> normalizer.</description></item>
+        /// <item><description><c>pre_tokenizer</c> of type <c>Metaspace</c> — <c>add_prefix_space</c>, <c>prepend_scheme</c> and <c>replacement</c>; also searched inside a <c>Sequence</c> pre-tokenizer.</description></item>
+        /// <item><description><c>post_processor</c> (<c>TemplateProcessing</c>, <c>RobertaProcessing</c>, <c>BertProcessing</c>, or a <c>Sequence</c> of these) — the special tokens that wrap a single sequence, gated by <paramref name="addBeginningOfSentence"/> (prefix) and <paramref name="addEndOfSentence"/> (suffix).</description></item>
+        /// </list>
+        /// <para>
+        /// <c>remove_extra_whitespaces</c> has no direct representation in <c>tokenizer.json</c>; it is deduced from
+        /// the normalizer's whitespace-collapsing steps (a right-<c>Strip</c> plus a runs-of-spaces <c>Replace</c>) and
+        /// from a <c>WhitespaceSplit</c> pre-tokenizer, defaulting to <see langword="false"/> when none are present, to
+        /// match the Hugging Face fast-tokenizer runtime. Normalizers with content-modifying steps (per-character
+        /// <c>Replace</c>, <c>Lowercase</c>, <c>StripAccents</c>, Unicode normalization, <c>Nmt</c>, <c>Prepend</c>) are
+        /// applied in full before encoding. Pair-sequence templates and per-token <c>type_ids</c> are not applied.
+        /// Templates that contain more than one sequence placeholder, or none, are rejected with
+        /// <see cref="NotSupportedException"/>.
+        /// </para>
+        /// <para>
+        /// When creating the tokenizer, ensure that the JSON stream is sourced from a trusted provider.
+        /// </para>
+        /// </remarks>
+        public static SentencePieceTokenizer CreateFromTokenizerJson(
+            Stream tokenizerJsonStream,
+            bool addBeginningOfSentence = true,
+            bool addEndOfSentence = false,
+            IReadOnlyDictionary<string, int>? specialTokens = null)
+        {
+            if (tokenizerJsonStream is null)
+            {
+                throw new ArgumentNullException(nameof(tokenizerJsonStream));
+            }
+
+            using JsonDocument doc = JsonDocument.Parse(tokenizerJsonStream);
+            JsonElement root = doc.RootElement;
+
+            // JsonElement.TryGetProperty throws InvalidOperationException on a non-object, so reject a
+            // non-object document up front with the same exception type used for every other malformed field.
+            if (root.ValueKind != JsonValueKind.Object)
+            {
+                throw new InvalidDataException("The tokenizer.json root must be a JSON object.");
+            }
+
+            // Validate model type
+            if (!root.TryGetProperty("model", out JsonElement modelElement))
+            {
+                throw new InvalidDataException("The tokenizer.json does not contain a 'model' property.");
+            }
+
+            if (modelElement.ValueKind != JsonValueKind.Object)
+            {
+                throw new InvalidDataException("The tokenizer.json 'model' property must be a JSON object.");
+            }
+
+            // Validate the model is Unigram. Older tokenizer.json files (e.g. xlm-roberta-base, albert) omit the
+            // model "type" entirely; treat a model that has a "vocab" but no BPE "merges" as Unigram, which matches
+            // how the Hugging Face loaders disambiguate these files.
+            if (modelElement.TryGetProperty("type", out JsonElement modelTypeElement) &&
+                modelTypeElement.ValueKind == JsonValueKind.String)
+            {
+                if (!string.Equals(modelTypeElement.GetString(), "Unigram", StringComparison.OrdinalIgnoreCase))
+                {
+                    throw new InvalidDataException($"Expected model type 'Unigram' but found '{modelTypeElement.GetString()}'.");
+                }
+            }
+            else if (modelElement.TryGetProperty("merges", out _))
+            {
+                throw new InvalidDataException("The tokenizer.json 'model' has no 'type' and contains 'merges'; this factory only supports 'Unigram' models.");
+            }
+
+            if (!modelElement.TryGetProperty("unk_id", out JsonElement unkIdElement))
+            {
+                throw new InvalidDataException("The tokenizer.json model does not contain an 'unk_id' property.");
+            }
+
+            if (unkIdElement.ValueKind != JsonValueKind.Number)
+            {
+                throw new InvalidDataException("The tokenizer.json model 'unk_id' property must be a number.");
+            }
+
+            // TryGetInt32 instead of GetInt32: a fractional or out-of-range number would otherwise surface as
+            // FormatException rather than the InvalidDataException used for all other malformed input.
+            if (!unkIdElement.TryGetInt32(out int unkId))
+            {
+                throw new InvalidDataException("The tokenizer.json model 'unk_id' property must be a 32-bit integer.");
+            }
+
+            bool byteFallback = modelElement.TryGetProperty("byte_fallback", out JsonElement byteFallbackElement) &&
+                                byteFallbackElement.ValueKind == JsonValueKind.True;
+
+            if (!modelElement.TryGetProperty("vocab", out JsonElement vocabElement) ||
+                vocabElement.ValueKind != JsonValueKind.Array)
+            {
+                throw new InvalidDataException("The tokenizer.json model does not contain a valid 'vocab' array.");
+            }
+
+            List<(string Piece, float Score)> vocab = new List<(string Piece, float Score)>(vocabElement.GetArrayLength());
+            foreach (JsonElement entry in vocabElement.EnumerateArray())
+            {
+                if (entry.ValueKind != JsonValueKind.Array || entry.GetArrayLength() < 2)
+                {
+                    throw new InvalidDataException("Each entry in 'model.vocab' must be a [piece, score] array.");
+                }
+
+                // TryGetSingle keeps an unrepresentable score on the InvalidDataException path as well.
+                if (entry[0].ValueKind != JsonValueKind.String ||
+                    entry[1].ValueKind != JsonValueKind.Number ||
+                    !entry[1].TryGetSingle(out float score))
+                {
+                    throw new InvalidDataException("Each entry in 'model.vocab' must be a [string piece, number score] pair.");
+                }
+
+                // ValueKind == String guarantees GetString() returns a non-null value.
+                vocab.Add((entry[0].GetString()!, score));
+            }
+
+            if (unkId < 0 || unkId >= vocab.Count)
+            {
+                throw new InvalidDataException($"The tokenizer.json model 'unk_id' ({unkId}) is out of range for a vocabulary of {vocab.Count} pieces.");
+            }
+
+            // Extract normalizer settings
+            byte[]? precompiledCharsMap = null;
+            bool addDummyPrefix = true;
+            // HF tokenizer.json has no remove_extra_whitespaces flag; SpmConverter encodes that behavior as
+            // explicit normalizer steps (a right-Strip plus a Replace collapsing runs of spaces). Deduce it from
+            // those steps, defaulting to false when absent to match the HF fast-tokenizer runtime.
+            bool removeExtraWhitespaces = false;
+            // When the normalizer has content-modifying steps that the charsmap + removeExtraWhitespaces
+            // approximation cannot represent (per-character Replace, Lowercase, Unicode normalization, ...),
+            // apply the full normalizer chain (charsmap included) before the Metaspace pass instead.
+            SentencePieceNormalizationStep? chainNormalizer = null;
+            if (root.TryGetProperty("normalizer", out JsonElement normalizerElement) &&
+                normalizerElement.ValueKind == JsonValueKind.Object)
+            {
+                if (SentencePieceNormalizationStep.HasRichSteps(normalizerElement))
+                {
+                    chainNormalizer = SentencePieceNormalizationStep.Build(normalizerElement);
+                    // The chain owns the charsmap and any strip/collapse, so the model's own normalizer must do
+                    // only the Metaspace escaping to avoid applying either transformation twice.
+                    precompiledCharsMap = null;
+                    removeExtraWhitespaces = false;
+                }
+                else
+                {
+                    precompiledCharsMap = ExtractPrecompiledCharsMap(normalizerElement);
+                    removeExtraWhitespaces = NormalizerCollapsesWhitespace(normalizerElement);
+                }
+            }
+
+            // Extract pre_tokenizer settings. treatWhitespaceAsSuffix has no tokenizer.json representation, so it
+            // stays at its default; it is still passed to the model which supports it.
+            bool escapeWhiteSpaces = true;
+            bool treatWhitespaceAsSuffix = false;
+            if (root.TryGetProperty("pre_tokenizer", out JsonElement preTokenizerElement) &&
+                preTokenizerElement.ValueKind == JsonValueKind.Object)
+            {
+                ExtractMetaspaceSettings(preTokenizerElement, ref addDummyPrefix, ref escapeWhiteSpaces);
+
+                // A WhitespaceSplit pre-tokenizer splits on whitespace and drops the empties, which collapses runs of
+                // whitespace and strips leading/trailing whitespace before Metaspace adds the dummy prefix. That is
+                // exactly remove_extra_whitespaces, so honor it even when the normalizer carries no collapse step.
+                // This is independent of chain-mode: the chain handles the normalizer's own steps, but whitespace
+                // collapsing here comes from the pre-tokenizer and is applied by the model's Metaspace pass.
+                if (PreTokenizerSplitsWhitespace(preTokenizerElement))
+                {
+                    removeExtraWhitespaces = true;
+                }
+            }
+
+            // Merge the special tokens declared in added_tokens (authoritative source for their IDs) with any
+            // caller-supplied special tokens; the caller's entries win on conflict.
+            Dictionary<string, int> mergedSpecialTokens = ParseAddedTokens(root);
+            if (specialTokens is not null)
+            {
+                foreach (var kvp in specialTokens)
+                {
+                    mergedSpecialTokens[kvp.Key] = kvp.Value;
+                }
+            }
+
+            // Resolve the prefix/suffix special-token wrapping from the post_processor (if present), falling back
+            // to the SentencePiece-conventional <s> / </s> names otherwise.
+            ResolvePostProcessorAffixes(root, vocab, mergedSpecialTokens,
+                out List<(int Id, string Token)> prefixTokens, out List<(int Id, string Token)> suffixTokens);
+
+            // Ensure every wrapping token is registered as a special token so it is classified Control and round-trips on decode.
+            foreach (var (id, token) in prefixTokens)
+            {
+                mergedSpecialTokens[token] = id;
+            }
+            foreach (var (id, token) in suffixTokens)
+            {
+                mergedSpecialTokens[token] = id;
+            }
+
+            // The padding piece is looked up by its conventional name; -1 (from FindPieceId) means "absent".
+            int padId = mergedSpecialTokens.TryGetValue("<pad>", out int p) ? p : FindPieceId(vocab, "<pad>");
+
+            SentencePieceUnigramModel model = new SentencePieceUnigramModel(
+                vocab, unkId, addBeginningOfSentence, addEndOfSentence,
+                precompiledCharsMap is not null ? precompiledCharsMap.AsSpan() : default,
+                addDummyPrefix, escapeWhiteSpaces, treatWhitespaceAsSuffix, removeExtraWhitespaces, byteFallback,
+                mergedSpecialTokens.Count > 0 ? mergedSpecialTokens : null,
+                prefixTokens, suffixTokens, padId);
+
+            if (chainNormalizer is not null)
+            {
+                model.Normalizer!.NormalizationChain = chainNormalizer;
+            }
+
+            return new SentencePieceTokenizer(model);
+        }
+
+        // Reads the special tokens (those marked "special": true) from the top-level added_tokens array.
+        private static Dictionary<string, int> ParseAddedTokens(JsonElement root)
+        {
+            Dictionary<string, int> result = new Dictionary<string, int>();
+
+            // TryGetProperty throws on a non-object element, so treat a non-object root like a missing array.
+            if (root.ValueKind != JsonValueKind.Object ||
+                !root.TryGetProperty("added_tokens", out JsonElement addedTokens) ||
+                addedTokens.ValueKind != JsonValueKind.Array)
+            {
+                return result;
+            }
+
+            foreach (JsonElement entry in addedTokens.EnumerateArray())
+            {
+                // Only object entries explicitly flagged "special": true contribute; everything else is skipped.
+                if (entry.ValueKind != JsonValueKind.Object ||
+                    !entry.TryGetProperty("special", out JsonElement specialElement) ||
+                    specialElement.ValueKind != JsonValueKind.True)
+                {
+                    continue;
+                }
+
+                // Entries missing either field are ignored rather than rejected.
+                if (!entry.TryGetProperty("content", out JsonElement contentElement) ||
+                    !entry.TryGetProperty("id", out JsonElement idElement))
+                {
+                    continue;
+                }
+
+                // TryGetInt32 instead of GetInt32: a fractional or out-of-range id must be reported as
+                // InvalidDataException like the other shape errors, not leak a FormatException.
+                if (contentElement.ValueKind != JsonValueKind.String ||
+                    idElement.ValueKind != JsonValueKind.Number ||
+                    !idElement.TryGetInt32(out int id))
+                {
+                    throw new InvalidDataException("An 'added_tokens' entry must have a string 'content' and a numeric 'id'.");
+                }
+
+                // Later entries with the same content overwrite earlier ones.
+                result[contentElement.GetString()!] = id;
+            }
+
+            return result;
+        }
+
+        // Resolves the ordered prefix/suffix special tokens that wrap an encoded sequence, from the post_processor.
+        // When the file has no post_processor object, the conventional <s> (prefix) and </s> (suffix) names are
+        // used instead, each only if it can be found among the special tokens or the vocabulary.
+        private static void ResolvePostProcessorAffixes(
+            JsonElement root,
+            IReadOnlyList<(string Piece, float Score)> vocab,
+            IReadOnlyDictionary<string, int> specialTokens,
+            out List<(int Id, string Token)> prefixTokens,
+            out List<(int Id, string Token)> suffixTokens)
+        {
+            prefixTokens = new List<(int Id, string Token)>();
+            suffixTokens = new List<(int Id, string Token)>();
+
+            if (root.TryGetProperty("post_processor", out JsonElement postProcessor) &&
+                postProcessor.ValueKind == JsonValueKind.Object)
+            {
+                ProcessPostProcessor(postProcessor, vocab, specialTokens, prefixTokens, suffixTokens);
+                return;
+            }
+
+            // No post_processor: fall back to the SentencePiece-conventional names.
+            AddAffixToken(prefixTokens, "<s>", vocab, specialTokens, required: false);
+            AddAffixToken(suffixTokens, "</s>", vocab, specialTokens, required: false);
+        }
+
+        // Dispatches on the post-processor "type" and appends the wrapping tokens it declares to the
+        // prefix/suffix lists. Sequence processors are walked recursively in declaration order.
+        private static void ProcessPostProcessor(
+            JsonElement postProcessor,
+            IReadOnlyList<(string Piece, float Score)> vocab,
+            IReadOnlyDictionary<string, int> specialTokens,
+            List<(int Id, string Token)> prefixTokens,
+            List<(int Id, string Token)> suffixTokens)
+        {
+            string? type = GetStringOrNull(postProcessor, "type");
+
+            switch (type)
+            {
+                case "TemplateProcessing":
+                    ProcessTemplate(postProcessor, vocab, specialTokens, prefixTokens, suffixTokens);
+                    break;
+
+                // Both processors are handled identically: "cls" leads the sequence and "sep" trails it.
+                case "RobertaProcessing":
+                case "BertProcessing":
+                    AddProcessorAffix(postProcessor, "cls", prefixTokens, vocab, specialTokens);
+                    AddProcessorAffix(postProcessor, "sep", suffixTokens, vocab, specialTokens);
+                    break;
+
+                case "Sequence":
+                    if (postProcessor.TryGetProperty("processors", out JsonElement processors) && processors.ValueKind == JsonValueKind.Array)
+                    {
+                        foreach (JsonElement inner in processors.EnumerateArray())
+                        {
+                            if (inner.ValueKind == JsonValueKind.Object)
+                            {
+                                ProcessPostProcessor(inner, vocab, specialTokens, prefixTokens, suffixTokens);
+                            }
+                        }
+                    }
+                    break;
+
+                default:
+                    // ByteLevel and other processors do not contribute special-token wrapping; ignore them.
+                    break;
+            }
+        }
+
+        // Parses a TemplateProcessing "single" template into leading (prefix) and trailing (suffix) special tokens.
+        private static void ProcessTemplate(
+            JsonElement postProcessor,
+            IReadOnlyList<(string Piece, float Score)> vocab,
+            IReadOnlyDictionary<string, int> specialTokens,
+            List<(int Id, string Token)> prefixTokens,
+            List<(int Id, string Token)> suffixTokens)
+        {
+            // A template without a "single" array contributes nothing.
+            if (!postProcessor.TryGetProperty("single", out JsonElement single) || single.ValueKind != JsonValueKind.Array)
+            {
+                return;
+            }
+
+            JsonElement? ppSpecialTokens = postProcessor.TryGetProperty("special_tokens", out JsonElement st) && st.ValueKind == JsonValueKind.Object
+                ? st : (JsonElement?)null;
+
+            // Special tokens seen before the sequence placeholder are prefixes; those after it are suffixes.
+            bool seenSequence = false;
+            foreach (JsonElement item in single.EnumerateArray())
+            {
+                if (item.ValueKind != JsonValueKind.Object)
+                {
+                    continue;
+                }
+
+                if (item.TryGetProperty("Sequence", out _))
+                {
+                    if (seenSequence)
+                    {
+                        throw new NotSupportedException("tokenizer.json post_processor templates with more than one sequence are not supported.");
+                    }
+
+                    seenSequence = true;
+                }
+                else if (item.TryGetProperty("SpecialToken", out JsonElement specialToken))
+                {
+                    // TryGetProperty throws InvalidOperationException on a non-object, so validate the shape first.
+                    if (specialToken.ValueKind != JsonValueKind.Object)
+                    {
+                        throw new InvalidDataException("A post_processor template 'SpecialToken' must be a JSON object.");
+                    }
+
+                    // A SpecialToken without an "id" is ignored.
+                    if (!specialToken.TryGetProperty("id", out JsonElement idElement))
+                    {
+                        continue;
+                    }
+
+                    if (idElement.ValueKind != JsonValueKind.String)
+                    {
+                        throw new InvalidDataException("A post_processor template 'SpecialToken.id' must be a string.");
+                    }
+
+                    string tokenName = idElement.GetString()!;
+                    int id = ResolveTemplateTokenId(tokenName, ppSpecialTokens, specialTokens, vocab);
+                    (seenSequence ? suffixTokens : prefixTokens).Add((id, tokenName));
+                }
+            }
+
+            if (!seenSequence)
+            {
+                throw new NotSupportedException("tokenizer.json post_processor template does not contain a sequence placeholder.");
+            }
+        }
+
+        // Resolves a template token name to an id. Preference order: the template's own special_tokens table
+        // (validated against the added tokens / vocabulary), then the added tokens, then a vocabulary lookup.
+        private static int ResolveTemplateTokenId(
+            string tokenName,
+            JsonElement? ppSpecialTokens,
+            IReadOnlyDictionary<string, int> specialTokens,
+            IReadOnlyList<(string Piece, float Score)> vocab)
+        {
+            // The entry must be an object before its "ids" can be read; a malformed entry is treated as absent
+            // and resolution falls through to the name-based lookups below.
+            if (ppSpecialTokens is JsonElement st &&
+                st.TryGetProperty(tokenName, out JsonElement entry) &&
+                entry.ValueKind == JsonValueKind.Object &&
+                entry.TryGetProperty("ids", out JsonElement ids) &&
+                ids.ValueKind == JsonValueKind.Array &&
+                ids.GetArrayLength() > 0)
+            {
+                if (ids[0].ValueKind != JsonValueKind.Number)
+                {
+                    throw new InvalidDataException($"The tokenizer.json post_processor special token '{tokenName}' has a non-numeric id.");
+                }
+
+                // TryGetInt32 keeps fractional/out-of-range numbers on the InvalidDataException path.
+                if (!ids[0].TryGetInt32(out int id))
+                {
+                    throw new InvalidDataException($"The tokenizer.json post_processor special token '{tokenName}' has an id that is not a 32-bit integer.");
+                }
+
+                // Validate the id maps back to the referenced token (via added tokens or the vocab), mirroring
+                // AddProcessorAffix, so an inconsistent file cannot emit an id whose decoded token differs.
+                bool consistent = (specialTokens.TryGetValue(tokenName, out int mappedId) && mappedId == id)
+                    || (id >= 0 && id < vocab.Count && vocab[id].Piece == tokenName);
+                if (!consistent)
+                {
+                    throw new InvalidDataException($"The tokenizer.json post_processor special token '{tokenName}' maps to id {id}, which does not match the vocabulary or added tokens.");
+                }
+
+                return id;
+            }
+
+            if (specialTokens.TryGetValue(tokenName, out int specialId))
+            {
+                return specialId;
+            }
+
+            int vocabId = FindPieceId(vocab, tokenName);
+            if (vocabId < 0)
+            {
+                throw new InvalidDataException($"The tokenizer.json post_processor references special token '{tokenName}' that is not present in the vocabulary.");
+            }
+
+            return vocabId;
+        }
+
+        // Reads a Roberta/Bert style [token, id] property and appends it to the target list after validating it.
+        // A property that is absent or not shaped as [string, number, ...] is ignored.
+        private static void AddProcessorAffix(
+            JsonElement postProcessor,
+            string property,
+            List<(int Id, string Token)> target,
+            IReadOnlyList<(string Piece, float Score)> vocab,
+            IReadOnlyDictionary<string, int> specialTokens)
+        {
+            // Roberta/Bert processors store cls/sep as [token, id] arrays.
+            if (!postProcessor.TryGetProperty(property, out JsonElement el) ||
+                el.ValueKind != JsonValueKind.Array ||
+                el.GetArrayLength() < 2 ||
+                el[0].ValueKind != JsonValueKind.String ||
+                el[1].ValueKind != JsonValueKind.Number)
+            {
+                return;
+            }
+
+            string token = el[0].GetString()!;
+
+            // TryGetInt32 instead of GetInt32 so a fractional/out-of-range id is reported as invalid data.
+            if (!el[1].TryGetInt32(out int id))
+            {
+                throw new InvalidDataException($"The post-processor '{property}' token '{token}' has an id that is not a 32-bit integer.");
+            }
+
+            // Validate the [token, id] pair against the vocabulary / added tokens so an inconsistent file cannot
+            // emit ids that do not map to the intended token.
+            bool consistent = (specialTokens.TryGetValue(token, out int specialId) && specialId == id)
+                || (id >= 0 && id < vocab.Count && vocab[id].Piece == token);
+            if (!consistent)
+            {
+                throw new InvalidDataException($"The post-processor '{property}' token '{token}' with id {id} does not match the vocabulary or added tokens.");
+            }
+
+            target.Add((id, token));
+        }
+
+        // Appends tokenName to the target list when it can be resolved (special tokens first, then the vocabulary).
+        // An unresolved name is skipped unless 'required' is set, in which case it is an error.
+        private static void AddAffixToken(
+            List<(int Id, string Token)> target,
+            string tokenName,
+            IReadOnlyList<(string Piece, float Score)> vocab,
+            IReadOnlyDictionary<string, int> specialTokens,
+            bool required)
+        {
+            int id = specialTokens.TryGetValue(tokenName, out int specialId) ? specialId : FindPieceId(vocab, tokenName);
+            if (id >= 0)
+            {
+                target.Add((id, tokenName));
+            }
+            else if (required)
+            {
+                throw new InvalidDataException($"The tokenizer.json does not contain the required special token '{tokenName}'.");
+            }
+        }
+
+        // Linear search for the first piece equal to 'token'; returns its index (token id) or -1 when absent.
+        private static int FindPieceId(IReadOnlyList<(string Piece, float Score)> vocab, string token)
+        {
+            for (int i = 0; i < vocab.Count; i++)
+            {
+                if (vocab[i].Piece == token)
+                {
+                    return i;
+                }
+            }
+
+            return -1;
+        }
+
+        // Reads a string-valued property, returning null when the property is absent or not a JSON string. Keeps the
+        // switch/compare parsing paths from throwing InvalidOperationException on malformed (non-string) tokenizer.json.
+        private static string? GetStringOrNull(JsonElement element, string propertyName) =>
+            element.TryGetProperty(propertyName, out JsonElement value) && value.ValueKind == JsonValueKind.String
+                ? value.GetString()
+                : null;
+
+        // Returns the decoded precompiled_charsmap of a Precompiled normalizer, searching inside a Sequence
+        // normalizer as well; returns null when no map is present. When a Sequence holds several maps the last wins.
+        private static byte[]? ExtractPrecompiledCharsMap(JsonElement normalizer)
+        {
+            if (normalizer.ValueKind != JsonValueKind.Object)
+            {
+                return null;
+            }
+
+            string? type = GetStringOrNull(normalizer, "type");
+            if (string.Equals(type, "Precompiled", StringComparison.OrdinalIgnoreCase))
+            {
+                if (normalizer.TryGetProperty("precompiled_charsmap", out JsonElement mapElement) &&
+                    mapElement.ValueKind == JsonValueKind.String)
+                {
+                    string? base64 = mapElement.GetString();
+                    if (base64 is not null)
+                    {
+                        return SentencePieceNormalizationStep.DecodePrecompiledCharsMap(base64);
+                    }
+                }
+                return null;
+            }
+            else if (string.Equals(type, "Sequence", StringComparison.OrdinalIgnoreCase) &&
+                     normalizer.TryGetProperty("normalizers", out JsonElement normalizersElement) &&
+                     normalizersElement.ValueKind == JsonValueKind.Array)
+            {
+                // Reached only for non-rich Sequences (charsmap plus whitespace-collapse/strip steps); rich steps
+                // such as Nmt route through the managed normalizer chain (HasRichSteps) and never get here. Extract
+                // the precompiled map and skip the non-rich steps handled elsewhere rather than failing the load.
+                byte[]? result = null;
+                foreach (JsonElement inner in normalizersElement.EnumerateArray())
+                {
+                    if (inner.ValueKind != JsonValueKind.Object)
+                    {
+                        continue;
+                    }
+
+                    byte[]? innerResult = ExtractPrecompiledCharsMap(inner);
+                    if (innerResult is not null)
+                    {
+                        result = innerResult;
+                    }
+                }
+                return result;
+            }
+
+            // Other normalizer types (Nmt, Replace, Lowercase, ...) carry no precompiled map; treat as absent.
+            return null;
+        }
+
+        // Detects whether the normalizer collapses extra whitespace, i.e. SentencePiece's remove_extra_whitespaces.
+        // HF's SpmConverter emits this as a right-Strip plus a Replace of a runs-of-spaces Regex (" {2,}") -> "▁".
+        // Either step on its own (directly or inside a Sequence) is enough for this method to return true.
+        private static bool NormalizerCollapsesWhitespace(JsonElement normalizer)
+        {
+            if (normalizer.ValueKind != JsonValueKind.Object)
+            {
+                return false;
+            }
+
+            string? type = GetStringOrNull(normalizer, "type");
+
+            if (string.Equals(type, "Strip", StringComparison.OrdinalIgnoreCase))
+            {
+                // A right-Strip removes trailing whitespace; treat its presence as the strip half of the behavior.
+                // A missing strip_right counts as enabled; only an explicit false disables it.
+                return !normalizer.TryGetProperty("strip_right", out JsonElement stripRight) || stripRight.ValueKind != JsonValueKind.False;
+            }
+
+            if (string.Equals(type, "Replace", StringComparison.OrdinalIgnoreCase))
+            {
+                return ReplaceCollapsesSpaces(normalizer);
+            }
+
+            if (string.Equals(type, "Sequence", StringComparison.OrdinalIgnoreCase) &&
+                normalizer.TryGetProperty("normalizers", out JsonElement normalizersElement) &&
+                normalizersElement.ValueKind == JsonValueKind.Array)
+            {
+                foreach (JsonElement inner in normalizersElement.EnumerateArray())
+                {
+                    if (NormalizerCollapsesWhitespace(inner))
+                    {
+                        return true;
+                    }
+                }
+            }
+
+            return false;
+        }
+
+        // True only for a Replace whose Regex matches runs of two-or-more spaces, not a single-space Metaspace Replace.
+        private static bool ReplaceCollapsesSpaces(JsonElement replace)
+        {
+            if (!replace.TryGetProperty("pattern", out JsonElement patternElement) ||
+                patternElement.ValueKind != JsonValueKind.Object ||
+                !patternElement.TryGetProperty("Regex", out JsonElement regexElement) ||
+                regexElement.ValueKind != JsonValueKind.String)
+            {
+                return false;
+            }
+
+            string? pattern = regexElement.GetString();
+            if (pattern is null)
+            {
+                return false;
+            }
+
+            // Do not trim: HF's canonical patterns " {2,}" and " +" carry a significant leading space.
+            switch (pattern)
+            {
+                case " {2,}":
+                case " +":
+                case "[ ]+":
+                case "[ ]{2,}":
+                case "\\s+":
+                case "\\s{2,}":
+                    return true;
+                default:
+                    return false;
+            }
+        }
+
+        // Returns true if the pre-tokenizer splits on whitespace (WhitespaceSplit/Whitespace), recursing into a
+        // Sequence. Such a split discards whitespace runs, matching SentencePiece's remove_extra_whitespaces.
+ private static bool PreTokenizerSplitsWhitespace(JsonElement preTokenizer) + { + if (preTokenizer.ValueKind != JsonValueKind.Object) + { + return false; + } + + string? type = GetStringOrNull(preTokenizer, "type"); + if (string.Equals(type, "WhitespaceSplit", StringComparison.OrdinalIgnoreCase) || + string.Equals(type, "Whitespace", StringComparison.OrdinalIgnoreCase)) + { + return true; + } + + if (string.Equals(type, "Sequence", StringComparison.OrdinalIgnoreCase) && + preTokenizer.TryGetProperty("pretokenizers", out JsonElement preTokenizersElement) && + preTokenizersElement.ValueKind == JsonValueKind.Array) + { + foreach (JsonElement inner in preTokenizersElement.EnumerateArray()) + { + if (PreTokenizerSplitsWhitespace(inner)) + { + return true; + } + } + } + + return false; + } + + private static void ExtractMetaspaceSettings(JsonElement preTokenizer, ref bool addDummyPrefix, ref bool escapeWhiteSpaces) + { + if (preTokenizer.ValueKind != JsonValueKind.Object) + { + return; + } + + string? type = GetStringOrNull(preTokenizer, "type"); + if (string.Equals(type, "Metaspace", StringComparison.OrdinalIgnoreCase)) + { + if (preTokenizer.TryGetProperty("add_prefix_space", out JsonElement addPrefixElement)) + { + if (addPrefixElement.ValueKind != JsonValueKind.True && addPrefixElement.ValueKind != JsonValueKind.False) + { + throw new InvalidDataException("The pre_tokenizer 'add_prefix_space' must be a boolean."); + } + + addDummyPrefix = addPrefixElement.GetBoolean(); + } + + if (preTokenizer.TryGetProperty("replacement", out JsonElement replacementElement)) + { + if (replacementElement.ValueKind != JsonValueKind.String && replacementElement.ValueKind != JsonValueKind.Null) + { + throw new InvalidDataException("The pre_tokenizer 'replacement' must be a string."); + } + + // HF Metaspace's 'replacement' is the actual whitespace marker character. The SentencePiece model + // only supports U+2581 ('▁'); reject any other marker rather than silently not escaping spaces. 
+ string? replacement = replacementElement.GetString(); + if (replacement is not null && replacement != "\u2581") // U+2581 LOWER ONE EIGHTH BLOCK (▁) + { + throw new NotSupportedException( + $"The Metaspace 'replacement' '{replacement}' is not supported; only U+2581 ('\u2581') is supported."); + } + + escapeWhiteSpaces = true; + } + + if (preTokenizer.TryGetProperty("prepend_scheme", out JsonElement prependSchemeElement)) + { + string? scheme = prependSchemeElement.ValueKind == JsonValueKind.String ? prependSchemeElement.GetString() : null; + // "never" suppresses the dummy prefix; "always"/"first" keep the default (true) + if (string.Equals(scheme, "never", StringComparison.OrdinalIgnoreCase)) + { + addDummyPrefix = false; + } + } + } + else if (string.Equals(type, "Sequence", StringComparison.OrdinalIgnoreCase) && + preTokenizer.TryGetProperty("pretokenizers", out JsonElement preTokenizersElement) && + preTokenizersElement.ValueKind == JsonValueKind.Array) + { + foreach (JsonElement inner in preTokenizersElement.EnumerateArray()) + { + ExtractMetaspaceSettings(inner, ref addDummyPrefix, ref escapeWhiteSpaces); + } + } + } } } diff --git a/src/Microsoft.ML.Tokenizers/Model/SentencePieceUnigramModel.cs b/src/Microsoft.ML.Tokenizers/Model/SentencePieceUnigramModel.cs index 3714206cf0..feb28e8f1e 100644 --- a/src/Microsoft.ML.Tokenizers/Model/SentencePieceUnigramModel.cs +++ b/src/Microsoft.ML.Tokenizers/Model/SentencePieceUnigramModel.cs @@ -8,6 +8,7 @@ using System.Collections.Generic; using System.Collections.ObjectModel; using System.Diagnostics; +using System.IO; using System.Linq; using System.Runtime.InteropServices; using System.Text; @@ -22,6 +23,8 @@ internal sealed class SentencePieceUnigramModel : SentencePieceBaseModel private readonly DoubleArrayTrie _trie; private readonly float _minScore; private readonly float _maxScore; + private readonly (int Id, string Token)[] _prefixTokens; + private readonly (int Id, string Token)[] _suffixTokens; private 
const float UnkPenalty = 10.0f; public SentencePieceUnigramModel(ModelProto modelProto, bool addBos, bool addEos, IReadOnlyDictionary? specialTokens = null) : base(modelProto, addBos, addEos, specialTokens) @@ -91,6 +94,266 @@ public SentencePieceUnigramModel(ModelProto modelProto, bool addBos, bool addEos _vocab[modelProto.TrainerSpec.PadPiece] = modelProto.TrainerSpec.PadId; _vocabReverse[modelProto.TrainerSpec.PadId] = (modelProto.TrainerSpec.PadPiece, 0f, ModelProto.Types.SentencePiece.Types.Type.Control); } + + _prefixTokens = DefaultAffix(BeginningOfSentenceId, BeginningOfSentenceToken); + _suffixTokens = DefaultAffix(EndOfSentenceId, EndOfSentenceToken); + } + + // Constructor that builds a Unigram model directly from a list of (piece, score) pairs. + // BOS, EOS, and PAD tokens are identified by their names ("", "", "") in the vocab; + // if not found by name, they are treated as absent (id = -1) to avoid misidentifying real pieces. + internal SentencePieceUnigramModel( + IReadOnlyList<(string Piece, float Score)> pieces, + int unkId, + bool addBos, + bool addEos, + ReadOnlySpan precompiledCharsmap, + bool addDummyPrefix, + bool escapeWhiteSpaces, + bool treatWhitespaceAsSuffix, + bool removeExtraWhitespaces, + bool byteFallback, + IReadOnlyDictionary? specialTokens) + : this(pieces, unkId, addBos, addEos, precompiledCharsmap, addDummyPrefix, escapeWhiteSpaces, + treatWhitespaceAsSuffix, removeExtraWhitespaces, byteFallback, specialTokens, + CheckSpecialId(addBos, FindSpecialTokenId(ValidateVocab(pieces, unkId), ""), "addBeginningOfSentence"), + CheckSpecialId(addEos, FindSpecialTokenId(pieces, ""), "addEndOfSentence"), + FindSpecialTokenId(pieces, ""), prefixTokens: null, suffixTokens: null) + { + } + + // Constructor that builds a Unigram model with explicit prefix/suffix special-token lists, for example + // resolved from a tokenizer.json post_processor template. 
// Shared implementation behind the vocabulary-based constructors.
//
// pieces        : ordered (piece, score) vocabulary; the index of each entry is its token id.
// unkId         : index of the unknown piece; validated by GetPieceAtIndex in the base-constructor call.
// bosId/eosId   : ids used for the beginning/end-of-sentence tokens, or a negative value when absent.
// padId         : id of the padding piece, or a negative value when absent.
// prefixTokens  : explicit tokens emitted before the text; null falls back to the single BOS token (if any).
// suffixTokens  : explicit tokens emitted after the text; null falls back to the single EOS token (if any).
//
// Throws ArgumentNullException / ArgumentOutOfRangeException (via GetPieceAtIndex) for a null vocabulary or an
// out-of-range unkId, and InvalidDataException when byte fallback is enabled but the vocabulary does not
// contain the 256 byte pieces <0x00>..<0xFF> at consecutive ids.
// NOTE(review): the fallback BOS/EOS token text below is an empty string in this view of the file - confirm
// that is the intended literal.
private SentencePieceUnigramModel(
    IReadOnlyList<(string Piece, float Score)> pieces,
    int unkId,
    bool addBos,
    bool addEos,
    ReadOnlySpan<byte> precompiledCharsmap,
    bool addDummyPrefix,
    bool escapeWhiteSpaces,
    bool treatWhitespaceAsSuffix,
    bool removeExtraWhitespaces,
    bool byteFallback,
    IReadOnlyDictionary<string, int>? specialTokens,
    int bosId, int eosId, int padId,
    IReadOnlyList<(int Id, string Token)>? prefixTokens,
    IReadOnlyList<(int Id, string Token)>? suffixTokens)
    : base(addBos, addEos,
          bosId >= 0 && bosId < GetPieceCount(pieces) ? pieces[bosId].Piece : "", bosId,
          eosId >= 0 && eosId < GetPieceCount(pieces) ? pieces[eosId].Piece : "", eosId,
          GetPieceAtIndex(pieces, unkId), unkId,
          addDummyPrefix, escapeWhiteSpaces, treatWhitespaceAsSuffix, byteFallback,
          precompiledCharsmap, removeExtraWhitespaces, specialTokens)
{
    // GetPieceAtIndex in the base call has already rejected a null vocabulary.
    Debug.Assert(pieces is not null);

    _vocab = new SortedDictionary<string, int>(OrdinalUtf8StringComparer.Instance);
    _vocabReverse = new (string Piece, float Score, ModelProto.Types.SentencePiece.Types.Type Type)[pieces!.Count];
    _minScore = float.MaxValue;
    _maxScore = float.MinValue;

    // Control tokens (BOS/EOS/PAD plus any caller- or added_tokens-supplied special tokens) are kept
    // out of the trie so normal segmentation never produces them; they are re-inserted afterwards.
    HashSet<int> controlIds = new HashSet<int>();
    AddControlId(controlIds, bosId);
    AddControlId(controlIds, eosId);
    AddControlId(controlIds, padId);
    if (specialTokens is not null)
    {
        foreach (int specialId in specialTokens.Values)
        {
            AddControlId(controlIds, specialId);
        }
    }

    for (int i = 0; i < pieces.Count; i++)
    {
        var (piece, score) = pieces[i];
        if (i == unkId)
        {
            _vocabReverse[i] = (piece, score, ModelProto.Types.SentencePiece.Types.Type.Unknown);
        }
        else if (controlIds.Contains(i))
        {
            _vocabReverse[i] = (piece, score, ModelProto.Types.SentencePiece.Types.Type.Control);
        }
        else
        {
            // Only normal pieces participate in segmentation and in the score range.
            _vocabReverse[i] = (piece, score, ModelProto.Types.SentencePiece.Types.Type.Normal);
            _vocab.Add(piece, i);
            _minScore = Math.Min(_minScore, score);
            _maxScore = Math.Max(_maxScore, score);
        }
    }

    if (ByteFallback)
    {
        // Byte fallback requires a contiguous block of the 256 byte pieces <0x00>..<0xFF>; encode/decode map a
        // byte value to ByteCodeToIdOffset + value. Checking only the two endpoints would accept a vocabulary
        // whose interior byte pieces are missing or reordered, so every slot of the block is verified.
        bool validByteBlock =
            _vocab.TryGetValue("<0x00>", out int firstByteId) &
            _vocab.TryGetValue("<0xFF>", out int lastByteId);
        validByteBlock = validByteBlock && lastByteId - firstByteId == 0xFF;

        for (int value = 0; validByteBlock && value <= 0xFF; value++)
        {
            string expected = "<0x" + value.ToString("X2", System.Globalization.CultureInfo.InvariantCulture) + ">";
            validByteBlock = string.Equals(pieces[firstByteId + value].Piece, expected, StringComparison.Ordinal);
        }

        if (!validByteBlock)
        {
            throw new InvalidDataException("The tokenizer.json model enables byte_fallback but does not contain a contiguous <0x00>..<0xFF> byte-piece block required to represent it.");
        }

        ByteCodeToIdOffset = firstByteId;
        MaxByteId = lastByteId;
        OneByteUtf8EncodingMaxId = ByteCodeToIdOffset + 0x7F;
        MaxIdByteFallbackId = ByteCodeToIdOffset + 0xFF;
    }
    // When byte fallback is disabled the byte offsets stay at 0 so decode treats no ids as byte pieces, even
    // if the vocab happens to contain <0xNN> entries (otherwise normal low ids would be dropped as bytes).

    _trie = new DoubleArrayTrie(_vocab);

    // Re-insert special tokens into the vocab maps after the trie is built so they map like regular tokens.
    string unkToken = pieces[unkId].Piece;
    _vocab[unkToken] = unkId;
    _vocabReverse[unkId] = (unkToken, 0f, ModelProto.Types.SentencePiece.Types.Type.Unknown);

    foreach (int controlId in controlIds)
    {
        if (controlId == unkId)
        {
            continue; // unk is classified Unknown above; don't downgrade it to Control.
        }

        // Special-token ids supplied by the caller may lie outside the vocabulary; those are skipped.
        if (controlId >= 0 && controlId < pieces.Count)
        {
            string piece = pieces[controlId].Piece;
            _vocab[piece] = controlId;
            _vocabReverse[controlId] = (piece, 0f, ModelProto.Types.SentencePiece.Types.Type.Control);
        }
    }

    _prefixTokens = prefixTokens is not null ? ToAffixArray(prefixTokens) : DefaultAffix(BeginningOfSentenceId, BeginningOfSentenceToken);
    _suffixTokens = suffixTokens is not null ? ToAffixArray(suffixTokens) : DefaultAffix(EndOfSentenceId, EndOfSentenceToken);
}
new[] { (id, token) } : Array.Empty<(int, string)>(); + + private static (int Id, string Token)[] ToAffixArray(IReadOnlyList<(int Id, string Token)> tokens) + { + var array = new (int Id, string Token)[tokens.Count]; + for (int i = 0; i < tokens.Count; i++) + { + array[i] = tokens[i]; + } + + return array; + } + + private static int FirstId(IReadOnlyList<(int Id, string Token)> tokens) => tokens.Count > 0 ? tokens[0].Id : -1; + + private void AddPrefixTokens(List tokens) + { + foreach (var (id, token) in _prefixTokens) + { + tokens.Add(new EncodedToken(id, token, new Range(0, 0))); + } + } + + private void AddSuffixTokens(List tokens, int offset) + { + foreach (var (id, token) in _suffixTokens) + { + tokens.Add(new EncodedToken(id, token, new Range(offset, offset))); + } + } + + private static void AddControlId(HashSet set, int id) + { + if (id >= 0) + { + set.Add(id); + } + } + + private static int GetPieceCount(IReadOnlyList<(string Piece, float Score)>? pieces) + => pieces?.Count ?? 0; + + private static string GetPieceAtIndex(IReadOnlyList<(string Piece, float Score)>? pieces, int index) + { + if (pieces is null) + { + throw new ArgumentNullException("vocab"); + } + + if ((uint)index >= (uint)pieces.Count) + { + throw new ArgumentOutOfRangeException("unkId", "unkId must be a valid index in the vocabulary."); + } + + return pieces[index].Piece; + } + + // Validates pieces is not null and unkId is in range; returns pieces unchanged. + private static IReadOnlyList<(string Piece, float Score)> ValidateVocab( + IReadOnlyList<(string Piece, float Score)>? pieces, int unkId) + { + if (pieces is null) + { + throw new ArgumentNullException("vocab"); + } + + if ((uint)unkId >= (uint)pieces.Count) + { + throw new ArgumentOutOfRangeException("unkId", "unkId must be a valid index in the vocabulary."); + } + + return pieces; + } + + // Finds a special token by name; returns -1 if not found. + private static int FindSpecialTokenId(IReadOnlyList<(string Piece, float Score)>? 
pieces, string tokenName) + { + if (pieces is null) + { + return -1; + } + + for (int i = 0; i < pieces.Count; i++) + { + if (pieces[i].Piece == tokenName) + { + return i; + } + } + + return -1; + } + + private static int CheckSpecialId(bool required, int id, string paramName) + { + if (required && id < 0) + { + throw new ArgumentException($"The vocabulary does not contain the required special token.", paramName); + } + return id; } public override IReadOnlyDictionary Vocabulary => new ReadOnlyDictionary(_vocab); @@ -218,7 +481,7 @@ private void EncodeToTokensWithSpecialTokens( if (addBeginningOfSentence) { - tokens.Add(new EncodedToken(BeginningOfSentenceId, BeginningOfSentenceToken, new Range(0, 0))); + AddPrefixTokens(tokens); } int currentOffset = 0; @@ -250,7 +513,7 @@ private void EncodeToTokensWithSpecialTokens( if (addEndOfSentence) { - tokens.Add(new EncodedToken(EndOfSentenceId, EndOfSentenceToken, new Range(progressOffset, progressOffset))); + AddSuffixTokens(tokens, progressOffset); } normalizedText = normalizedString.AsSpan().Slice(0, normalizedStringIndex).ToString(); @@ -268,7 +531,7 @@ private void EncodeToTokensWithoutSpecialTokens( { if (addBeginningOfSentence) { - tokens.Add(new EncodedToken(BeginningOfSentenceId, BeginningOfSentenceToken, new Range(0, 0))); + AddPrefixTokens(tokens); } int progressOffset = 0; @@ -278,7 +541,7 @@ private void EncodeToTokensWithoutSpecialTokens( if (addEndOfSentence) { - tokens.Add(new EncodedToken(EndOfSentenceId, EndOfSentenceToken, new Range(progressOffset, progressOffset))); + AddSuffixTokens(tokens, progressOffset); } normalizedText = normalizedString.AsSpan().Slice(0, normalizedStringIndex).ToString(); @@ -571,12 +834,15 @@ public override IReadOnlyList EncodeToIds( if (addBeginningOfSentence) { - ids.Add(BeginningOfSentenceId); - if (maxTokenCount == 1) + foreach (var (id, _) in _prefixTokens) { - normalizedText = null; - charsConsumed = 0; - return ids; // done. no more space for anything else. 
+ ids.Add(id); + if (ids.Count >= maxTokenCount) + { + normalizedText = null; + charsConsumed = 0; + return ids; // done. no more space for anything else. + } } } @@ -595,9 +861,17 @@ public override IReadOnlyList EncodeToIds( EncodeToIdsWithoutSpecialTokens(textToEncode, considerNormalization, ids, buffer, ref normalizedString, out normalizedText, out charsConsumed, maxTokenCount); } - if (addEndOfSentence && ids.Count < maxTokenCount) + if (addEndOfSentence) { - ids.Add(EndOfSentenceId); + foreach (var (id, _) in _suffixTokens) + { + if (ids.Count >= maxTokenCount) + { + break; + } + + ids.Add(id); + } } if (normalizedString is not null) @@ -841,7 +1115,7 @@ private void EncodeToIdsInternal( if (maxTokenCount == int.MaxValue) { - Debug.Assert(unknownTokensCount == 0 && unknownTokensTracking is null); + Debug.Assert(ByteFallback || (unknownTokensCount == 0 && unknownTokensTracking is null)); if (ByteFallback && unknownTokensCount > 0) { @@ -960,13 +1234,15 @@ public override int CountTokens( if (addBeginningOfSentence) { - tokenCount++; - - if (maxTokenCount == 1) + foreach (var _ in _prefixTokens) { - normalizedText = null; - charsConsumed = 0; - return tokenCount; + tokenCount++; + if (tokenCount >= maxTokenCount) + { + normalizedText = null; + charsConsumed = 0; + return tokenCount; + } } } @@ -987,7 +1263,15 @@ public override int CountTokens( if (addEndOfSentence && tokenCount < maxTokenCount) { - tokenCount++; + foreach (var _ in _suffixTokens) + { + if (tokenCount >= maxTokenCount) + { + break; + } + + tokenCount++; + } } if (normalizedString is not null) @@ -1228,12 +1512,14 @@ public override int GetIndexByTokenCountFromEnd( if (addEndOfSentence) { - tokenCount++; - - if (maxTokenCount == 1) + foreach (var _ in _suffixTokens) { - normalizedText = null; - return textToEncode.Length; + tokenCount++; + if (tokenCount >= maxTokenCount) + { + normalizedText = null; + return textToEncode.Length; + } } } @@ -1256,7 +1542,15 @@ public override int 
GetIndexByTokenCountFromEnd( if (addBeginningOfSentence && tokenCount < maxTokenCount) { - tokenCount++; + foreach (var _ in _prefixTokens) + { + if (tokenCount >= maxTokenCount) + { + break; + } + + tokenCount++; + } } ArrayPool.Shared.Return(buffer); diff --git a/src/Microsoft.ML.Tokenizers/Normalizer/SentencePieceNormalizationStep.cs b/src/Microsoft.ML.Tokenizers/Normalizer/SentencePieceNormalizationStep.cs new file mode 100644 index 0000000000..163ac7f8f3 --- /dev/null +++ b/src/Microsoft.ML.Tokenizers/Normalizer/SentencePieceNormalizationStep.cs @@ -0,0 +1,492 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Globalization; +using System.IO; +using System.Text; +using System.Text.Json; +using System.Text.RegularExpressions; + +namespace Microsoft.ML.Tokenizers +{ + /// + /// Applies a Hugging Face tokenizer.json normalizer chain in managed code so that JSON-only + /// SentencePiece (Unigram) tokenizers normalize identically to the reference implementation before + /// the SentencePiece model performs Metaspace pre-tokenization and Unigram decoding. + /// + /// + /// SentencePiece's own fuses the precompiled charsmap with the + /// Metaspace whitespace handling in a single pass and cannot represent arbitrary normalizer steps + /// (per-character Replace, Lowercase, Unicode normalization, ...). When such steps are + /// present this chain runs first (charsmap included), then the model's normalizer performs only the + /// Metaspace escaping — matching the Hugging Face order (normalizer chain, then Metaspace pre-tokenizer). + /// + /// Tradeoff: this is intentionally a simple string pipeline that allocates one intermediate string per + /// step. 
/// <summary>
/// Applies a Hugging Face <c>tokenizer.json</c> normalizer chain in managed code so that JSON-only
/// SentencePiece (Unigram) tokenizers normalize identically to the reference implementation before
/// the SentencePiece model performs Metaspace pre-tokenization and Unigram decoding.
/// </summary>
/// <remarks>
/// <para>
/// SentencePiece's own normalizer fuses the precompiled charsmap with the Metaspace whitespace handling in a
/// single pass and cannot represent arbitrary normalizer steps (per-character <c>Replace</c>,
/// <c>Lowercase</c>, Unicode normalization, ...). When such steps are present this chain runs first
/// (charsmap included), then the model's normalizer performs only the Metaspace escaping, matching the
/// Hugging Face order (normalizer chain, then Metaspace pre-tokenizer).
/// </para>
/// <para>
/// Tradeoff: this is intentionally a simple string pipeline that allocates one intermediate string per
/// step. The steps rely on string-only BCL APIs, so a span-based rewrite could only remove a subset of those
/// allocations and was judged not to earn its complexity here: this chain runs only for the minority of
/// models whose normalizer has content-modifying steps; the common path (a bare charsmap, and every
/// protobuf-loaded model) sets <c>NormalizationChain</c> to null and bypasses it entirely. The caller
/// decodes/encodes at the boundary through pooled buffers to avoid intermediate byte[] garbage.
/// </para>
/// </remarks>
internal abstract class SentencePieceNormalizationStep
{
    /// <summary>Applies this normalization step (or chain of steps) to <paramref name="text"/>.</summary>
    /// <param name="text">The text to normalize.</param>
    /// <returns>The normalized text; implementations return the input instance when nothing changes.</returns>
    public abstract string Normalize(string text);

    // string.Normalize throws on lone (unpaired) surrogates, which a caller-supplied string may legitimately
    // contain. Replace any unpaired surrogate with U+FFFD so Unicode normalization degrades gracefully.
    // Returns the original instance when no lone surrogate is present.
    private static string ReplaceLoneSurrogates(string text)
    {
        StringBuilder? sb = null;
        for (int i = 0; i < text.Length; i++)
        {
            char c = text[i];
            bool lone;
            if (char.IsHighSurrogate(c))
            {
                lone = i + 1 >= text.Length || !char.IsLowSurrogate(text[i + 1]);
            }
            else if (char.IsLowSurrogate(c))
            {
                lone = i == 0 || !char.IsHighSurrogate(text[i - 1]);
            }
            else
            {
                lone = false;
            }

            if (lone)
            {
                // Lazily start the builder with everything before the first lone surrogate.
                sb ??= new StringBuilder(text, 0, i, text.Length);
                sb.Append('\uFFFD');
            }
            else
            {
                sb?.Append(c);
            }
        }

        return sb is null ? text : sb.ToString();
    }

    /// <summary>
    /// Returns true if the normalizer tree contains a content-modifying step that the charsmap plus
    /// removeExtraWhitespaces approximation cannot represent (e.g. a literal/character <c>Replace</c>,
    /// <c>Lowercase</c>, Unicode normalization, or <c>Prepend</c>). A bare precompiled charsmap, a
    /// <c>Strip</c>, or a whitespace-collapsing <c>Replace</c> are not "rich".
    /// </summary>
    /// <param name="normalizer">A tokenizer.json normalizer element; non-objects and untyped objects yield false.</param>
    public static bool HasRichSteps(JsonElement normalizer)
    {
        if (normalizer.ValueKind != JsonValueKind.Object || !normalizer.TryGetProperty("type", out JsonElement typeElement) ||
            typeElement.ValueKind != JsonValueKind.String)
        {
            return false;
        }

        string? type = typeElement.GetString();
        switch (type)
        {
            case "Sequence":
                if (normalizer.TryGetProperty("normalizers", out JsonElement steps) && steps.ValueKind == JsonValueKind.Array)
                {
                    foreach (JsonElement step in steps.EnumerateArray())
                    {
                        if (HasRichSteps(step))
                        {
                            return true;
                        }
                    }
                }
                return false;

            case "Precompiled":
            case "Strip":
                return false;

            case "Replace":
                // A whitespace-collapsing Replace is already approximated by removeExtraWhitespaces; any other
                // Replace (literal punctuation splitting, character substitution) is content-modifying.
                return !ReplaceIsWhitespaceCollapse(normalizer);

            default:
                // Lowercase, NFC/NFD/NFKC/NFKD, Nmt, Prepend, ... all change content.
                return true;
        }
    }

    /// <summary>
    /// Builds the managed normalizer chain from a tokenizer.json normalizer element.
    /// </summary>
    /// <param name="normalizer">The normalizer JSON object (a single step or a <c>Sequence</c>).</param>
    /// <returns>A step whose <see cref="Normalize(string)"/> applies the described normalization.</returns>
    /// <exception cref="InvalidDataException">The element is not an object, or a known property has the wrong JSON type.</exception>
    /// <exception cref="NotSupportedException">
    /// The normalizer type is not modeled; thrown so callers fail loudly rather than silently mis-tokenizing.
    /// </exception>
    public static SentencePieceNormalizationStep Build(JsonElement normalizer)
    {
        if (normalizer.ValueKind != JsonValueKind.Object)
        {
            throw new InvalidDataException("A tokenizer.json normalizer entry must be a JSON object.");
        }

        string? type = normalizer.TryGetProperty("type", out JsonElement typeElement) && typeElement.ValueKind == JsonValueKind.String
            ? typeElement.GetString() : null;
        switch (type)
        {
            case "Sequence":
                var children = new List<SentencePieceNormalizationStep>();
                if (normalizer.TryGetProperty("normalizers", out JsonElement steps) && steps.ValueKind == JsonValueKind.Array)
                {
                    foreach (JsonElement step in steps.EnumerateArray())
                    {
                        children.Add(Build(step));
                    }
                }
                return new SequenceStep(children);

            case "Precompiled":
                {
                    string? charsMap = null;
                    if (normalizer.TryGetProperty("precompiled_charsmap", out JsonElement mapElement))
                    {
                        if (mapElement.ValueKind != JsonValueKind.String && mapElement.ValueKind != JsonValueKind.Null)
                        {
                            throw new InvalidDataException("The Precompiled normalizer 'precompiled_charsmap' must be a string.");
                        }

                        charsMap = mapElement.GetString();
                    }

                    // A missing, null or empty charsmap produces a pass-through step.
                    return new PrecompiledStep(string.IsNullOrEmpty(charsMap) ? default : DecodePrecompiledCharsMap(charsMap!));
                }

            case "Replace":
                return ReplaceStep.Create(normalizer);

            case "Strip":
                // Each side is stripped unless its flag is explicitly the JSON literal false.
                return new StripStep(
                    stripLeft: !normalizer.TryGetProperty("strip_left", out JsonElement left) || left.ValueKind != JsonValueKind.False,
                    stripRight: !normalizer.TryGetProperty("strip_right", out JsonElement right) || right.ValueKind != JsonValueKind.False);

            case "Lowercase":
                return LowercaseStep.Instance;

            case "StripAccents":
                return StripAccentsStep.Instance;

            case "NFC":
                return new UnicodeStep(NormalizationForm.FormC);
            case "NFD":
                return new UnicodeStep(NormalizationForm.FormD);
            case "NFKC":
                return new UnicodeStep(NormalizationForm.FormKC);
            case "NFKD":
                return new UnicodeStep(NormalizationForm.FormKD);

            case "Prepend":
                {
                    if (normalizer.TryGetProperty("prepend", out JsonElement prependElement) &&
                        prependElement.ValueKind != JsonValueKind.String && prependElement.ValueKind != JsonValueKind.Null)
                    {
                        throw new InvalidDataException("The Prepend normalizer 'prepend' must be a string.");
                    }

                    string prepend = prependElement.ValueKind == JsonValueKind.String ? prependElement.GetString() ?? "" : "";
                    return new PrependStep(prepend);
                }

            case "Nmt":
                return NmtStep.Instance;

            default:
                throw new NotSupportedException(
                    $"Unigram normalizer type '{type ?? ""}' is not supported when loading a tokenizer.json with content-modifying normalizer steps.");
        }
    }

    // Decodes a base64 'precompiled_charsmap' value, surfacing malformed input as InvalidDataException so callers
    // get a consistent, diagnostic failure for bad tokenizer.json files instead of a raw FormatException.
    internal static byte[] DecodePrecompiledCharsMap(string base64)
    {
        try
        {
            return Convert.FromBase64String(base64);
        }
        catch (FormatException ex)
        {
            throw new InvalidDataException("The tokenizer.json normalizer 'precompiled_charsmap' is not valid base64.", ex);
        }
    }

    // Mirrors SentencePieceTokenizer.ReplaceCollapsesSpaces: a Replace whose Regex matches runs of spaces.
    // Only the exact pattern spellings listed below are recognized.
    private static bool ReplaceIsWhitespaceCollapse(JsonElement replace)
    {
        if (!replace.TryGetProperty("pattern", out JsonElement patternElement) ||
            patternElement.ValueKind != JsonValueKind.Object ||
            !patternElement.TryGetProperty("Regex", out JsonElement regexElement) ||
            regexElement.ValueKind != JsonValueKind.String)
        {
            return false;
        }

        switch (regexElement.GetString())
        {
            case " {2,}":
            case " +":
            case "[ ]+":
            case "[ ]{2,}":
            case "\\s+":
            case "\\s{2,}":
                return true;
            default:
                return false;
        }
    }

    // Runs the child steps in order, feeding each step's output to the next.
    private sealed class SequenceStep : SentencePieceNormalizationStep
    {
        private readonly IReadOnlyList<SentencePieceNormalizationStep> _steps;

        public SequenceStep(IReadOnlyList<SentencePieceNormalizationStep> steps) => _steps = steps;

        public override string Normalize(string text)
        {
            foreach (SentencePieceNormalizationStep step in _steps)
            {
                text = step.Normalize(text);
            }
            return text;
        }
    }

    // Culture-invariant lowercasing.
    private sealed class LowercaseStep : SentencePieceNormalizationStep
    {
        public static readonly LowercaseStep Instance = new LowercaseStep();
        public override string Normalize(string text) => text.ToLowerInvariant();
    }

    // Port of the Hugging Face tokenizers "StripAccents" normalizer: decompose (NFD) and drop combining marks.
    // NOTE(review): the reference normalizer is documented as "to be used with NFD"; confirm that decomposing
    // here (rather than relying on a preceding NFD/NFKD step) is the intended behavior.
    private sealed class StripAccentsStep : SentencePieceNormalizationStep
    {
        public static readonly StripAccentsStep Instance = new StripAccentsStep();

        // The reference filter removes every code point whose General_Category is Mark, i.e. nonspacing (Mn),
        // spacing combining (Mc) and enclosing (Me) marks, not only the nonspacing ones.
        private static bool IsCombiningMark(UnicodeCategory category) =>
            category == UnicodeCategory.NonSpacingMark ||
            category == UnicodeCategory.SpacingCombiningMark ||
            category == UnicodeCategory.EnclosingMark;

        public override string Normalize(string text)
        {
            if (text.Length == 0)
            {
                return text;
            }

            string decomposed = ReplaceLoneSurrogates(text).Normalize(NormalizationForm.FormD);
            StringBuilder? sb = null;
            int i = 0;
            while (i < decomposed.Length)
            {
                // Iterate by code point so combining marks encoded as surrogate pairs (astral plane) are
                // classified and dropped, matching the reference (which filters on full code points).
                int charCount = char.IsHighSurrogate(decomposed[i]) && i + 1 < decomposed.Length && char.IsLowSurrogate(decomposed[i + 1]) ? 2 : 1;
                if (IsCombiningMark(CharUnicodeInfo.GetUnicodeCategory(decomposed, i)))
                {
                    if (sb is null)
                    {
                        // First dropped mark: start the builder with the untouched prefix.
                        sb = new StringBuilder(decomposed.Length);
                        sb.Append(decomposed, 0, i);
                    }
                }
                else
                {
                    sb?.Append(decomposed, i, charCount);
                }

                i += charCount;
            }

            return sb is null ? decomposed : sb.ToString();
        }
    }

    // Applies one Unicode normalization form after neutralizing lone surrogates.
    private sealed class UnicodeStep : SentencePieceNormalizationStep
    {
        private readonly NormalizationForm _form;
        public UnicodeStep(NormalizationForm form) => _form = form;
        public override string Normalize(string text) => text.Length == 0 ? text : ReplaceLoneSurrogates(text).Normalize(_form);
    }

    // Prefixes the text with a fixed string; an empty prefix is a no-op.
    private sealed class PrependStep : SentencePieceNormalizationStep
    {
        private readonly string _prefix;
        public PrependStep(string prefix) => _prefix = prefix;
        public override string Normalize(string text) => _prefix.Length == 0 ? text : _prefix + text;
    }

    // Port of the Hugging Face tokenizers "Nmt" normalizer: drop a set of control characters and map a set of
    // whitespace/format characters to a regular space. A no-op for text without those characters.
    private sealed class NmtStep : SentencePieceNormalizationStep
    {
        public static readonly NmtStep Instance = new NmtStep();

        private static bool IsRemoved(char c) =>
            (c >= '\u0001' && c <= '\u0008') || c == '\u000B' || (c >= '\u000E' && c <= '\u001F') ||
            c == '\u007F' || c == '\u008F' || c == '\u009F';

        private static bool IsSpace(char c) =>
            c == '\u0009' || c == '\u000A' || c == '\u000C' || c == '\u000D' || c == '\u1680' ||
            (c >= '\u200B' && c <= '\u200F') || c == '\u2028' || c == '\u2029' ||
            c == '\u2581' || c == '\uFEFF' || c == '\uFFFD';

        public override string Normalize(string text)
        {
            // The builder is created only once the first character that must change is found.
            StringBuilder? sb = null;
            for (int i = 0; i < text.Length; i++)
            {
                char c = text[i];
                if (IsRemoved(c))
                {
                    sb ??= NewBuilder(text, i);
                    continue;
                }

                char mapped = IsSpace(c) ? ' ' : c;
                if (sb is not null)
                {
                    sb.Append(mapped);
                }
                else if (mapped != c)
                {
                    sb = NewBuilder(text, i);
                    sb.Append(mapped);
                }
            }

            return sb is null ? text : sb.ToString();
        }

        private static StringBuilder NewBuilder(string text, int upTo)
        {
            var sb = new StringBuilder(text.Length);
            sb.Append(text, 0, upTo);
            return sb;
        }
    }

    // Trims whitespace from the left and/or right end of the text.
    private sealed class StripStep : SentencePieceNormalizationStep
    {
        private readonly bool _stripLeft;
        private readonly bool _stripRight;

        public StripStep(bool stripLeft, bool stripRight)
        {
            _stripLeft = stripLeft;
            _stripRight = stripRight;
        }

        public override string Normalize(string text)
        {
            if (_stripLeft && _stripRight)
            {
                return text.Trim();
            }
            if (_stripLeft)
            {
                return text.TrimStart();
            }
            return _stripRight ? text.TrimEnd() : text;
        }
    }

    // Replaces every match of a literal string or of a regular expression with fixed content.
    private sealed class ReplaceStep : SentencePieceNormalizationStep
    {
        // Bounds worst-case evaluation of regexes parsed from untrusted tokenizer.json.
        private static readonly TimeSpan _regexTimeout = TimeSpan.FromSeconds(1);

        private readonly string _content;
        private readonly string? _literal;
        private readonly Regex? _regex;

        // Replacement text handed to Regex.Replace, computed once. Hugging Face replaces the matched range
        // with 'content' literally; '$' is escaped so Regex.Replace does not interpret it as a substitution
        // pattern (e.g. "$0", "$&") and diverge from the reference.
        private readonly string _regexReplacement;

        private ReplaceStep(string? literal, Regex? regex, string content)
        {
            _literal = literal;
            _regex = regex;
            _content = content;
            _regexReplacement = content.IndexOf('$') < 0 ? content : content.Replace("$", "$$");
        }

        // Parses a "Replace" normalizer object. 'content' defaults to the empty string when absent or not a
        // string; 'pattern' must be an object holding either a "String" or a "Regex" member.
        public static ReplaceStep Create(JsonElement normalizer)
        {
            string content = normalizer.TryGetProperty("content", out JsonElement contentElement) && contentElement.ValueKind == JsonValueKind.String
                ? contentElement.GetString() ?? "" : "";
            if (!normalizer.TryGetProperty("pattern", out JsonElement pattern) || pattern.ValueKind != JsonValueKind.Object)
            {
                throw new InvalidDataException("Replace normalizer is missing its pattern.");
            }

            if (pattern.TryGetProperty("String", out JsonElement literal))
            {
                if (literal.ValueKind != JsonValueKind.String)
                {
                    throw new InvalidDataException("Replace normalizer 'String' pattern must be a string.");
                }

                return new ReplaceStep(literal.GetString() ?? "", regex: null, content);
            }

            if (pattern.TryGetProperty("Regex", out JsonElement regex))
            {
                if (regex.ValueKind != JsonValueKind.String)
                {
                    throw new InvalidDataException("Replace normalizer 'Regex' pattern must be a string.");
                }

                string regexPattern = regex.GetString()!;
                try
                {
                    return new ReplaceStep(literal: null, new Regex(regexPattern, RegexOptions.CultureInvariant, _regexTimeout), content);
                }
                catch (ArgumentException ex)
                {
                    throw new InvalidDataException($"Replace normalizer has an invalid Regex pattern '{regexPattern}'.", ex);
                }
            }

            throw new NotSupportedException("Replace normalizer requires a String or Regex pattern.");
        }

        public override string Normalize(string text)
        {
            if (_regex is not null)
            {
                return _regex.Replace(text, _regexReplacement);
            }

            // An empty literal pattern is treated as "nothing to replace".
            return string.IsNullOrEmpty(_literal) ? text : text.Replace(_literal, _content);
        }
    }

    // Applies a SentencePiece precompiled charsmap by delegating to a charsmap-only SentencePieceNormalizer
    // (no dummy prefix, no whitespace escaping, no whitespace stripping), reusing the existing DARTS engine.
    private sealed class PrecompiledStep : SentencePieceNormalizationStep
    {
        // Null when no charsmap was supplied, in which case the step is a pass-through.
        private readonly SentencePieceNormalizer? _charsMap;

        public PrecompiledStep(ReadOnlySpan<byte> precompiledCharsMap)
        {
            if (!precompiledCharsMap.IsEmpty)
            {
                _charsMap = new SentencePieceNormalizer(
                    precompiledCharsMap,
                    removeExtraWhiteSpaces: false,
                    addDummyPrefix: false,
                    escapeWhiteSpaces: false,
                    treatWhitespaceAsSuffix: false,
                    specialTokens: null);
            }
        }

        public override string Normalize(string text) => _charsMap is null ? text : _charsMap.NormalizeUtf8ToString(text);
    }
}
///
@@ -96,12 +103,66 @@ public override string Normalize(string original)
             return Normalize(original.AsSpan());
         }
 
+        // Applies normalization through the byte-level path, which is the only path that consults the precompiled
+        // charsmap (the char-level path only does Metaspace/whitespace handling). Used by the JSON normalizer
+        // chain's Precompiled step, which needs the charsmap folding as a text -> text transform.
+        // Returns string.Empty for null/empty input or when normalization produces no bytes.
+        internal string NormalizeUtf8ToString(string text)
+        {
+            if (string.IsNullOrEmpty(text))
+            {
+                return string.Empty;
+            }
+
+            int maxByteCount = Encoding.UTF8.GetMaxByteCount(text.Length);
+            byte[]? inputArray = null;
+            byte[]? poolArray = null;
+            try
+            {
+                // Rent inside the try block so the first buffer is still returned if the second Rent throws.
+                inputArray = ArrayPool<byte>.Shared.Rent(maxByteCount);
+                poolArray = ArrayPool<byte>.Shared.Rent(Math.Max(64, maxByteCount * 2));
+                Span<byte> normalized = poolArray;
+
+                int inputLength = Helpers.GetUtf8Bytes(text.AsSpan(), inputArray);
+                int length = Normalize(inputArray.AsSpan(0, inputLength), ref normalized, ref poolArray);
+
+                // Decode from the output span; 'poolArray' is passed by ref and may have been replaced by the callee.
+                return length == 0 ? string.Empty : Helpers.GetString(normalized.Slice(0, length));
+            }
+            finally
+            {
+                if (inputArray is not null)
+                {
+                    ArrayPool<byte>.Shared.Return(inputArray);
+                }
+
+                if (poolArray is not null)
+                {
+                    ArrayPool<byte>.Shared.Return(poolArray);
+                }
+            }
+        }
+
         /// <summary>
         /// Normalize the original string according to SentencePiece normalization.
         /// </summary>
         /// <param name="original">The original string to normalize.</param>
         /// <returns>The normalized string.</returns>
         public override string Normalize(ReadOnlySpan<char> original)
+        {
+            if (NormalizationChain is not null)
+            {
+                // Chain-mode (content-modifying tokenizer.json normalizer) is an off-the-hot-path branch that runs a
+                // simple string pipeline; the intermediate string here is intentional (see SentencePieceNormalizationStep).
+                return NormalizeCore(NormalizationChain.Normalize(original.ToString()).AsSpan());
+            }
+
+            return NormalizeCore(original);
+        }
+
+        // Char-level normalization body; Normalize calls it after any NormalizationChain has been applied.
+        private string NormalizeCore(ReadOnlySpan<char> original)
         {
             int startIndex = 0;
             int endIndex = original.Length - 1;
@@ -339,6 +400,40 @@ internal int Normalize(ReadOnlySpan<byte> input, ref Span<byte> normalized, ref
                 return 0;
             }
 
+            if (NormalizationChain is not null)
+            {
+                // Chain-mode (content-modifying tokenizer.json normalizer) runs a simple string pipeline off the hot
+                // path; the one decoded string is intentional (see SentencePieceNormalizationStep). Decode the input
+                // and re-encode the result through the span-based helpers + a pooled buffer to avoid byte[] garbage.
+                string chained = NormalizationChain.Normalize(Helpers.GetString(input));
+                if (chained.Length == 0)
+                {
+                    return 0;
+                }
+
+                byte[] chainedBytes = ArrayPool<byte>.Shared.Rent(Encoding.UTF8.GetMaxByteCount(chained.Length));
+                try
+                {
+                    int chainedLength = Helpers.GetUtf8Bytes(chained.AsSpan(), chainedBytes);
+                    return NormalizeCore(chainedBytes.AsSpan(0, chainedLength), ref normalized, ref poolArray);
+                }
+                finally
+                {
+                    ArrayPool<byte>.Shared.Return(chainedBytes);
+                }
+            }
+
+            return NormalizeCore(input, ref normalized, ref poolArray);
+        }
+
+        // Byte-level normalization body; Normalize calls it after any NormalizationChain has been applied.
+        private int NormalizeCore(ReadOnlySpan<byte> input, ref Span<byte> normalized, ref byte[]? poolArray)
+        {
+            if (input.IsEmpty)
+            {
+                return 0;
+            }
+
             int consumed = 0;
 
             // Ignores heading space.
diff --git a/test/Microsoft.ML.Tokenizers.Tests/UnigramTests.cs b/test/Microsoft.ML.Tokenizers.Tests/UnigramTests.cs index ca671ddebe..bd3e3e7806 100644 --- a/test/Microsoft.ML.Tokenizers.Tests/UnigramTests.cs +++ b/test/Microsoft.ML.Tokenizers.Tests/UnigramTests.cs @@ -562,5 +562,1232 @@ public void SpecialTokensTest() Assert.Equal("", _unigramTokenizer.EndOfSentenceToken); Assert.Equal(2, _unigramTokenizer.EndOfSentenceId); } + + [Fact] + public void CreateFromVocabTest() + { + // Build a minimal synthetic Unigram vocab: =0, =1, =2, then normal tokens + var vocab = new List<(string Piece, float Score)> + { + ("", 0f), + ("", 0f), + ("", 0f), + ("▁Hello", -1f), + (",", -2f), + ("▁world", -3f), + ("!", -4f), + }; + + SentencePieceTokenizer tokenizer = SentencePieceTokenizer.Create( + vocab, unkId: 0, addBeginningOfSentence: false, addEndOfSentence: false); + + Assert.Equal("", tokenizer.UnknownToken); + Assert.Equal(0, tokenizer.UnknownId); + Assert.Equal("", tokenizer.BeginningOfSentenceToken); + Assert.Equal(1, tokenizer.BeginningOfSentenceId); + Assert.Equal("", tokenizer.EndOfSentenceToken); + Assert.Equal(2, tokenizer.EndOfSentenceId); + + IReadOnlyList ids = tokenizer.EncodeToIds("Hello, world!", addBeginningOfSentence: false, addEndOfSentence: false); + Assert.Equal(new[] { 3, 4, 5, 6 }, ids); + + string decoded = tokenizer.Decode(ids, considerSpecialTokens: false); + Assert.Equal("Hello, world!", decoded); + } + + [Fact] + public void CreateFromVocabNullTest() + { + Assert.Throws(() => + SentencePieceTokenizer.Create((IEnumerable<(string Piece, float Score)>)null!, unkId: 0)); + } + + [Fact] + public void CreateFromVocabInvalidUnkIdTest() + { + var vocab = new List<(string Piece, float Score)> { ("a", 0f) }; + Assert.Throws(() => + SentencePieceTokenizer.Create(vocab, unkId: 5)); + } + + [Fact] + public void CreateFromTokenizerJsonTest() + { + using Stream jsonStream = File.OpenRead(Path.Combine("Paraphrase-multilingual-MiniLM-L12-v2", "tokenizer.json")); + 
SentencePieceTokenizer jsonTokenizer = SentencePieceTokenizer.CreateFromTokenizerJson( + jsonStream, addBeginningOfSentence: false, addEndOfSentence: false); + + // The tokenizer.json vocab has =0, =1, =2, =3, then normal tokens + // (shifted +1 relative to .model which has =0, =1, =2) + Assert.Equal("", jsonTokenizer.UnknownToken); + Assert.Equal(3, jsonTokenizer.UnknownId); + Assert.Equal("", jsonTokenizer.BeginningOfSentenceToken); + Assert.Equal(0, jsonTokenizer.BeginningOfSentenceId); + Assert.Equal("", jsonTokenizer.EndOfSentenceToken); + Assert.Equal(2, jsonTokenizer.EndOfSentenceId); + + // Pieces produced should match the .model tokenizer; IDs are shifted by +1 + IReadOnlyList jsonTokens = jsonTokenizer.EncodeToTokens("Hello, world!", out _, addBeginningOfSentence: false, addEndOfSentence: false); + IReadOnlyList modelTokens = _unigramTokenizer.EncodeToTokens("Hello, world!", out _, addBeginningOfSentence: false, addEndOfSentence: false); + + Assert.Equal(modelTokens.Count, jsonTokens.Count); + for (int i = 0; i < modelTokens.Count; i++) + { + Assert.Equal(modelTokens[i].Value, jsonTokens[i].Value); + // JSON IDs are offset by 1 from the .model IDs for normal tokens + Assert.Equal(modelTokens[i].Id + 1, jsonTokens[i].Id); + } + } + + [Fact] + public void CreateFromTokenizerJsonNullStreamTest() + { + Assert.Throws(() => + SentencePieceTokenizer.CreateFromTokenizerJson(null!)); + } + + [Fact] + public void CreateFromTokenizerJsonNormalizationTest() + { + // Verify that the JSON tokenizer applies the precompiled charsmap normalization + // (same normalization as the .model tokenizer) + using Stream jsonStream = File.OpenRead(Path.Combine("Paraphrase-multilingual-MiniLM-L12-v2", "tokenizer.json")); + SentencePieceTokenizer jsonTokenizer = SentencePieceTokenizer.CreateFromTokenizerJson( + jsonStream, addBeginningOfSentence: false, addEndOfSentence: false); + + // "㍻" normalizes to "平成" via the precompiled charsmap (NFKC normalization) + IReadOnlyList jsonIds = 
jsonTokenizer.EncodeToIds("㍻", addBeginningOfSentence: false, addEndOfSentence: false); + IReadOnlyList modelIds = _unigramTokenizer.EncodeToIds("㍻", addBeginningOfSentence: false, addEndOfSentence: false); + + Assert.Equal(modelIds.Count, jsonIds.Count); + for (int i = 0; i < modelIds.Count; i++) + { + Assert.Equal(modelIds[i] + 1, jsonIds[i]); + } + } + + [Fact] + public void CreateFromVocabNoSpecialTokensTest() + { + // Vocab without // — resembles bge-m3/potion layout. + // Verify that real pieces (e.g. ",") are not marked Control and remain encodable. + var vocab = new List<(string Piece, float Score)> + { + ("[PAD]", 0f), // 0 + ("[UNK]", 0f), // 1 + (",", -1f), // 2 + ("▁Hello", -2f), // 3 + ("▁world", -3f), // 4 + ("!", -4f), // 5 + }; + + SentencePieceTokenizer tokenizer = SentencePieceTokenizer.Create( + vocab, unkId: 1, addBeginningOfSentence: false, addEndOfSentence: false); + + // "," must be in the vocabulary and encodable (not silently dropped as Control) + IReadOnlyList ids = tokenizer.EncodeToIds("Hello, world!", addBeginningOfSentence: false, addEndOfSentence: false); + Assert.Contains(2, ids); // id 2 is "," + } + + [Fact] + public void CreateFromVocabBosRequiredButAbsentTest() + { + // Vocab without : addBeginningOfSentence:true should throw rather than emit index 0. + var vocab = new List<(string Piece, float Score)> + { + ("[UNK]", 0f), + ("▁Hello", -1f), + }; + + Assert.Throws(() => + SentencePieceTokenizer.Create(vocab, unkId: 0, addBeginningOfSentence: true)); + } + + [Fact] + public void CreateFromTokenizerJsonSequenceNormalizerWithExtraStepsTest() + { + // A Sequence normalizer that interleaves the precompiled map with Nmt and a collapsing Replace builds the + // managed normalizer chain (Nmt is content-modifying) and loads rather than failing. 
+ string json = """ + { + "model": { + "type": "Unigram", + "unk_id": 0, + "vocab": [["", 0.0], ["a", -1.0]] + }, + "normalizer": { + "type": "Sequence", + "normalizers": [ + { "type": "Nmt" }, + { "type": "Precompiled", "precompiled_charsmap": "" }, + { "type": "Replace", "pattern": { "Regex": " {2,}" }, "content": "\u2581" } + ] + } + } + """; + + using Stream stream = new System.IO.MemoryStream(System.Text.Encoding.UTF8.GetBytes(json)); + SentencePieceTokenizer tokenizer = SentencePieceTokenizer.CreateFromTokenizerJson(stream, addBeginningOfSentence: false); + Assert.NotNull(tokenizer); + } + + [Fact] + public void CreateFromTokenizerJsonAppliesStandaloneNmtNormalizer() + { + // A normalizer that is only Nmt must still be applied: U+200B (zero-width space) maps to a regular space. + string json = """ + { + "normalizer": { "type": "Nmt" }, + "pre_tokenizer": { "type": "Metaspace", "replacement": "\u2581", "add_prefix_space": true }, + "model": { + "type": "Unigram", + "unk_id": 0, + "vocab": [["", 0.0], ["\u2581a", -1.0], ["\u2581b", -1.0]] + } + } + """; + + Assert.Equal(new[] { "\u2581a", "\u2581b" }, JsonUnigramTokens(json, "a\u200Bb")); + } + + // Vocab shared by the remove_extra_whitespaces deduction tests; "▁" is its own piece so a preserved + // extra space surfaces as an extra token. + private const string WhitespaceDeductionVocab = + "\"vocab\": [[\"\", 0.0], [\"\u2581a\", -1.0], [\"\u2581b\", -1.0], [\"\u2581\", -3.0], [\"a\", -10.0], [\"b\", -10.0]]"; + + [Fact] + public void CreateFromTokenizerJsonDeducesRemoveExtraWhitespacesFromReplaceStep() + { + // HF encodes remove_extra_whitespaces as a Strip + Replace(" {2,}" -> "▁"); the collapsing Replace + // ALONE (no sibling Strip) must enable whitespace collapsing so "a b" collapses to two pieces. 
+ string json = $$""" + { + "normalizer": { + "type": "Sequence", + "normalizers": [ + { "type": "Replace", "pattern": { "Regex": " {2,}" }, "content": "\u2581" } + ] + }, + "pre_tokenizer": { "type": "Metaspace", "replacement": "\u2581", "add_prefix_space": true }, + "model": { + "type": "Unigram", + "unk_id": 0, + {{WhitespaceDeductionVocab}} + } + } + """; + + using Stream stream = new System.IO.MemoryStream(System.Text.Encoding.UTF8.GetBytes(json)); + SentencePieceTokenizer tokenizer = SentencePieceTokenizer.CreateFromTokenizerJson(stream, addBeginningOfSentence: false, addEndOfSentence: false); + + Assert.Equal(2, tokenizer.CountTokens("a b", addBeginningOfSentence: false, addEndOfSentence: false)); + } + + [Fact] + public void CreateFromTokenizerJsonDeducesRemoveExtraWhitespacesFromStripStep() + { + // A right-Strip alone (no Replace) also marks the behavior. + string json = $$""" + { + "normalizer": { + "type": "Sequence", + "normalizers": [ + { "type": "Strip", "strip_left": false, "strip_right": true } + ] + }, + "pre_tokenizer": { "type": "Metaspace", "replacement": "\u2581", "add_prefix_space": true }, + "model": { + "type": "Unigram", + "unk_id": 0, + {{WhitespaceDeductionVocab}} + } + } + """; + + using Stream stream = new System.IO.MemoryStream(System.Text.Encoding.UTF8.GetBytes(json)); + SentencePieceTokenizer tokenizer = SentencePieceTokenizer.CreateFromTokenizerJson(stream, addBeginningOfSentence: false, addEndOfSentence: false); + + Assert.Equal(2, tokenizer.CountTokens("a b", addBeginningOfSentence: false, addEndOfSentence: false)); + } + + [Fact] + public void CreateFromTokenizerJsonNoCollapseStepPreservesExtraWhitespace() + { + // Without a Strip/Replace collapsing step (e.g. older bare-Precompiled files), remove_extra_whitespaces + // is deduced false to match the HF fast-tokenizer runtime, so the extra space is preserved as a token. 
+ string json = $$""" + { + "normalizer": { "type": "Precompiled", "precompiled_charsmap": "" }, + "pre_tokenizer": { "type": "Metaspace", "replacement": "\u2581", "add_prefix_space": true }, + "model": { + "type": "Unigram", + "unk_id": 0, + {{WhitespaceDeductionVocab}} + } + } + """; + + using Stream stream = new System.IO.MemoryStream(System.Text.Encoding.UTF8.GetBytes(json)); + SentencePieceTokenizer tokenizer = SentencePieceTokenizer.CreateFromTokenizerJson(stream, addBeginningOfSentence: false, addEndOfSentence: false); + + Assert.Equal(3, tokenizer.CountTokens("a b", addBeginningOfSentence: false, addEndOfSentence: false)); + } + + private static string[] JsonUnigramTokens(string json, string text) + { + using Stream stream = new System.IO.MemoryStream(System.Text.Encoding.UTF8.GetBytes(json)); + SentencePieceTokenizer tokenizer = SentencePieceTokenizer.CreateFromTokenizerJson(stream, addBeginningOfSentence: false, addEndOfSentence: false); + return tokenizer.EncodeToTokens(text, out _, addBeginningOfSentence: false, addEndOfSentence: false).Select(t => t.Value).ToArray(); + } + + [Fact] + public void CreateFromTokenizerJsonAppliesAnchoredRegexReplaceNormalizer() + { + // Regression: a zero-width/anchored regex match at offset 0 (e.g. "^") must still insert the content. + string json = """ + { + "normalizer": { "type": "Replace", "pattern": { "Regex": "^" }, "content": "X" }, + "pre_tokenizer": { "type": "Metaspace", "replacement": "\u2581", "add_prefix_space": true }, + "model": { + "type": "Unigram", + "unk_id": 0, + "vocab": [["", 0.0], ["\u2581Xab", -1.0]] + } + } + """; + + Assert.Equal(new[] { "\u2581Xab" }, JsonUnigramTokens(json, "ab")); + } + + [Fact] + public void CreateFromTokenizerJsonRegexReplaceContentIsLiteral() + { + // Regression: a '$' in the Replace content is substituted literally (Hugging Face semantics), not as a + // regex replacement reference such as "$0". Exercises both the net8 span path and the net48 string path. 
+ string json = """ + { + "normalizer": { "type": "Replace", "pattern": { "Regex": "a" }, "content": "$0X" }, + "pre_tokenizer": { "type": "Metaspace", "replacement": "\u2581", "add_prefix_space": true }, + "model": { + "type": "Unigram", + "unk_id": 0, + "vocab": [["", 0.0], ["\u2581$0X", -1.0]] + } + } + """; + + Assert.Equal(new[] { "\u2581$0X" }, JsonUnigramTokens(json, "a")); + } + + [Fact] + public void CreateFromTokenizerJsonAppliesPunctuationReplaceNormalizer() + { + // A content-modifying Replace (punctuation splitting, as model2vec/potion uses) must be applied before + // Metaspace, so "a,b" is split into separate pieces rather than tokenized as one run. + string json = """ + { + "normalizer": { "type": "Replace", "pattern": { "String": "," }, "content": " , " }, + "pre_tokenizer": { "type": "Metaspace", "replacement": "\u2581", "add_prefix_space": true }, + "model": { + "type": "Unigram", + "unk_id": 0, + "vocab": [["", 0.0], ["\u2581a", -1.0], ["\u2581,", -1.0], ["\u2581b", -1.0]] + } + } + """; + + Assert.Equal(new[] { "\u2581a", "\u2581,", "\u2581b" }, JsonUnigramTokens(json, "a,b")); + } + + [Fact] + public void CreateFromTokenizerJsonAppliesLowercaseNormalizer() + { + // A Lowercase normalizer step must run before encoding so upper-case input maps to the lower-case piece. + string json = """ + { + "normalizer": { "type": "Lowercase" }, + "pre_tokenizer": { "type": "Metaspace", "replacement": "\u2581", "add_prefix_space": true }, + "model": { + "type": "Unigram", + "unk_id": 0, + "vocab": [["", 0.0], ["\u2581hello", -1.0]] + } + } + """; + + Assert.Equal(new[] { "\u2581hello" }, JsonUnigramTokens(json, "HELLO")); + } + + [Fact] + public void CreateFromTokenizerJsonAppliesNmtNormalizer() + { + // Nmt maps format characters (e.g. U+200B zero-width space) to a regular space; combined with a rich + // step it runs in the chain, so "A\u200BB" normalizes to two space-separated pieces. 
+ string json = """ + { + "normalizer": { "type": "Sequence", "normalizers": [ { "type": "Nmt" }, { "type": "Lowercase" } ] }, + "pre_tokenizer": { "type": "Metaspace", "replacement": "\u2581", "add_prefix_space": true }, + "model": { + "type": "Unigram", + "unk_id": 0, + "vocab": [["", 0.0], ["\u2581a", -1.0], ["\u2581b", -1.0]] + } + } + """; + + Assert.Equal(new[] { "\u2581a", "\u2581b" }, JsonUnigramTokens(json, "A\u200BB")); + } + + [Fact] + public void CreateFromTokenizerJsonAppliesNfkcNormalizer() + { + // An NFKC normalizer step must run before encoding so the circled digit U+2460 maps to "1". + string json = """ + { + "normalizer": { "type": "NFKC" }, + "pre_tokenizer": { "type": "Metaspace", "replacement": "\u2581", "add_prefix_space": true }, + "model": { + "type": "Unigram", + "unk_id": 0, + "vocab": [["", 0.0], ["\u25811", -1.0]] + } + } + """; + + Assert.Equal(new[] { "\u25811" }, JsonUnigramTokens(json, "\u2460")); + } + + [Fact] + public void CreateFromTokenizerJsonAppliesStripAccentsNormalizer() + { + // A StripAccents normalizer step decomposes and drops combining marks, so "café" maps to the "cafe" piece. + string json = """ + { + "normalizer": { "type": "StripAccents" }, + "pre_tokenizer": { "type": "Metaspace", "replacement": "\u2581", "add_prefix_space": true }, + "model": { + "type": "Unigram", + "unk_id": 0, + "vocab": [["", 0.0], ["\u2581cafe", -1.0]] + } + } + """; + + Assert.Equal(new[] { "\u2581cafe" }, JsonUnigramTokens(json, "caf\u00e9")); + } + + [Fact] + public void CreateFromTokenizerJsonStripAccentsDropsAstralCombiningMark() + { + // Regression: a combining mark encoded as a surrogate pair (astral plane, e.g. U+1D167) must also be + // stripped, which requires classifying by code point rather than per UTF-16 code unit. 
+ string json = """ + { + "normalizer": { "type": "StripAccents" }, + "pre_tokenizer": { "type": "Metaspace", "replacement": "\u2581", "add_prefix_space": true }, + "model": { + "type": "Unigram", + "unk_id": 0, + "vocab": [["", 0.0], ["\u2581ab", -1.0]] + } + } + """; + + // "a" + U+1D167 (MUSICAL SYMBOL COMBINING TREMOLO-1, a NonSpacingMark) + "b" -> "ab". + Assert.Equal(new[] { "\u2581ab" }, JsonUnigramTokens(json, "a\uD834\uDD67b")); + } + + [Fact] + public void CreateFromTokenizerJsonByteFallbackRoundTrips() + { + // A Unigram model with byte_fallback and the 256 <0x00>..<0xFF> byte pieces must encode out-of-vocabulary + // characters to byte-fallback ids and decode them back to the original text (requires MaxByteId to be set). + StringBuilder vocab = new StringBuilder("[[\"\", 0.0], [\"\", 0.0], [\"\", 0.0]"); + for (int b = 0; b <= 0xFF; b++) + { + vocab.Append($", [\"<0x{b:X2}>\", -10.0]"); + } + vocab.Append(", [\"\\u2581a\", -1.0]]"); + + string json = $$""" + { + "model": { "type": "Unigram", "unk_id": 0, "byte_fallback": true, "vocab": {{vocab}} }, + "pre_tokenizer": { "type": "Metaspace", "replacement": "\u2581", "add_prefix_space": false } + } + """; + + using Stream stream = new MemoryStream(Encoding.UTF8.GetBytes(json)); + SentencePieceTokenizer tokenizer = SentencePieceTokenizer.CreateFromTokenizerJson(stream, addBeginningOfSentence: false, addEndOfSentence: false); + + foreach (string text in new[] { "x", "\u20AC", "Ab9" }) + { + IReadOnlyList ids = tokenizer.EncodeToIds(text, addBeginningOfSentence: false, addEndOfSentence: false); + Assert.DoesNotContain(0, ids); // byte fallback, not the id + Assert.Equal(text, tokenizer.Decode(ids)); + } + } + + [Fact] + public void CreateFromTokenizerJsonByteFallbackWithoutBytePiecesThrows() + { + // byte_fallback enabled but the <0x00>..<0xFF> pieces are absent must fail loudly rather than mis-encode. 
+ string json = """ + { + "model": { "type": "Unigram", "unk_id": 0, "byte_fallback": true, "vocab": [["", 0.0], ["\u2581a", -1.0]] }, + "pre_tokenizer": { "type": "Metaspace", "replacement": "\u2581", "add_prefix_space": true } + } + """; + + using Stream stream = new MemoryStream(Encoding.UTF8.GetBytes(json)); + Assert.Throws(() => + SentencePieceTokenizer.CreateFromTokenizerJson(stream, addBeginningOfSentence: false, addEndOfSentence: false)); + } + + [Fact] + public void CreateFromTokenizerJsonNonMetaspaceReplacementThrows() + { + // HF Metaspace 'replacement' other than U+2581 cannot be represented by the model; reject it. + string json = """ + { + "model": { "type": "Unigram", "unk_id": 0, "vocab": [["", 0.0], ["_a", -1.0]] }, + "pre_tokenizer": { "type": "Metaspace", "replacement": "_", "add_prefix_space": true } + } + """; + + using Stream stream = new MemoryStream(Encoding.UTF8.GetBytes(json)); + Assert.Throws(() => + SentencePieceTokenizer.CreateFromTokenizerJson(stream, addBeginningOfSentence: false, addEndOfSentence: false)); + } + + [Fact] + public void CreateFromTokenizerJsonWhitespaceSplitPreTokenizerCollapsesWhitespace() + { + // A WhitespaceSplit pre-tokenizer collapses runs of whitespace (and strips ends), matching + // remove_extra_whitespaces, even though the normalizer carries no collapse step. + string json = """ + { + "pre_tokenizer": { + "type": "Sequence", + "pretokenizers": [ + { "type": "WhitespaceSplit" }, + { "type": "Metaspace", "replacement": "\u2581", "add_prefix_space": true } + ] + }, + "model": { + "type": "Unigram", + "unk_id": 0, + "vocab": [["", 0.0], ["\u2581a", -1.0], ["\u2581b", -1.0]] + } + } + """; + + Assert.Equal(new[] { "\u2581a", "\u2581b" }, JsonUnigramTokens(json, " a b ")); + } + + [Fact] + public void CreateFromTokenizerJsonAppliesPrecompiledCharsMapInChainMode() + { + // Regression: a Precompiled charsmap nested in a content-modifying (chain-mode) normalizer must still be + // applied. 
Wrap the bundled model's real charsmap with a Lowercase step to force chain-mode and verify + // the charsmap still folds a full-width digit while Lowercase also runs. + using JsonDocument bundled = JsonDocument.Parse(File.ReadAllText(Path.Combine("Paraphrase-multilingual-MiniLM-L12-v2", "tokenizer.json"))); + string charsMap = bundled.RootElement.GetProperty("normalizer").GetProperty("precompiled_charsmap").GetString()!; + + string json = $$""" + { + "normalizer": { + "type": "Sequence", + "normalizers": [ + { "type": "Precompiled", "precompiled_charsmap": "{{charsMap}}" }, + { "type": "Lowercase" } + ] + }, + "pre_tokenizer": { "type": "Metaspace", "replacement": "\u2581", "add_prefix_space": true }, + "model": { + "type": "Unigram", + "unk_id": 0, + "vocab": [["", 0.0], ["\u25811", -1.0], ["\u2581a", -1.0]] + } + } + """; + + // Full-width "1" (U+FF11) folds to "1" via the charsmap; uppercase "A" lower-cases to "a". + Assert.Equal(new[] { "\u25811" }, JsonUnigramTokens(json, "\uFF11")); + Assert.Equal(new[] { "\u2581a" }, JsonUnigramTokens(json, "A")); + } + + [Fact] + public void CreateFromTokenizerJsonChainHandlesLoneSurrogate() + { + // A content-modifying normalizer that runs Unicode normalization must not throw on a lone surrogate. + string json = """ + { + "normalizer": { "type": "NFKC" }, + "pre_tokenizer": { "type": "Metaspace", "replacement": "\u2581", "add_prefix_space": true }, + "model": { + "type": "Unigram", + "unk_id": 0, + "vocab": [["", 0.0], ["\u2581a", -1.0], ["\u2581b", -1.0]] + } + } + """; + + string[] tokens = JsonUnigramTokens(json, "a\uD800b"); + Assert.NotEmpty(tokens); + } + + [Fact] + public void CreateFromTokenizerJsonUnsupportedNormalizerStepThrows() + { + // When a content-modifying chain references a normalizer type we do not model, fail loudly rather than + // silently mis-tokenizing. 
+ string json = """ + { + "normalizer": { "type": "Sequence", "normalizers": [ { "type": "Lowercase" }, { "type": "BertNormalizer" } ] }, + "pre_tokenizer": { "type": "Metaspace", "replacement": "\u2581", "add_prefix_space": true }, + "model": { + "type": "Unigram", + "unk_id": 0, + "vocab": [["", 0.0], ["\u2581a", -1.0]] + } + } + """; + + using Stream stream = new System.IO.MemoryStream(System.Text.Encoding.UTF8.GetBytes(json)); + Assert.Throws(() => + SentencePieceTokenizer.CreateFromTokenizerJson(stream, addBeginningOfSentence: false, addEndOfSentence: false)); + } + + [Fact] + public void CreateFromTokenizerJsonNullNormalizerTest() + { + // A null normalizer value in JSON should not throw. + string json = """ + { + "model": { + "type": "Unigram", + "unk_id": 0, + "vocab": [["", 0.0], ["a", -1.0]] + }, + "normalizer": null + } + """; + + using Stream stream = new System.IO.MemoryStream(System.Text.Encoding.UTF8.GetBytes(json)); + SentencePieceTokenizer tokenizer = SentencePieceTokenizer.CreateFromTokenizerJson( + stream, addBeginningOfSentence: false); + Assert.NotNull(tokenizer); + } + + [Fact] + public void CreateFromVocabAbsentBosNotDecodedAsIdZeroTest() + { + // Vocab without /. With the add flags off, BOS/EOS must stay absent (-1) + // rather than being clamped to 0, so id 0 decodes as its real piece. + var vocab = new List<(string Piece, float Score)> + { + ("", 0f), // 0 + ("▁Hello", -1f), // 1 + ("▁world", -2f), // 2 + }; + + SentencePieceTokenizer tokenizer = SentencePieceTokenizer.Create( + vocab, unkId: 0, addBeginningOfSentence: false, addEndOfSentence: false); + + Assert.Equal(-1, tokenizer.BeginningOfSentenceId); + Assert.Equal(-1, tokenizer.EndOfSentenceId); + + // id 0 is , not BOS; decoding it with considerSpecialTokens must yield the unk piece. 
+ string decoded = tokenizer.Decode(new[] { 0 }, considerSpecialTokens: true); + Assert.Equal("", decoded); + } + + [Fact] + public void CreateFromTokenizerJsonMissingModelTypeInfersUnigram() + { + // Older tokenizer.json files (xlm-roberta-base, albert) omit model.type; a model with a [piece, score] + // vocab and no BPE merges is inferred as Unigram and loads. + string json = """ + { + "model": { + "unk_id": 0, + "vocab": [["", 0.0], ["\u2581a", -1.0]] + }, + "pre_tokenizer": { "type": "Metaspace", "replacement": "\u2581", "add_prefix_space": true } + } + """; + + using Stream stream = new System.IO.MemoryStream(System.Text.Encoding.UTF8.GetBytes(json)); + SentencePieceTokenizer tokenizer = SentencePieceTokenizer.CreateFromTokenizerJson(stream, addBeginningOfSentence: false, addEndOfSentence: false); + Assert.Equal(new[] { "\u2581a" }, tokenizer.EncodeToTokens("a", out _, addBeginningOfSentence: false, addEndOfSentence: false).Select(t => t.Value).ToArray()); + } + + [Fact] + public void CreateFromTokenizerJsonMissingTypeWithMergesThrows() + { + // A model with no type but BPE "merges" is not a Unigram model and must be rejected. + string json = """ + { + "model": { + "unk_id": 0, + "vocab": { "": 0, "a": 1, "b": 2 }, + "merges": ["a b"] + } + } + """; + + using Stream stream = new System.IO.MemoryStream(System.Text.Encoding.UTF8.GetBytes(json)); + Assert.Throws(() => + SentencePieceTokenizer.CreateFromTokenizerJson(stream, addBeginningOfSentence: false)); + } + + [Fact] + public void CreateFromTokenizerJsonNonNumericUnkIdThrows() + { + // A non-numeric unk_id is a malformed file and must fail with InvalidDataException, not a cast error. 
+ string json = """ + { + "model": { + "type": "Unigram", + "unk_id": "zero", + "vocab": [["", 0.0], ["a", -1.0]] + } + } + """; + + using Stream stream = new System.IO.MemoryStream(System.Text.Encoding.UTF8.GetBytes(json)); + Assert.Throws(() => + SentencePieceTokenizer.CreateFromTokenizerJson(stream, addBeginningOfSentence: false)); + } + + [Fact] + public void CreateFromTokenizerJsonNonNumericVocabScoreThrows() + { + // A vocab entry whose score is not a number is malformed and must fail with InvalidDataException. + string json = """ + { + "model": { + "type": "Unigram", + "unk_id": 0, + "vocab": [["", 0.0], ["a", "not-a-score"]] + } + } + """; + + using Stream stream = new System.IO.MemoryStream(System.Text.Encoding.UTF8.GetBytes(json)); + Assert.Throws(() => + SentencePieceTokenizer.CreateFromTokenizerJson(stream, addBeginningOfSentence: false)); + } + + [Fact] + public void CreateFromTokenizerJsonNonNumericAddedTokenIdThrows() + { + // A special added_token with a non-numeric id is malformed and must fail with InvalidDataException. + string json = """ + { + "model": { + "type": "Unigram", + "unk_id": 0, + "vocab": [["", 0.0], ["a", -1.0]] + }, + "added_tokens": [ { "id": "one", "content": "", "special": true } ] + } + """; + + using Stream stream = new System.IO.MemoryStream(System.Text.Encoding.UTF8.GetBytes(json)); + Assert.Throws(() => + SentencePieceTokenizer.CreateFromTokenizerJson(stream, addBeginningOfSentence: false)); + } + + [Fact] + public void CreateFromTokenizerJsonNonStringTemplateSpecialTokenIdThrows() + { + // A post_processor template SpecialToken whose id is not a string must fail with InvalidDataException. 
+ string json = """ + { + "model": { + "type": "Unigram", + "unk_id": 0, + "vocab": [["", 0.0], ["a", -1.0]] + }, + "post_processor": { + "type": "TemplateProcessing", + "single": [ + { "Sequence": { "id": "A", "type_id": 0 } }, + { "SpecialToken": { "id": 5, "type_id": 0 } } + ] + } + } + """; + + using Stream stream = new System.IO.MemoryStream(System.Text.Encoding.UTF8.GetBytes(json)); + Assert.Throws(() => + SentencePieceTokenizer.CreateFromTokenizerJson(stream, addBeginningOfSentence: false)); + } + + [Fact] + public void CreateFromTokenizerJsonNonStringPrependNormalizerThrows() + { + // A Prepend normalizer with a non-string 'prepend' must fail with InvalidDataException. + string json = """ + { + "model": { + "type": "Unigram", + "unk_id": 0, + "vocab": [["", 0.0], ["a", -1.0]] + }, + "normalizer": { "type": "Prepend", "prepend": 5 } + } + """; + + using Stream stream = new System.IO.MemoryStream(System.Text.Encoding.UTF8.GetBytes(json)); + Assert.Throws(() => + SentencePieceTokenizer.CreateFromTokenizerJson(stream, addBeginningOfSentence: false)); + } + + [Fact] + public void CreateFromTokenizerJsonNonObjectNormalizerSequenceEntryThrows() + { + // A rich normalizer Sequence containing a non-object entry must fail with InvalidDataException. + string json = """ + { + "model": { + "type": "Unigram", + "unk_id": 0, + "vocab": [["", 0.0], ["a", -1.0]] + }, + "normalizer": { "type": "Sequence", "normalizers": [ { "type": "Lowercase" }, 123 ] } + } + """; + + using Stream stream = new System.IO.MemoryStream(System.Text.Encoding.UTF8.GetBytes(json)); + Assert.Throws(() => + SentencePieceTokenizer.CreateFromTokenizerJson(stream, addBeginningOfSentence: false)); + } + + [Fact] + public void CreateFromTokenizerJsonNonStringPrecompiledCharsMapInChainThrows() + { + // A non-string precompiled_charsmap inside a rich chain must fail with InvalidDataException. 
+ string json = """ + { + "model": { + "type": "Unigram", + "unk_id": 0, + "vocab": [["", 0.0], ["a", -1.0]] + }, + "normalizer": { "type": "Sequence", "normalizers": [ { "type": "Lowercase" }, { "type": "Precompiled", "precompiled_charsmap": 5 } ] } + } + """; + + using Stream stream = new System.IO.MemoryStream(System.Text.Encoding.UTF8.GetBytes(json)); + Assert.Throws(() => + SentencePieceTokenizer.CreateFromTokenizerJson(stream, addBeginningOfSentence: false)); + } + + [Fact] + public void CreateFromTokenizerJsonInconsistentTemplateSpecialTokenIdThrows() + { + // A template special_tokens id that does not map back to the referenced token must fail loudly. + string json = """ + { + "model": { + "type": "Unigram", + "unk_id": 0, + "vocab": [["", 0.0], ["", 0.0], ["a", -1.0]] + }, + "post_processor": { + "type": "TemplateProcessing", + "single": [ + { "Sequence": { "id": "A", "type_id": 0 } }, + { "SpecialToken": { "id": "", "type_id": 0 } } + ], + "special_tokens": { + "": { "id": "", "ids": [2], "tokens": [""] } + } + } + } + """; + + using Stream stream = new System.IO.MemoryStream(System.Text.Encoding.UTF8.GetBytes(json)); + Assert.Throws(() => + SentencePieceTokenizer.CreateFromTokenizerJson(stream, addBeginningOfSentence: false)); + } + + [Fact] + public void CreateFromTokenizerJsonNonNumericTemplateSpecialTokenIdsThrows() + { + // A post_processor special_tokens entry whose ids[0] is not numeric must fail with InvalidDataException. 
+ string json = """ + { + "model": { + "type": "Unigram", + "unk_id": 0, + "vocab": [["", 0.0], ["", 0.0], ["a", -1.0]] + }, + "post_processor": { + "type": "TemplateProcessing", + "single": [ + { "Sequence": { "id": "A", "type_id": 0 } }, + { "SpecialToken": { "id": "", "type_id": 0 } } + ], + "special_tokens": { + "": { "id": "", "ids": ["one"], "tokens": [""] } + } + } + } + """; + + using Stream stream = new System.IO.MemoryStream(System.Text.Encoding.UTF8.GetBytes(json)); + Assert.Throws(() => + SentencePieceTokenizer.CreateFromTokenizerJson(stream, addBeginningOfSentence: false)); + } + + [Fact] + public void CreateFromTokenizerJsonNonBooleanAddPrefixSpaceThrows() + { + // A pre_tokenizer add_prefix_space that is not a boolean must fail with InvalidDataException. + string json = """ + { + "model": { + "type": "Unigram", + "unk_id": 0, + "vocab": [["", 0.0], ["a", -1.0]] + }, + "pre_tokenizer": { "type": "Metaspace", "replacement": "\u2581", "add_prefix_space": "true" } + } + """; + + using Stream stream = new System.IO.MemoryStream(System.Text.Encoding.UTF8.GetBytes(json)); + Assert.Throws(() => + SentencePieceTokenizer.CreateFromTokenizerJson(stream, addBeginningOfSentence: false)); + } + + [Fact] + public void CreateFromTokenizerJsonNonStringMetaspaceReplacementThrows() + { + // A pre_tokenizer replacement that is not a string must fail with InvalidDataException. 
+ string json = """ + { + "model": { + "type": "Unigram", + "unk_id": 0, + "vocab": [["", 0.0], ["a", -1.0]] + }, + "pre_tokenizer": { "type": "Metaspace", "replacement": 5, "add_prefix_space": true } + } + """; + + using Stream stream = new System.IO.MemoryStream(System.Text.Encoding.UTF8.GetBytes(json)); + Assert.Throws(() => + SentencePieceTokenizer.CreateFromTokenizerJson(stream, addBeginningOfSentence: false)); + } + + [Fact] + public void CreateFromTokenizerJsonInvalidReplaceRegexThrows() + { + // An invalid Regex in a Replace normalizer must surface as InvalidDataException, not ArgumentException. + string json = """ + { + "model": { + "type": "Unigram", + "unk_id": 0, + "vocab": [["", 0.0], ["a", -1.0]] + }, + "normalizer": { "type": "Replace", "pattern": { "Regex": "[" }, "content": "x" } + } + """; + + using Stream stream = new System.IO.MemoryStream(System.Text.Encoding.UTF8.GetBytes(json)); + Assert.Throws(() => + SentencePieceTokenizer.CreateFromTokenizerJson(stream, addBeginningOfSentence: false)); + } + + [Fact] + public void CreateFromTokenizerJsonNonStringNormalizerTypeIsIgnored() + { + // A normalizer whose 'type' is not a string is handled in a controlled way (treated as no-op), not thrown. + string json = """ + { + "model": { + "type": "Unigram", + "unk_id": 0, + "vocab": [["", 0.0], ["a", -1.0]] + }, + "normalizer": { "type": 123 } + } + """; + + using Stream stream = new System.IO.MemoryStream(System.Text.Encoding.UTF8.GetBytes(json)); + SentencePieceTokenizer tokenizer = SentencePieceTokenizer.CreateFromTokenizerJson(stream, addBeginningOfSentence: false); + Assert.NotNull(tokenizer); + } + + [Fact] + public void CreateFromTokenizerJsonOutOfRangeUnkIdThrows() + { + // unk_id must index into the parsed vocab; an out-of-range value is a malformed file. 
+ string json = """ + { + "model": { + "type": "Unigram", + "unk_id": 5, + "vocab": [["", 0.0], ["a", -1.0]] + } + } + """; + + using Stream stream = new System.IO.MemoryStream(System.Text.Encoding.UTF8.GetBytes(json)); + Assert.Throws(() => + SentencePieceTokenizer.CreateFromTokenizerJson(stream, addBeginningOfSentence: false)); + } + + [Fact] + public void CreateFromTokenizerJsonMalformedPrecompiledCharsMapThrows() + { + // A non-base64 precompiled_charsmap must surface as InvalidDataException, not a raw FormatException. + string json = """ + { + "model": { + "type": "Unigram", + "unk_id": 0, + "vocab": [["", 0.0], ["a", -1.0]] + }, + "normalizer": { "type": "Precompiled", "precompiled_charsmap": "not valid base64!!" } + } + """; + + using Stream stream = new System.IO.MemoryStream(System.Text.Encoding.UTF8.GetBytes(json)); + Assert.Throws(() => + SentencePieceTokenizer.CreateFromTokenizerJson(stream, addBeginningOfSentence: false)); + } + + [Fact] + public void CreateFromTokenizerJsonByteFallbackDisabledKeepsByteLikePieces() + { + // When byte_fallback is false, <0xNN> pieces in the vocab are ordinary tokens and must decode as + // themselves rather than being treated as byte-fallback pieces (which would drop low ids on decode). + string json = """ + { + "model": { + "type": "Unigram", + "unk_id": 0, + "byte_fallback": false, + "vocab": [["", 0.0], ["<0x00>", -1.0], ["a", -2.0]] + } + } + """; + + using Stream stream = new System.IO.MemoryStream(System.Text.Encoding.UTF8.GetBytes(json)); + SentencePieceTokenizer tokenizer = SentencePieceTokenizer.CreateFromTokenizerJson(stream, addBeginningOfSentence: false); + + // Decoding the byte-like piece id round-trips to its literal piece, and the following normal id is not dropped. 
+ Assert.Equal("<0x00>a", tokenizer.Decode(new[] { 1, 2 })); + } + + [Fact] + public void CreateFromTokenizerJsonNonUnigramModelTypeTest() + { + string json = """ + { + "model": { + "type": "BPE", + "unk_id": 0, + "vocab": [["", 0.0], ["a", -1.0]] + } + } + """; + + using Stream stream = new System.IO.MemoryStream(System.Text.Encoding.UTF8.GetBytes(json)); + Assert.Throws(() => + SentencePieceTokenizer.CreateFromTokenizerJson(stream, addBeginningOfSentence: false)); + } + + [Fact] + public void CreateFromTokenizerJsonNullModelTest() + { + string json = """ + { + "model": null + } + """; + + using Stream stream = new System.IO.MemoryStream(System.Text.Encoding.UTF8.GetBytes(json)); + Assert.Throws(() => + SentencePieceTokenizer.CreateFromTokenizerJson(stream, addBeginningOfSentence: false)); + } + + [Fact] + public void CreateFromTokenizerJsonNullPreTokenizerTest() + { + // A null pre_tokenizer value in JSON should not throw. + string json = """ + { + "model": { + "type": "Unigram", + "unk_id": 0, + "vocab": [["", 0.0], ["a", -1.0]] + }, + "pre_tokenizer": null + } + """; + + using Stream stream = new System.IO.MemoryStream(System.Text.Encoding.UTF8.GetBytes(json)); + SentencePieceTokenizer tokenizer = SentencePieceTokenizer.CreateFromTokenizerJson( + stream, addBeginningOfSentence: false); + Assert.NotNull(tokenizer); + } + + [Fact] + public void CreateFromTokenizerJsonTemplateMultiTokenSuffixTest() + { + // XLNet-style template: the sequence is followed by two special tokens ( then ). 
+ string json = """ + { + "model": { + "type": "Unigram", + "unk_id": 1, + "vocab": [["", 0.0], ["", 0.0], ["", 0.0], ["a", -1.0], ["b", -2.0]] + }, + "added_tokens": [ + { "id": 0, "content": "", "special": true }, + { "id": 1, "content": "", "special": true }, + { "id": 2, "content": "", "special": true } + ], + "pre_tokenizer": { "type": "Metaspace", "add_prefix_space": false, "replacement": "\u2581" }, + "post_processor": { + "type": "TemplateProcessing", + "single": [ + { "Sequence": { "id": "A", "type_id": 0 } }, + { "SpecialToken": { "id": "", "type_id": 0 } }, + { "SpecialToken": { "id": "", "type_id": 0 } } + ], + "special_tokens": { + "": { "id": "", "ids": [0], "tokens": [""] }, + "": { "id": "", "ids": [2], "tokens": [""] } + } + } + } + """; + + using Stream stream = new System.IO.MemoryStream(System.Text.Encoding.UTF8.GetBytes(json)); + SentencePieceTokenizer tokenizer = SentencePieceTokenizer.CreateFromTokenizerJson(stream, addBeginningOfSentence: false, addEndOfSentence: false); + + Assert.Equal(0, tokenizer.EndOfSentenceId); + Assert.Equal("", tokenizer.EndOfSentenceToken); + + IReadOnlyList withSuffix = tokenizer.EncodeToIds("a", addBeginningOfSentence: false, addEndOfSentence: true); + Assert.Equal(new[] { 3, 0, 2 }, withSuffix); + + IReadOnlyList withoutSuffix = tokenizer.EncodeToIds("a", addBeginningOfSentence: false, addEndOfSentence: false); + Assert.Equal(new[] { 3 }, withoutSuffix); + + Assert.Equal("a", tokenizer.Decode(withSuffix, considerSpecialTokens: false)); + } + + [Fact] + public void CreateFromTokenizerJsonRobertaProcessingTest() + { + // RobertaProcessing wraps the sequence with cls () at the front and sep () at the end. 
+ string json = """ + { + "model": { + "type": "Unigram", + "unk_id": 1, + "vocab": [["", 0.0], ["", 0.0], ["", 0.0], ["a", -1.0]] + }, + "added_tokens": [ + { "id": 0, "content": "", "special": true }, + { "id": 2, "content": "", "special": true } + ], + "pre_tokenizer": { "type": "Metaspace", "add_prefix_space": false, "replacement": "\u2581" }, + "post_processor": { "type": "RobertaProcessing", "sep": ["", 2], "cls": ["", 0] } + } + """; + + using Stream stream = new System.IO.MemoryStream(System.Text.Encoding.UTF8.GetBytes(json)); + SentencePieceTokenizer tokenizer = SentencePieceTokenizer.CreateFromTokenizerJson(stream, addBeginningOfSentence: false, addEndOfSentence: false); + + Assert.Equal(0, tokenizer.BeginningOfSentenceId); + Assert.Equal("", tokenizer.BeginningOfSentenceToken); + Assert.Equal(2, tokenizer.EndOfSentenceId); + Assert.Equal("", tokenizer.EndOfSentenceToken); + + IReadOnlyList ids = tokenizer.EncodeToIds("a", addBeginningOfSentence: true, addEndOfSentence: true); + Assert.Equal(new[] { 0, 3, 2 }, ids); + } + + [Fact] + public void CreateFromTokenizerJsonInconsistentProcessorAffixThrows() + { + // RobertaProcessing cls/sep [token, id] pair whose id does not map to the token must be rejected. + string json = """ + { + "model": { + "type": "Unigram", + "unk_id": 1, + "vocab": [["", 0.0], ["", 0.0], ["", 0.0], ["a", -1.0]] + }, + "pre_tokenizer": { "type": "Metaspace", "add_prefix_space": false, "replacement": "\u2581" }, + "post_processor": { "type": "RobertaProcessing", "sep": ["", 3], "cls": ["", 0] } + } + """; + + using Stream stream = new System.IO.MemoryStream(System.Text.Encoding.UTF8.GetBytes(json)); + Assert.Throws(() => + SentencePieceTokenizer.CreateFromTokenizerJson(stream, addBeginningOfSentence: false, addEndOfSentence: false)); + } + + [Fact] + public void CreateFromTokenizerJsonAddedTokenRecognizedTest() + { + // A special token from added_tokens that is not // must still be recognized as atomic. 
+ string json = """ + { + "model": { + "type": "Unigram", + "unk_id": 1, + "vocab": [["", 0.0], ["", 0.0], ["", 0.0], ["a", -1.0], ["", -5.0]] + }, + "added_tokens": [ + { "id": 4, "content": "", "special": true } + ], + "pre_tokenizer": { "type": "Metaspace", "add_prefix_space": false, "replacement": "\u2581" } + } + """; + + using Stream stream = new System.IO.MemoryStream(System.Text.Encoding.UTF8.GetBytes(json)); + SentencePieceTokenizer tokenizer = SentencePieceTokenizer.CreateFromTokenizerJson(stream, addBeginningOfSentence: false, addEndOfSentence: false); + + IReadOnlyList ids = tokenizer.EncodeToIds("aa", addBeginningOfSentence: false, addEndOfSentence: false); + Assert.Equal(new[] { 3, 4, 3 }, ids); + } + + [Fact] + public void CreateFromTokenizerJsonTemplateMultiSequenceThrowsTest() + { + // A template with more than one sequence placeholder cannot be represented and must be rejected. + string json = """ + { + "model": { + "type": "Unigram", + "unk_id": 0, + "vocab": [["", 0.0], ["a", -1.0]] + }, + "post_processor": { + "type": "TemplateProcessing", + "single": [ + { "Sequence": { "id": "A", "type_id": 0 } }, + { "Sequence": { "id": "B", "type_id": 0 } } + ], + "special_tokens": {} + } + } + """; + + using Stream stream = new System.IO.MemoryStream(System.Text.Encoding.UTF8.GetBytes(json)); + Assert.Throws(() => + SentencePieceTokenizer.CreateFromTokenizerJson(stream, addBeginningOfSentence: false, addEndOfSentence: false)); + } } }