131 lines
5 KiB
C++
131 lines
5 KiB
C++
/*
|
|
* Copyright (C) 2017 The Android Open Source Project
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
// Inference code for the feed-forward text classification models.
|
|
|
|
#ifndef LIBTEXTCLASSIFIER_SMARTSELECT_TEXT_CLASSIFICATION_MODEL_H_
|
|
#define LIBTEXTCLASSIFIER_SMARTSELECT_TEXT_CLASSIFICATION_MODEL_H_
|
|
|
|
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "base.h"
#include "common/embedding-network.h"
#include "common/feature-extractor.h"
#include "common/memory_image/embedding-network-params-from-image.h"
#include "common/mmap.h"
#include "smartselect/feature-processor.h"
#include "smartselect/model-params.h"
#include "smartselect/text-classification-model.pb.h"
#include "smartselect/types.h"
|
|
|
|
namespace libtextclassifier {
|
|
|
|
// SmartSelection/Sharing feed-forward model.
|
|
class TextClassificationModel {
|
|
public:
|
|
// Loads TextClassificationModel from given file given by an int
|
|
// file descriptor.
|
|
explicit TextClassificationModel(int fd);
|
|
|
|
// Bit flags for the input selection.
|
|
enum SelectionInputFlags { SELECTION_IS_URL = 0x1, SELECTION_IS_EMAIL = 0x2 };
|
|
|
|
// Runs inference for given a context and current selection (i.e. index
|
|
// of the first and one past last selected characters (utf8 codepoint
|
|
// offsets)). Returns the indices (utf8 codepoint offsets) of the selection
|
|
// beginning character and one past selection end character.
|
|
// Returns the original click_indices if an error occurs.
|
|
// NOTE: The selection indices are passed in and returned in terms of
|
|
// UTF8 codepoints (not bytes).
|
|
// Requires that the model is a smart selection model.
|
|
CodepointSpan SuggestSelection(const std::string& context,
|
|
CodepointSpan click_indices) const;
|
|
|
|
// Classifies the selected text given the context string.
|
|
// Requires that the model is a smart sharing model.
|
|
// Returns an empty result if an error occurs.
|
|
std::vector<std::pair<std::string, float>> ClassifyText(
|
|
const std::string& context, CodepointSpan click_indices,
|
|
int input_flags = 0) const;
|
|
|
|
protected:
|
|
// Removes punctuation from the beginning and end of the selection and returns
|
|
// the new selection span.
|
|
CodepointSpan StripPunctuation(CodepointSpan selection,
|
|
const std::string& context) const;
|
|
|
|
// During evaluation we need access to the feature processor.
|
|
FeatureProcessor* SelectionFeatureProcessor() const {
|
|
return selection_feature_processor_.get();
|
|
}
|
|
|
|
// Collection name when url hint is accepted.
|
|
const std::string kUrlHintCollection = "url";
|
|
|
|
// Collection name when email hint is accepted.
|
|
const std::string kEmailHintCollection = "email";
|
|
|
|
// Collection name for other.
|
|
const std::string kOtherCollection = "other";
|
|
|
|
// Collection name for phone.
|
|
const std::string kPhoneCollection = "phone";
|
|
|
|
SelectionModelOptions selection_options_;
|
|
SharingModelOptions sharing_options_;
|
|
|
|
private:
|
|
bool LoadModels(const nlp_core::MmapHandle& mmap_handle);
|
|
|
|
nlp_core::EmbeddingNetwork::Vector InferInternal(
|
|
const std::string& context, CodepointSpan span,
|
|
const FeatureProcessor& feature_processor,
|
|
const nlp_core::EmbeddingNetwork& network,
|
|
const FeatureVectorFn& feature_vector_fn,
|
|
std::vector<CodepointSpan>* selection_label_spans) const;
|
|
|
|
// Returns a selection suggestion with a score.
|
|
std::pair<CodepointSpan, float> SuggestSelectionInternal(
|
|
const std::string& context, CodepointSpan click_indices) const;
|
|
|
|
// Returns a selection suggestion and makes sure it's symmetric. Internally
|
|
// runs several times SuggestSelectionInternal.
|
|
CodepointSpan SuggestSelectionSymmetrical(const std::string& context,
|
|
CodepointSpan click_indices) const;
|
|
|
|
bool initialized_;
|
|
nlp_core::ScopedMmap mmap_;
|
|
std::unique_ptr<ModelParams> selection_params_;
|
|
std::unique_ptr<FeatureProcessor> selection_feature_processor_;
|
|
std::unique_ptr<nlp_core::EmbeddingNetwork> selection_network_;
|
|
FeatureVectorFn selection_feature_fn_;
|
|
std::unique_ptr<FeatureProcessor> sharing_feature_processor_;
|
|
std::unique_ptr<ModelParams> sharing_params_;
|
|
std::unique_ptr<nlp_core::EmbeddingNetwork> sharing_network_;
|
|
FeatureVectorFn sharing_feature_fn_;
|
|
|
|
std::set<int> punctuation_to_strip_;
|
|
};
|
|
|
|
// Parses the merged image given as a file descriptor, and reads
|
|
// the ModelOptions proto from the selection model.
// |fd| is the file descriptor of the merged image; |model_options| is the
// caller-provided object the proto is read into.
// NOTE(review): the bool result presumably reports whether parsing and
// reading succeeded -- confirm against the definition, which is not in this
// header.
|
|
bool ReadSelectionModelOptions(int fd, ModelOptions* model_options);
|
|
|
|
} // namespace libtextclassifier
|
|
|
|
#endif // LIBTEXTCLASSIFIER_SMARTSELECT_TEXT_CLASSIFICATION_MODEL_H_
|