3 changes: 3 additions & 0 deletions tensorflow_text/core/kernels/BUILD
@@ -211,6 +211,7 @@ tf_cc_library(
":darts_clone_trie_builder",
":darts_clone_trie_wrapper",
":fast_bert_normalizer_model",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
@@ -1254,9 +1255,11 @@ cc_library(
":string_vocab",
":whitespace_tokenizer",
":whitespace_tokenizer_config_builder",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/random",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
# lite/kernels/shim:status_macros tensorflow dep,
31 changes: 29 additions & 2 deletions tensorflow_text/core/kernels/fast_bert_normalizer.h
@@ -15,9 +15,12 @@
#ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_FAST_BERT_NORMALIZER_H_
#define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_FAST_BERT_NORMALIZER_H_

#include <cstddef>
#include <cstdint>
#include <vector>

#include "absl/base/optimization.h"
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "icu4c/source/common/unicode/utf8.h"
#include "tensorflow/lite/kernels/shim/status_macros.h"
@@ -83,7 +86,12 @@ class FastBertNormalizer {
// lifetime of the instance.
static absl::StatusOr<FastBertNormalizer> Create(
const uint32_t* trie_data, int data_for_codepoint_zero,
const char* normalized_string_pool) {
const char* normalized_string_pool,
size_t normalized_string_pool_size = static_cast<size_t>(-1)) {
if (trie_data == nullptr || normalized_string_pool == nullptr) {
return absl::InvalidArgumentError(
"trie_data or normalized_string_pool is null");
}
FastBertNormalizer result;
SH_ASSIGN_OR_RETURN(auto trie,
trie_utils::DartsCloneTrieWrapper::Create(trie_data));
@@ -92,6 +100,7 @@
result.data_for_codepoint_zero_ = data_for_codepoint_zero;
result.normalized_string_pool_ =
reinterpret_cast<const char*>(normalized_string_pool);
result.normalized_string_pool_size_ = normalized_string_pool_size;
return result;
}

@@ -103,11 +112,20 @@
// through the lifetime of the instance.
static absl::StatusOr<FastBertNormalizer> Create(
const void* model_flatbuffer) {
if (model_flatbuffer == nullptr) {
return absl::InvalidArgumentError("model_flatbuffer is null");
}
// `GetFastBertNormalizerModel()` is autogenerated by flatbuffer.
auto model = GetFastBertNormalizerModel(model_flatbuffer);
if (model == nullptr || model->trie_array() == nullptr ||
model->normalized_string_pool() == nullptr) {
return absl::InvalidArgumentError(
"FastBertNormalizerModel or its required fields are null");
}
return Create(
model->trie_array()->data(), model->data_for_codepoint_zero(),
reinterpret_cast<const char*>(model->normalized_string_pool()->data()));
reinterpret_cast<const char*>(model->normalized_string_pool()->data()),
model->normalized_string_pool()->size());
}

// Normalizes the input based on config `lower_case_nfd_strip_accents`.
@@ -290,6 +308,12 @@ class FastBertNormalizer {
}
const int offset = (data & text_norm::kNormalizedStringOffsetMask) >>
text_norm::kBitsToEncodeUtf8LengthOfNormalizedString;
if (ABSL_PREDICT_FALSE(
offset < 0 ||
(normalized_string_pool_size_ != static_cast<size_t>(-1) &&
offset + len > normalized_string_pool_size_))) {
return "";
}
return absl::string_view(normalized_string_pool_ + offset, len);
}

@@ -331,6 +355,9 @@ class FastBertNormalizer {
// The string pool of normalized strings. Each normalized string is a
// substring denoted by (offset and length).
const char* normalized_string_pool_;

// The size of normalized_string_pool_ if known, or -1.
size_t normalized_string_pool_size_ = static_cast<size_t>(-1);
};

} // namespace text
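For illustration, the pool lookup guard added above can be read as a standalone pattern. A minimal sketch, assuming a hypothetical helper name (`SafePoolLookup`) and the same `static_cast<size_t>(-1)` "size unknown" sentinel as the diff; it is not part of the library:

#include <cstddef>
#include <string_view>

// Sentinel meaning "pool size unknown", mirroring static_cast<size_t>(-1)
// in FastBertNormalizer above.
constexpr std::size_t kUnknownSize = static_cast<std::size_t>(-1);

// Returns a view into `pool`, or an empty view when the requested substring
// would read past the end of a pool of known size.
std::string_view SafePoolLookup(const char* pool, std::size_t pool_size,
                                std::size_t offset, std::size_t len) {
  if (pool == nullptr ||
      (pool_size != kUnknownSize && offset + len > pool_size)) {
    return {};  // Fail closed rather than reading out of bounds.
  }
  return std::string_view(pool + offset, len);
}

Returning an empty view on failure matches the `return "";` fallback in the diff: the hot path stays cheap (hence the `ABSL_PREDICT_FALSE` hint), at the cost of mapping a corrupt model to an empty normalization rather than an error.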
25 changes: 25 additions & 0 deletions tensorflow_text/core/kernels/fast_wordpiece_tokenizer.cc
@@ -17,9 +17,11 @@
#include <memory>

#include "absl/base/attributes.h"
#include "absl/base/optimization.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "icu4c/source/common/unicode/uchar.h"
@@ -48,6 +50,11 @@ FastWordpieceTokenizer::Create(const void* config_flatbuffer) {
FastWordpieceTokenizer tokenizer;
// `GetFastWordpieceTokenizerConfig()` is autogenerated by flatbuffer.
tokenizer.config_ = GetFastWordpieceTokenizerConfig(config_flatbuffer);
if (tokenizer.config_ == nullptr ||
tokenizer.config_->trie_array() == nullptr) {
return absl::InvalidArgumentError(
"FastWordpieceTokenizerConfig or its trie_array is null.");
}
auto trie_or = trie_utils::DartsCloneTrieWrapper::Create(
tokenizer.config_->trie_array()->data());
if (!trie_or.ok()) {
@@ -127,8 +134,23 @@ FastWordpieceTokenizer::DetokenizeToTokens(
"true in the config flatbuffer. Please rebuild the model flatbuffer "
"by setting support_detokenization=true.");
}
if (config_->vocab_array() == nullptr ||
config_->vocab_is_suffix_array() == nullptr) {
return absl::InternalError(
"Missing vocab_array or vocab_is_suffix_array in config.");
}
const int vocab_size = config_->vocab_array()->size();
const int is_suffix_size = config_->vocab_is_suffix_array()->size();
for (int id : input) {
if (ABSL_PREDICT_FALSE(id < 0 || id >= vocab_size ||
id >= is_suffix_size)) {
return absl::OutOfRangeError(
absl::StrCat("Token ID out of bounds: ", id));
}
auto vocab = config_->vocab_array()->Get(id);
if (ABSL_PREDICT_FALSE(vocab == nullptr)) {
return absl::InternalError("Null vocab string in vocab_array.");
}
auto is_suffix = config_->vocab_is_suffix_array()->Get(id);
if (!subwords.empty() && !is_suffix) {
// When current subword is not a suffix token, it marks the start of a new
@@ -140,6 +162,9 @@
// Special case: when a suffix token e.g. "##a" appears at the start of the
// input ids, we preserve the suffix_indicator.
if (subwords.empty() && is_suffix) {
if (ABSL_PREDICT_FALSE(config_->suffix_indicator() == nullptr)) {
return absl::InternalError("Missing suffix_indicator in config.");
}
subwords.emplace_back(config_->suffix_indicator()->string_view());
}
subwords.emplace_back(vocab->string_view());
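The id checks above follow a generic pre-validation pattern. A minimal sketch with a hypothetical `ValidateTokenIds` helper (not in this PR), assuming every id must index a vector of size `limit`:

#include <vector>

#include "absl/status/status.h"
#include "absl/strings/str_cat.h"

// Rejects any id outside [0, limit) before it is used as an index into a
// flatbuffer vector, mirroring the checks added in DetokenizeToTokens.
absl::Status ValidateTokenIds(const std::vector<int>& ids, int limit) {
  for (int id : ids) {
    if (id < 0 || id >= limit) {
      return absl::OutOfRangeError(
          absl::StrCat("Token ID out of bounds: ", id));
    }
  }
  return absl::OkStatus();
}

In the diff the effective bound is min(vocab_size, is_suffix_size), since each id must index both `vocab_array` and `vocab_is_suffix_array`.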
20 changes: 20 additions & 0 deletions tensorflow_text/core/kernels/phrase_tokenizer.cc
@@ -20,7 +20,10 @@
#include <string>
#include <vector>

#include "absl/base/optimization.h"
#include "absl/status/status.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "tensorflow/lite/kernels/shim/status_macros.h"
@@ -34,6 +37,12 @@ namespace text {
PhraseTokenizer tokenizer;
// `GetPhraseTokenizerConfig()` is autogenerated by flatbuffer.
tokenizer.phrase_config_ = GetPhraseTokenizerConfig(config_flatbuffer);
if (tokenizer.phrase_config_ == nullptr ||
tokenizer.phrase_config_->vocab_trie() == nullptr ||
tokenizer.phrase_config_->whitespace_config() == nullptr) {
return absl::InvalidArgumentError(
"PhraseTokenizerConfig or required fields are null.");
}
tokenizer.trie_ = absl::make_unique<sentencepiece::DoubleArrayTrie>(
tokenizer.phrase_config_->vocab_trie()->nodes());
tokenizer.prob_ = static_cast<float>(tokenizer.phrase_config_->prob()) / 100;
@@ -174,8 +183,19 @@ absl::StatusOr<std::vector<std::string>> PhraseTokenizer::DetokenizeToTokens(
"true in the config flatbuffer. Please rebuild the model flatbuffer "
"by setting support_detokenization=true.");
}
if (phrase_config_->vocab_array() == nullptr) {
return absl::InternalError("Missing vocab_array in config.");
}
const int vocab_size = phrase_config_->vocab_array()->size();
for (int id : input) {
if (ABSL_PREDICT_FALSE(id < 0 || id >= vocab_size)) {
return absl::OutOfRangeError(
absl::StrCat("Token ID out of bounds: ", id));
}
auto vocab = phrase_config_->vocab_array()->Get(id);
if (ABSL_PREDICT_FALSE(vocab == nullptr)) {
return absl::InternalError("Null vocab string in vocab_array.");
}
output_tokens.emplace_back(vocab->string_view());
}
return output_tokens;
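At construction time the same defense applies to the flatbuffer itself: generated accessors return nullptr for absent fields, so every required field is checked before the tokenizer dereferences it. A minimal sketch using a hypothetical stand-in for the generated config type:

#include "absl/status/status.h"

// Hypothetical stand-in for the generated PhraseTokenizerConfig; the real
// accessors come from the FlatBuffers schema.
struct StubConfig {
  const void* vocab_trie() const { return vocab_trie_; }
  const void* whitespace_config() const { return whitespace_config_; }
  const void* vocab_trie_ = nullptr;
  const void* whitespace_config_ = nullptr;
};

// Mirrors the validation added to PhraseTokenizer::Create(): fail fast with
// InvalidArgumentError instead of crashing on a null field later.
absl::Status ValidateConfig(const StubConfig* config) {
  if (config == nullptr || config->vocab_trie() == nullptr ||
      config->whitespace_config() == nullptr) {
    return absl::InvalidArgumentError(
        "PhraseTokenizerConfig or required fields are null.");
  }
  return absl::OkStatus();
}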