768 lines
26 KiB
C++
768 lines
26 KiB
C++
/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
|
|
|
|
#include "smartselect/feature-processor.h"
|
|
|
|
#include <iterator>
|
|
#include <set>
|
|
#include <vector>
|
|
|
|
#include "smartselect/text-classification-model.pb.h"
|
|
#include "util/base/logging.h"
|
|
#include "util/strings/utf8.h"
|
|
#include "util/utf8/unicodetext.h"
|
|
#include "unicode/brkiter.h"
|
|
#include "unicode/errorcode.h"
|
|
#include "unicode/uchar.h"
|
|
|
|
namespace libtextclassifier {
|
|
|
|
namespace internal {
|
|
|
|
// Builds the token feature extractor configuration from the corresponding
// fields of the feature processor options.
TokenFeatureExtractorOptions BuildTokenFeatureExtractorOptions(
    const FeatureProcessorOptions& options) {
  TokenFeatureExtractorOptions result;

  // Scalar settings are copied over one-to-one.
  result.num_buckets = options.num_buckets();
  result.max_word_length = options.max_word_length();
  result.extract_case_feature = options.extract_case_feature();
  result.unicode_aware_features = options.unicode_aware_features();
  result.extract_selection_mask_feature =
      options.extract_selection_mask_feature();
  result.remap_digits = options.remap_digits();
  result.lowercase_tokens = options.lowercase_tokens();

  // Repeated settings are appended element by element, in order.
  for (int i = 0; i < options.chargram_orders_size(); ++i) {
    result.chargram_orders.push_back(options.chargram_orders(i));
  }
  for (const auto& regexp : options.regexp_feature()) {
    result.regexp_features.push_back(regexp);
  }

  return result;
}
|
|
|
|
// Deserializes a FeatureProcessorOptions message from its serialized form.
//
// The return type offers no way to signal failure, so a payload that does not
// parse is reported through the error log and the message is returned in
// whatever state the parser left it.
FeatureProcessorOptions ParseSerializedOptions(
    const std::string& serialized_options) {
  FeatureProcessorOptions options;
  if (!options.ParseFromString(serialized_options)) {
    TC_LOG(ERROR) << "Could not parse serialized FeatureProcessorOptions.";
  }
  return options;
}
|
|
|
|
// Splits every token that is crossed by a selection boundary so that the
// selection start and end always coincide with token boundaries.
//
// A token is only split when a boundary lies strictly inside it; the token is
// then replaced in-place by two or three consecutive sub-tokens whose
// codepoint offsets continue from the original token's start.
void SplitTokensOnSelectionBoundaries(CodepointSpan selection,
                                      std::vector<Token>* tokens) {
  for (auto token_it = tokens->begin(); token_it != tokens->end();
       ++token_it) {
    const UnicodeText word =
        UTF8ToUnicodeText(token_it->value, /*do_copy=*/false);

    std::vector<UnicodeText::const_iterator> cut_points;
    auto cursor = word.begin();
    int cursor_index = token_it->start;

    // Cut at the selection start if it falls strictly inside the token.
    const bool cut_at_selection_start =
        token_it->start < selection.first && selection.first < token_it->end;
    if (cut_at_selection_start) {
      std::advance(cursor, selection.first - cursor_index);
      cut_points.push_back(cursor);
      cursor_index = selection.first;
    }

    // Cut at the selection end if it falls strictly inside the token. The
    // cursor continues from the previous cut, if there was one.
    const bool cut_at_selection_end =
        token_it->start < selection.second && selection.second < token_it->end;
    if (cut_at_selection_end) {
      std::advance(cursor, selection.second - cursor_index);
      cut_points.push_back(cursor);
    }

    if (cut_points.empty()) {
      // No boundary inside this token; leave it untouched.
      continue;
    }

    // The remainder of the token after the last cut forms the final piece,
    // unless the last cut already sits at the end of the token.
    if (cut_points.back() != word.end()) {
      cut_points.push_back(word.end());
    }

    std::vector<Token> pieces;
    auto piece_begin = word.begin();
    int piece_start_index = token_it->start;
    for (const auto& piece_end : cut_points) {
      pieces.push_back(
          Token(word.UTF8Substring(piece_begin, piece_end), piece_start_index,
                piece_start_index + std::distance(piece_begin, piece_end)));
      piece_begin = piece_end;
      piece_start_index = pieces.back().end;
    }

    // Swap the original token for its pieces and leave the iterator on the
    // last inserted piece so the loop increment moves on to the next token.
    token_it = tokens->erase(token_it);
    token_it = tokens->insert(token_it, pieces.begin(), pieces.end());
    std::advance(token_it, pieces.size() - 1);
  }
}
|
|
|
|
void FindSubstrings(const UnicodeText& t, const std::set<char32>& codepoints,
|
|
std::vector<UnicodeTextRange>* ranges) {
|
|
UnicodeText::const_iterator start = t.begin();
|
|
UnicodeText::const_iterator curr = start;
|
|
UnicodeText::const_iterator end = t.end();
|
|
for (; curr != end; ++curr) {
|
|
if (codepoints.find(*curr) != codepoints.end()) {
|
|
if (start != curr) {
|
|
ranges->push_back(std::make_pair(start, curr));
|
|
}
|
|
start = curr;
|
|
++start;
|
|
}
|
|
}
|
|
if (start != end) {
|
|
ranges->push_back(std::make_pair(start, end));
|
|
}
|
|
}
|
|
|
|
void StripTokensFromOtherLines(const std::string& context, CodepointSpan span,
|
|
std::vector<Token>* tokens) {
|
|
const UnicodeText context_unicode = UTF8ToUnicodeText(context,
|
|
/*do_copy=*/false);
|
|
std::vector<UnicodeTextRange> lines;
|
|
std::set<char32> codepoints;
|
|
codepoints.insert('\n');
|
|
codepoints.insert('|');
|
|
internal::FindSubstrings(context_unicode, codepoints, &lines);
|
|
|
|
auto span_start = context_unicode.begin();
|
|
if (span.first > 0) {
|
|
std::advance(span_start, span.first);
|
|
}
|
|
auto span_end = context_unicode.begin();
|
|
if (span.second > 0) {
|
|
std::advance(span_end, span.second);
|
|
}
|
|
for (const UnicodeTextRange& line : lines) {
|
|
// Find the line that completely contains the span.
|
|
if (line.first <= span_start && line.second >= span_end) {
|
|
const CodepointIndex last_line_begin_index =
|
|
std::distance(context_unicode.begin(), line.first);
|
|
const CodepointIndex last_line_end_index =
|
|
last_line_begin_index + std::distance(line.first, line.second);
|
|
|
|
for (auto token = tokens->begin(); token != tokens->end();) {
|
|
if (token->start >= last_line_begin_index &&
|
|
token->end <= last_line_end_index) {
|
|
++token;
|
|
} else {
|
|
token = tokens->erase(token);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
} // namespace internal
|
|
|
|
std::string FeatureProcessor::GetDefaultCollection() const {
|
|
if (options_.default_collection() >= options_.collections_size()) {
|
|
TC_LOG(ERROR) << "No collections specified. Returning empty string.";
|
|
return "";
|
|
}
|
|
return options_.collections(options_.default_collection());
|
|
}
|
|
|
|
std::vector<Token> FeatureProcessor::Tokenize(
|
|
const std::string& utf8_text) const {
|
|
if (options_.tokenization_type() ==
|
|
libtextclassifier::FeatureProcessorOptions::INTERNAL_TOKENIZER) {
|
|
return tokenizer_.Tokenize(utf8_text);
|
|
} else if (options_.tokenization_type() ==
|
|
libtextclassifier::FeatureProcessorOptions::ICU ||
|
|
options_.tokenization_type() ==
|
|
libtextclassifier::FeatureProcessorOptions::MIXED) {
|
|
std::vector<Token> result;
|
|
if (!ICUTokenize(utf8_text, &result)) {
|
|
return {};
|
|
}
|
|
if (options_.tokenization_type() ==
|
|
libtextclassifier::FeatureProcessorOptions::MIXED) {
|
|
InternalRetokenize(utf8_text, &result);
|
|
}
|
|
return result;
|
|
} else {
|
|
TC_LOG(ERROR) << "Unknown tokenization type specified. Using "
|
|
"internal.";
|
|
return tokenizer_.Tokenize(utf8_text);
|
|
}
|
|
}
|
|
|
|
bool FeatureProcessor::LabelToSpan(
|
|
const int label, const VectorSpan<Token>& tokens,
|
|
std::pair<CodepointIndex, CodepointIndex>* span) const {
|
|
if (tokens.size() != GetNumContextTokens()) {
|
|
return false;
|
|
}
|
|
|
|
TokenSpan token_span;
|
|
if (!LabelToTokenSpan(label, &token_span)) {
|
|
return false;
|
|
}
|
|
|
|
const int result_begin_token = token_span.first;
|
|
const int result_begin_codepoint =
|
|
tokens[options_.context_size() - result_begin_token].start;
|
|
const int result_end_token = token_span.second;
|
|
const int result_end_codepoint =
|
|
tokens[options_.context_size() + result_end_token].end;
|
|
|
|
if (result_begin_codepoint == kInvalidIndex ||
|
|
result_end_codepoint == kInvalidIndex) {
|
|
*span = CodepointSpan({kInvalidIndex, kInvalidIndex});
|
|
} else {
|
|
*span = CodepointSpan({result_begin_codepoint, result_end_codepoint});
|
|
}
|
|
return true;
|
|
}
|
|
|
|
bool FeatureProcessor::LabelToTokenSpan(const int label,
|
|
TokenSpan* token_span) const {
|
|
if (label >= 0 && label < label_to_selection_.size()) {
|
|
*token_span = label_to_selection_[label];
|
|
return true;
|
|
} else {
|
|
return false;
|
|
}
|
|
}
|
|
|
|
bool FeatureProcessor::SpanToLabel(
|
|
const std::pair<CodepointIndex, CodepointIndex>& span,
|
|
const std::vector<Token>& tokens, int* label) const {
|
|
if (tokens.size() != GetNumContextTokens()) {
|
|
return false;
|
|
}
|
|
|
|
const int click_position =
|
|
options_.context_size(); // Click is always in the middle.
|
|
const int padding = options_.context_size() - options_.max_selection_span();
|
|
|
|
int span_left = 0;
|
|
for (int i = click_position - 1; i >= padding; i--) {
|
|
if (tokens[i].start != kInvalidIndex && tokens[i].end > span.first) {
|
|
++span_left;
|
|
} else {
|
|
break;
|
|
}
|
|
}
|
|
|
|
int span_right = 0;
|
|
for (int i = click_position + 1; i < tokens.size() - padding; ++i) {
|
|
if (tokens[i].end != kInvalidIndex && tokens[i].start < span.second) {
|
|
++span_right;
|
|
} else {
|
|
break;
|
|
}
|
|
}
|
|
|
|
// Check that the spanned tokens cover the whole span.
|
|
bool tokens_match_span;
|
|
if (options_.snap_label_span_boundaries_to_containing_tokens()) {
|
|
tokens_match_span =
|
|
tokens[click_position - span_left].start <= span.first &&
|
|
tokens[click_position + span_right].end >= span.second;
|
|
} else {
|
|
tokens_match_span =
|
|
tokens[click_position - span_left].start == span.first &&
|
|
tokens[click_position + span_right].end == span.second;
|
|
}
|
|
|
|
if (tokens_match_span) {
|
|
*label = TokenSpanToLabel({span_left, span_right});
|
|
} else {
|
|
*label = kInvalidLabel;
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
int FeatureProcessor::TokenSpanToLabel(const TokenSpan& span) const {
|
|
auto it = selection_to_label_.find(span);
|
|
if (it != selection_to_label_.end()) {
|
|
return it->second;
|
|
} else {
|
|
return kInvalidLabel;
|
|
}
|
|
}
|
|
|
|
// Maps a codepoint span to the half-open range of token indices that lie
// entirely inside it, ignoring padding tokens.
//
// The result is {first covered index, last covered index + 1}; both entries
// are kInvalidIndex when no token is covered.
TokenSpan CodepointSpanToTokenSpan(const std::vector<Token>& selectable_tokens,
                                   CodepointSpan codepoint_span) {
  const int span_begin = codepoint_span.first;
  const int span_end = codepoint_span.second;

  TokenIndex first_covered = kInvalidIndex;
  TokenIndex past_last_covered = kInvalidIndex;
  for (int i = 0; i < selectable_tokens.size(); ++i) {
    const Token& token = selectable_tokens[i];
    const bool fully_inside = !token.is_padding && span_begin <= token.start &&
                              token.end <= span_end;
    if (!fully_inside) {
      continue;
    }
    if (first_covered == kInvalidIndex) {
      first_covered = i;
    }
    past_last_covered = i + 1;
  }
  return {first_covered, past_last_covered};
}
|
|
|
|
// Maps a half-open token index range to the codepoint span reaching from the
// start of its first token to the end of its last token.
// No bounds checking is performed on `token_span`.
CodepointSpan TokenSpanToCodepointSpan(
    const std::vector<Token>& selectable_tokens, TokenSpan token_span) {
  const Token& first_token = selectable_tokens[token_span.first];
  const Token& last_token = selectable_tokens[token_span.second - 1];
  return {first_token.start, last_token.end};
}
|
|
|
|
namespace {
|
|
|
|
// Finds a single token that completely contains the given span.
|
|
int FindTokenThatContainsSpan(const std::vector<Token>& selectable_tokens,
|
|
CodepointSpan codepoint_span) {
|
|
const int codepoint_start = std::get<0>(codepoint_span);
|
|
const int codepoint_end = std::get<1>(codepoint_span);
|
|
|
|
for (int i = 0; i < selectable_tokens.size(); ++i) {
|
|
if (codepoint_start >= selectable_tokens[i].start &&
|
|
codepoint_end <= selectable_tokens[i].end) {
|
|
return i;
|
|
}
|
|
}
|
|
return kInvalidIndex;
|
|
}
|
|
|
|
} // namespace
|
|
|
|
namespace internal {
|
|
|
|
// Returns the index of the single token that corresponds to the click span,
// or kInvalidIndex if the span does not map to exactly one token.
int CenterTokenFromClick(CodepointSpan span,
                         const std::vector<Token>& selectable_tokens) {
  const TokenSpan covered = CodepointSpanToTokenSpan(selectable_tokens, span);
  int first_token = covered.first;
  int end_token = covered.second;

  // If the span does not fully cover any token, fall back to a token that
  // completely contains the click span. This is useful e.g. when Android
  // builds the selection using ICU tokenization and ends up with only a
  // portion of our space-separated token. E.g. for "(857)" Android would
  // select "857".
  const bool nothing_covered =
      first_token == kInvalidIndex || end_token == kInvalidIndex;
  if (nothing_covered) {
    const int containing_token =
        FindTokenThatContainsSpan(selectable_tokens, span);
    if (containing_token != kInvalidIndex) {
      first_token = containing_token;
      end_token = containing_token + 1;
    }
  }

  // We only allow clicks that are exactly 1 selectable token.
  const bool is_single_token = (end_token - first_token) == 1;
  return is_single_token ? first_token : kInvalidIndex;
}
|
|
|
|
// Returns the index of the token in the middle of the tokens covered by the
// span (rounding down), or kInvalidIndex if the span covers no token.
int CenterTokenFromMiddleOfSelection(
    CodepointSpan span, const std::vector<Token>& selectable_tokens) {
  const TokenSpan covered = CodepointSpanToTokenSpan(selectable_tokens, span);
  if (covered.first == kInvalidIndex || covered.second == kInvalidIndex) {
    return kInvalidIndex;
  }
  // `covered.second` is one past the last covered token, hence the -1.
  return (covered.first + covered.second - 1) / 2;
}
|
|
|
|
} // namespace internal
|
|
|
|
// Finds the index of the center token for the given span, using the strategy
// selected by the `center_token_selection_method` option.
// Logs an error and returns kInvalidIndex for an unrecognized method.
int FeatureProcessor::FindCenterToken(CodepointSpan span,
                                      const std::vector<Token>& tokens) const {
  switch (options_.center_token_selection_method()) {
    case FeatureProcessorOptions::CENTER_TOKEN_FROM_CLICK:
      return internal::CenterTokenFromClick(span, tokens);

    case FeatureProcessorOptions::CENTER_TOKEN_MIDDLE_OF_SELECTION:
      return internal::CenterTokenFromMiddleOfSelection(span, tokens);

    case FeatureProcessorOptions::DEFAULT_CENTER_TOKEN_METHOD:
      // TODO(zilka): Remove once we have new models on the device.
      // It uses the fact that the sharing model uses
      // split_tokens_on_selection_boundaries and the selection model does
      // not. So depending on this we select the right way of finding the
      // click location.
      if (options_.split_tokens_on_selection_boundaries()) {
        // SmartSharing model.
        return internal::CenterTokenFromMiddleOfSelection(span, tokens);
      }
      // SmartSelection model.
      return internal::CenterTokenFromClick(span, tokens);

    default:
      TC_LOG(ERROR) << "Invalid center token selection method.";
      return kInvalidIndex;
  }
}
|
|
|
|
bool FeatureProcessor::SelectionLabelSpans(
|
|
const VectorSpan<Token> tokens,
|
|
std::vector<CodepointSpan>* selection_label_spans) const {
|
|
for (int i = 0; i < label_to_selection_.size(); ++i) {
|
|
CodepointSpan span;
|
|
if (!LabelToSpan(i, tokens, &span)) {
|
|
TC_LOG(ERROR) << "Could not convert label to span: " << i;
|
|
return false;
|
|
}
|
|
selection_label_spans->push_back(span);
|
|
}
|
|
return true;
|
|
}
|
|
|
|
void FeatureProcessor::PrepareCodepointRanges(
|
|
const std::vector<FeatureProcessorOptions::CodepointRange>&
|
|
codepoint_ranges,
|
|
std::vector<CodepointRange>* prepared_codepoint_ranges) {
|
|
prepared_codepoint_ranges->clear();
|
|
prepared_codepoint_ranges->reserve(codepoint_ranges.size());
|
|
for (const FeatureProcessorOptions::CodepointRange& range :
|
|
codepoint_ranges) {
|
|
prepared_codepoint_ranges->push_back(
|
|
CodepointRange(range.start(), range.end()));
|
|
}
|
|
|
|
std::sort(prepared_codepoint_ranges->begin(),
|
|
prepared_codepoint_ranges->end(),
|
|
[](const CodepointRange& a, const CodepointRange& b) {
|
|
return a.start < b.start;
|
|
});
|
|
}
|
|
|
|
float FeatureProcessor::SupportedCodepointsRatio(
|
|
int click_pos, const std::vector<Token>& tokens) const {
|
|
int num_supported = 0;
|
|
int num_total = 0;
|
|
for (int i = click_pos - options_.context_size();
|
|
i <= click_pos + options_.context_size(); ++i) {
|
|
const bool is_valid_token = i >= 0 && i < tokens.size();
|
|
if (is_valid_token) {
|
|
const UnicodeText value =
|
|
UTF8ToUnicodeText(tokens[i].value, /*do_copy=*/false);
|
|
for (auto codepoint : value) {
|
|
if (IsCodepointInRanges(codepoint, supported_codepoint_ranges_)) {
|
|
++num_supported;
|
|
}
|
|
++num_total;
|
|
}
|
|
}
|
|
}
|
|
return static_cast<float>(num_supported) / static_cast<float>(num_total);
|
|
}
|
|
|
|
bool FeatureProcessor::IsCodepointInRanges(
|
|
int codepoint, const std::vector<CodepointRange>& codepoint_ranges) const {
|
|
auto it = std::lower_bound(codepoint_ranges.begin(), codepoint_ranges.end(),
|
|
codepoint,
|
|
[](const CodepointRange& range, int codepoint) {
|
|
// This function compares range with the
|
|
// codepoint for the purpose of finding the first
|
|
// greater or equal range. Because of the use of
|
|
// std::lower_bound it needs to return true when
|
|
// range < codepoint; the first time it will
|
|
// return false the lower bound is found and
|
|
// returned.
|
|
//
|
|
// It might seem weird that the condition is
|
|
// range.end <= codepoint here but when codepoint
|
|
// == range.end it means it's actually just
|
|
// outside of the range, thus the range is less
|
|
// than the codepoint.
|
|
return range.end <= codepoint;
|
|
});
|
|
if (it != codepoint_ranges.end() && it->start <= codepoint &&
|
|
it->end > codepoint) {
|
|
return true;
|
|
} else {
|
|
return false;
|
|
}
|
|
}
|
|
|
|
int FeatureProcessor::CollectionToLabel(const std::string& collection) const {
|
|
const auto it = collection_to_label_.find(collection);
|
|
if (it == collection_to_label_.end()) {
|
|
return options_.default_collection();
|
|
} else {
|
|
return it->second;
|
|
}
|
|
}
|
|
|
|
std::string FeatureProcessor::LabelToCollection(int label) const {
|
|
if (label >= 0 && label < collection_to_label_.size()) {
|
|
return options_.collections(label);
|
|
} else {
|
|
return GetDefaultCollection();
|
|
}
|
|
}
|
|
|
|
void FeatureProcessor::MakeLabelMaps() {
|
|
for (int i = 0; i < options_.collections().size(); ++i) {
|
|
collection_to_label_[options_.collections(i)] = i;
|
|
}
|
|
|
|
int selection_label_id = 0;
|
|
for (int l = 0; l < (options_.max_selection_span() + 1); ++l) {
|
|
for (int r = 0; r < (options_.max_selection_span() + 1); ++r) {
|
|
if (!options_.selection_reduced_output_space() ||
|
|
r + l <= options_.max_selection_span()) {
|
|
TokenSpan token_span{l, r};
|
|
selection_to_label_[token_span] = selection_label_id;
|
|
label_to_selection_.push_back(token_span);
|
|
++selection_label_id;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
void FeatureProcessor::TokenizeAndFindClick(const std::string& context,
|
|
CodepointSpan input_span,
|
|
std::vector<Token>* tokens,
|
|
int* click_pos) const {
|
|
TC_CHECK(tokens != nullptr);
|
|
*tokens = Tokenize(context);
|
|
|
|
if (options_.split_tokens_on_selection_boundaries()) {
|
|
internal::SplitTokensOnSelectionBoundaries(input_span, tokens);
|
|
}
|
|
|
|
if (options_.only_use_line_with_click()) {
|
|
internal::StripTokensFromOtherLines(context, input_span, tokens);
|
|
}
|
|
|
|
int local_click_pos;
|
|
if (click_pos == nullptr) {
|
|
click_pos = &local_click_pos;
|
|
}
|
|
*click_pos = FindCenterToken(input_span, *tokens);
|
|
}
|
|
|
|
namespace internal {
|
|
|
|
// Trims or pads `tokens` so that exactly the context needed around the click
// remains, and updates `*click_pos` to the clicked token's new index.
//
// On each side the needed context is `context_size` plus the corresponding
// component of `relative_click_span`. Surplus tokens beyond that are removed;
// missing tokens are filled with default-constructed Token objects, but never
// more than `context_size` of them per side.
//
// Assumes `*click_pos` is a valid index into `tokens`.
void StripOrPadTokens(TokenSpan relative_click_span, int context_size,
                      std::vector<Token>* tokens, int* click_pos) {
  // --- Right side ---
  const int num_tokens = static_cast<int>(tokens->size());
  const int right_context_needed = relative_click_span.second + context_size;
  // Index one past the last token that is needed on the right.
  const int wanted_end = *click_pos + right_context_needed + 1;
  if (wanted_end >= num_tokens) {
    // Pad max the context size.
    const int num_pad_tokens =
        std::min(context_size, wanted_end - num_tokens);
    tokens->insert(tokens->end(), num_pad_tokens, Token());
  } else {
    // Strip unused tokens. This also covers the case of exactly one surplus
    // token, which the previous `< size() - 1` condition left in place.
    tokens->erase(tokens->begin() + wanted_end, tokens->end());
  }

  // --- Left side ---
  const int left_context_needed = relative_click_span.first + context_size;
  if (*click_pos < left_context_needed) {
    // Pad max the context size.
    const int num_pad_tokens =
        std::min(context_size, left_context_needed - *click_pos);
    tokens->insert(tokens->begin(), num_pad_tokens, Token());
    *click_pos += num_pad_tokens;
  } else if (*click_pos > left_context_needed) {
    // Strip unused tokens.
    const int num_strip_tokens = *click_pos - left_context_needed;
    tokens->erase(tokens->begin(), tokens->begin() + num_strip_tokens);
    *click_pos -= num_strip_tokens;
  }
}
|
|
|
|
} // namespace internal
|
|
|
|
bool FeatureProcessor::ExtractFeatures(
|
|
const std::string& context, CodepointSpan input_span,
|
|
TokenSpan relative_click_span, const FeatureVectorFn& feature_vector_fn,
|
|
int feature_vector_size, std::vector<Token>* tokens, int* click_pos,
|
|
std::unique_ptr<CachedFeatures>* cached_features) const {
|
|
TokenizeAndFindClick(context, input_span, tokens, click_pos);
|
|
|
|
// If the default click method failed, let's try to do sub-token matching
|
|
// before we fail.
|
|
if (*click_pos == kInvalidIndex) {
|
|
*click_pos = internal::CenterTokenFromClick(input_span, *tokens);
|
|
if (*click_pos == kInvalidIndex) {
|
|
return false;
|
|
}
|
|
}
|
|
|
|
internal::StripOrPadTokens(relative_click_span, options_.context_size(),
|
|
tokens, click_pos);
|
|
|
|
if (options_.min_supported_codepoint_ratio() > 0) {
|
|
const float supported_codepoint_ratio =
|
|
SupportedCodepointsRatio(*click_pos, *tokens);
|
|
if (supported_codepoint_ratio < options_.min_supported_codepoint_ratio()) {
|
|
TC_LOG(INFO) << "Not enough supported codepoints in the context: "
|
|
<< supported_codepoint_ratio;
|
|
return false;
|
|
}
|
|
}
|
|
|
|
std::vector<std::vector<int>> sparse_features(tokens->size());
|
|
std::vector<std::vector<float>> dense_features(tokens->size());
|
|
for (int i = 0; i < tokens->size(); ++i) {
|
|
const Token& token = (*tokens)[i];
|
|
if (!feature_extractor_.Extract(token, token.IsContainedInSpan(input_span),
|
|
&(sparse_features[i]),
|
|
&(dense_features[i]))) {
|
|
TC_LOG(ERROR) << "Could not extract token's features: " << token;
|
|
return false;
|
|
}
|
|
}
|
|
|
|
cached_features->reset(new CachedFeatures(
|
|
*tokens, options_.context_size(), sparse_features, dense_features,
|
|
feature_vector_fn, feature_vector_size));
|
|
|
|
if (*cached_features == nullptr) {
|
|
return false;
|
|
}
|
|
|
|
if (options_.feature_version() == 0) {
|
|
(*cached_features)
|
|
->SetV0FeatureMode(feature_vector_size -
|
|
feature_extractor_.DenseFeaturesCount());
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
bool FeatureProcessor::ICUTokenize(const std::string& context,
|
|
std::vector<Token>* result) const {
|
|
icu::ErrorCode status;
|
|
icu::UnicodeString unicode_text = icu::UnicodeString::fromUTF8(context);
|
|
std::unique_ptr<icu::BreakIterator> break_iterator(
|
|
icu::BreakIterator::createWordInstance(icu::Locale("en"), status));
|
|
if (!status.isSuccess()) {
|
|
TC_LOG(ERROR) << "Break iterator did not initialize properly: "
|
|
<< status.errorName();
|
|
return false;
|
|
}
|
|
|
|
break_iterator->setText(unicode_text);
|
|
|
|
size_t last_break_index = 0;
|
|
size_t break_index = 0;
|
|
size_t last_unicode_index = 0;
|
|
size_t unicode_index = 0;
|
|
while ((break_index = break_iterator->next()) != icu::BreakIterator::DONE) {
|
|
icu::UnicodeString token(unicode_text, last_break_index,
|
|
break_index - last_break_index);
|
|
int token_length = token.countChar32();
|
|
unicode_index = last_unicode_index + token_length;
|
|
|
|
std::string token_utf8;
|
|
token.toUTF8String(token_utf8);
|
|
|
|
bool is_whitespace = true;
|
|
for (int i = 0; i < token.length(); i++) {
|
|
if (!u_isWhitespace(token.char32At(i))) {
|
|
is_whitespace = false;
|
|
}
|
|
}
|
|
|
|
if (!is_whitespace || options_.icu_preserve_whitespace_tokens()) {
|
|
result->push_back(Token(token_utf8, last_unicode_index, unicode_index));
|
|
}
|
|
|
|
last_break_index = break_index;
|
|
last_unicode_index = unicode_index;
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
void FeatureProcessor::InternalRetokenize(const std::string& context,
|
|
std::vector<Token>* tokens) const {
|
|
const UnicodeText unicode_text =
|
|
UTF8ToUnicodeText(context, /*do_copy=*/false);
|
|
|
|
std::vector<Token> result;
|
|
CodepointSpan span(-1, -1);
|
|
for (Token& token : *tokens) {
|
|
const UnicodeText unicode_token_value =
|
|
UTF8ToUnicodeText(token.value, /*do_copy=*/false);
|
|
bool should_retokenize = true;
|
|
for (const int codepoint : unicode_token_value) {
|
|
if (!IsCodepointInRanges(codepoint,
|
|
internal_tokenizer_codepoint_ranges_)) {
|
|
should_retokenize = false;
|
|
break;
|
|
}
|
|
}
|
|
|
|
if (should_retokenize) {
|
|
if (span.first < 0) {
|
|
span.first = token.start;
|
|
}
|
|
span.second = token.end;
|
|
} else {
|
|
TokenizeSubstring(unicode_text, span, &result);
|
|
span.first = -1;
|
|
result.emplace_back(std::move(token));
|
|
}
|
|
}
|
|
TokenizeSubstring(unicode_text, span, &result);
|
|
|
|
*tokens = std::move(result);
|
|
}
|
|
|
|
void FeatureProcessor::TokenizeSubstring(const UnicodeText& unicode_text,
|
|
CodepointSpan span,
|
|
std::vector<Token>* result) const {
|
|
if (span.first < 0) {
|
|
// There is no span to tokenize.
|
|
return;
|
|
}
|
|
|
|
// Extract the substring.
|
|
UnicodeText::const_iterator it_begin = unicode_text.begin();
|
|
for (int i = 0; i < span.first; ++i) {
|
|
++it_begin;
|
|
}
|
|
UnicodeText::const_iterator it_end = unicode_text.begin();
|
|
for (int i = 0; i < span.second; ++i) {
|
|
++it_end;
|
|
}
|
|
const std::string text = unicode_text.UTF8Substring(it_begin, it_end);
|
|
|
|
// Run the tokenizer and update the token bounds to reflect the offset of the
|
|
// substring.
|
|
std::vector<Token> tokens = tokenizer_.Tokenize(text);
|
|
for (Token& token : tokens) {
|
|
token.start += span.first;
|
|
token.end += span.first;
|
|
result->emplace_back(std::move(token));
|
|
}
|
|
}
|
|
|
|
} // namespace libtextclassifier
|