Feature/ngram similarity function by Dronplane · Pull Request #11276 · arangodb/arangodb · GitHub
[go: up one dir, main page]

Skip to content

Feature/ngram similarity function #11276

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Mar 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion arangod/Aql/AqlFunctionFeature.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,9 @@ void AqlFunctionFeature::addStringFunctions() {
add({"SOUNDEX", ".", flags, &Functions::Soundex});
add({"LEVENSHTEIN_DISTANCE", ".,.", flags, &Functions::LevenshteinDistance});
add({"LEVENSHTEIN_MATCH", ".,.,.|.", flags, &Functions::LevenshteinMatch}); // (attribute, target, max distance, [include transpositions])

add({"NGRAM_MATCH", ".,.|.,.", flags, &Functions::NgramMatch}); // (attribute, target, [threshold, analyzer]) OR (attribute, target, [analyzer])
add({"NGRAM_SIMILARITY", ".,.,.", flags, &Functions::NgramSimilarity}); // (attribute, target, ngram size)
add({"NGRAM_POSITIONAL_SIMILARITY", ".,.,.", flags, &Functions::NgramPositionalSimilarity}); // (attribute, target, ngram size)
// special flags:
add({"RANDOM_TOKEN", ".", Function::makeFlags(FF::CanRunOnDBServer),
&Functions::RandomToken}); // not deterministic and not cacheable
Expand Down
199 changes: 197 additions & 2 deletions arangod/Aql/Functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,11 @@
#include "Pregel/Worker.h"
#include "IResearch/VelocyPackHelper.h"
#include "IResearch/IResearchPDP.h"
#include "IResearch/IResearchAnalyzerFeature.h"
#include "IResearch/IResearchFilterFactory.h"
#include "Random/UniformCharacter.h"
#include "Rest/Version.h"
#include "RestServer/SystemDatabaseFeature.h"
#include "Ssl/SslInterface.h"
#include "Transaction/Context.h"
#include "Transaction/Helpers.h"
Expand All @@ -73,7 +76,10 @@
#include "VocBase/KeyGenerator.h"
#include "VocBase/LogicalCollection.h"

#include "analysis/token_attributes.hpp"
#include "utils/levenshtein_utils.hpp"
#include "utils/ngram_match_utils.hpp"


#include <boost/uuid/uuid.hpp>
#include <boost/uuid/uuid_generators.hpp>
Expand Down Expand Up @@ -1578,6 +1584,195 @@ AqlValue Functions::LevenshteinDistance(ExpressionContext*, transaction::Methods
return AqlValue(AqlValueHintInt(encoded));
}


namespace {
/// Shared implementation behind NGRAM_SIMILARITY and NGRAM_POSITIONAL_SIMILARITY.
///
/// Expects (attribute, target, ngram size). Any argument of the wrong type, or
/// an ngram size below 1, registers a warning on `ctx` and yields a null value.
/// Otherwise both strings are converted to character codes and handed to
/// irs::ngram_similarity, whose result is returned as a double.
///
/// @tparam search_semantics forwarded unchanged to irs::ngram_similarity
/// @param AFN  AQL function name used when registering warnings
/// @param ctx  expression context that receives warnings
/// @param trx  not used by this helper
/// @param args AQL call arguments
template<bool search_semantics>
AqlValue NgramSimilarityHelper(char const* AFN, ExpressionContext* ctx, transaction::Methods* trx,
                               VPackFunctionParameters const& args) {
  // All type-mismatch paths report the same warning and return null.
  auto const rejectArgument = [ctx, AFN]() {
    arangodb::aql::registerInvalidArgumentWarning(ctx, AFN);
    return arangodb::aql::AqlValue{ arangodb::aql::AqlValueHintNull{} };
  };

  if (args.size() < 3) {
    registerWarning(
        ctx, AFN,
        arangodb::Result{ TRI_ERROR_QUERY_FUNCTION_ARGUMENT_NUMBER_MISMATCH,
                          "Minimum 3 arguments are expected." });
    return AqlValue(AqlValueHintNull());
  }

  // 1st argument: the attribute string.
  auto const& attributeArg = extractFunctionParameterValue(args, 0);
  if (ADB_UNLIKELY(!attributeArg.isString())) {
    return rejectArgument();
  }
  auto const attributeText = arangodb::iresearch::getStringRef(attributeArg.slice());

  // 2nd argument: the target string.
  auto const& targetArg = extractFunctionParameterValue(args, 1);
  if (ADB_UNLIKELY(!targetArg.isString())) {
    return rejectArgument();
  }
  auto const targetText = arangodb::iresearch::getStringRef(targetArg.slice());

  // 3rd argument: the ngram size, which must be a number >= 1.
  auto const& sizeArg = extractFunctionParameterValue(args, 2);
  if (ADB_UNLIKELY(!sizeArg.isNumber())) {
    return rejectArgument();
  }
  auto const ngramLength = sizeArg.toInt64();
  if (ADB_UNLIKELY(ngramLength <= 0)) {
    arangodb::aql::registerWarning(ctx, AFN,
                                   arangodb::Result{TRI_ERROR_BAD_PARAMETER,
                                                    "Invalid ngram size. Should be 1 or greater"});
    return arangodb::aql::AqlValue{ arangodb::aql::AqlValueHintNull{} };
  }

  // The similarity is computed over character codes rather than raw bytes.
  auto attributeCodes =
      basics::StringUtils::characterCodes(attributeText.c_str(), attributeText.size());
  auto targetCodes =
      basics::StringUtils::characterCodes(targetText.c_str(), targetText.size());

  // Note the argument order: target first, attribute second.
  auto const score = irs::ngram_similarity<uint32_t, search_semantics>(
      targetCodes.data(), targetCodes.size(),
      attributeCodes.data(), attributeCodes.size(),
      ngramLength);
  return AqlValue(AqlValueHintDouble(score));
}
}

/// AQL function NGRAM_SIMILARITY (binary ngram similarity).
/// Delegates to NgramSimilarityHelper instantiated with search_semantics = true,
/// which performs all argument validation and the actual computation.
AqlValue Functions::NgramSimilarity(ExpressionContext* ctx, transaction::Methods* trx,
                                    VPackFunctionParameters const& args) {
  return NgramSimilarityHelper<true>("NGRAM_SIMILARITY", ctx, trx, args);
}

/// AQL function NGRAM_POSITIONAL_SIMILARITY (positional ngram similarity).
/// Delegates to NgramSimilarityHelper instantiated with search_semantics = false,
/// which performs all argument validation and the actual computation.
AqlValue Functions::NgramPositionalSimilarity(ExpressionContext* ctx, transaction::Methods* trx,
                                              VPackFunctionParameters const& args) {
  return NgramSimilarityHelper<false>("NGRAM_POSITIONAL_SIMILARITY", ctx, trx, args);
}

/// Executes NGRAM_MATCH based on binary ngram similarity.
///
/// Accepted argument lists:
///   (attribute, target, analyzer)
///   (attribute, target, threshold, analyzer)
///
/// Both strings are tokenized with the requested analyzer. The function returns
/// `true` as soon as the number of target ngrams found, in order, among the
/// attribute ngrams reaches ceil(targetNgramCount * threshold); otherwise it
/// returns `false`. If either string produces no ngrams the result is `false`.
///
/// Any invalid argument (wrong type, threshold outside (0, 1], unknown analyzer)
/// or missing server feature registers a warning on `ctx` and yields null.
///
/// @param ctx  expression context that receives warnings
/// @param trx  transaction used to reach the server features; must not be null
/// @param args AQL call arguments
AqlValue Functions::NgramMatch(ExpressionContext* ctx, transaction::Methods* trx,
                               VPackFunctionParameters const& args) {
  static char const* AFN = "NGRAM_MATCH";

  auto const argc = args.size();

  if (argc < 3) {
    // For const evaluation the analyzer has to be set explicitly (we can't
    // access the filter context), but we can't make the analyzer mandatory in
    // the AQL function signature - that would break SEARCH.
    registerWarning(
        ctx, AFN,
        arangodb::Result{ TRI_ERROR_QUERY_FUNCTION_ARGUMENT_NUMBER_MISMATCH,
                          "Minimum 3 arguments are expected."});
    return AqlValue(AqlValueHintNull());
  }

  auto const& attribute = extractFunctionParameterValue(args, 0);
  if (ADB_UNLIKELY(!attribute.isString())) {
    arangodb::aql::registerInvalidArgumentWarning(ctx, AFN);
    return arangodb::aql::AqlValue{ arangodb::aql::AqlValueHintNull{} };
  }
  auto const attributeValue = iresearch::getStringRef(attribute.slice());

  auto const& target = extractFunctionParameterValue(args, 1);
  if (ADB_UNLIKELY(!target.isString())) {
    arangodb::aql::registerInvalidArgumentWarning(ctx, AFN);
    return arangodb::aql::AqlValue{ arangodb::aql::AqlValueHintNull{} };
  }
  auto const targetValue = iresearch::getStringRef(target.slice());

  // With three arguments the analyzer is the 3rd one and the default threshold
  // applies; with four arguments the 3rd is the threshold and the 4th the analyzer.
  auto threshold = arangodb::iresearch::FilterConstants::DefaultNgramMatchThreshold;
  size_t analyzerPosition = 2;
  if (argc > 3) {
    auto const& thresholdArg = extractFunctionParameterValue(args, 2);
    analyzerPosition = 3;
    if (ADB_UNLIKELY(!thresholdArg.isNumber())) {
      arangodb::aql::registerInvalidArgumentWarning(ctx, AFN);
      return arangodb::aql::AqlValue{ arangodb::aql::AqlValueHintNull{} };
    }
    threshold = thresholdArg.toDouble();
    if (threshold <= 0 || threshold > 1) {
      arangodb::aql::registerWarning(
          ctx, AFN,
          arangodb::Result{TRI_ERROR_BAD_PARAMETER, "Threshold must be between 0 and 1" });
      // Do not continue with an out-of-range threshold: the required match
      // count computed below would be meaningless (and converting a negative
      // value to size_t is undefined), so bail out like every other invalid
      // argument does.
      return arangodb::aql::AqlValue{ arangodb::aql::AqlValueHintNull{} };
    }
  }

  auto const& analyzerArg = extractFunctionParameterValue(args, analyzerPosition);
  if (ADB_UNLIKELY(!analyzerArg.isString())) {
    arangodb::aql::registerInvalidArgumentWarning(ctx, AFN);
    return arangodb::aql::AqlValue{ arangodb::aql::AqlValueHintNull{} };
  }
  if (ADB_UNLIKELY(nullptr == trx)) {
    arangodb::aql::registerWarning(ctx, AFN, TRI_ERROR_INTERNAL);
    return arangodb::aql::AqlValue{ arangodb::aql::AqlValueHintNull{} };
  }
  auto const analyzerId = arangodb::iresearch::getStringRef(analyzerArg.slice());
  auto& server = trx->vocbase().server();
  if (!server.hasFeature<iresearch::IResearchAnalyzerFeature>()) {
    arangodb::aql::registerWarning(ctx, AFN, TRI_ERROR_INTERNAL);
    return arangodb::aql::AqlValue{ arangodb::aql::AqlValueHintNull{} };
  }
  auto& analyzerFeature = server.getFeature<iresearch::IResearchAnalyzerFeature>();

  auto sysVocbase = server.hasFeature<arangodb::SystemDatabaseFeature>()
                        ? server.getFeature<arangodb::SystemDatabaseFeature>().use()
                        : nullptr;

  if (ADB_UNLIKELY(nullptr == sysVocbase)) {
    arangodb::aql::registerWarning(ctx, AFN, TRI_ERROR_INTERNAL);
    return arangodb::aql::AqlValue{ arangodb::aql::AqlValueHintNull{} };
  }
  auto analyzer = analyzerFeature.get(analyzerId, trx->vocbase(), *sysVocbase);
  if (!analyzer) {
    arangodb::aql::registerWarning(
        ctx, AFN,
        arangodb::Result{ TRI_ERROR_BAD_PARAMETER, "Unable to load requested analyzer" });
    return arangodb::aql::AqlValue{ arangodb::aql::AqlValueHintNull{} };
  }

  auto analyzerImpl = analyzer->get();
  TRI_ASSERT(analyzerImpl);
  irs::term_attribute const& token = *analyzerImpl->attributes().get<irs::term_attribute>();

  // Collect the ngrams emitted by the analyzer for the attribute value ...
  std::vector<irs::bstring> attrNgrams;
  analyzerImpl->reset(attributeValue);
  while (analyzerImpl->next()) {
    attrNgrams.push_back(token.value());
  }

  // ... and, reusing the same analyzer instance, for the target value.
  std::vector<irs::bstring> targetNgrams;
  analyzerImpl->reset(targetValue);
  while (analyzerImpl->next()) {
    targetNgrams.push_back(token.value());
  }

  // Consider only non-empty ngram sets: no ngrams emitted means no data in the
  // index, which means no match.
  if (!targetNgrams.empty() && !attrNgrams.empty()) {
    auto const thresholdMatches = static_cast<size_t>(
        std::ceil(static_cast<double>(targetNgrams.size()) * threshold));
    // Longest-common-subsequence style dynamic programming using a single row:
    // cache[i] is the best match count for the first i attribute ngrams against
    // the target ngrams processed so far. cache[0] is never written and stays 0.
    std::vector<size_t> cache(attrNgrams.size() + 1, 0);
    for (auto const& targetNgram : targetNgrams) {
      // Upper-left (diagonal) cell for the current column. It must restart at
      // cache[0] == 0 for every row; carrying it over from the previous row
      // would over-count matches.
      size_t d = 0;
      for (size_t s_ngram_idx = 1; s_ngram_idx <= attrNgrams.size(); ++s_ngram_idx) {
        size_t const curMatches =
            d + static_cast<size_t>(attrNgrams[s_ngram_idx - 1] == targetNgram);
        if (curMatches >= thresholdMatches) {
          // Enough ngrams matched already - no need to finish the table.
          return arangodb::aql::AqlValue{ arangodb::aql::AqlValueHintBool{true} };
        }
        // Remember the previous row's value before overwriting it: it is the
        // diagonal cell for the next column.
        auto const tmp = cache[s_ngram_idx];
        cache[s_ngram_idx] =
            std::max(std::max(cache[s_ngram_idx - 1], cache[s_ngram_idx]), curMatches);
        d = tmp;
      }
    }
  }
  return arangodb::aql::AqlValue{ arangodb::aql::AqlValueHintBool{false} };
}


/// Executes LEVENSHTEIN_MATCH
AqlValue Functions::LevenshteinMatch(ExpressionContext* ctx, transaction::Methods* trx,
VPackFunctionParameters const& args) {
Expand Down Expand Up @@ -6560,7 +6755,7 @@ AqlValue Functions::ReplaceNth(ExpressionContext* expressionContext, transaction
registerInvalidArgumentWarning(expressionContext, AFN);
return AqlValue(AqlValueHintNull());
}

if (offset.isNull(true)) {
THROW_ARANGO_EXCEPTION_PARAMS(TRI_ERROR_QUERY_FUNCTION_ARGUMENT_TYPE_MISMATCH, AFN);
}
Expand All @@ -6580,7 +6775,7 @@ AqlValue Functions::ReplaceNth(ExpressionContext* expressionContext, transaction
AqlValueMaterializer materializer(trx);
VPackSlice arraySlice = materializer.slice(baseArray, false);
VPackSlice replaceValue = materializer.slice(newValue, false);

transaction::BuilderLeaser builder(trx);
builder->openArray();

Expand Down
7 changes: 7 additions & 0 deletions arangod/Aql/Functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,13 @@ struct Functions {
static AqlValue LevenshteinMatch(arangodb::aql::ExpressionContext*,
transaction::Methods*,
VPackFunctionParameters const&);
static AqlValue NgramSimilarity(ExpressionContext*, transaction::Methods*,
VPackFunctionParameters const&);
static AqlValue NgramPositionalSimilarity(ExpressionContext* ctx,
transaction::Methods*,
VPackFunctionParameters const&);
static AqlValue NgramMatch(ExpressionContext*, transaction::Methods*,
VPackFunctionParameters const&);
// Date
static AqlValue DateNow(arangodb::aql::ExpressionContext*,
transaction::Methods*, VPackFunctionParameters const&);
Expand Down
6 changes: 3 additions & 3 deletions arangod/IResearch/IResearchFeature.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ arangodb::aql::AqlValue dummyFilterFunc(arangodb::aql::ExpressionContext*,
arangodb::containers::SmallVector<arangodb::aql::AqlValue> const&) {
THROW_ARANGO_EXCEPTION_MESSAGE(
TRI_ERROR_NOT_IMPLEMENTED,
"ArangoSearch filter functions EXISTS, IN_RANGE, PHRASE, NGRAM_MATCH "
"ArangoSearch filter functions EXISTS, IN_RANGE, PHRASE "
" are designed to be used only within a corresponding SEARCH statement "
"of ArangoSearch view."
" Please ensure function signature is correct.");
Expand Down Expand Up @@ -407,7 +407,6 @@ void registerFilters(arangodb::aql::AqlFunctionFeature& functions) {
addFunction(functions, { "MIN_MATCH", ".,.|.+", flags, &minMatchFunc }); // (filter expression [, filter expression, ... ], min match count)
addFunction(functions, { "BOOST", ".,.", flags, &contextFunc }); // (filter expression, boost)
addFunction(functions, { "ANALYZER", ".,.", flags, &contextFunc }); // (filter expression, analyzer)
addFunction(functions, { "NGRAM_MATCH", ".,.|.,.", flags, &dummyFilterFunc }); // (attribute, target, threshold, [analyzer]) OR (attribute, target, [analyzer])
}

namespace {
Expand Down Expand Up @@ -584,7 +583,8 @@ bool isFilter(arangodb::aql::Function const& func) noexcept {
func.implementation == &minMatchFunc ||
func.implementation == &startsWithFunc ||
func.implementation == &aql::Functions::LevenshteinMatch ||
func.implementation == &aql::Functions::Like;
func.implementation == &aql::Functions::Like ||
func.implementation == &aql::Functions::NgramMatch;
}

bool isScorer(arangodb::aql::Function const& func) noexcept {
Expand Down
8 changes: 4 additions & 4 deletions arangod/IResearch/IResearchFilterFactory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2237,7 +2237,7 @@ arangodb::Result getLevenshteinArguments(char const* funcName, bool isFilter,
};
}

scoringLimit = 128; // FIXME make configurable
scoringLimit = FilterConstants::DefaultScoringTermsLimit;

return {};
}
Expand Down Expand Up @@ -2623,7 +2623,7 @@ arangodb::Result fromFuncNgramMatch(
}
}

double_t threshold = 0.7;
auto threshold = FilterConstants::DefaultNgramMatchThreshold;
TRI_ASSERT(filterCtx.analyzer);
auto analyzerPool = filterCtx.analyzer;

Expand Down Expand Up @@ -2769,7 +2769,7 @@ arangodb::Result fromFuncStartsWith(
return rv;
}

size_t scoringLimit = 128; // FIXME make configurable
size_t scoringLimit = FilterConstants::DefaultScoringTermsLimit;

if (argc > 2) {
// 3rd (optional) argument defines a number of scored terms
Expand Down Expand Up @@ -2915,7 +2915,7 @@ arangodb::Result fromFuncLike(
return res;
}

const size_t scoringLimit = 128; // FIXME make configurable
const auto scoringLimit = FilterConstants::DefaultScoringTermsLimit;

if (filter) {
std::string name;
Expand Down
7 changes: 7 additions & 0 deletions arangod/IResearch/IResearchFilterFactory.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,13 @@ struct FilterFactory {
arangodb::aql::AstNode const& node);
}; // FilterFactory


/// Default values shared by the ArangoSearch filter code and the matching
/// AQL function implementations.
struct FilterConstants {
  /// Scoring terms limit used when a caller does not supply one.
  static constexpr size_t DefaultScoringTermsLimit = 128;
  /// NGRAM_MATCH similarity threshold used when a caller does not supply one.
  static constexpr double_t DefaultNgramMatchThreshold = 0.7;
};

} // namespace iresearch
} // namespace arangodb

Expand Down
Loading
0