753 lines
24 KiB
C++
753 lines
24 KiB
C++
#include <algorithm>
|
|
#include <stdexcept>
|
|
|
|
#include "trie.h"
|
|
|
|
namespace marisa {
|
|
namespace {
|
|
|
|
template <typename T, typename U>
|
|
class PredictCallback {
|
|
public:
|
|
PredictCallback(T key_ids, U keys, std::size_t max_num_results)
|
|
: key_ids_(key_ids), keys_(keys),
|
|
max_num_results_(max_num_results), num_results_(0) {}
|
|
PredictCallback(const PredictCallback &callback)
|
|
: key_ids_(callback.key_ids_), keys_(callback.keys_),
|
|
max_num_results_(callback.max_num_results_),
|
|
num_results_(callback.num_results_) {}
|
|
|
|
bool operator()(marisa::UInt32 key_id, const std::string &key) {
|
|
if (key_ids_.is_valid()) {
|
|
key_ids_.insert(num_results_, key_id);
|
|
}
|
|
if (keys_.is_valid()) {
|
|
keys_.insert(num_results_, key);
|
|
}
|
|
return ++num_results_ < max_num_results_;
|
|
}
|
|
|
|
private:
|
|
T key_ids_;
|
|
U keys_;
|
|
const std::size_t max_num_results_;
|
|
std::size_t num_results_;
|
|
|
|
// Disallows assignment.
|
|
PredictCallback &operator=(const PredictCallback &);
|
|
};
|
|
|
|
} // namespace
|
|
|
|
// Returns the key stored under key_id as a freshly built string.
// Raises MARISA_STATE_ERROR when empty() is true and MARISA_PARAM_ERROR when
// key_id is not smaller than num_keys_.
std::string Trie::restore(UInt32 key_id) const {
  MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
  MARISA_THROW_IF(!(key_id < num_keys_), MARISA_PARAM_ERROR);

  std::string restored;
  restore_(key_id, &restored);
  return restored;
}
|
|
|
|
// Appends the key stored under key_id to *key (existing contents are kept;
// see restore_, which only touches the part after the current length).
// Raises MARISA_STATE_ERROR when empty() is true, and MARISA_PARAM_ERROR when
// key_id is not smaller than num_keys_ or key is NULL.
void Trie::restore(UInt32 key_id, std::string *key) const {
  MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
  MARISA_THROW_IF(!(key_id < num_keys_), MARISA_PARAM_ERROR);
  MARISA_THROW_IF(NULL == key, MARISA_PARAM_ERROR);

  restore_(key_id, key);
}
|
|
|
|
// Writes the key stored under key_id into the caller-supplied buffer and
// returns whatever restore_ reports (the key length in bytes, which may be
// larger than what fitted into key_buf).
// Raises MARISA_STATE_ERROR when empty() is true, and MARISA_PARAM_ERROR when
// key_id is not smaller than num_keys_ or when key_buf is NULL while
// key_buf_size is non-zero. A NULL buffer with size 0 is accepted.
std::size_t Trie::restore(UInt32 key_id, char *key_buf,
    std::size_t key_buf_size) const {
  MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
  MARISA_THROW_IF(!(key_id < num_keys_), MARISA_PARAM_ERROR);
  MARISA_THROW_IF((key_buf_size != 0) && (key_buf == NULL),
      MARISA_PARAM_ERROR);

  const std::size_t key_length = restore_(key_id, key_buf, key_buf_size);
  return key_length;
}
|
|
|
|
// Exact-match lookup of a NUL-terminated string.
// Returns the result of lookup_: the key id, or notfound() when str is not a
// registered key.
// Raises MARISA_STATE_ERROR when empty() is true and MARISA_PARAM_ERROR when
// str is NULL.
UInt32 Trie::lookup(const char *str) const {
  MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
  MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);

  const CQuery query(str);
  return lookup_<CQuery>(query);
}
|
|
|
|
// Exact-match lookup of the byte sequence [ptr, ptr + length).
// Returns the result of lookup_: the key id, or notfound() when the sequence
// is not a registered key.
// Raises MARISA_STATE_ERROR when empty() is true and MARISA_PARAM_ERROR when
// ptr is NULL while length is non-zero.
UInt32 Trie::lookup(const char *ptr, std::size_t length) const {
  MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
  MARISA_THROW_IF((length != 0) && (ptr == NULL), MARISA_PARAM_ERROR);

  const Query query(ptr, length);
  return lookup_<const Query &>(query);
}
|
|
|
|
std::size_t Trie::find(const char *str,
|
|
UInt32 *key_ids, std::size_t *key_lengths,
|
|
std::size_t max_num_results) const {
|
|
MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
|
|
MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
|
|
return find_<CQuery>(CQuery(str),
|
|
MakeContainer(key_ids), MakeContainer(key_lengths), max_num_results);
|
|
}
|
|
|
|
std::size_t Trie::find(const char *ptr, std::size_t length,
|
|
UInt32 *key_ids, std::size_t *key_lengths,
|
|
std::size_t max_num_results) const {
|
|
MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
|
|
MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
|
|
return find_<const Query &>(Query(ptr, length),
|
|
MakeContainer(key_ids), MakeContainer(key_lengths), max_num_results);
|
|
}
|
|
|
|
std::size_t Trie::find(const char *str,
|
|
std::vector<UInt32> *key_ids, std::vector<std::size_t> *key_lengths,
|
|
std::size_t max_num_results) const {
|
|
MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
|
|
MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
|
|
return find_<CQuery>(CQuery(str),
|
|
MakeContainer(key_ids), MakeContainer(key_lengths), max_num_results);
|
|
}
|
|
|
|
std::size_t Trie::find(const char *ptr, std::size_t length,
|
|
std::vector<UInt32> *key_ids, std::vector<std::size_t> *key_lengths,
|
|
std::size_t max_num_results) const {
|
|
MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
|
|
MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
|
|
return find_<const Query &>(Query(ptr, length),
|
|
MakeContainer(key_ids), MakeContainer(key_lengths), max_num_results);
|
|
}
|
|
|
|
// Returns, via find_first_, the key id of the first terminal node met while
// following the NUL-terminated query str, or notfound() if there is none.
// key_length may be NULL; otherwise it receives the matched length on
// success.
// Raises MARISA_STATE_ERROR when empty() is true and MARISA_PARAM_ERROR when
// str is NULL.
UInt32 Trie::find_first(const char *str,
    std::size_t *key_length) const {
  MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
  MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);

  const CQuery query(str);
  return find_first_<CQuery>(query, key_length);
}
|
|
|
|
// Returns, via find_first_, the key id of the first terminal node met while
// following the byte sequence [ptr, ptr + length), or notfound() if there is
// none. key_length may be NULL; otherwise it receives the matched length on
// success.
// Raises MARISA_STATE_ERROR when empty() is true and MARISA_PARAM_ERROR when
// ptr is NULL while length is non-zero.
UInt32 Trie::find_first(const char *ptr, std::size_t length,
    std::size_t *key_length) const {
  MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
  MARISA_THROW_IF((length != 0) && (ptr == NULL), MARISA_PARAM_ERROR);

  const Query query(ptr, length);
  return find_first_<const Query &>(query, key_length);
}
|
|
|
|
// Returns, via find_last_, the key id of the last terminal node met while
// following the NUL-terminated query str, or notfound() if there is none.
// key_length may be NULL; otherwise it receives the matched length on
// success.
// Raises MARISA_STATE_ERROR when empty() is true and MARISA_PARAM_ERROR when
// str is NULL.
UInt32 Trie::find_last(const char *str,
    std::size_t *key_length) const {
  MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
  MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);

  const CQuery query(str);
  return find_last_<CQuery>(query, key_length);
}
|
|
|
|
// Returns, via find_last_, the key id of the last terminal node met while
// following the byte sequence [ptr, ptr + length), or notfound() if there is
// none. key_length may be NULL; otherwise it receives the matched length on
// success.
// Raises MARISA_STATE_ERROR when empty() is true and MARISA_PARAM_ERROR when
// ptr is NULL while length is non-zero.
UInt32 Trie::find_last(const char *ptr, std::size_t length,
    std::size_t *key_length) const {
  MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
  MARISA_THROW_IF((length != 0) && (ptr == NULL), MARISA_PARAM_ERROR);

  const Query query(ptr, length);
  return find_last_<const Query &>(query, key_length);
}
|
|
|
|
std::size_t Trie::predict(const char *str,
|
|
UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
|
|
MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
|
|
MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
|
|
return (keys == NULL) ?
|
|
predict_breadth_first(str, key_ids, keys, max_num_results) :
|
|
predict_depth_first(str, key_ids, keys, max_num_results);
|
|
}
|
|
|
|
std::size_t Trie::predict(const char *ptr, std::size_t length,
|
|
UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
|
|
MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
|
|
MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
|
|
return (keys == NULL) ?
|
|
predict_breadth_first(ptr, length, key_ids, keys, max_num_results) :
|
|
predict_depth_first(ptr, length, key_ids, keys, max_num_results);
|
|
}
|
|
|
|
std::size_t Trie::predict(const char *str,
|
|
std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
|
|
std::size_t max_num_results) const {
|
|
MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
|
|
MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
|
|
return (keys == NULL) ?
|
|
predict_breadth_first(str, key_ids, keys, max_num_results) :
|
|
predict_depth_first(str, key_ids, keys, max_num_results);
|
|
}
|
|
|
|
std::size_t Trie::predict(const char *ptr, std::size_t length,
|
|
std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
|
|
std::size_t max_num_results) const {
|
|
MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
|
|
MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
|
|
return (keys == NULL) ?
|
|
predict_breadth_first(ptr, length, key_ids, keys, max_num_results) :
|
|
predict_depth_first(ptr, length, key_ids, keys, max_num_results);
|
|
}
|
|
|
|
std::size_t Trie::predict_breadth_first(const char *str,
|
|
UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
|
|
MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
|
|
MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
|
|
return predict_breadth_first_<CQuery>(CQuery(str),
|
|
MakeContainer(key_ids), MakeContainer(keys), max_num_results);
|
|
}
|
|
|
|
std::size_t Trie::predict_breadth_first(const char *ptr, std::size_t length,
|
|
UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
|
|
MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
|
|
MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
|
|
return predict_breadth_first_<const Query &>(Query(ptr, length),
|
|
MakeContainer(key_ids), MakeContainer(keys), max_num_results);
|
|
}
|
|
|
|
std::size_t Trie::predict_breadth_first(const char *str,
|
|
std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
|
|
std::size_t max_num_results) const {
|
|
MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
|
|
MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
|
|
return predict_breadth_first_<CQuery>(CQuery(str),
|
|
MakeContainer(key_ids), MakeContainer(keys), max_num_results);
|
|
}
|
|
|
|
std::size_t Trie::predict_breadth_first(const char *ptr, std::size_t length,
|
|
std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
|
|
std::size_t max_num_results) const {
|
|
MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
|
|
MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
|
|
return predict_breadth_first_<const Query &>(Query(ptr, length),
|
|
MakeContainer(key_ids), MakeContainer(keys), max_num_results);
|
|
}
|
|
|
|
std::size_t Trie::predict_depth_first(const char *str,
|
|
UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
|
|
MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
|
|
MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
|
|
return predict_depth_first_<CQuery>(CQuery(str),
|
|
MakeContainer(key_ids), MakeContainer(keys), max_num_results);
|
|
}
|
|
|
|
std::size_t Trie::predict_depth_first(const char *ptr, std::size_t length,
|
|
UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
|
|
MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
|
|
MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
|
|
return predict_depth_first_<const Query &>(Query(ptr, length),
|
|
MakeContainer(key_ids), MakeContainer(keys), max_num_results);
|
|
}
|
|
|
|
std::size_t Trie::predict_depth_first(
|
|
const char *str, std::vector<UInt32> *key_ids,
|
|
std::vector<std::string> *keys, std::size_t max_num_results) const {
|
|
MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
|
|
MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
|
|
return predict_depth_first_<CQuery>(CQuery(str),
|
|
MakeContainer(key_ids), MakeContainer(keys), max_num_results);
|
|
}
|
|
|
|
std::size_t Trie::predict_depth_first(
|
|
const char *ptr, std::size_t length, std::vector<UInt32> *key_ids,
|
|
std::vector<std::string> *keys, std::size_t max_num_results) const {
|
|
MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
|
|
MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
|
|
return predict_depth_first_<const Query &>(Query(ptr, length),
|
|
MakeContainer(key_ids), MakeContainer(keys), max_num_results);
|
|
}
|
|
|
|
// Appends the key identified by key_id to *key without validating arguments.
//
// The key is collected while climbing from the key's node to the root
// (node 0), i.e. back to front, and the appended region is reversed once at
// the end. A linked node contributes a multi-byte segment through
// trie_restore / tail_restore; that segment is reversed in place right away
// so that the final reversal puts it back in the order the helper produced.
void Trie::restore_(UInt32 key_id, std::string *key) const {
  const std::size_t base = key->length();
  for (UInt32 cur = key_id_to_node(key_id); cur != 0;
       cur = get_parent(cur)) {
    if (!has_link(cur)) {
      // Plain node: its label is a single byte.
      *key += labels_[cur];
      continue;
    }
    const std::size_t segment_begin = key->length();
    if (has_trie()) {
      trie_->trie_restore(get_link(cur), key);
    } else {
      tail_restore(cur, key);
    }
    std::reverse(key->begin() + segment_begin, key->end());
  }
  std::reverse(key->begin() + base, key->end());
}
|
|
|
|
// Appends to *key the label spelled by node and each of its ancestors, in
// that order, stopping once the parent is the root (node 0). The starting
// node is always processed, even if it is 0. No reversal is done here.
// Linked nodes are expanded through the next trie when has_trie() is true,
// otherwise through the tail.
void Trie::trie_restore(UInt32 node, std::string *key) const {
  for (;;) {
    if (!has_link(node)) {
      *key += labels_[node];
    } else if (has_trie()) {
      trie_->trie_restore(get_link(node), key);
    } else {
      tail_restore(node, key);
    }
    node = get_parent(node);
    if (node == 0) {
      break;
    }
  }
}
|
|
|
|
// Appends to *key the tail string referenced by a linked node.
//
// The tail offset is links_[link_id] * 256 + labels_[node], where link_id is
// the rank of node among linked nodes. In binary mode the length is the
// distance to the next link's offset; otherwise the tail is NUL-terminated.
void Trie::tail_restore(UInt32 node, std::string *key) const {
  const UInt32 link_id = link_flags_.rank1(node);
  const UInt32 offset = (links_[link_id] * 256) + labels_[node];
  const char *const text = reinterpret_cast<const char *>(tail_[offset]);

  if (tail_.mode() != MARISA_BINARY_TAIL) {
    key->append(text);
    return;
  }
  const UInt32 length = (links_[link_id + 1] * 256)
      + labels_[link_flags_.select1(link_id + 1)] - offset;
  key->append(text, length);
}
|
|
|
|
// Restores the key identified by key_id into key_buf without validating
// arguments, and returns the key length in bytes.
//
// Bytes are stored only at positions below key_buf_size, but the position
// keeps advancing, so the return value is the full key length even when the
// buffer is too small. The key is gathered back to front while climbing to
// the root; reversals and the trailing '\0' are applied only when the
// position is still inside the buffer. If the returned length is not below
// key_buf_size, the buffer contents are not a usable key.
std::size_t Trie::restore_(UInt32 key_id, char *key_buf,
    std::size_t key_buf_size) const {
  std::size_t written = 0;
  for (UInt32 cur = key_id_to_node(key_id); cur != 0;
       cur = get_parent(cur)) {
    if (has_link(cur)) {
      const std::size_t segment_begin = written;
      if (has_trie()) {
        trie_->trie_restore(get_link(cur), key_buf, key_buf_size, written);
      } else {
        tail_restore(cur, key_buf, key_buf_size, written);
      }
      // Pre-reverse the segment so the final reversal restores its order.
      if (written < key_buf_size) {
        std::reverse(key_buf + segment_begin, key_buf + written);
      }
    } else {
      if (written < key_buf_size) {
        key_buf[written] = labels_[cur];
      }
      ++written;
    }
  }

  if (written < key_buf_size) {
    key_buf[written] = '\0';
    std::reverse(key_buf, key_buf + written);
  }
  return written;
}
|
|
|
|
// Buffer variant of trie_restore: emits the label of node and of each
// ancestor, in that order, stopping once the parent is the root (node 0).
// pos is advanced for every byte; a byte is stored only while pos is below
// key_buf_size. The starting node is always processed.
void Trie::trie_restore(UInt32 node, char *key_buf,
    std::size_t key_buf_size, std::size_t &pos) const {
  for (;;) {
    if (!has_link(node)) {
      if (pos < key_buf_size) {
        key_buf[pos] = labels_[node];
      }
      ++pos;
    } else if (has_trie()) {
      trie_->trie_restore(get_link(node), key_buf, key_buf_size, pos);
    } else {
      tail_restore(node, key_buf, key_buf_size, pos);
    }
    node = get_parent(node);
    if (node == 0) {
      break;
    }
  }
}
|
|
|
|
// Buffer variant of tail_restore: copies the tail string referenced by a
// linked node to key_buf starting at pos. pos is advanced for every tail
// byte; a byte is stored only while pos is below key_buf_size.
//
// The tail offset is links_[link_id] * 256 + labels_[node]. In binary mode
// the length is the distance to the next link's offset; otherwise the tail
// is read up to its NUL terminator.
void Trie::tail_restore(UInt32 node, char *key_buf,
    std::size_t key_buf_size, std::size_t &pos) const {
  const UInt32 link_id = link_flags_.rank1(node);
  const UInt32 offset = (links_[link_id] * 256) + labels_[node];
  const UInt8 *src = tail_[offset];

  if (tail_.mode() == MARISA_BINARY_TAIL) {
    const UInt32 length = (links_[link_id + 1] * 256)
        + labels_[link_flags_.select1(link_id + 1)] - offset;
    for (UInt32 i = 0; i < length; ++i, ++pos) {
      if (pos < key_buf_size) {
        key_buf[pos] = src[i];
      }
    }
    return;
  }

  for (; *src != '\0'; ++src, ++pos) {
    if (pos < key_buf_size) {
      key_buf[pos] = *src;
    }
  }
}
|
|
|
|
template <typename T>
|
|
UInt32 Trie::lookup_(T query) const {
|
|
UInt32 node = 0;
|
|
std::size_t pos = 0;
|
|
while (!query.ends_at(pos)) {
|
|
if (!find_child<T>(node, query, pos)) {
|
|
return notfound();
|
|
}
|
|
}
|
|
return terminal_flags_[node] ? node_to_key_id(node) : notfound();
|
|
}
|
|
|
|
template <typename T>
|
|
std::size_t Trie::trie_match(UInt32 node, T query,
|
|
std::size_t pos) const {
|
|
if (has_link(node)) {
|
|
std::size_t next_pos;
|
|
if (has_trie()) {
|
|
next_pos = trie_->trie_match<T>(get_link(node), query, pos);
|
|
} else {
|
|
next_pos = tail_match<T>(node, get_link_id(node), query, pos);
|
|
}
|
|
if ((next_pos == mismatch()) || (next_pos == pos)) {
|
|
return next_pos;
|
|
}
|
|
pos = next_pos;
|
|
} else if (labels_[node] != query[pos]) {
|
|
return pos;
|
|
} else {
|
|
++pos;
|
|
}
|
|
node = get_parent(node);
|
|
while (node != 0) {
|
|
if (query.ends_at(pos)) {
|
|
return mismatch();
|
|
}
|
|
if (has_link(node)) {
|
|
std::size_t next_pos;
|
|
if (has_trie()) {
|
|
next_pos = trie_->trie_match<T>(get_link(node), query, pos);
|
|
} else {
|
|
next_pos = tail_match<T>(node, get_link_id(node), query, pos);
|
|
}
|
|
if ((next_pos == mismatch()) || (next_pos == pos)) {
|
|
return mismatch();
|
|
}
|
|
pos = next_pos;
|
|
} else if (labels_[node] != query[pos]) {
|
|
return mismatch();
|
|
} else {
|
|
++pos;
|
|
}
|
|
node = get_parent(node);
|
|
}
|
|
return pos;
|
|
}
|
|
|
|
template std::size_t Trie::trie_match<CQuery>(UInt32 node,
|
|
CQuery query, std::size_t pos) const;
|
|
template std::size_t Trie::trie_match<const Query &>(UInt32 node,
|
|
const Query &query, std::size_t pos) const;
|
|
|
|
template <typename T>
|
|
std::size_t Trie::tail_match(UInt32 node, UInt32 link_id,
|
|
T query, std::size_t pos) const {
|
|
const UInt32 offset = (links_[link_id] * 256) + labels_[node];
|
|
const UInt8 *ptr = tail_[offset];
|
|
if (*ptr != query[pos]) {
|
|
return pos;
|
|
} else if (tail_.mode() == MARISA_BINARY_TAIL) {
|
|
const UInt32 length = (links_[link_id + 1] * 256)
|
|
+ labels_[link_flags_.select1(link_id + 1)] - offset;
|
|
for (UInt32 i = 1; i < length; ++i) {
|
|
if (query.ends_at(pos + i) || (ptr[i] != query[pos + i])) {
|
|
return mismatch();
|
|
}
|
|
}
|
|
return pos + length;
|
|
} else {
|
|
for (++ptr, ++pos; *ptr != '\0'; ++ptr, ++pos) {
|
|
if (query.ends_at(pos) || (*ptr != query[pos])) {
|
|
return mismatch();
|
|
}
|
|
}
|
|
return pos;
|
|
}
|
|
}
|
|
|
|
template std::size_t Trie::tail_match<CQuery>(UInt32 node,
|
|
UInt32 link_id, CQuery query, std::size_t pos) const;
|
|
template std::size_t Trie::tail_match<const Query &>(UInt32 node,
|
|
UInt32 link_id, const Query &query, std::size_t pos) const;
|
|
|
|
template <typename T, typename U, typename V>
|
|
std::size_t Trie::find_(T query, U key_ids, V key_lengths,
|
|
std::size_t max_num_results) const {
|
|
if (max_num_results == 0) {
|
|
return 0;
|
|
}
|
|
std::size_t count = 0;
|
|
UInt32 node = 0;
|
|
std::size_t pos = 0;
|
|
do {
|
|
if (terminal_flags_[node]) {
|
|
if (key_ids.is_valid()) {
|
|
key_ids.insert(count, node_to_key_id(node));
|
|
}
|
|
if (key_lengths.is_valid()) {
|
|
key_lengths.insert(count, pos);
|
|
}
|
|
if (++count >= max_num_results) {
|
|
return count;
|
|
}
|
|
}
|
|
} while (!query.ends_at(pos) && find_child<T>(node, query, pos));
|
|
return count;
|
|
}
|
|
|
|
template <typename T>
|
|
UInt32 Trie::find_first_(T query, std::size_t *key_length) const {
|
|
UInt32 node = 0;
|
|
std::size_t pos = 0;
|
|
do {
|
|
if (terminal_flags_[node]) {
|
|
if (key_length != NULL) {
|
|
*key_length = pos;
|
|
}
|
|
return node_to_key_id(node);
|
|
}
|
|
} while (!query.ends_at(pos) && find_child<T>(node, query, pos));
|
|
return notfound();
|
|
}
|
|
|
|
template <typename T>
|
|
UInt32 Trie::find_last_(T query, std::size_t *key_length) const {
|
|
UInt32 node = 0;
|
|
UInt32 node_found = notfound();
|
|
std::size_t pos = 0;
|
|
std::size_t pos_found = mismatch();
|
|
do {
|
|
if (terminal_flags_[node]) {
|
|
node_found = node;
|
|
pos_found = pos;
|
|
}
|
|
} while (!query.ends_at(pos) && find_child<T>(node, query, pos));
|
|
if (node_found != notfound()) {
|
|
if (key_length != NULL) {
|
|
*key_length = pos_found;
|
|
}
|
|
return node_to_key_id(node_found);
|
|
}
|
|
return notfound();
|
|
}
|
|
|
|
template <typename T, typename U, typename V>
|
|
std::size_t Trie::predict_breadth_first_(T query, U key_ids, V keys,
|
|
std::size_t max_num_results) const {
|
|
if (max_num_results == 0) {
|
|
return 0;
|
|
}
|
|
UInt32 node = 0;
|
|
std::size_t pos = 0;
|
|
while (!query.ends_at(pos)) {
|
|
if (!predict_child<T>(node, query, pos, NULL)) {
|
|
return 0;
|
|
}
|
|
}
|
|
std::string key;
|
|
std::size_t count = 0;
|
|
if (terminal_flags_[node]) {
|
|
const UInt32 key_id = node_to_key_id(node);
|
|
if (key_ids.is_valid()) {
|
|
key_ids.insert(count, key_id);
|
|
}
|
|
if (keys.is_valid()) {
|
|
restore(key_id, &key);
|
|
keys.insert(count, key);
|
|
}
|
|
if (++count >= max_num_results) {
|
|
return count;
|
|
}
|
|
}
|
|
const UInt32 louds_pos = get_child(node);
|
|
if (!louds_[louds_pos]) {
|
|
return count;
|
|
}
|
|
UInt32 node_begin = louds_pos_to_node(louds_pos, node);
|
|
UInt32 node_end = louds_pos_to_node(get_child(node + 1), node + 1);
|
|
while (node_begin < node_end) {
|
|
const UInt32 key_id_begin = node_to_key_id(node_begin);
|
|
const UInt32 key_id_end = node_to_key_id(node_end);
|
|
if (key_ids.is_valid()) {
|
|
UInt32 temp_count = count;
|
|
for (UInt32 key_id = key_id_begin; key_id < key_id_end; ++key_id) {
|
|
key_ids.insert(temp_count, key_id);
|
|
if (++temp_count >= max_num_results) {
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
if (keys.is_valid()) {
|
|
UInt32 temp_count = count;
|
|
for (UInt32 key_id = key_id_begin; key_id < key_id_end; ++key_id) {
|
|
key.clear();
|
|
restore(key_id, &key);
|
|
keys.insert(temp_count, key);
|
|
if (++temp_count >= max_num_results) {
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
count += key_id_end - key_id_begin;
|
|
if (count >= max_num_results) {
|
|
return max_num_results;
|
|
}
|
|
node_begin = louds_pos_to_node(get_child(node_begin), node_begin);
|
|
node_end = louds_pos_to_node(get_child(node_end), node_end);
|
|
}
|
|
return count;
|
|
}
|
|
|
|
template <typename T, typename U, typename V>
|
|
std::size_t Trie::predict_depth_first_(T query, U key_ids, V keys,
|
|
std::size_t max_num_results) const {
|
|
if (max_num_results == 0) {
|
|
return 0;
|
|
} else if (keys.is_valid()) {
|
|
PredictCallback<U, V> callback(key_ids, keys, max_num_results);
|
|
return predict_callback_(query, callback);
|
|
}
|
|
|
|
UInt32 node = 0;
|
|
std::size_t pos = 0;
|
|
while (!query.ends_at(pos)) {
|
|
if (!predict_child<T>(node, query, pos, NULL)) {
|
|
return 0;
|
|
}
|
|
}
|
|
std::size_t count = 0;
|
|
if (terminal_flags_[node]) {
|
|
if (key_ids.is_valid()) {
|
|
key_ids.insert(count, node_to_key_id(node));
|
|
}
|
|
if (++count >= max_num_results) {
|
|
return count;
|
|
}
|
|
}
|
|
Cell cell;
|
|
cell.set_louds_pos(get_child(node));
|
|
if (!louds_[cell.louds_pos()]) {
|
|
return count;
|
|
}
|
|
cell.set_node(louds_pos_to_node(cell.louds_pos(), node));
|
|
cell.set_key_id(node_to_key_id(cell.node()));
|
|
Vector<Cell> stack;
|
|
stack.push_back(cell);
|
|
std::size_t stack_pos = 1;
|
|
while (stack_pos != 0) {
|
|
Cell &cur = stack[stack_pos - 1];
|
|
if (!louds_[cur.louds_pos()]) {
|
|
cur.set_louds_pos(cur.louds_pos() + 1);
|
|
--stack_pos;
|
|
continue;
|
|
}
|
|
cur.set_louds_pos(cur.louds_pos() + 1);
|
|
if (terminal_flags_[cur.node()]) {
|
|
if (key_ids.is_valid()) {
|
|
key_ids.insert(count, cur.key_id());
|
|
}
|
|
if (++count >= max_num_results) {
|
|
return count;
|
|
}
|
|
cur.set_key_id(cur.key_id() + 1);
|
|
}
|
|
if (stack_pos == stack.size()) {
|
|
cell.set_louds_pos(get_child(cur.node()));
|
|
cell.set_node(louds_pos_to_node(cell.louds_pos(), cur.node()));
|
|
cell.set_key_id(node_to_key_id(cell.node()));
|
|
stack.push_back(cell);
|
|
}
|
|
stack[stack_pos - 1].set_node(stack[stack_pos - 1].node() + 1);
|
|
++stack_pos;
|
|
}
|
|
return count;
|
|
}
|
|
|
|
template <typename T>
|
|
std::size_t Trie::trie_prefix_match(UInt32 node, T query,
|
|
std::size_t pos, std::string *key) const {
|
|
if (has_link(node)) {
|
|
std::size_t next_pos;
|
|
if (has_trie()) {
|
|
next_pos = trie_->trie_prefix_match<T>(get_link(node), query, pos, key);
|
|
} else {
|
|
next_pos = tail_prefix_match<T>(
|
|
node, get_link_id(node), query, pos, key);
|
|
}
|
|
if ((next_pos == mismatch()) || (next_pos == pos)) {
|
|
return next_pos;
|
|
}
|
|
pos = next_pos;
|
|
} else if (labels_[node] != query[pos]) {
|
|
return pos;
|
|
} else {
|
|
++pos;
|
|
}
|
|
node = get_parent(node);
|
|
while (node != 0) {
|
|
if (query.ends_at(pos)) {
|
|
if (key != NULL) {
|
|
trie_restore(node, key);
|
|
}
|
|
return pos;
|
|
}
|
|
if (has_link(node)) {
|
|
std::size_t next_pos;
|
|
if (has_trie()) {
|
|
next_pos = trie_->trie_prefix_match<T>(
|
|
get_link(node), query, pos, key);
|
|
} else {
|
|
next_pos = tail_prefix_match<T>(
|
|
node, get_link_id(node), query, pos, key);
|
|
}
|
|
if ((next_pos == mismatch()) || (next_pos == pos)) {
|
|
return next_pos;
|
|
}
|
|
pos = next_pos;
|
|
} else if (labels_[node] != query[pos]) {
|
|
return mismatch();
|
|
} else {
|
|
++pos;
|
|
}
|
|
node = get_parent(node);
|
|
}
|
|
return pos;
|
|
}
|
|
|
|
template std::size_t Trie::trie_prefix_match<CQuery>(UInt32 node,
|
|
CQuery query, std::size_t pos, std::string *key) const;
|
|
template std::size_t Trie::trie_prefix_match<const Query &>(UInt32 node,
|
|
const Query &query, std::size_t pos, std::string *key) const;
|
|
|
|
template <typename T>
|
|
std::size_t Trie::tail_prefix_match(UInt32 node, UInt32 link_id,
|
|
T query, std::size_t pos, std::string *key) const {
|
|
const UInt32 offset = (links_[link_id] * 256) + labels_[node];
|
|
const UInt8 *ptr = tail_[offset];
|
|
if (*ptr != query[pos]) {
|
|
return pos;
|
|
} else if (tail_.mode() == MARISA_BINARY_TAIL) {
|
|
const UInt32 length = (links_[link_id + 1] * 256)
|
|
+ labels_[link_flags_.select1(link_id + 1)] - offset;
|
|
for (UInt32 i = 1; i < length; ++i) {
|
|
if (query.ends_at(pos + i)) {
|
|
if (key != NULL) {
|
|
key->append(reinterpret_cast<const char *>(ptr + i), length - i);
|
|
}
|
|
return pos + i;
|
|
} else if (ptr[i] != query[pos + i]) {
|
|
return mismatch();
|
|
}
|
|
}
|
|
return pos + length;
|
|
} else {
|
|
for (++ptr, ++pos; *ptr != '\0'; ++ptr, ++pos) {
|
|
if (query.ends_at(pos)) {
|
|
if (key != NULL) {
|
|
key->append(reinterpret_cast<const char *>(ptr));
|
|
}
|
|
return pos;
|
|
} else if (*ptr != query[pos]) {
|
|
return mismatch();
|
|
}
|
|
}
|
|
return pos;
|
|
}
|
|
}
|
|
|
|
template std::size_t Trie::tail_prefix_match<CQuery>(
|
|
UInt32 node, UInt32 link_id,
|
|
CQuery query, std::size_t pos, std::string *key) const;
|
|
template std::size_t Trie::tail_prefix_match<const Query &>(
|
|
UInt32 node, UInt32 link_id,
|
|
const Query &query, std::size_t pos, std::string *key) const;
|
|
|
|
} // namespace marisa
|