8000 Apm 55 by alexbakharew · Pull Request #14364 · arangodb/arangodb · GitHub
[go: up one dir, main page]

Skip to content

Apm 55 #14364

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jun 17, 2021
Merged

Apm 55 #14364

Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
gauss decay
  • Loading branch information
alexbakharew committed Jun 11, 2021
commit 9517969690bd5fbc7e6189f3f795be9bad6232d7
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ build

*.deb
*.rpm
*.user

.DS_Store
*.swp
Expand Down
3 changes: 3 additions & 0 deletions arangod/Aql/AqlFunctionFeature.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,9 @@ void AqlFunctionFeature::addListFunctions() {
add({"REPLACE_NTH", ".,.,.|.", flags, &Functions::ReplaceNth});
add({"INTERLEAVE", ".,.|+", flags, &Functions::Interleave});

add({"GAUSS_DECAY", ".,.,.,.,.,", flags, &Functions::GaussDecay});
add({"EXP_DECAY", ".,.,.,.,.,", flags, &Functions::ExpDecay});
add({"LINEAR_DECAY", ".,.,.,.,.,", flags, &Functions::LinearDecay});
// special flags:
// CALL and APPLY will always run on the coordinator and are not deterministic
// and not cacheable, as we don't know what function is actually gonna be
Expand Down
99 changes: 99 additions & 0 deletions arangod/Aql/Functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8848,6 +8848,105 @@ AqlValue Functions::MakeDistributeGraphInput(arangodb::aql::ExpressionContext* e
return AqlValue{input};
}

double gauss_decay_f(const double arg, const double origin, const double scale, const double offset = 0, const double decay = 0.5) {

double sigma_sqr = - (scale * scale) / (2 * std::log(decay));
double val = std::exp(- (std::pow(std::max(0.0, std::fabs(arg - origin) - offset), 2)) / (2 * sigma_sqr));
return val > decay ? val : decay;
}

AqlValue Functions::GaussDecay(arangodb::aql::ExpressionContext* expressionContext,
AstNode const& node,
VPackFunctionParameters const& parameters) {

AqlValue const& arg_value = extractFunctionParameterValue(parameters, 0);
AqlValue const& origin_value = extractFunctionParameterValue(parameters, 1);
AqlValue const& scale_value = extractFunctionParameterValue(parameters, 2);
AqlValue const& offset_value = extractFunctionParameterValue(parameters, 3);
AqlValue const& decay_value = extractFunctionParameterValue(parameters, 4);

// extract AQL function name
auto const* impl = static_cast<arangodb::aql::Function*>(node.getData());
TRI_ASSERT(impl != nullptr);

// check type of arguments
if ((!arg_value.isRange() && !arg_value.isArray() && !arg_value.isNumber()) ||
!origin_value.isNumber() || !scale_value.isNumber() ||
!offset_value.isNumber() || !decay_value.isNumber()) {

registerInvalidArgumentWarning(expressionContext, impl->name.c_str());
return AqlValue(AqlValueHintNull());
}

// extracting values
bool failed;
bool error = false;
double origin = origin_value.toDouble(failed);
error |= failed;
double scale = scale_value.toDouble(failed);
error |= failed;
double offset = offset_value.toDouble(failed);
error |= failed;
double decay = decay_value.toDouble(failed);
error |= failed;

if (error) {
registerWarning(expressionContext, impl->name.c_str(), TRI_ERROR_QUERY_FUNCTION_ARGUMENT_TYPE_MISMATCH);
return AqlValue(AqlValueHintNull());
}

// check that parameters are correct
if (origin < 0 || scale < 0 || offset < 0 || decay < 0) {
registerWarning(expressionContext, impl->name.c_str(), TRI_ERROR_QUERY_NUMBER_OUT_OF_RANGE);
return AqlValue(AqlValueHintNull());
}

// argument is number
if (arg_value.isNumber()) {
double arg = arg_value.slice().getNumber<double>();
double func_res = gauss_decay_f(arg, origin, scale, offset, decay);
return ::numberValue(func_res, true);
} else {
// argument is array or range
auto* trx = &expressionContext->trx();
AqlValueMaterializer materializer(&trx->vpackOptions());
VPackSlice slice = materializer.slice(arg_value, true);
TRI_ASSERT(slice.isArray());

VPackBuilder builder;
{
VPackArrayBuilder array_builder(&builder);
for (VPackSlice curr_arg : VPackArrayIterator(slice)) {
if (!curr_arg.isNumber()) {
registerWarning(expressionContext, impl->name.c_str(), TRI_ERROR_QUERY_FUNCTION_ARGUMENT_TYPE_MISMATCH);
return AqlValue(AqlValueHintNull());
}
double arg = curr_arg.getNumber<double>();
double func_res = gauss_decay_f(arg, origin, scale, offset, decay);
builder.add(VPackValue(func_res));
}
}

return AqlValue(builder.slice());
}
}

AqlValue Functions::ExpDecay(arangodb::aql::ExpressionContext* expressionContext,
AstNode const& node,
VPackFunctionParameters const& parameters) {
// temp plug
AqlValue tmp;
return tmp;
}

AqlValue Functions::LinearDecay(arangodb::aql::ExpressionContext* expressionContext,
AstNode const& node,
VPackFunctionParameters const& parameters) {
// temp plug
AqlValue tmp;
return tmp;
}

AqlValue Functions::NotImplemented(ExpressionContext* expressionContext, AstNode const&,
VPackFunctionParameters const& params) {
registerError(expressionContext, "UNKNOWN", TRI_ERROR_NOT_IMPLEMENTED);
Expand Down
9 changes: 9 additions & 0 deletions arangod/Aql/Functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,15 @@ struct Functions {
static AqlValue MakeDistributeGraphInput(arangodb::aql::ExpressionContext*,
AstNode const&, VPackFunctionParameters const&);

static AqlValue GaussDecay(arangodb::aql::ExpressionContext*,
AstNode const&, VPackFunctionParameters const&);

static AqlValue ExpDecay(arangodb::aql::ExpressionContext*,
AstNode const&, VPackFunctionParameters const&);

static AqlValue LinearDecay(arangodb::aql::ExpressionContext*,
AstNode const&, VPackFunctionParameters const&);

/// @brief dummy function that will only throw an error when called
static AqlValue NotImplemented(arangodb::aql::ExpressionContext*,
AstNode const&, VPackFunctionParameters const&);
Expand Down
162 changes: 162 additions & 0 deletions tests/Aql/DecaysFunctionTest.cpp
6D47
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
////////////////////////////////////////////////////////////////////////////////
/// DISCLAIMER
///
/// Copyright 2014-2020 ArangoDB GmbH, Cologne, Germany
/// Copyright 2004-2014 triAGENS GmbH, Cologne, Germany
///
/// Licensed under the Apache License, Version 2.0 (the "License");
/// you may not use this file except in compliance with the License.
/// You may obtain a copy of the License at
///
/// http://www.apache.org/licenses/LICENSE-2.0
///
/// Unless required by applicable law or agreed to in writing, software
/// distributed under the License is distributed on an "AS IS" BASIS,
/// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
/// See the License for the specific language governing permissions and
/// limitations under the License.
///
/// Copyright holder is ArangoDB GmbH, Cologne, Germany
///
/// @author Alexey Bakharew
////////////////////////////////////////////////////////////////////////////////

#include "gtest/gtest.h"

#include "fakeit.hpp"

#include <vector>

#include "Aql/AstNode.h"
#include "Aql/ExpressionContext.h"
#include "Aql/Function.h"
#include "Aql/Functions.h"
#include "Containers/SmallVector.h"
#include "Transaction/Context.h"
#include "Transaction/Methods.h"
#include "IResearch/IResearchQueryCommon.h"
#include <velocypack/Builder.h>
#include <velocypack/Iterator.h>
#include <velocypack/Parser.h>
#include <velocypack/Slice.h>
#include <velocypack/velocypack-aliases.h>

using namespace arangodb;
using namespace arangodb::aql;
using namespace arangodb::containers;

namespace {


// helper function
SmallVector<AqlValue> create_arg_vec(const VPackSlice slice) {

SmallVector<AqlValue>::allocator_type::arena_type arena;
SmallVector<AqlValue> params{arena};

for (const auto arg : VPackArrayIterator(slice)) {
params.emplace_back(AqlValue(arg));
}

return params;
}

void expect_eq_slices(const VPackSlice actual_slice, const VPackSlice expected_slice) {
if (actual_slice.isArray() && expected_slice.isArray()) {
VPackValueLength actual_size = actual_slice.length();
VPackValueLength expected_size = actual_slice.length();
ASSERT_EQ(actual_size, expected_size);

double lhs, rhs;
for(VPackValueLength i = 0; i < actual_size; ++i) {
lhs = actual_slice.at(i).getNumber<decltype (lhs)>();
rhs = actual_slice.at(i).getNumber<decltype (rhs)>();
ASSERT_DOUBLE_EQ(lhs, rhs);
}
} else if (actual_slice.isNumber() && expected_slice.isNumber()) {
double lhs = actual_slice.getNumber<decltype (lhs)>();
double rhs = expected_slice.getNumber<decltype (rhs)>();
ASSERT_DOUBLE_EQ(lhs, rhs);
} else {
ASSERT_TRUE(false);
}

return;
}

AqlValue evaluate_gauss(const SmallVector<AqlValue> params) {
fakeit::Mock<ExpressionContext> expressionContextMock;
ExpressionContext& expressionContext = expressionContextMock.get();
fakeit::When(Method(expressionContextMock, registerWarning)).AlwaysDo([](ErrorCode, char const*){ });

VPackOptions options;
fakeit::Mock<transaction::Context> trxCtxMock;
fakeit::When(Method(trxCtxMock, getVPackOptions)).AlwaysReturn(&options);
transaction::Context& trxCtx = trxCtxMock.get();

fakeit::Mock<transaction::Methods> trxMock;
fakeit::When(Method(trxMock, transactionContextPtr)).AlwaysReturn(&trxCtx);
fakeit::When(Method(trxMock, vpackOptions)).AlwaysReturn(options);
transaction::Methods& trx = trxMock.get();

fakeit::When(Method(expressionContextMock, trx)).AlwaysDo([&trx]() -> transaction::Methods& {
return trx;
});

arangodb::aql::Function f("GAUSS_DECAY", &Functions::GaussDecay);
arangodb::aql::AstNode node(NODE_TYPE_FCALL);
node.setData(static_cast<void const*>(&f));

return Functions::GaussDecay(&expressionContext, node, params);
}

void assertGauss(char const* expected, char const* args) {

// get slice for expected value
auto const expected_json = VPackParser::fromJson(expected);
auto const expected_slice = expected_json->slice();
ASSERT_TRUE(expected_slice.isArray() || expected_slice.isNumber());

// get slice for args value
auto const args_json = VPackParser::fromJson(args);
auto const args_slice = args_json->slice();
ASSERT_TRUE(args_slice.isArray());

// create params vector from args slice
SmallVector<AqlValue> params = create_arg_vec(args_slice);

// evaluate
auto const actual_value = evaluate_gauss(params);
ASSERT_TRUE(actual_value.isNumber() || actual_value.isArray());

// check equality
expect_eq_slices(actual_value.slice(), expected_slice);
return;
}

void assertGaussFail(char const* args) {
// get slice for args value
auto const args_json = VPackParser::fromJson(args);
auto const args_slice = args_json->slice();
ASSERT_TRUE(args_slice.isArray());

// create params vector from args slice
SmallVector<AqlValue> params = create_arg_vec(args_slice);

ASSERT_TRUE(evaluate_gauss(params).isNull(false));
}

TEST(GaussDecayFunctionTest, test) {
assertGauss("0.5", "[20, 40, 5, 5, 0.5]");
assertGauss("1", "[41, 40, 5, 5, 0.5]");
assertGauss("[0.5, 1.0]", "[[20.0, 41], 40, 5, 5, 0.5]");
assertGauss("1.0", "[40, 40, 5, 5, 0.5]");
assertGauss("1.0", "[49.987, 49.987, 0.001, 0.001, 0.2]");
assertGauss("0.2715403018822964", "[49.9889, 49.987, 0.001, 0.001, 0.2]");
assertGauss("0.1", "[-10, 40, 5, 0, 0.1]");
assertGaussFail("[30, 40, 5]");
assertGaussFail("[30, 40, 5, 100]");
assertGaussFail("[30, 40, 5, 100, -100]");
}

} // namespase
1 change: 1 addition & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ set(ARANGODB_TESTS_SOURCES
Aql/VelocyPackHelper.cpp
Aql/WaitingExecutionBlockMock.cpp
Aql/WindowExecutorTest.cpp
Aql/DecaysFunctionTest.cpp
AsyncAgencyComm/AsyncAgencyCommTest.cpp
Auth/UserManagerTest.cpp
Basics/conversions-test.cpp
Expand Down
0