8000 Apm 55 (#14364) · arangodb/arangodb@7f9c774 · GitHub
[go: up one dir, main page]

Skip to content

Commit 7f9c774

Browse files
authored
Apm 55 (#14364)
* gauss decay * 3 functions are ready * fix for win build * changes after code review * changelog * changes in functions + js tests + range test * renamed functions
1 parent 657985c commit 7f9c774

File tree

8 files changed

+532
-0
lines changed

8 files changed

+532
-0
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ build
2020

2121
*.deb
2222
*.rpm
23+
*.user
2324

2425
.DS_Store
2526
*.swp

CHANGELOG

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ devel
55

66
* Slightly improve specific warning messages for better readability.
77

8+
* Add 3 AQL functions: DECAY_GAUSS, DECAY_EXP and DECAY_LINEAR.
9+
810
* Fix URL request parsing in case data is handed in in small chunks.
911
Previously the URL could be cut off if the chunk size was smaller than
1012
the URL size.

arangod/Aql/AqlFunctionFeature.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,9 @@ void AqlFunctionFeature::addListFunctions() {
314314
add({"REPLACE_NTH", ".,.,.|.", flags, &Functions::ReplaceNth});
315315
add({"INTERLEAVE", ".,.|+", flags, &Functions::Interleave});
316316

317+
add({"DECAY_GAUSS", ".,.,.,.,.,", flags, &Functions::DecayGauss});
318+
add({"DECAY_EXP", ".,.,.,.,.,", flags, &Functions::DecayExp});
319+
add({"DECAY_LINEAR", ".,.,.,.,.,", flags, &Functions::DecayLinear});
317320
// special flags:
318321
// CALL and APPLY will always run on the coordinator and are not deterministic
319322
// and not cacheable, as we don't know what function is actually gonna be

arangod/Aql/Functions.cpp

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1313,6 +1313,14 @@ Result parseShape(ExpressionContext* exprCtx,
13131313
}
13141314
}
13151315

1316+
irs::string_ref getFunctionName(const AstNode& node) {
1317+
1318+
TRI_ASSERT(aql::NODE_TYPE_FCALL == node.type);
1319+
auto const* impl = static_cast<arangodb::aql::Function*>(node.getData());
1320+
TRI_ASSERT(impl != nullptr);
1321+
return impl->name;
1322+
}
1323+
13161324
} // namespace
13171325

13181326
namespace arangodb {
@@ -8848,6 +8856,148 @@ AqlValue Functions::MakeDistributeGraphInput(arangodb::aql::ExpressionContext* e
88488856
return AqlValue{input};
88498857
}
88508858

8859+
template <typename F>
8860+
AqlValue decayFuncImpl(arangodb::aql::ExpressionContext* expressionContext,
8861+
AstNode const& node,
8862+
VPackFunctionParameters const& parameters,
8863+
F&& decayFuncFactory) {
8864+
8865+
AqlValue const& argValue = extractFunctionParameterValue(parameters, 0);
8866+
AqlValue const& originValue = extractFunctionParameterValue(parameters, 1);
8867+
AqlValue const& scaleValue = extractFunctionParameterValue(parameters, 2);
8868+
AqlValue const& offsetValue = extractFunctionParameterValue(parameters, 3);
8869+
AqlValue const& decayValue = extractFunctionParameterValue(parameters, 4);
8870+
8871+
// check type of arguments
8872+
if ((!argValue.isRange() && !argValue.isArray() && !argValue.isNumber()) ||
8873+
!originValue.isNumber() || !scaleValue.isNumber() ||
8874+
!offsetValue.isNumber() || !decayValue.isNumber()) {
8875+
8876+
registerInvalidArgumentWarning(expressionContext,
8877+
getFunctionName(node).c_str());
8878+
return AqlValue(AqlValueHintNull());
8879+
}
8880+
8881+
// extracting values
8882+
bool failed;
8883+
bool error = false;
8884+
double origin = originValue.toDouble(failed);
8885+
error |= failed;
8886+
double scale = scaleValue.toDouble(failed);
8887+
error |= failed;
8888+
double offset = offsetValue.toDouble(failed);
8889+
error |= failed;
8890+
double decay = decayValue.toDouble(failed);
8891+
error |= failed;
8892+
8893+
if (error) {
8894+
registerWarning(expressionContext,
8895+
getFunctionName(node).c_str(),
8896+
TRI_ERROR_QUERY_FUNCTION_ARGUMENT_TYPE_MISMATCH);
8897+
return AqlValue(AqlValueHintNull());
8898+
}
8899+
8900+
// check that parameters are correct
8901+
if (scale <= 0 || offset < 0 || decay <= 0 || decay >= 1) {
8902+
registerWarning(expressionContext,
8903+
getFunctionName(node).c_str(),
8904+
TRI_ERROR_QUERY_NUMBER_OUT_OF_RANGE);
8905+
return AqlValue(AqlValueHintNull());
8906+
}
8907+
8908+
// get lambda for further calculation
8909+
auto decayFunc = decayFuncFactory(origin, scale, offset, decay);
8910+
8911+
// argument is number
8912+
if (argValue.isNumber()) {
8913+
double arg = argValue.slice().getNumber<double>();
8914+
double funcRes = decayFunc(arg);
8915+
return ::numberValue(funcRes, true);
8916+
} else {
8917+
// argument is array or range
8918+
auto* trx = &expressionContext->trx();
8919+
AqlValueMaterializer materializer(&trx->vpackOptions());
8920+
VPackSlice slice = materializer.slice(argValue, true);
8921+
TRI_ASSERT(slice.isArray());
8922+
8923+
VPackBuilder builder;
8924+
{
8925+
VPackArrayBuilder arrayBuilder(&builder);
8926+
for (VPackSlice currArg : VPackArrayIterator(slice)) {
8927+
if (!currArg.isNumber()) {
8928+
registerWarning(expressionContext,
8929+
getFunctionName(node).c_str(),
8930+
TRI_ERROR_QUERY_FUNCTION_ARGUMENT_TYPE_MISMATCH);
8931+
return AqlValue(AqlValueHintNull());
8932+
}
8933+
double arg = currArg.getNumber<double>();
8934+
double funcRes = decayFunc(arg);
8935+
builder.add(VPackValue(funcRes));
8936+
}
8937+
}
8938+
8939+
return AqlValue(std::move(*builder.steal()));
8940+
}
8941+
}
8942+
8943+
AqlValue Functions::DecayGauss(arangodb::aql::ExpressionContext* expressionContext,
8944+
AstNode const& node,
8945+
VPackFunctionParameters const& parameters) {
8946+
8947+
auto gaussDecayFactory = [](const double origin,
8948+
const double scale,
8949+
const double offset,
8950+
const double decay) {
8951+
const double sigmaSqr = - (scale * scale) / (2 * std::log(decay));
8952+
return [=](double arg) {
8953+
double max = std::max(0.0, std::fabs(arg - origin) - offset);
8954+
double numerator = max * max;
8955+
double val = std::exp(- numerator / (2 * sigmaSqr));
8956+
return val;
8957+
};
8958+
};
8959+
8960+
return decayFuncImpl(expressionContext, node, parameters, gaussDecayFactory);
8961+
}
8962+
8963+
AqlValue Functions::DecayExp(arangodb::aql::ExpressionContext* expressionContext,
8964+
AstNode const& node,
8965+
VPackFunctionParameters const& parameters) {
8966+
8967+
auto expDecayFactory = [](const double origin,
8968+
const double scale,
8969+
const double offset,
8970+
const double decay) {
8971+
const double lambda = std::log(decay) / scale;
8972+
return [=](double arg) {
8973+
double numerator = lambda * std::max(0.0, std::abs(arg - origin) - offset);
8974+
double val = std::exp(numerator);
8975+
return val;
8976+
};
8977+
};
8978+
8979+
return decayFuncImpl(expressionContext, node, parameters, expDecayFactory);
8980+
}
8981+
8982+
AqlValue Functions::DecayLinear(arangodb::aql::ExpressionContext* expressionContext,
8983+
AstNode const& node,
8984+
VPackFunctionParameters const& parameters) {
8985+
8986+
auto linearDecayFactory = [](const double origin,
8987+
const double scale,
8988+
const double offset,
8989+
const double decay) {
8990+
const double s = scale / (1.0 - decay);
8991+
return [=](double arg) {
8992+
double max = std::max(0.0, std::fabs(arg - origin) - offset);
8993+
double val = std::max((s - max) / s, 0.0);
8994+
return val;
8995+
};
8996+
};
8997+
8998+
return decayFuncImpl(expressionContext, node, parameters, linearDecayFactory);
8999+
}
9000+
88519001
AqlValue Functions::NotImplemented(ExpressionContext* expressionContext, AstNode const&,
88529002
VPackFunctionParameters const& params) {
88539003
registerError(expressionContext, "UNKNOWN", TRI_ERROR_NOT_IMPLEMENTED);

arangod/Aql/Functions.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -532,6 +532,15 @@ struct Functions {
532532
static AqlValue MakeDistributeGraphInput(arangodb::aql::ExpressionContext*,
533533
AstNode const&, VPackFunctionParameters const&);
534534

535+
static AqlValue DecayGauss(arangodb::aql::ExpressionContext*,
536+
AstNode const&, VPackFunctionParameters const&);
537+
538+
static AqlValue DecayExp(arangodb::aql::ExpressionContext*,
539+
AstNode const&, VPackFunctionParameters const&);
540+
541+
static AqlValue DecayLinear(arangodb::aql::ExpressionContext*,
542+
As 548F tNode const&, VPackFunctionParameters const&);
543+
535544
/// @brief dummy function that will only throw an error when called
536545
static AqlValue NotImplemented(arangodb::aql::ExpressionContext*,
537546
AstNode const&, VPackFunctionParameters const&);

0 commit comments

Comments
 (0)
0