152 lines
5 KiB
C++
152 lines
5 KiB
C++
/*
|
|
* Copyright (C) 2017 The Android Open Source Project
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
// Model parameter loading.
|
|
|
|
#ifndef LIBTEXTCLASSIFIER_SMARTSELECT_MODEL_PARAMS_H_
|
|
#define LIBTEXTCLASSIFIER_SMARTSELECT_MODEL_PARAMS_H_
|
|
|
|
#include <memory>
#include <utility>

#include "common/embedding-network.h"
#include "common/memory_image/embedding-network-params-from-image.h"
#include "smartselect/text-classification-model.pb.h"
|
|
|
|
namespace libtextclassifier {
|
|
|
|
class EmbeddingParams : public nlp_core::EmbeddingNetworkParamsFromImage {
|
|
public:
|
|
EmbeddingParams(const void* start, uint64 num_bytes, int context_size)
|
|
: EmbeddingNetworkParamsFromImage(start, num_bytes),
|
|
context_size_(context_size) {}
|
|
|
|
int embeddings_size() const override { return context_size_ * 2 + 1; }
|
|
|
|
int embedding_num_features_size() const override {
|
|
return context_size_ * 2 + 1;
|
|
}
|
|
|
|
int embedding_num_features(int i) const override { return 1; }
|
|
|
|
int embeddings_num_rows(int i) const override {
|
|
return EmbeddingNetworkParamsFromImage::embeddings_num_rows(0);
|
|
};
|
|
|
|
int embeddings_num_cols(int i) const override {
|
|
return EmbeddingNetworkParamsFromImage::embeddings_num_cols(0);
|
|
};
|
|
|
|
const void* embeddings_weights(int i) const override {
|
|
return EmbeddingNetworkParamsFromImage::embeddings_weights(0);
|
|
};
|
|
|
|
nlp_core::QuantizationType embeddings_quant_type(int i) const override {
|
|
return EmbeddingNetworkParamsFromImage::embeddings_quant_type(0);
|
|
}
|
|
|
|
const nlp_core::float16* embeddings_quant_scales(int i) const override {
|
|
return EmbeddingNetworkParamsFromImage::embeddings_quant_scales(0);
|
|
}
|
|
|
|
private:
|
|
int context_size_;
|
|
};
|
|
|
|
// Loads and holds the parameters of the inference network.
|
|
//
|
|
// This class overrides a couple of methods of EmbeddingNetworkParamsFromImage
|
|
// because we only have one embedding matrix for all positions of context,
|
|
// whereas the original class would have a separate one for each.
|
|
class ModelParams : public nlp_core::EmbeddingNetworkParamsFromImage {
|
|
public:
|
|
const FeatureProcessorOptions& GetFeatureProcessorOptions() const {
|
|
return feature_processor_options_;
|
|
}
|
|
|
|
const SelectionModelOptions& GetSelectionModelOptions() const {
|
|
return selection_options_;
|
|
}
|
|
|
|
const SharingModelOptions& GetSharingModelOptions() const {
|
|
return sharing_options_;
|
|
}
|
|
|
|
std::shared_ptr<EmbeddingParams> GetEmbeddingParams() const {
|
|
return embedding_params_;
|
|
}
|
|
|
|
protected:
|
|
int embeddings_size() const override {
|
|
return embedding_params_->embeddings_size();
|
|
}
|
|
|
|
int embedding_num_features_size() const override {
|
|
return embedding_params_->embedding_num_features_size();
|
|
}
|
|
|
|
int embedding_num_features(int i) const override {
|
|
return embedding_params_->embedding_num_features(i);
|
|
}
|
|
|
|
int embeddings_num_rows(int i) const override {
|
|
return embedding_params_->embeddings_num_rows(i);
|
|
};
|
|
|
|
int embeddings_num_cols(int i) const override {
|
|
return embedding_params_->embeddings_num_cols(i);
|
|
};
|
|
|
|
const void* embeddings_weights(int i) const override {
|
|
return embedding_params_->embeddings_weights(i);
|
|
};
|
|
|
|
nlp_core::QuantizationType embeddings_quant_type(int i) const override {
|
|
return embedding_params_->embeddings_quant_type(i);
|
|
}
|
|
|
|
const nlp_core::float16* embeddings_quant_scales(int i) const override {
|
|
return embedding_params_->embeddings_quant_scales(i);
|
|
}
|
|
|
|
private:
|
|
friend ModelParams* ModelParamsBuilder(
|
|
const void* start, uint64 num_bytes,
|
|
std::shared_ptr<EmbeddingParams> external_embedding_params);
|
|
|
|
ModelParams(const void* start, uint64 num_bytes,
|
|
std::shared_ptr<EmbeddingParams> embedding_params,
|
|
const SelectionModelOptions& selection_options,
|
|
const SharingModelOptions& sharing_options,
|
|
const FeatureProcessorOptions& feature_processor_options)
|
|
: EmbeddingNetworkParamsFromImage(start, num_bytes),
|
|
selection_options_(selection_options),
|
|
sharing_options_(sharing_options),
|
|
feature_processor_options_(feature_processor_options),
|
|
context_size_(feature_processor_options_.context_size()),
|
|
embedding_params_(std::move(embedding_params)) {}
|
|
|
|
SelectionModelOptions selection_options_;
|
|
SharingModelOptions sharing_options_;
|
|
FeatureProcessorOptions feature_processor_options_;
|
|
int context_size_;
|
|
std::shared_ptr<EmbeddingParams> embedding_params_;
|
|
};
|
|
|
|
ModelParams* ModelParamsBuilder(
|
|
const void* start, uint64 num_bytes,
|
|
std::shared_ptr<EmbeddingParams> external_embedding_params);
|
|
|
|
} // namespace libtextclassifier
|
|
|
|
#endif // LIBTEXTCLASSIFIER_SMARTSELECT_MODEL_PARAMS_H_
|