351 lines
9.5 KiB
C++
351 lines
9.5 KiB
C++
#ifndef MARISA_TRIE_INLINE_H_
|
|
#define MARISA_TRIE_INLINE_H_
|
|
|
|
#include <stdexcept>
|
|
|
|
#include "cell.h"
|
|
|
|
namespace marisa {
|
|
|
|
// Returns the key string whose ID is `key_id`, rebuilt through restore().
inline std::string Trie::operator[](UInt32 key_id) const {
  std::string restored_key;
  restore(key_id, &restored_key);
  return restored_key;
}
|
|
|
|
// Shorthand for lookup(str) on a C string (no explicit length is passed, so
// presumably NUL-terminated).
inline UInt32 Trie::operator[](const char *str) const {
  const UInt32 key_id = lookup(str);
  return key_id;
}
|
|
|
|
// std::string form of operator[]; forwards to lookup(const std::string &).
inline UInt32 Trie::operator[](const std::string &str) const {
  const UInt32 key_id = lookup(str);
  return key_id;
}
|
|
|
|
// Looks up a std::string by passing its buffer and explicit length to the
// (pointer, length) overload, so embedded NUL bytes are not a terminator.
inline UInt32 Trie::lookup(const std::string &str) const {
  const char *const ptr = str.c_str();
  const std::size_t length = str.length();
  return lookup(ptr, length);
}
|
|
|
|
inline std::size_t Trie::find(const std::string &str,
|
|
UInt32 *key_ids, std::size_t *key_lengths,
|
|
std::size_t max_num_results) const {
|
|
return find(str.c_str(), str.length(),
|
|
key_ids, key_lengths, max_num_results);
|
|
}
|
|
|
|
inline std::size_t Trie::find(const std::string &str,
|
|
std::vector<UInt32> *key_ids, std::vector<std::size_t> *key_lengths,
|
|
std::size_t max_num_results) const {
|
|
return find(str.c_str(), str.length(),
|
|
key_ids, key_lengths, max_num_results);
|
|
}
|
|
|
|
// std::string adapter for find_first(pointer, length, key_length).
inline UInt32 Trie::find_first(const std::string &str,
    std::size_t *key_length) const {
  const char *const ptr = str.c_str();
  const std::size_t length = str.length();
  return find_first(ptr, length, key_length);
}
|
|
|
|
// std::string adapter for find_last(pointer, length, key_length).
inline UInt32 Trie::find_last(const std::string &str,
    std::size_t *key_length) const {
  const char *const ptr = str.c_str();
  const std::size_t length = str.length();
  return find_last(ptr, length, key_length);
}
|
|
|
|
template <typename T>
|
|
inline std::size_t Trie::find_callback(const char *str,
|
|
T callback) const {
|
|
MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
|
|
MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
|
|
return find_callback_<CQuery>(CQuery(str), callback);
|
|
}
|
|
|
|
template <typename T>
|
|
inline std::size_t Trie::find_callback(const char *ptr, std::size_t length,
|
|
T callback) const {
|
|
MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
|
|
MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
|
|
return find_callback_<const Query &>(Query(ptr, length), callback);
|
|
}
|
|
|
|
template <typename T>
|
|
inline std::size_t Trie::find_callback(const std::string &str,
|
|
T callback) const {
|
|
return find_callback(str.c_str(), str.length(), callback);
|
|
}
|
|
|
|
inline std::size_t Trie::predict(const std::string &str,
|
|
UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
|
|
return predict(str.c_str(), str.length(), key_ids, keys, max_num_results);
|
|
}
|
|
|
|
inline std::size_t Trie::predict(const std::string &str,
|
|
std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
|
|
std::size_t max_num_results) const {
|
|
return predict(str.c_str(), str.length(), key_ids, keys, max_num_results);
|
|
}
|
|
|
|
inline std::size_t Trie::predict_breadth_first(const std::string &str,
|
|
UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
|
|
return predict_breadth_first(str.c_str(), str.length(),
|
|
key_ids, keys, max_num_results);
|
|
}
|
|
|
|
inline std::size_t Trie::predict_breadth_first(const std::string &str,
|
|
std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
|
|
std::size_t max_num_results) const {
|
|
return predict_breadth_first(str.c_str(), str.length(),
|
|
key_ids, keys, max_num_results);
|
|
}
|
|
|
|
inline std::size_t Trie::predict_depth_first(const std::string &str,
|
|
UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
|
|
return predict_depth_first(str.c_str(), str.length(),
|
|
key_ids, keys, max_num_results);
|
|
}
|
|
|
|
inline std::size_t Trie::predict_depth_first(const std::string &str,
|
|
std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
|
|
std::size_t max_num_results) const {
|
|
return predict_depth_first(str.c_str(), str.length(),
|
|
key_ids, keys, max_num_results);
|
|
}
|
|
|
|
template <typename T>
|
|
inline std::size_t Trie::predict_callback(
|
|
const char *str, T callback) const {
|
|
return predict_callback_<CQuery>(CQuery(str), callback);
|
|
}
|
|
|
|
template <typename T>
|
|
inline std::size_t Trie::predict_callback(
|
|
const char *ptr, std::size_t length,
|
|
T callback) const {
|
|
return predict_callback_<const Query &>(Query(ptr, length), callback);
|
|
}
|
|
|
|
template <typename T>
|
|
inline std::size_t Trie::predict_callback(
|
|
const std::string &str, T callback) const {
|
|
return predict_callback(str.c_str(), str.length(), callback);
|
|
}
|
|
|
|
inline bool Trie::empty() const {
|
|
return louds_.empty();
|
|
}
|
|
|
|
inline std::size_t Trie::num_keys() const {
|
|
return num_keys_;
|
|
}
|
|
|
|
// Sentinel key ID (MARISA_NOT_FOUND); presumably what lookups return when no
// key matches.
inline UInt32 Trie::notfound() {
  const UInt32 not_found_id = MARISA_NOT_FOUND;
  return not_found_id;
}
|
|
|
|
inline std::size_t Trie::mismatch() {
|
|
return MARISA_MISMATCH;
|
|
}
|
|
|
|
// Tries to descend from `node` to the child whose edge matches `query`
// starting at position `pos`.
//
// On success returns true, sets `node` to the matched child and advances
// `pos` past the consumed characters. A plain label consumes one character.
// A linked edge, resolved through trie_ or the tail, may consume several.
// Returns false when `node` has no children, when a linked edge reports
// mismatch(), or when no sibling matches.
// NOTE: `node` is overwritten as soon as a child exists, so after a false
// return it no longer holds the caller's original value.
template <typename T>
inline bool Trie::find_child(UInt32 &node, T query,
    std::size_t &pos) const {
  // Children of `node` start right after its 0 bit in louds_. A 0 bit at
  // that position means there are no children.
  UInt32 louds_pos = get_child(node);
  if (!louds_[louds_pos]) {
    return false;
  }
  node = louds_pos_to_node(louds_pos, node);
  // Only the first linked sibling pays for get_link_id(). Each later linked
  // sibling takes the previous link ID plus one.
  UInt32 link_id = MARISA_UINT32_MAX;
  do {
    if (has_link(node)) {
      if (link_id == MARISA_UINT32_MAX) {
        link_id = get_link_id(node);
      } else {
        ++link_id;
      }
      // Match the linked (multi-character) edge against the query, either in
      // the next-level trie (trie_) or in the tail.
      std::size_t next_pos = has_trie() ?
          trie_->trie_match<T>(get_link(node, link_id), query, pos) :
          tail_match<T>(node, link_id, query, pos);
      if (next_pos == mismatch()) {
        return false;
      } else if (next_pos != pos) {
        pos = next_pos;
        return true;
      }
      // next_pos == pos: nothing consumed; fall through to the next sibling.
    } else if (labels_[node] == query[pos]) {
      ++pos;
      return true;
    }
    // Siblings are consecutive nodes; the run ends at the next 0 bit.
    ++node;
    ++louds_pos;
  } while (louds_[louds_pos]);
  return false;
}
|
|
|
|
template <typename T, typename U>
|
|
std::size_t Trie::find_callback_(T query, U callback) const {
|
|
std::size_t count = 0;
|
|
UInt32 node = 0;
|
|
std::size_t pos = 0;
|
|
do {
|
|
if (terminal_flags_[node]) {
|
|
++count;
|
|
if (!callback(node_to_key_id(node), pos)) {
|
|
return count;
|
|
}
|
|
}
|
|
} while (!query.ends_at(pos) && find_child<T>(node, query, pos));
|
|
return count;
|
|
}
|
|
|
|
// Prefix-search counterpart of find_child(). It descends from `node` to the
// child whose edge matches `query` at `pos`, and also passes `key` to the
// linked-edge matchers (trie_prefix_match / tail_prefix_match). Those
// presumably append the matched edge text to *key; verify in their
// definitions.
//
// Returns true with `node` set to the child and `pos` advanced on success.
// Returns false when there are no children, a linked edge reports
// mismatch(), or no sibling matches.
// NOTE: as in find_child(), `node` is overwritten even when false is
// returned.
template <typename T>
inline bool Trie::predict_child(UInt32 &node, T query, std::size_t &pos,
    std::string *key) const {
  // A 0 bit at the child position means `node` has no children.
  UInt32 louds_pos = get_child(node);
  if (!louds_[louds_pos]) {
    return false;
  }
  node = louds_pos_to_node(louds_pos, node);
  // The first linked sibling gets its ID from get_link_id(); later linked
  // siblings reuse it plus one.
  UInt32 link_id = MARISA_UINT32_MAX;
  do {
    if (has_link(node)) {
      if (link_id == MARISA_UINT32_MAX) {
        link_id = get_link_id(node);
      } else {
        ++link_id;
      }
      std::size_t next_pos = has_trie() ?
          trie_->trie_prefix_match<T>(
              get_link(node, link_id), query, pos, key) :
          tail_prefix_match<T>(node, link_id, query, pos, key);
      if (next_pos == mismatch()) {
        return false;
      } else if (next_pos != pos) {
        pos = next_pos;
        return true;
      }
      // next_pos == pos: nothing consumed; try the next sibling.
    } else if (labels_[node] == query[pos]) {
      // A plain label consumes one character. `key` is not touched here;
      // the query text is copied into it later by query.insert() in
      // predict_callback_.
      ++pos;
      return true;
    }
    ++node;
    ++louds_pos;
  } while (louds_[louds_pos]);
  return false;
}
|
|
|
|
// Enumerates every key that has `query` as a prefix, calling
// callback(key_id, key) for each. Returns the number of callback
// invocations, including one that returned false and stopped the
// enumeration. Returns 0 if the query prefix itself is not in the trie.
//
// Phase 1 walks down from the root with predict_child() until the query is
// consumed. Phase 2 reports the reached node if it is terminal, then visits
// its subtree in depth-first pre-order using an explicit stack of Cell.
// Each Cell tracks, for one depth, the next LOUDS position, node and key ID
// to visit, plus the key length at that depth.
template <typename T, typename U>
std::size_t Trie::predict_callback_(T query, U callback) const {
  std::string key;
  UInt32 node = 0;
  std::size_t pos = 0;
  while (!query.ends_at(pos)) {
    if (!predict_child<T>(node, query, pos, &key)) {
      return 0;
    }
  }
  // Presumably puts the query text into `key` so it starts with the matched
  // prefix; confirm against Query/CQuery::insert.
  query.insert(&key);
  std::size_t count = 0;
  if (terminal_flags_[node]) {
    ++count;
    if (!callback(node_to_key_id(node), key)) {
      return count;
    }
  }
  // Seed the stack with the first child of the prefix node (if any).
  Cell cell;
  cell.set_louds_pos(get_child(node));
  if (!louds_[cell.louds_pos()]) {
    return count;
  }
  cell.set_node(louds_pos_to_node(cell.louds_pos(), node));
  cell.set_key_id(node_to_key_id(cell.node()));
  cell.set_length(key.length());
  Vector<Cell> stack;
  stack.push_back(cell);
  // stack_pos is the current depth (number of active cells). Cells beyond
  // it are kept in `stack` and reused when that depth is entered again.
  std::size_t stack_pos = 1;
  while (stack_pos != 0) {
    // Later code indexes stack[] directly instead of using `cur` after
    // push_back, presumably because push_back may reallocate.
    Cell &cur = stack[stack_pos - 1];
    if (!louds_[cur.louds_pos()]) {
      // A 0 bit ends this sibling run: step past it and pop one level. The
      // step presumably leaves the cell pointing at the next node's child
      // run, ready for reuse when this depth is revisited.
      cur.set_louds_pos(cur.louds_pos() + 1);
      --stack_pos;
      continue;
    }
    cur.set_louds_pos(cur.louds_pos() + 1);
    // Rebuild the key: truncate to this depth's prefix, then append the
    // current node's edge (linked edge via trie_/tail, otherwise one label).
    key.resize(cur.length());
    if (has_link(cur.node())) {
      if (has_trie()) {
        trie_->trie_restore(get_link(cur.node()), &key);
      } else {
        tail_restore(cur.node(), &key);
      }
    } else {
      key += labels_[cur.node()];
    }
    if (terminal_flags_[cur.node()]) {
      ++count;
      if (!callback(cur.key_id(), key)) {
        return count;
      }
      // Only terminal nodes consume a key ID, so the next terminal node
      // at this depth gets the following ID.
      cur.set_key_id(cur.key_id() + 1);
    }
    // First time this depth is reached: allocate a cell for cur's
    // children. Otherwise the existing cell at stack[stack_pos] is reused.
    if (stack_pos == stack.size()) {
      cell.set_louds_pos(get_child(cur.node()));
      cell.set_node(louds_pos_to_node(cell.louds_pos(), cur.node()));
      cell.set_key_id(node_to_key_id(cell.node()));
      stack.push_back(cell);
    }
    // Descend into the children, and move this depth on to the next
    // sibling for when we come back.
    stack[stack_pos].set_length(key.length());
    stack[stack_pos - 1].set_node(stack[stack_pos - 1].node() + 1);
    ++stack_pos;
  }
  return count;
}
|
|
|
|
// Maps a key ID to its node via select1 over terminal_flags_. This is the
// inverse of node_to_key_id(), which uses rank1.
inline UInt32 Trie::key_id_to_node(UInt32 key_id) const {
  return this->terminal_flags_.select1(key_id);
}
|
|
|
|
// Maps a node to a key ID via rank1 over terminal_flags_. This presumably
// counts terminal nodes before `node`.
inline UInt32 Trie::node_to_key_id(UInt32 node) const {
  return this->terminal_flags_.rank1(node);
}
|
|
|
|
// Converts a LOUDS bit position into a node number, given the node whose
// child list contains that position. Unsigned arithmetic, so the regrouped
// subtraction yields the same value modulo 2^32.
inline UInt32 Trie::louds_pos_to_node(UInt32 louds_pos,
    UInt32 parent_node) const {
  return louds_pos - (parent_node + 1);
}
|
|
|
|
// LOUDS position of the first child slot of `node`: one past the position
// returned by select0(node). A 0 bit there means `node` has no children.
inline UInt32 Trie::get_child(UInt32 node) const {
  return 1 + louds_.select0(node);
}
|
|
|
|
// Returns the parent of `node`. Nodes numbered up to num_first_branches_
// are treated as children of the root (0). For the others the parent is
// derived from select1(node).
inline UInt32 Trie::get_parent(UInt32 node) const {
  if (node <= num_first_branches_) {
    return 0;
  }
  return louds_.select1(node) - node - 1;
}
|
|
|
|
// True when `node`'s edge is a link. This is always false when the trie
// keeps no link flags at all.
inline bool Trie::has_link(UInt32 node) const {
  return !link_flags_.empty() && link_flags_[node];
}
|
|
|
|
// Index of `node`'s link, via rank1 over link_flags_. Callers check
// has_link() first.
inline UInt32 Trie::get_link_id(UInt32 node) const {
  return this->link_flags_.rank1(node);
}
|
|
|
|
// Link value of `node`, computing its link ID on the spot.
inline UInt32 Trie::get_link(UInt32 node) const {
  const UInt32 link_id = get_link_id(node);
  return get_link(node, link_id);
}
|
|
|
|
// Combines links_[link_id], scaled by 256, with the node's label. The label
// is presumably a single byte, so it fills the low 8 bits.
inline UInt32 Trie::get_link(UInt32 node, UInt32 link_id) const {
  return labels_[node] + (links_[link_id] * 256);
}
|
|
|
|
inline bool Trie::has_link() const {
|
|
return !link_flags_.empty();
|
|
}
|
|
|
|
inline bool Trie::has_trie() const {
|
|
return trie_.get() != NULL;
|
|
}
|
|
|
|
inline bool Trie::has_tail() const {
|
|
return !tail_.empty();
|
|
}
|
|
|
|
} // namespace marisa
|
|
|
|
#endif // MARISA_TRIE_INLINE_H_
|