245 lines
8.5 KiB
C++
245 lines
8.5 KiB
C++
/*
|
|
* Copyright (C) 2017 The Android Open Source Project
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#ifndef LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_PARAMS_FROM_PROTO_H_
|
|
#define LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_PARAMS_FROM_PROTO_H_
|
|
|
|
#include <algorithm>
|
|
#include <memory>
|
|
#include <string>
|
|
#include <utility>
|
|
#include <vector>
|
|
|
|
#include "common/embedding-network-package.pb.h"
|
|
#include "common/embedding-network-params.h"
|
|
#include "common/embedding-network.pb.h"
|
|
#include "common/float16.h"
|
|
#include "common/little-endian-data.h"
|
|
#include "common/task-context.h"
|
|
#include "common/task-spec.pb.h"
|
|
#include "util/base/integral_types.h"
|
|
#include "util/base/logging.h"
|
|
|
|
namespace libtextclassifier {
|
|
namespace nlp_core {
|
|
|
|
// A wrapper class that owns and exposes an EmbeddingNetworkProto message via
|
|
// the EmbeddingNetworkParams interface.
|
|
//
|
|
// The EmbeddingNetworkParams interface encapsulates the weight matrices of the
|
|
// embeddings, hidden and softmax layers as transposed versions of their
|
|
// counterparts in the original EmbeddingNetworkProto. The matrices in the proto
|
|
// passed to this class' constructor must likewise already have been transposed.
|
|
// See embedding-network-params.h for details.
|
|
class EmbeddingNetworkParamsFromProto : public EmbeddingNetworkParams {
|
|
public:
|
|
// Constructor that takes ownership of the provided proto. See class-comment
|
|
// for the requirements that certain weight matrices must satisfy.
|
|
explicit EmbeddingNetworkParamsFromProto(
|
|
std::unique_ptr<EmbeddingNetworkProto> proto)
|
|
: proto_(std::move(proto)) {
|
|
valid_ = true;
|
|
|
|
// Initialize these vectors to have the required number of elements
|
|
// regardless of quantization status. This is to support the unlikely case
|
|
// where only some embeddings are quantized, along with the fact that
|
|
// EmbeddingNetworkParams interface accesses them by index.
|
|
embeddings_quant_scales_.resize(proto_->embeddings_size());
|
|
embeddings_quant_weights_.resize(proto_->embeddings_size());
|
|
for (int i = 0; i < proto_->embeddings_size(); ++i) {
|
|
MatrixParams *embedding = proto_->mutable_embeddings()->Mutable(i);
|
|
if (!embedding->is_quantized()) {
|
|
continue;
|
|
}
|
|
|
|
bool success = FillVectorFromDataBytesInLittleEndian(
|
|
embedding->bytes_for_quantized_values(),
|
|
embedding->rows() * embedding->cols(),
|
|
&(embeddings_quant_weights_[i]));
|
|
if (!success) {
|
|
TC_LOG(ERROR) << "Problem decoding quant_weights for embeddings #" << i;
|
|
valid_ = false;
|
|
}
|
|
|
|
// The repeated field bytes_for_quantized_values uses a lot of memory.
|
|
// Since it's no longer necessary (and we own the proto), we clear it.
|
|
embedding->clear_bytes_for_quantized_values();
|
|
|
|
success = FillVectorFromDataBytesInLittleEndian(
|
|
embedding->bytes_for_col_scales(),
|
|
embedding->rows(),
|
|
&(embeddings_quant_scales_[i]));
|
|
if (!success) {
|
|
TC_LOG(ERROR) << "Problem decoding col_scales for embeddings #" << i;
|
|
valid_ = false;
|
|
}
|
|
|
|
// See comments for clear_bytes_for_quantized_values().
|
|
embedding->clear_bytes_for_col_scales();
|
|
}
|
|
}
|
|
|
|
const TaskSpec *GetTaskSpec() override {
|
|
if (!proto_) {
|
|
return nullptr;
|
|
}
|
|
auto extension_id = task_spec_in_embedding_network_proto;
|
|
if (proto_->HasExtension(extension_id)) {
|
|
return &(proto_->GetExtension(extension_id));
|
|
} else {
|
|
TC_LOG(ERROR) << "Unable to get TaskSpec from EmbeddingNetworkProto";
|
|
return nullptr;
|
|
}
|
|
}
|
|
|
|
// Returns true if these params are valid. False otherwise (e.g., if the
|
|
// original proto data was corrupted).
|
|
bool is_valid() { return valid_; }
|
|
|
|
protected:
|
|
int embeddings_size() const override { return proto_->embeddings_size(); }
|
|
|
|
int embeddings_num_rows(int i) const override {
|
|
TC_DCHECK(InRange(i, embeddings_size()));
|
|
return proto_->embeddings(i).rows();
|
|
}
|
|
|
|
int embeddings_num_cols(int i) const override {
|
|
TC_DCHECK(InRange(i, embeddings_size()));
|
|
return proto_->embeddings(i).cols();
|
|
}
|
|
|
|
const void *embeddings_weights(int i) const override {
|
|
TC_DCHECK(InRange(i, embeddings_size()));
|
|
if (proto_->embeddings(i).is_quantized()) {
|
|
return static_cast<const void *>(embeddings_quant_weights_.at(i).data());
|
|
} else {
|
|
return static_cast<const void *>(proto_->embeddings(i).value().data());
|
|
}
|
|
}
|
|
|
|
QuantizationType embeddings_quant_type(int i) const override {
|
|
TC_DCHECK(InRange(i, embeddings_size()));
|
|
return proto_->embeddings(i).is_quantized() ? QuantizationType::UINT8
|
|
: QuantizationType::NONE;
|
|
}
|
|
|
|
const float16 *embeddings_quant_scales(int i) const override {
|
|
TC_DCHECK(InRange(i, embeddings_size()));
|
|
return proto_->embeddings(i).is_quantized()
|
|
? embeddings_quant_scales_.at(i).data()
|
|
: nullptr;
|
|
}
|
|
|
|
int hidden_size() const override { return proto_->hidden_size(); }
|
|
|
|
int hidden_num_rows(int i) const override {
|
|
TC_DCHECK(InRange(i, hidden_size()));
|
|
return proto_->hidden(i).rows();
|
|
}
|
|
|
|
int hidden_num_cols(int i) const override {
|
|
TC_DCHECK(InRange(i, hidden_size()));
|
|
return proto_->hidden(i).cols();
|
|
}
|
|
|
|
const void *hidden_weights(int i) const override {
|
|
TC_DCHECK(InRange(i, hidden_size()));
|
|
return proto_->hidden(i).value().data();
|
|
}
|
|
|
|
int hidden_bias_size() const override { return proto_->hidden_bias_size(); }
|
|
|
|
int hidden_bias_num_rows(int i) const override {
|
|
TC_DCHECK(InRange(i, hidden_bias_size()));
|
|
return proto_->hidden_bias(i).rows();
|
|
}
|
|
|
|
int hidden_bias_num_cols(int i) const override {
|
|
TC_DCHECK(InRange(i, hidden_bias_size()));
|
|
return proto_->hidden_bias(i).cols();
|
|
}
|
|
|
|
const void *hidden_bias_weights(int i) const override {
|
|
TC_DCHECK(InRange(i, hidden_bias_size()));
|
|
return proto_->hidden_bias(i).value().data();
|
|
}
|
|
|
|
int softmax_size() const override { return proto_->has_softmax() ? 1 : 0; }
|
|
|
|
int softmax_num_rows(int i) const override {
|
|
TC_DCHECK(InRange(i, softmax_size()));
|
|
return proto_->has_softmax() ? proto_->softmax().rows() : 0;
|
|
}
|
|
|
|
int softmax_num_cols(int i) const override {
|
|
TC_DCHECK(InRange(i, softmax_size()));
|
|
return proto_->has_softmax() ? proto_->softmax().cols() : 0;
|
|
}
|
|
|
|
const void *softmax_weights(int i) const override {
|
|
TC_DCHECK(InRange(i, softmax_size()));
|
|
return proto_->has_softmax() ? proto_->softmax().value().data() : nullptr;
|
|
}
|
|
|
|
int softmax_bias_size() const override {
|
|
return proto_->has_softmax_bias() ? 1 : 0;
|
|
}
|
|
|
|
int softmax_bias_num_rows(int i) const override {
|
|
TC_DCHECK(InRange(i, softmax_bias_size()));
|
|
return proto_->has_softmax_bias() ? proto_->softmax_bias().rows() : 0;
|
|
}
|
|
|
|
int softmax_bias_num_cols(int i) const override {
|
|
TC_DCHECK(InRange(i, softmax_bias_size()));
|
|
return proto_->has_softmax_bias() ? proto_->softmax_bias().cols() : 0;
|
|
}
|
|
|
|
const void *softmax_bias_weights(int i) const override {
|
|
TC_DCHECK(InRange(i, softmax_bias_size()));
|
|
return proto_->has_softmax_bias() ? proto_->softmax_bias().value().data()
|
|
: nullptr;
|
|
}
|
|
|
|
int embedding_num_features_size() const override {
|
|
return proto_->embedding_num_features_size();
|
|
}
|
|
|
|
int embedding_num_features(int i) const override {
|
|
TC_DCHECK(InRange(i, embedding_num_features_size()));
|
|
return proto_->embedding_num_features(i);
|
|
}
|
|
|
|
private:
|
|
std::unique_ptr<EmbeddingNetworkProto> proto_;
|
|
|
|
// True if these params are valid. May be false if the original proto was
|
|
// corrupted. We prefer to set this to false to CHECK-failing.
|
|
bool valid_;
|
|
|
|
// When the embeddings are quantized, these members are used to store their
|
|
// numeric values using the types expected by the rest of the class. Due to
|
|
// technical reasons, the proto stores this info using larger types (i.e.,
|
|
// more bits).
|
|
std::vector<std::vector<float16>> embeddings_quant_scales_;
|
|
std::vector<std::vector<uint8>> embeddings_quant_weights_;
|
|
};
|
|
|
|
} // namespace nlp_core
|
|
} // namespace libtextclassifier
|
|
|
|
#endif // LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_PARAMS_FROM_PROTO_H_
|