197 lines
7.2 KiB
C++
197 lines
7.2 KiB
C++
/*
|
|
* Copyright (C) 2017 The Android Open Source Project
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#ifndef LIBTEXTCLASSIFIER_COMMON_EMBEDDING_FEATURE_EXTRACTOR_H_
|
|
#define LIBTEXTCLASSIFIER_COMMON_EMBEDDING_FEATURE_EXTRACTOR_H_
|
|
|
|
#include <memory>
|
|
#include <string>
|
|
#include <vector>
|
|
|
|
#include "common/feature-extractor.h"
|
|
#include "common/task-context.h"
|
|
#include "common/workspace.h"
|
|
#include "util/base/logging.h"
|
|
#include "util/base/macros.h"
|
|
|
|
namespace libtextclassifier {
|
|
namespace nlp_core {
|
|
|
|
// An EmbeddingFeatureExtractor manages the extraction of features for
|
|
// embedding-based models. It wraps a sequence of underlying classes of feature
|
|
// extractors, along with associated predicate maps. Each class of feature
|
|
// extractors is associated with a name, e.g., "words", "labels", "tags".
|
|
//
|
|
// The class is split between a generic abstract version,
|
|
// GenericEmbeddingFeatureExtractor (that can be initialized without knowing the
|
|
// signature of the ExtractFeatures method) and a typed version.
|
|
//
|
|
// The predicate maps must be initialized before use: they can be loaded using
|
|
// Read() or updated via UpdateMapsForExample.
|
|
class GenericEmbeddingFeatureExtractor {
|
|
public:
|
|
GenericEmbeddingFeatureExtractor() {}
|
|
virtual ~GenericEmbeddingFeatureExtractor() {}
|
|
|
|
// Get the prefix std::string to put in front of all arguments, so they don't
|
|
// conflict with other embedding models.
|
|
virtual const std::string ArgPrefix() const = 0;
|
|
|
|
// Initializes predicate maps and embedding space names that are common for
|
|
// all embedding-based feature extractors.
|
|
virtual bool Init(TaskContext *context);
|
|
|
|
// Requests workspace for the underlying feature extractors. This is
|
|
// implemented in the typed class.
|
|
virtual void RequestWorkspaces(WorkspaceRegistry *registry) = 0;
|
|
|
|
// Returns number of embedding spaces.
|
|
int NumEmbeddings() const { return embedding_dims_.size(); }
|
|
|
|
// Number of predicates for the embedding at a given index (vocabulary size).
|
|
// Returns -1 if index is out of bounds.
|
|
int EmbeddingSize(int index) const {
|
|
const GenericFeatureExtractor *extractor = generic_feature_extractor(index);
|
|
return (extractor == nullptr) ? -1 : extractor->GetDomainSize();
|
|
}
|
|
|
|
// Returns the dimensionality of the embedding space.
|
|
int EmbeddingDims(int index) const { return embedding_dims_[index]; }
|
|
|
|
// Accessor for embedding dims (dimensions of the embedding spaces).
|
|
const std::vector<int> &embedding_dims() const { return embedding_dims_; }
|
|
|
|
const std::vector<std::string> &embedding_fml() const {
|
|
return embedding_fml_;
|
|
}
|
|
|
|
// Get parameter name by concatenating the prefix and the original name.
|
|
std::string GetParamName(const std::string ¶m_name) const {
|
|
std::string full_name = ArgPrefix();
|
|
full_name.push_back('_');
|
|
full_name.append(param_name);
|
|
return full_name;
|
|
}
|
|
|
|
protected:
|
|
// Provides the generic class with access to the templated extractors. This is
|
|
// used to get the type information out of the feature extractor without
|
|
// knowing the specific calling arguments of the extractor itself.
|
|
// Returns nullptr for an out-of-bounds idx.
|
|
virtual const GenericFeatureExtractor *generic_feature_extractor(
|
|
int idx) const = 0;
|
|
|
|
private:
|
|
// Embedding space names for parameter sharing.
|
|
std::vector<std::string> embedding_names_;
|
|
|
|
// FML strings for each feature extractor.
|
|
std::vector<std::string> embedding_fml_;
|
|
|
|
// Size of each of the embedding spaces (maximum predicate id).
|
|
std::vector<int> embedding_sizes_;
|
|
|
|
// Embedding dimensions of the embedding spaces (i.e. 32, 64 etc.)
|
|
std::vector<int> embedding_dims_;
|
|
|
|
TC_DISALLOW_COPY_AND_ASSIGN(GenericEmbeddingFeatureExtractor);
|
|
};
|
|
|
|
// Templated, object-specific implementation of the
|
|
// EmbeddingFeatureExtractor. EXTRACTOR should be a FeatureExtractor<OBJ,
|
|
// ARGS...> class that has the appropriate FeatureTraits() to ensure that
|
|
// locator type features work.
|
|
//
|
|
// Note: for backwards compatibility purposes, this always reads the FML spec
|
|
// from "<prefix>_features".
|
|
template <class EXTRACTOR, class OBJ, class... ARGS>
|
|
class EmbeddingFeatureExtractor : public GenericEmbeddingFeatureExtractor {
|
|
public:
|
|
// Initializes all predicate maps, feature extractors, etc.
|
|
bool Init(TaskContext *context) override {
|
|
if (!GenericEmbeddingFeatureExtractor::Init(context)) {
|
|
return false;
|
|
}
|
|
feature_extractors_.resize(embedding_fml().size());
|
|
for (int i = 0; i < embedding_fml().size(); ++i) {
|
|
feature_extractors_[i].reset(new EXTRACTOR());
|
|
if (!feature_extractors_[i]->Parse(embedding_fml()[i])) {
|
|
return false;
|
|
}
|
|
if (!feature_extractors_[i]->Setup(context)) {
|
|
return false;
|
|
}
|
|
}
|
|
for (auto &feature_extractor : feature_extractors_) {
|
|
if (!feature_extractor->Init(context)) {
|
|
return false;
|
|
}
|
|
}
|
|
return true;
|
|
}
|
|
|
|
// Requests workspaces from the registry. Must be called after Init(), and
|
|
// before Preprocess().
|
|
void RequestWorkspaces(WorkspaceRegistry *registry) override {
|
|
for (auto &feature_extractor : feature_extractors_) {
|
|
feature_extractor->RequestWorkspaces(registry);
|
|
}
|
|
}
|
|
|
|
// Must be called on the object one state for each sentence, before any
|
|
// feature extraction (e.g., UpdateMapsForExample, ExtractFeatures).
|
|
void Preprocess(WorkspaceSet *workspaces, OBJ *obj) const {
|
|
for (auto &feature_extractor : feature_extractors_) {
|
|
feature_extractor->Preprocess(workspaces, obj);
|
|
}
|
|
}
|
|
|
|
// Extracts features using the extractors. Note that features must already
|
|
// be initialized to the correct number of feature extractors. No predicate
|
|
// mapping is applied.
|
|
void ExtractFeatures(const WorkspaceSet &workspaces, const OBJ &obj,
|
|
ARGS... args,
|
|
std::vector<FeatureVector> *features) const {
|
|
TC_DCHECK(features != nullptr);
|
|
TC_DCHECK_EQ(features->size(), feature_extractors_.size());
|
|
for (int i = 0; i < feature_extractors_.size(); ++i) {
|
|
(*features)[i].clear();
|
|
feature_extractors_[i]->ExtractFeatures(workspaces, obj, args...,
|
|
&(*features)[i]);
|
|
}
|
|
}
|
|
|
|
protected:
|
|
// Provides generic access to the feature extractors.
|
|
const GenericFeatureExtractor *generic_feature_extractor(
|
|
int idx) const override {
|
|
if ((idx < 0) || (idx >= feature_extractors_.size())) {
|
|
TC_LOG(ERROR) << "Out of bounds index " << idx;
|
|
TC_DCHECK(false); // Crash in debug mode.
|
|
return nullptr;
|
|
}
|
|
return feature_extractors_[idx].get();
|
|
}
|
|
|
|
private:
|
|
// Templated feature extractor class.
|
|
std::vector<std::unique_ptr<EXTRACTOR>> feature_extractors_;
|
|
};
|
|
|
|
} // namespace nlp_core
|
|
} // namespace libtextclassifier
|
|
|
|
#endif // LIBTEXTCLASSIFIER_COMMON_EMBEDDING_FEATURE_EXTRACTOR_H_
|