246 lines
9.1 KiB
C++
246 lines
9.1 KiB
C++
/*
|
|
* Copyright (C) 2017 The Android Open Source Project
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#ifndef LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_H_
|
|
#define LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_H_
|
|
|
|
#include <memory>
|
|
#include <vector>
|
|
|
|
#include "common/embedding-network-params.h"
|
|
#include "common/feature-extractor.h"
|
|
#include "common/vector-span.h"
|
|
#include "util/base/integral_types.h"
|
|
#include "util/base/logging.h"
|
|
#include "util/base/macros.h"
|
|
|
|
namespace libtextclassifier {
|
|
namespace nlp_core {
|
|
|
|
// Classifier using a hand-coded feed-forward neural network.
|
|
//
|
|
// No gradient computation, just inference.
|
|
//
|
|
// Classification works as follows:
|
|
//
|
|
// Discrete features -> Embeddings -> Concatenation -> Hidden+ -> Softmax
|
|
//
|
|
// In words: given some discrete features, this class extracts the embeddings
|
|
// for these features, concatenates them, passes them through one or two hidden
|
|
// layers (each layer uses Relu) and next through a softmax layer that computes
|
|
// an unnormalized score for each possible class. Note: there is always a
|
|
// softmax layer.
|
|
class EmbeddingNetwork {
|
|
public:
|
|
// Class used to represent an embedding matrix. Each row is the embedding on
|
|
// a vocabulary element. Number of columns = number of embedding dimensions.
|
|
class EmbeddingMatrix {
|
|
public:
|
|
explicit EmbeddingMatrix(const EmbeddingNetworkParams::Matrix source_matrix)
|
|
: rows_(source_matrix.rows),
|
|
cols_(source_matrix.cols),
|
|
quant_type_(source_matrix.quant_type),
|
|
data_(source_matrix.elements),
|
|
row_size_in_bytes_(GetRowSizeInBytes(cols_, quant_type_)),
|
|
quant_scales_(source_matrix.quant_scales) {}
|
|
|
|
// Returns vocabulary size; one embedding for each vocabulary element.
|
|
int size() const { return rows_; }
|
|
|
|
// Returns number of weights in embedding of each vocabulary element.
|
|
int dim() const { return cols_; }
|
|
|
|
// Returns quantization type for this embedding matrix.
|
|
QuantizationType quant_type() const { return quant_type_; }
|
|
|
|
// Gets embedding for k-th vocabulary element: on return, sets *data to
|
|
// point to the embedding weights and *scale to the quantization scale (1.0
|
|
// if no quantization).
|
|
void get_embedding(int k, const void **data, float *scale) const {
|
|
if ((k < 0) || (k >= size())) {
|
|
TC_LOG(ERROR) << "Index outside [0, " << size() << "): " << k;
|
|
|
|
// In debug mode, crash. In prod, pretend that k is 0.
|
|
TC_DCHECK(false);
|
|
k = 0;
|
|
}
|
|
*data = reinterpret_cast<const char *>(data_) + k * row_size_in_bytes_;
|
|
if (quant_type_ == QuantizationType::NONE) {
|
|
*scale = 1.0;
|
|
} else {
|
|
*scale = Float16To32(quant_scales_[k]);
|
|
}
|
|
}
|
|
|
|
private:
|
|
static int GetRowSizeInBytes(int cols, QuantizationType quant_type) {
|
|
switch (quant_type) {
|
|
case QuantizationType::NONE:
|
|
return cols * sizeof(float);
|
|
case QuantizationType::UINT8:
|
|
return cols * sizeof(uint8);
|
|
default:
|
|
TC_LOG(ERROR) << "Unknown quant type: "
|
|
<< static_cast<int>(quant_type);
|
|
return 0;
|
|
}
|
|
}
|
|
|
|
// Vocabulary size.
|
|
const int rows_;
|
|
|
|
// Number of elements in each embedding.
|
|
const int cols_;
|
|
|
|
const QuantizationType quant_type_;
|
|
|
|
// Pointer to the embedding weights, in row-major order. This is a pointer
|
|
// to an array of floats / uint8, depending on the quantization type.
|
|
// Not owned.
|
|
const void *const data_;
|
|
|
|
// Number of bytes for one row. Used to jump to next row in data_.
|
|
const int row_size_in_bytes_;
|
|
|
|
// Pointer to quantization scales. nullptr if no quantization. Otherwise,
|
|
// quant_scales_[i] is scale for embedding of i-th vocabulary element.
|
|
const float16 *const quant_scales_;
|
|
|
|
TC_DISALLOW_COPY_AND_ASSIGN(EmbeddingMatrix);
|
|
};
|
|
|
|
// An immutable vector that doesn't own the memory that stores the underlying
|
|
// floats. Can be used e.g., as a wrapper around model weights stored in the
|
|
// static memory.
|
|
class VectorWrapper {
|
|
public:
|
|
VectorWrapper() : VectorWrapper(nullptr, 0) {}
|
|
|
|
// Constructs a vector wrapper around the size consecutive floats that start
|
|
// at address data. Note: the underlying data should be alive for at least
|
|
// the lifetime of this VectorWrapper object. That's trivially true if data
|
|
// points to statically allocated data :)
|
|
VectorWrapper(const float *data, int size) : data_(data), size_(size) {}
|
|
|
|
int size() const { return size_; }
|
|
|
|
const float *data() const { return data_; }
|
|
|
|
private:
|
|
const float *data_; // Not owned.
|
|
int size_;
|
|
|
|
// Doesn't own anything, so it can be copied and assigned at will :)
|
|
};
|
|
|
|
typedef std::vector<VectorWrapper> Matrix;
|
|
typedef std::vector<float> Vector;
|
|
|
|
// Constructs an embedding network using the parameters from model.
|
|
//
|
|
// Note: model should stay alive for at least the lifetime of this
|
|
// EmbeddingNetwork object.
|
|
explicit EmbeddingNetwork(const EmbeddingNetworkParams *model);
|
|
|
|
virtual ~EmbeddingNetwork() {}
|
|
|
|
// Returns true if this EmbeddingNetwork object has been correctly constructed
|
|
// and is ready to use. Idea: in case of errors, mark this EmbeddingNetwork
|
|
// object as invalid, but do not crash.
|
|
bool is_valid() const { return valid_; }
|
|
|
|
// Runs forward computation to fill scores with unnormalized output unit
|
|
// scores. This is useful for making predictions.
|
|
//
|
|
// Returns true on success, false on error (e.g., if !is_valid()).
|
|
bool ComputeFinalScores(const std::vector<FeatureVector> &features,
|
|
Vector *scores) const;
|
|
|
|
// Same as above, but allows specification of extra neural network inputs that
|
|
// will be appended to the embedding vector build from features.
|
|
bool ComputeFinalScores(const std::vector<FeatureVector> &features,
|
|
const std::vector<float> extra_inputs,
|
|
Vector *scores) const;
|
|
|
|
// Constructs the concatenated input embedding vector in place in output
|
|
// vector concat. Returns true on success, false on error.
|
|
bool ConcatEmbeddings(const std::vector<FeatureVector> &features,
|
|
Vector *concat) const;
|
|
|
|
// Sums embeddings for all features from |feature_vector| and adds result
|
|
// to values from the array pointed-to by |output|. Embeddings for continuous
|
|
// features are weighted by the feature weight.
|
|
//
|
|
// NOTE: output should point to an array of EmbeddingSize(es_index) floats.
|
|
bool GetEmbedding(const FeatureVector &feature_vector, int es_index,
|
|
float *embedding) const;
|
|
|
|
// Runs the feed-forward neural network for |input| and computes logits for
|
|
// softmax layer.
|
|
bool ComputeLogits(const Vector &input, Vector *scores) const;
|
|
|
|
// Same as above but uses a view of the feature vector.
|
|
bool ComputeLogits(const VectorSpan<float> &input, Vector *scores) const;
|
|
|
|
// Returns the size (the number of columns) of the embedding space es_index.
|
|
int EmbeddingSize(int es_index) const;
|
|
|
|
protected:
|
|
// Builds an embedding for given feature vector, and places it from
|
|
// concat_offset to the concat vector.
|
|
bool GetEmbeddingInternal(const FeatureVector &feature_vector,
|
|
EmbeddingMatrix *embedding_matrix,
|
|
int concat_offset, float *concat,
|
|
int embedding_size) const;
|
|
|
|
// Templated function that computes the logit scores given the concatenated
|
|
// input embeddings.
|
|
bool ComputeLogitsInternal(const VectorSpan<float> &concat,
|
|
Vector *scores) const;
|
|
|
|
// Computes the softmax scores (prior to normalization) from the concatenated
|
|
// representation. Returns true on success, false on error.
|
|
template <typename ScaleAdderClass>
|
|
bool FinishComputeFinalScoresInternal(const VectorSpan<float> &concat,
|
|
Vector *scores) const;
|
|
|
|
// Set to true on successful construction, false otherwise.
|
|
bool valid_ = false;
|
|
|
|
// Network parameters.
|
|
|
|
// One weight matrix for each embedding space.
|
|
std::vector<std::unique_ptr<EmbeddingMatrix>> embedding_matrices_;
|
|
|
|
// concat_offset_[i] is the input layer offset for i-th embedding space.
|
|
std::vector<int> concat_offset_;
|
|
|
|
// Size of the input ("concatenation") layer.
|
|
int concat_layer_size_;
|
|
|
|
// One weight matrix and one vector of bias weights for each hiden layer.
|
|
std::vector<Matrix> hidden_weights_;
|
|
std::vector<VectorWrapper> hidden_bias_;
|
|
|
|
// Weight matrix and bias vector for the softmax layer.
|
|
Matrix softmax_weights_;
|
|
VectorWrapper softmax_bias_;
|
|
};
|
|
|
|
} // namespace nlp_core
|
|
} // namespace libtextclassifier
|
|
|
|
#endif // LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_H_
|