/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
|
|
|
|
#ifndef LIBTEXTCLASSIFIER_SMARTSELECT_CACHED_FEATURES_H_
|
|
#define LIBTEXTCLASSIFIER_SMARTSELECT_CACHED_FEATURES_H_
|
|
|
|
#include <functional>
#include <memory>
#include <vector>

#include "base.h"
#include "common/vector-span.h"
#include "smartselect/types.h"
|
|
|
|
namespace libtextclassifier {
|
|
|
|
// Holds state for extracting features across multiple calls and reusing them.
|
|
// Assumes that features for each Token are independent.
|
|
class CachedFeatures {
|
|
public:
|
|
// Extracts the features for the given sequence of tokens.
|
|
// - context_size: Specifies how many tokens to the left, and how many
|
|
// tokens to the right spans the context.
|
|
// - sparse_features, dense_features: Extracted features for each token.
|
|
// - feature_vector_fn: Writes features for given Token to the specified
|
|
// storage.
|
|
// NOTE: The function can assume that the underlying
|
|
// storage is initialized to all zeros.
|
|
// - feature_vector_size: Size of a feature vector for one Token.
|
|
CachedFeatures(VectorSpan<Token> tokens, int context_size,
|
|
const std::vector<std::vector<int>>& sparse_features,
|
|
const std::vector<std::vector<float>>& dense_features,
|
|
const std::function<bool(const std::vector<int>&,
|
|
const std::vector<float>&, float*)>&
|
|
feature_vector_fn,
|
|
int feature_vector_size)
|
|
: tokens_(tokens),
|
|
context_size_(context_size),
|
|
feature_vector_size_(feature_vector_size),
|
|
remap_v0_feature_vector_(false),
|
|
remap_v0_chargram_embedding_size_(-1) {
|
|
Extract(sparse_features, dense_features, feature_vector_fn);
|
|
}
|
|
|
|
// Gets a VectorSpan with the features for given click position.
|
|
bool Get(int click_pos, VectorSpan<float>* features,
|
|
VectorSpan<Token>* output_tokens);
|
|
|
|
// Turns on a compatibility mode, which re-maps the extracted features to the
|
|
// v0 feature format (where the dense features were at the end).
|
|
// WARNING: Internally v0_feature_storage_ is used as a backing buffer for
|
|
// VectorSpan<float>, so the output of Extract is valid only until the next
|
|
// call or destruction of the current CachedFeatures object.
|
|
// TODO(zilka): Remove when we'll have retrained models.
|
|
void SetV0FeatureMode(int chargram_embedding_size) {
|
|
remap_v0_feature_vector_ = true;
|
|
remap_v0_chargram_embedding_size_ = chargram_embedding_size;
|
|
v0_feature_storage_.resize(feature_vector_size_ * (context_size_ * 2 + 1));
|
|
}
|
|
|
|
protected:
|
|
// Extracts features for all tokens and stores them for later retrieval.
|
|
void Extract(const std::vector<std::vector<int>>& sparse_features,
|
|
const std::vector<std::vector<float>>& dense_features,
|
|
const std::function<bool(const std::vector<int>&,
|
|
const std::vector<float>&, float*)>&
|
|
feature_vector_fn);
|
|
|
|
// Remaps extracted features to V0 feature format. The mapping is using
|
|
// the v0_feature_storage_ as the backing storage for the mapped features.
|
|
// For each token the features consist of:
|
|
// - chargram embeddings
|
|
// - dense features
|
|
// They are concatenated together as [chargram embeddings; dense features]
|
|
// for each token independently.
|
|
// The V0 features require that the chargram embeddings for tokens are
|
|
// concatenated first together, and at the end, the dense features for the
|
|
// tokens are concatenated to it.
|
|
void RemapV0FeatureVector(VectorSpan<float>* features);
|
|
|
|
private:
|
|
const VectorSpan<Token> tokens_;
|
|
const int context_size_;
|
|
const int feature_vector_size_;
|
|
bool remap_v0_feature_vector_;
|
|
int remap_v0_chargram_embedding_size_;
|
|
|
|
std::vector<float> features_;
|
|
std::vector<float> v0_feature_storage_;
|
|
};
|
|
|
|
} // namespace libtextclassifier
|
|
|
|
#endif // LIBTEXTCLASSIFIER_SMARTSELECT_CACHED_FEATURES_H_
|