diff --git a/pom.xml b/pom.xml
index 3dea1ef15..2b8783fb8 100644
--- a/pom.xml
+++ b/pom.xml
@@ -266,6 +266,16 @@
h2
1.4.200
+
+ org.mongodb
+ mongodb-driver-sync
+ 4.1.1
+
+
+ com.arangodb
+ arangodb-java-driver
+ 6.9.0
+
diff --git a/src/sqlancer/GlobalState.java b/src/sqlancer/GlobalState.java
index 97c5a7a9d..642c9da6a 100644
--- a/src/sqlancer/GlobalState.java
+++ b/src/sqlancer/GlobalState.java
@@ -89,13 +89,13 @@ private ExecutionTimer executePrologue(Query> q) throws Exception {
timer = new ExecutionTimer().start();
}
if (getOptions().printAllStatements()) {
- System.out.println(q.getQueryString());
+ System.out.println(q.getLogString());
}
if (getOptions().logEachSelect()) {
if (logExecutionTime) {
- getLogger().writeCurrentNoLineBreak(q.getQueryString());
+ getLogger().writeCurrentNoLineBreak(q.getLogString());
} else {
- getLogger().writeCurrent(q.getQueryString());
+ getLogger().writeCurrent(q.getLogString());
}
}
return timer;
diff --git a/src/sqlancer/Main.java b/src/sqlancer/Main.java
index 755ce36a8..c5da9ec98 100644
--- a/src/sqlancer/Main.java
+++ b/src/sqlancer/Main.java
@@ -21,15 +21,18 @@
import com.beust.jcommander.JCommander;
import com.beust.jcommander.JCommander.Builder;
+import sqlancer.arangodb.ArangoDBProvider;
import sqlancer.citus.CitusProvider;
import sqlancer.clickhouse.ClickHouseProvider;
import sqlancer.cockroachdb.CockroachDBProvider;
import sqlancer.common.log.Loggable;
import sqlancer.common.query.Query;
import sqlancer.common.query.SQLancerResultSet;
+import sqlancer.cosmos.CosmosProvider;
import sqlancer.duckdb.DuckDBProvider;
import sqlancer.h2.H2Provider;
import sqlancer.mariadb.MariaDBProvider;
+import sqlancer.mongodb.MongoDBProvider;
import sqlancer.mysql.MySQLProvider;
import sqlancer.postgres.PostgresProvider;
import sqlancer.sqlite3.SQLite3Provider;
@@ -209,7 +212,7 @@ private void printState(FileWriter writer, StateToReproduce state) {
.getInfo(state.getDatabaseName(), state.getDatabaseVersion(), state.getSeedValue()).getLogString());
for (Query> s : state.getStatements()) {
- sb.append(s.getQueryString());
+ sb.append(s.getLogString());
sb.append('\n');
}
try {
@@ -554,6 +557,9 @@ private boolean run(MainOptions options, ExecutorService execService,
providers.add(new ClickHouseProvider());
providers.add(new DuckDBProvider());
providers.add(new H2Provider());
+ providers.add(new MongoDBProvider());
+ providers.add(new CosmosProvider());
+ providers.add(new ArangoDBProvider());
return providers;
}
diff --git a/src/sqlancer/arangodb/ArangoDBComparatorHelper.java b/src/sqlancer/arangodb/ArangoDBComparatorHelper.java
new file mode 100644
index 000000000..2a00a312d
--- /dev/null
+++ b/src/sqlancer/arangodb/ArangoDBComparatorHelper.java
@@ -0,0 +1,73 @@
+package sqlancer.arangodb;
+
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import com.arangodb.entity.BaseDocument;
+
+import sqlancer.IgnoreMeException;
+import sqlancer.Main;
+import sqlancer.arangodb.query.ArangoDBSelectQuery;
+import sqlancer.common.query.ExpectedErrors;
+
+public final class ArangoDBComparatorHelper {
+
+ private ArangoDBComparatorHelper() {
+
+ }
+
+ public static List getResultSetAsDocumentList(ArangoDBSelectQuery query,
+ ArangoDBProvider.ArangoDBGlobalState state) throws Exception {
+ ExpectedErrors errors = query.getExpectedErrors();
+ List result;
+ try {
+ query.executeAndGet(state);
+ Main.nrSuccessfulActions.addAndGet(1);
+ result = query.getResultSet();
+ return result;
+ } catch (Exception e) {
+ if (e instanceof IgnoreMeException) {
+ throw e;
+ }
+ Main.nrUnsuccessfulActions.addAndGet(1);
+ if (e.getMessage() == null) {
+ throw new AssertionError(query.getLogString(), e);
+ }
+ if (errors.errorIsExpected(e.getMessage())) {
+ throw new IgnoreMeException();
+ }
+ throw new AssertionError(query.getLogString(), e);
+ }
+
+ }
+
+ public static void assumeResultSetsAreEqual(List resultSet, List secondResultSet,
+ ArangoDBSelectQuery originalQuery) {
+ if (resultSet.size() != secondResultSet.size()) {
+ String assertionMessage = String.format("The Size of the result sets mismatch (%d and %d)!\n%s",
+ resultSet.size(), secondResultSet.size(), originalQuery.getLogString());
+ throw new AssertionError(assertionMessage);
+ }
+ Set firstHashSet = new HashSet<>(resultSet);
+ Set secondHashSet = new HashSet<>(secondResultSet);
+
+ if (!firstHashSet.equals(secondHashSet)) {
+ Set firstResultSetMisses = new HashSet<>(firstHashSet);
+ firstResultSetMisses.removeAll(secondHashSet);
+ Set secondResultSetMisses = new HashSet<>(secondHashSet);
+ secondResultSetMisses.removeAll(firstHashSet);
+ StringBuilder firstMisses = new StringBuilder();
+ for (BaseDocument document : firstResultSetMisses) {
+ firstMisses.append(document).append(" ");
+ }
+ StringBuilder secondMisses = new StringBuilder();
+ for (BaseDocument document : secondResultSetMisses) {
+ secondMisses.append(document).append(" ");
+ }
+ String assertMessage = String.format("The Content of the result sets mismatch!\n %s \n %s\n %s",
+ firstMisses.toString(), secondMisses.toString(), originalQuery.getLogString());
+ throw new AssertionError(assertMessage);
+ }
+ }
+}
diff --git a/src/sqlancer/arangodb/ArangoDBConnection.java b/src/sqlancer/arangodb/ArangoDBConnection.java
new file mode 100644
index 000000000..b3e5b85d3
--- /dev/null
+++ b/src/sqlancer/arangodb/ArangoDBConnection.java
@@ -0,0 +1,31 @@
+package sqlancer.arangodb;
+
+import com.arangodb.ArangoDB;
+import com.arangodb.ArangoDatabase;
+
+import sqlancer.SQLancerDBConnection;
+
+public class ArangoDBConnection implements SQLancerDBConnection {
+
+ private final ArangoDB client;
+ private final ArangoDatabase database;
+
+ public ArangoDBConnection(ArangoDB client, ArangoDatabase database) {
+ this.client = client;
+ this.database = database;
+ }
+
+ @Override
+ public String getDatabaseVersion() throws Exception {
+ return client.getVersion().getVersion();
+ }
+
+ @Override
+ public void close() throws Exception {
+ client.shutdown();
+ }
+
+ public ArangoDatabase getDatabase() {
+ return database;
+ }
+}
diff --git a/src/sqlancer/arangodb/ArangoDBLoggableFactory.java b/src/sqlancer/arangodb/ArangoDBLoggableFactory.java
new file mode 100644
index 000000000..927d9f320
--- /dev/null
+++ b/src/sqlancer/arangodb/ArangoDBLoggableFactory.java
@@ -0,0 +1,40 @@
+package sqlancer.arangodb;
+
+import java.util.Arrays;
+
+import sqlancer.common.log.Loggable;
+import sqlancer.common.log.LoggableFactory;
+import sqlancer.common.log.LoggedString;
+import sqlancer.common.query.Query;
+
+public class ArangoDBLoggableFactory extends LoggableFactory {
+ @Override
+ protected Loggable createLoggable(String input, String suffix) {
+ return new LoggedString(input + suffix);
+ }
+
+ @Override
+ public Query> getQueryForStateToReproduce(String queryString) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public Query> commentOutQuery(Query> query) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ protected Loggable infoToLoggable(String time, String databaseName, String databaseVersion, long seedValue) {
+ StringBuilder sb = new StringBuilder();
+ sb.append("// Time: ").append(time).append("\n");
+ sb.append("// Database: ").append(databaseName).append("\n");
+ sb.append("// Database version: ").append(databaseVersion).append("\n");
+ sb.append("// seed value: ").append(seedValue).append("\n");
+ return new LoggedString(sb.toString());
+ }
+
+ @Override
+ public Loggable convertStacktraceToLoggable(Throwable throwable) {
+ return new LoggedString(Arrays.toString(throwable.getStackTrace()) + "\n" + throwable.getMessage());
+ }
+}
diff --git a/src/sqlancer/arangodb/ArangoDBOptions.java b/src/sqlancer/arangodb/ArangoDBOptions.java
new file mode 100644
index 000000000..cdb7ee759
--- /dev/null
+++ b/src/sqlancer/arangodb/ArangoDBOptions.java
@@ -0,0 +1,44 @@
+package sqlancer.arangodb;
+
+import static sqlancer.arangodb.ArangoDBOptions.ArangoDBOracleFactory.QUERY_PARTITIONING;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import com.beust.jcommander.Parameter;
+
+import sqlancer.DBMSSpecificOptions;
+import sqlancer.OracleFactory;
+import sqlancer.arangodb.test.ArangoDBQueryPartitioningWhereTester;
+import sqlancer.common.oracle.CompositeTestOracle;
+import sqlancer.common.oracle.TestOracle;
+
+public class ArangoDBOptions implements DBMSSpecificOptions {
+
+ @Parameter(names = "--oracle")
+ public List oracles = Arrays.asList(QUERY_PARTITIONING);
+
+ @Parameter(names = "--test-random-type-inserts", description = "Insert random types instead of schema types.")
+ public boolean testRandomTypeInserts;
+
+ @Parameter(names = "--max-number-indexes", description = "The maximum number of indexes used.", arity = 1)
+ public int maxNumberIndexes = 15;
+
+ @Override
+ public List getTestOracleFactory() {
+ return oracles;
+ }
+
+ public enum ArangoDBOracleFactory implements OracleFactory {
+ QUERY_PARTITIONING {
+ @Override
+ public TestOracle create(ArangoDBProvider.ArangoDBGlobalState globalState) throws Exception {
+ List oracles = new ArrayList<>();
+ oracles.add(new ArangoDBQueryPartitioningWhereTester(globalState));
+ return new CompositeTestOracle(oracles, globalState);
+ }
+ }
+
+ }
+}
diff --git a/src/sqlancer/arangodb/ArangoDBProvider.java b/src/sqlancer/arangodb/ArangoDBProvider.java
new file mode 100644
index 000000000..d63d4a1cc
--- /dev/null
+++ b/src/sqlancer/arangodb/ArangoDBProvider.java
@@ -0,0 +1,134 @@
+package sqlancer.arangodb;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import com.arangodb.ArangoDB;
+import com.arangodb.ArangoDatabase;
+
+import sqlancer.AbstractAction;
+import sqlancer.ExecutionTimer;
+import sqlancer.GlobalState;
+import sqlancer.IgnoreMeException;
+import sqlancer.ProviderAdapter;
+import sqlancer.Randomly;
+import sqlancer.StatementExecutor;
+import sqlancer.arangodb.gen.ArangoDBCreateIndexGenerator;
+import sqlancer.arangodb.gen.ArangoDBInsertGenerator;
+import sqlancer.arangodb.gen.ArangoDBTableGenerator;
+import sqlancer.common.log.LoggableFactory;
+import sqlancer.common.query.Query;
+
+public class ArangoDBProvider
+ extends ProviderAdapter {
+
+ public ArangoDBProvider() {
+ super(ArangoDBGlobalState.class, ArangoDBOptions.class);
+ }
+
+ enum Action implements AbstractAction {
+ INSERT(ArangoDBInsertGenerator::getQuery), CREATE_INDEX(ArangoDBCreateIndexGenerator::getQuery);
+
+ private final ArangoDBQueryProvider queryProvider;
+
+ Action(ArangoDBQueryProvider queryProvider) {
+ this.queryProvider = queryProvider;
+ }
+
+ @Override
+ public Query> getQuery(ArangoDBGlobalState globalState) throws Exception {
+ return queryProvider.getQuery(globalState);
+ }
+ }
+
+ private static int mapActions(ArangoDBGlobalState globalState, Action a) {
+ Randomly r = globalState.getRandomly();
+ switch (a) {
+ case INSERT:
+ return r.getInteger(0, globalState.getOptions().getMaxNumberInserts());
+ case CREATE_INDEX:
+ return r.getInteger(0, globalState.getDmbsSpecificOptions().maxNumberIndexes);
+ default:
+ throw new AssertionError(a);
+ }
+ }
+
+ public static class ArangoDBGlobalState extends GlobalState {
+
+ private final List schemaTables = new ArrayList<>();
+
+ public void addTable(ArangoDBSchema.ArangoDBTable table) {
+ schemaTables.add(table);
+ }
+
+ @Override
+ protected void executeEpilogue(Query> q, boolean success, ExecutionTimer timer) throws Exception {
+ boolean logExecutionTime = getOptions().logExecutionTime();
+ if (success && getOptions().printSucceedingStatements()) {
+ System.out.println(q.getLogString());
+ }
+ if (logExecutionTime) {
+ getLogger().writeCurrent("//" + timer.end().asString());
+ }
+ if (q.couldAffectSchema()) {
+ updateSchema();
+ }
+ }
+
+ @Override
+ protected ArangoDBSchema readSchema() throws Exception {
+ return new ArangoDBSchema(schemaTables);
+ }
+ }
+
+ @Override
+ protected void checkViewsAreValid(ArangoDBGlobalState globalState) {
+
+ }
+
+ @Override
+ public void generateDatabase(ArangoDBGlobalState globalState) throws Exception {
+ for (int i = 0; i < Randomly.fromOptions(4, 5, 6); i++) {
+ boolean success;
+ do {
+ ArangoDBQueryAdapter queryAdapter = new ArangoDBTableGenerator().getQuery(globalState);
+ success = globalState.executeStatement(queryAdapter);
+ } while (!success);
+ }
+ StatementExecutor se = new StatementExecutor<>(globalState, Action.values(),
+ ArangoDBProvider::mapActions, (q) -> {
+ if (globalState.getSchema().getDatabaseTables().isEmpty()) {
+ throw new IgnoreMeException();
+ }
+ });
+ se.executeStatements();
+ }
+
+ @Override
+ public ArangoDBConnection createDatabase(ArangoDBGlobalState globalState) throws Exception {
+ ArangoDB arangoDB = new ArangoDB.Builder().user(globalState.getOptions().getUserName())
+ .password(globalState.getOptions().getPassword()).build();
+ ArangoDatabase database = arangoDB.db(globalState.getDatabaseName());
+ try {
+ database.drop();
+ // When the database does not exist, an ArangoDB exception is thrown. Since we are not sure
+ // if this is the first time the database is used, the simplest is dropping it and ignoring
+ // the exception.
+ } catch (Exception ignored) {
+
+ }
+ arangoDB.createDatabase(globalState.getDatabaseName());
+ database = arangoDB.db(globalState.getDatabaseName());
+ return new ArangoDBConnection(arangoDB, database);
+ }
+
+ @Override
+ public String getDBMSName() {
+ return "arangodb";
+ }
+
+ @Override
+ public LoggableFactory getLoggableFactory() {
+ return new ArangoDBLoggableFactory();
+ }
+}
diff --git a/src/sqlancer/arangodb/ArangoDBQueryAdapter.java b/src/sqlancer/arangodb/ArangoDBQueryAdapter.java
new file mode 100644
index 000000000..34cdb3709
--- /dev/null
+++ b/src/sqlancer/arangodb/ArangoDBQueryAdapter.java
@@ -0,0 +1,16 @@
+package sqlancer.arangodb;
+
+import sqlancer.common.query.Query;
+
+public abstract class ArangoDBQueryAdapter extends Query {
+ @Override
+ public String getQueryString() {
+ // Should not be called as it is used only in SQL dependent classes
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public String getUnterminatedQueryString() {
+ throw new UnsupportedOperationException();
+ }
+}
diff --git a/src/sqlancer/arangodb/ArangoDBQueryProvider.java b/src/sqlancer/arangodb/ArangoDBQueryProvider.java
new file mode 100644
index 000000000..94a4ffda3
--- /dev/null
+++ b/src/sqlancer/arangodb/ArangoDBQueryProvider.java
@@ -0,0 +1,6 @@
+package sqlancer.arangodb;
+
+@FunctionalInterface
+public interface ArangoDBQueryProvider {
+ ArangoDBQueryAdapter getQuery(S globalState) throws Exception;
+}
diff --git a/src/sqlancer/arangodb/ArangoDBSchema.java b/src/sqlancer/arangodb/ArangoDBSchema.java
new file mode 100644
index 000000000..35e251b8b
--- /dev/null
+++ b/src/sqlancer/arangodb/ArangoDBSchema.java
@@ -0,0 +1,70 @@
+package sqlancer.arangodb;
+
+import java.util.Collections;
+import java.util.List;
+
+import sqlancer.Randomly;
+import sqlancer.common.schema.AbstractSchema;
+import sqlancer.common.schema.AbstractTable;
+import sqlancer.common.schema.AbstractTableColumn;
+import sqlancer.common.schema.AbstractTables;
+import sqlancer.common.schema.TableIndex;
+
+public class ArangoDBSchema extends AbstractSchema {
+
+ public enum ArangoDBDataType {
+ INTEGER, DOUBLE, STRING, BOOLEAN;
+
+ public static ArangoDBDataType getRandom() {
+ return Randomly.fromOptions(values());
+ }
+ }
+
+ public static class ArangoDBColumn extends AbstractTableColumn {
+
+ private final boolean isId;
+ private final boolean isNullable;
+
+ public ArangoDBColumn(String name, ArangoDBDataType type, boolean isId, boolean isNullable) {
+ super(name, null, type);
+ this.isId = isId;
+ this.isNullable = isNullable;
+ }
+
+ public boolean isId() {
+ return isId;
+ }
+
+ public boolean isNullable() {
+ return isNullable;
+ }
+ }
+
+ public ArangoDBSchema(List databaseTables) {
+ super(databaseTables);
+ }
+
+ public static class ArangoDBTables extends AbstractTables {
+
+ public ArangoDBTables(List tables) {
+ super(tables);
+ }
+ }
+
+ public static class ArangoDBTable
+ extends AbstractTable {
+
+ public ArangoDBTable(String name, List columns, boolean isView) {
+ super(name, columns, Collections.emptyList(), isView);
+ }
+
+ @Override
+ public long getNrRows(ArangoDBProvider.ArangoDBGlobalState globalState) {
+ throw new UnsupportedOperationException();
+ }
+ }
+
+ public ArangoDBTables getRandomTableNonEmptyTables() {
+ return new ArangoDBTables(Randomly.nonEmptySubset(getDatabaseTables()));
+ }
+}
diff --git a/src/sqlancer/arangodb/ast/ArangoDBConstant.java b/src/sqlancer/arangodb/ast/ArangoDBConstant.java
new file mode 100644
index 000000000..351dbd822
--- /dev/null
+++ b/src/sqlancer/arangodb/ast/ArangoDBConstant.java
@@ -0,0 +1,108 @@
+package sqlancer.arangodb.ast;
+
+import com.arangodb.entity.BaseDocument;
+
+import sqlancer.common.ast.newast.Node;
+
+public abstract class ArangoDBConstant implements Node {
+ private ArangoDBConstant() {
+
+ }
+
+ public abstract void setValueInDocument(BaseDocument document, String key);
+
+ public abstract Object getValue();
+
+ public static class ArangoDBIntegerConstant extends ArangoDBConstant {
+
+ private final int value;
+
+ public ArangoDBIntegerConstant(int value) {
+ this.value = value;
+ }
+
+ @Override
+ public void setValueInDocument(BaseDocument document, String key) {
+ document.addAttribute(key, value);
+ }
+
+ @Override
+ public Object getValue() {
+ return value;
+ }
+ }
+
+ public static Node createIntegerConstant(int value) {
+ return new ArangoDBIntegerConstant(value);
+ }
+
+ public static class ArangoDBStringConstant extends ArangoDBConstant {
+ private final String value;
+
+ public ArangoDBStringConstant(String value) {
+ this.value = value;
+ }
+
+ @Override
+ public void setValueInDocument(BaseDocument document, String key) {
+ document.addAttribute(key, value);
+ }
+
+ @Override
+ public Object getValue() {
+ return "'" + value.replace("\\", "\\\\").replace("'", "\\'") + "'";
+ }
+ }
+
+ public static Node createStringConstant(String value) {
+ return new ArangoDBStringConstant(value);
+ }
+
+ public static class ArangoDBBooleanConstant extends ArangoDBConstant {
+ private final boolean value;
+
+ public ArangoDBBooleanConstant(boolean value) {
+ this.value = value;
+ }
+
+ @Override
+ public void setValueInDocument(BaseDocument document, String key) {
+ document.addAttribute(key, value);
+ }
+
+ @Override
+ public Object getValue() {
+ return value;
+ }
+ }
+
+ public static Node createBooleanConstant(boolean value) {
+ return new ArangoDBBooleanConstant(value);
+ }
+
+ public static class ArangoDBDoubleConstant extends ArangoDBConstant {
+ private final double value;
+
+ public ArangoDBDoubleConstant(double value) {
+ if (Double.isInfinite(value) || Double.isNaN(value)) {
+ this.value = 0.0;
+ } else {
+ this.value = value;
+ }
+ }
+
+ @Override
+ public void setValueInDocument(BaseDocument document, String key) {
+ document.addAttribute(key, value);
+ }
+
+ @Override
+ public Object getValue() {
+ return value;
+ }
+ }
+
+ public static Node createDoubleConstant(double value) {
+ return new ArangoDBDoubleConstant(value);
+ }
+}
diff --git a/src/sqlancer/arangodb/ast/ArangoDBExpression.java b/src/sqlancer/arangodb/ast/ArangoDBExpression.java
new file mode 100644
index 000000000..facbbfe9e
--- /dev/null
+++ b/src/sqlancer/arangodb/ast/ArangoDBExpression.java
@@ -0,0 +1,4 @@
+package sqlancer.arangodb.ast;
+
+public interface ArangoDBExpression {
+}
diff --git a/src/sqlancer/arangodb/ast/ArangoDBSelect.java b/src/sqlancer/arangodb/ast/ArangoDBSelect.java
new file mode 100644
index 000000000..9fb91d553
--- /dev/null
+++ b/src/sqlancer/arangodb/ast/ArangoDBSelect.java
@@ -0,0 +1,79 @@
+package sqlancer.arangodb.ast;
+
+import java.util.List;
+
+import sqlancer.arangodb.ArangoDBSchema;
+import sqlancer.common.ast.newast.Node;
+
+public class ArangoDBSelect implements Node {
+ private List fromColumns;
+ private List projectionColumns;
+ private boolean hasFilter;
+ private Node filterClause;
+ private boolean hasComputed;
+ private List> computedClause;
+
+ public List getFromColumns() {
+ if (fromColumns == null || fromColumns.isEmpty()) {
+ throw new IllegalStateException();
+ }
+ return fromColumns;
+ }
+
+ public void setFromColumns(List fromColumns) {
+ if (fromColumns == null || fromColumns.isEmpty()) {
+ throw new IllegalStateException();
+ }
+ this.fromColumns = fromColumns;
+ }
+
+ public List getProjectionColumns() {
+ if (projectionColumns == null) {
+ throw new IllegalStateException();
+ }
+ return projectionColumns;
+ }
+
+ public void setProjectionColumns(List projectionColumns) {
+ if (projectionColumns == null) {
+ throw new IllegalStateException();
+ }
+ this.projectionColumns = projectionColumns;
+ }
+
+ public void setFilterClause(Node filterClause) {
+ if (filterClause == null) {
+ hasFilter = false;
+ this.filterClause = null;
+ return;
+ }
+ hasFilter = true;
+ this.filterClause = filterClause;
+ }
+
+ public Node getFilterClause() {
+ return filterClause;
+ }
+
+ public boolean hasFilter() {
+ return hasFilter;
+ }
+
+ public void setComputedClause(List> computedColumns) {
+ if (computedColumns == null || computedColumns.isEmpty()) {
+ hasComputed = false;
+ this.computedClause = null;
+ return;
+ }
+ hasComputed = true;
+ this.computedClause = computedColumns;
+ }
+
+ public List> getComputedClause() {
+ return computedClause;
+ }
+
+ public boolean hasComputed() {
+ return hasComputed;
+ }
+}
diff --git a/src/sqlancer/arangodb/ast/ArangoDBUnsupportedPredicate.java b/src/sqlancer/arangodb/ast/ArangoDBUnsupportedPredicate.java
new file mode 100644
index 000000000..eabd25578
--- /dev/null
+++ b/src/sqlancer/arangodb/ast/ArangoDBUnsupportedPredicate.java
@@ -0,0 +1,6 @@
+package sqlancer.arangodb.ast;
+
+import sqlancer.common.ast.newast.Node;
+
+public class ArangoDBUnsupportedPredicate implements Node {
+}
diff --git a/src/sqlancer/arangodb/gen/ArangoDBComputedExpressionGenerator.java b/src/sqlancer/arangodb/gen/ArangoDBComputedExpressionGenerator.java
new file mode 100644
index 000000000..8a3b98871
--- /dev/null
+++ b/src/sqlancer/arangodb/gen/ArangoDBComputedExpressionGenerator.java
@@ -0,0 +1,85 @@
+package sqlancer.arangodb.gen;
+
+import sqlancer.Randomly;
+import sqlancer.arangodb.ArangoDBProvider;
+import sqlancer.arangodb.ArangoDBSchema;
+import sqlancer.arangodb.ast.ArangoDBConstant;
+import sqlancer.arangodb.ast.ArangoDBExpression;
+import sqlancer.common.ast.newast.ColumnReferenceNode;
+import sqlancer.common.ast.newast.NewFunctionNode;
+import sqlancer.common.ast.newast.Node;
+import sqlancer.common.gen.UntypedExpressionGenerator;
+
+public class ArangoDBComputedExpressionGenerator
+ extends UntypedExpressionGenerator, ArangoDBSchema.ArangoDBColumn> {
+ private final ArangoDBProvider.ArangoDBGlobalState globalState;
+
+ public ArangoDBComputedExpressionGenerator(ArangoDBProvider.ArangoDBGlobalState globalState) {
+ this.globalState = globalState;
+ }
+
+ @Override
+ public Node generateConstant() {
+ ArangoDBSchema.ArangoDBDataType dataType = ArangoDBSchema.ArangoDBDataType.getRandom();
+ switch (dataType) {
+ case INTEGER:
+ return ArangoDBConstant.createIntegerConstant((int) globalState.getRandomly().getInteger());
+ case BOOLEAN:
+ return ArangoDBConstant.createBooleanConstant(Randomly.getBoolean());
+ case DOUBLE:
+ return ArangoDBConstant.createDoubleConstant(globalState.getRandomly().getDouble());
+ case STRING:
+ return ArangoDBConstant.createStringConstant(globalState.getRandomly().getString());
+ default:
+ throw new AssertionError(dataType);
+ }
+ }
+
+ public enum ComputedFunction {
+ ADD(2, "+"), MINUS(2, "-"), MULTIPLY(2, "*"), DIVISION(2, "/"), MODULUS(2, "%");
+
+ private final int nrArgs;
+ private final String operatorName;
+
+ ComputedFunction(int nrArgs, String operatorName) {
+ this.nrArgs = nrArgs;
+ this.operatorName = operatorName;
+ }
+
+ public static ComputedFunction getRandom() {
+ return Randomly.fromOptions(values());
+ }
+
+ public int getNrArgs() {
+ return nrArgs;
+ }
+
+ public String getOperatorName() {
+ return operatorName;
+ }
+ }
+
+ @Override
+ protected Node generateExpression(int depth) {
+ if (depth >= globalState.getOptions().getMaxExpressionDepth() || Randomly.getBoolean()) {
+ return generateLeafNode();
+ }
+ ComputedFunction function = ComputedFunction.getRandom();
+ return new NewFunctionNode<>(generateExpressions(depth + 1, function.getNrArgs()), function);
+ }
+
+ @Override
+ protected Node generateColumn() {
+ return new ColumnReferenceNode<>(Randomly.fromList(columns));
+ }
+
+ @Override
+ public Node negatePredicate(Node predicate) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public Node isNull(Node expr) {
+ throw new UnsupportedOperationException();
+ }
+}
diff --git a/src/sqlancer/arangodb/gen/ArangoDBCreateIndexGenerator.java b/src/sqlancer/arangodb/gen/ArangoDBCreateIndexGenerator.java
new file mode 100644
index 000000000..6a1b872da
--- /dev/null
+++ b/src/sqlancer/arangodb/gen/ArangoDBCreateIndexGenerator.java
@@ -0,0 +1,18 @@
+package sqlancer.arangodb.gen;
+
+import sqlancer.arangodb.ArangoDBProvider;
+import sqlancer.arangodb.ArangoDBQueryAdapter;
+import sqlancer.arangodb.ArangoDBSchema;
+import sqlancer.arangodb.query.ArangoDBCreateIndexQuery;
+
+public final class ArangoDBCreateIndexGenerator {
+ private ArangoDBCreateIndexGenerator() {
+
+ }
+
+ public static ArangoDBQueryAdapter getQuery(ArangoDBProvider.ArangoDBGlobalState globalState) {
+ ArangoDBSchema.ArangoDBTable randomTable = globalState.getSchema().getRandomTable();
+ ArangoDBSchema.ArangoDBColumn column = randomTable.getRandomColumn();
+ return new ArangoDBCreateIndexQuery(column);
+ }
+}
diff --git a/src/sqlancer/arangodb/gen/ArangoDBFilterExpressionGenerator.java b/src/sqlancer/arangodb/gen/ArangoDBFilterExpressionGenerator.java
new file mode 100644
index 000000000..1a2fc4b5e
--- /dev/null
+++ b/src/sqlancer/arangodb/gen/ArangoDBFilterExpressionGenerator.java
@@ -0,0 +1,153 @@
+package sqlancer.arangodb.gen;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import sqlancer.Randomly;
+import sqlancer.arangodb.ArangoDBProvider;
+import sqlancer.arangodb.ArangoDBSchema;
+import sqlancer.arangodb.ast.ArangoDBConstant;
+import sqlancer.arangodb.ast.ArangoDBExpression;
+import sqlancer.arangodb.ast.ArangoDBUnsupportedPredicate;
+import sqlancer.common.ast.BinaryOperatorNode;
+import sqlancer.common.ast.newast.ColumnReferenceNode;
+import sqlancer.common.ast.newast.NewBinaryOperatorNode;
+import sqlancer.common.ast.newast.NewUnaryPrefixOperatorNode;
+import sqlancer.common.ast.newast.Node;
+import sqlancer.common.gen.UntypedExpressionGenerator;
+
+public class ArangoDBFilterExpressionGenerator
+ extends UntypedExpressionGenerator, ArangoDBSchema.ArangoDBColumn> {
+
+ private final ArangoDBProvider.ArangoDBGlobalState globalState;
+ private int numberOfComputedVariables;
+
+ private enum Expression {
+ BINARY_LOGICAL, UNARY_PREFIX, BINARY_COMPARISON
+ }
+
+ public ArangoDBFilterExpressionGenerator(ArangoDBProvider.ArangoDBGlobalState globalState) {
+ this.globalState = globalState;
+ }
+
+ public void setNumberOfComputedVariables(int numberOfComputedVariables) {
+ this.numberOfComputedVariables = numberOfComputedVariables;
+ }
+
+ @Override
+ public Node generateConstant() {
+ ArangoDBSchema.ArangoDBDataType dataType = ArangoDBSchema.ArangoDBDataType.getRandom();
+ switch (dataType) {
+ case INTEGER:
+ return ArangoDBConstant.createIntegerConstant((int) globalState.getRandomly().getInteger());
+ case BOOLEAN:
+ return ArangoDBConstant.createBooleanConstant(Randomly.getBoolean());
+ case DOUBLE:
+ return ArangoDBConstant.createDoubleConstant(globalState.getRandomly().getDouble());
+ case STRING:
+ return ArangoDBConstant.createStringConstant(globalState.getRandomly().getString());
+ default:
+ throw new AssertionError(dataType);
+ }
+ }
+
+ @Override
+ protected Node generateExpression(int depth) {
+ if (depth >= globalState.getOptions().getMaxExpressionDepth() || Randomly.getBoolean()) {
+ return generateLeafNode();
+ }
+ List possibleOptions = new ArrayList<>(Arrays.asList(Expression.values()));
+ Expression expression = Randomly.fromList(possibleOptions);
+ switch (expression) {
+ case BINARY_COMPARISON:
+ BinaryOperatorNode.Operator op = ArangoDBBinaryComparisonOperator.getRandom();
+ return new NewBinaryOperatorNode<>(generateExpression(depth + 1), generateExpression(depth + 1), op);
+ case UNARY_PREFIX:
+ return new NewUnaryPrefixOperatorNode<>(generateExpression(depth + 1),
+ ArangoDBUnaryPrefixOperator.getRandom());
+ case BINARY_LOGICAL:
+ op = ArangoDBBinaryLogicalOperator.getRandom();
+ return new NewBinaryOperatorNode<>(generateExpression(depth + 1), generateExpression(depth + 1), op);
+ default:
+ throw new AssertionError(expression);
+ }
+ }
+
+ @Override
+ protected Node generateColumn() {
+ ArangoDBSchema.ArangoDBTable dummy = new ArangoDBSchema.ArangoDBTable("", new ArrayList<>(), false);
+ if (Randomly.getBoolean() || numberOfComputedVariables == 0) {
+ ArangoDBSchema.ArangoDBColumn column = Randomly.fromList(columns);
+ return new ColumnReferenceNode<>(column);
+ } else {
+ int maxNumber = globalState.getRandomly().getInteger(0, numberOfComputedVariables);
+ ArangoDBSchema.ArangoDBColumn column = new ArangoDBSchema.ArangoDBColumn("c" + maxNumber,
+ ArangoDBSchema.ArangoDBDataType.INTEGER, false, false);
+ column.setTable(dummy);
+ return new ColumnReferenceNode<>(column);
+ }
+ }
+
+ @Override
+ public Node negatePredicate(Node predicate) {
+ return new NewUnaryPrefixOperatorNode<>(predicate, ArangoDBUnaryPrefixOperator.NOT);
+ }
+
+ @Override
+ public Node isNull(Node expr) {
+ return new ArangoDBUnsupportedPredicate<>();
+ }
+
+ public enum ArangoDBBinaryComparisonOperator implements BinaryOperatorNode.Operator {
+ EQUALS("=="), NOT_EQUALS("!="), LESS_THAN("<"), LESS_OR_EQUAL("<="), GREATER_THAN(">"), GREATER_OR_EQUAL(">=");
+
+ private final String representation;
+
+ ArangoDBBinaryComparisonOperator(String representation) {
+ this.representation = representation;
+ }
+
+ @Override
+ public String getTextRepresentation() {
+ return representation;
+ }
+
+ public static ArangoDBBinaryComparisonOperator getRandom() {
+ return Randomly.fromOptions(values());
+ }
+ }
+
+ public enum ArangoDBUnaryPrefixOperator implements BinaryOperatorNode.Operator {
+ NOT("!");
+
+ private final String representation;
+
+ ArangoDBUnaryPrefixOperator(String representation) {
+ this.representation = representation;
+ }
+
+ @Override
+ public String getTextRepresentation() {
+ return representation;
+ }
+
+ public static ArangoDBUnaryPrefixOperator getRandom() {
+ return Randomly.fromOptions(values());
+ }
+ }
+
+ public enum ArangoDBBinaryLogicalOperator implements BinaryOperatorNode.Operator {
+ AND, OR;
+
+ @Override
+ public String getTextRepresentation() {
+ return toString();
+ }
+
+ public static BinaryOperatorNode.Operator getRandom() {
+ return Randomly.fromOptions(values());
+ }
+ }
+
+}
diff --git a/src/sqlancer/arangodb/gen/ArangoDBInsertGenerator.java b/src/sqlancer/arangodb/gen/ArangoDBInsertGenerator.java
new file mode 100644
index 000000000..9a27ccd57
--- /dev/null
+++ b/src/sqlancer/arangodb/gen/ArangoDBInsertGenerator.java
@@ -0,0 +1,39 @@
+package sqlancer.arangodb.gen;
+
+import com.arangodb.entity.BaseDocument;
+
+import sqlancer.arangodb.ArangoDBProvider;
+import sqlancer.arangodb.ArangoDBQueryAdapter;
+import sqlancer.arangodb.ArangoDBSchema;
+import sqlancer.arangodb.query.ArangoDBConstantGenerator;
+import sqlancer.arangodb.query.ArangoDBInsertQuery;
+
+public final class ArangoDBInsertGenerator {
+
+ private final ArangoDBProvider.ArangoDBGlobalState globalState;
+
+ private ArangoDBInsertGenerator(ArangoDBProvider.ArangoDBGlobalState globalState) {
+ this.globalState = globalState;
+ }
+
+ public static ArangoDBQueryAdapter getQuery(ArangoDBProvider.ArangoDBGlobalState globalState) {
+ return new ArangoDBInsertGenerator(globalState).generate();
+ }
+
+ private ArangoDBQueryAdapter generate() {
+ BaseDocument result = new BaseDocument();
+ ArangoDBSchema.ArangoDBTable table = globalState.getSchema().getRandomTable();
+ ArangoDBConstantGenerator constantGenerator = new ArangoDBConstantGenerator(globalState);
+
+ for (int i = 0; i < table.getColumns().size(); i++) {
+ if (!globalState.getDmbsSpecificOptions().testRandomTypeInserts) {
+ constantGenerator.addRandomConstantWithType(result, table.getColumns().get(i).getName(),
+ table.getColumns().get(i).getType());
+ } else {
+ constantGenerator.addRandomConstant(result, table.getColumns().get(i).getName());
+ }
+ }
+
+ return new ArangoDBInsertQuery(table, result);
+ }
+}
diff --git a/src/sqlancer/arangodb/gen/ArangoDBTableGenerator.java b/src/sqlancer/arangodb/gen/ArangoDBTableGenerator.java
new file mode 100644
index 000000000..1236c3ce4
--- /dev/null
+++ b/src/sqlancer/arangodb/gen/ArangoDBTableGenerator.java
@@ -0,0 +1,44 @@
+package sqlancer.arangodb.gen;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import sqlancer.Randomly;
+import sqlancer.arangodb.ArangoDBProvider;
+import sqlancer.arangodb.ArangoDBQueryAdapter;
+import sqlancer.arangodb.ArangoDBSchema;
+import sqlancer.arangodb.query.ArangoDBCreateTableQuery;
+
+public class ArangoDBTableGenerator {
+
+ private ArangoDBSchema.ArangoDBTable table;
+ private final List columnsToBeAdded = new ArrayList<>();
+
+ public ArangoDBQueryAdapter getQuery(ArangoDBProvider.ArangoDBGlobalState globalState) {
+ String tableName = globalState.getSchema().getFreeTableName();
+ ArangoDBCreateTableQuery createTableQuery = new ArangoDBCreateTableQuery(tableName);
+ table = new ArangoDBSchema.ArangoDBTable(tableName, columnsToBeAdded, false);
+ for (int i = 0; i < Randomly.smallNumber() + 1; i++) {
+ String columnName = String.format("c%d", i);
+ createColumn(columnName);
+ }
+ globalState.addTable(table);
+ return createTableQuery;
+ }
+
+ private ArangoDBSchema.ArangoDBDataType createColumn(String columnName) {
+ ArangoDBSchema.ArangoDBDataType dataType = ArangoDBSchema.ArangoDBDataType.getRandom();
+ ArangoDBSchema.ArangoDBColumn newColumn = new ArangoDBSchema.ArangoDBColumn(columnName, dataType, false, false);
+ newColumn.setTable(table);
+ columnsToBeAdded.add(newColumn);
+ return dataType;
+ }
+
+ public String getTableName() {
+ return table.getName();
+ }
+
+ public ArangoDBSchema.ArangoDBTable getGeneratedTable() {
+ return table;
+ }
+}
diff --git a/src/sqlancer/arangodb/query/ArangoDBConstantGenerator.java b/src/sqlancer/arangodb/query/ArangoDBConstantGenerator.java
new file mode 100644
index 000000000..406e8adca
--- /dev/null
+++ b/src/sqlancer/arangodb/query/ArangoDBConstantGenerator.java
@@ -0,0 +1,46 @@
+package sqlancer.arangodb.query;
+
+import com.arangodb.entity.BaseDocument;
+
+import sqlancer.Randomly;
+import sqlancer.arangodb.ArangoDBProvider;
+import sqlancer.arangodb.ArangoDBSchema;
+import sqlancer.arangodb.ast.ArangoDBConstant;
+
+public class ArangoDBConstantGenerator {
+ private final ArangoDBProvider.ArangoDBGlobalState globalState;
+
+ public ArangoDBConstantGenerator(ArangoDBProvider.ArangoDBGlobalState globalState) {
+ this.globalState = globalState;
+ }
+
+ public void addRandomConstant(BaseDocument document, String key) {
+ ArangoDBSchema.ArangoDBDataType type = ArangoDBSchema.ArangoDBDataType.getRandom();
+ addRandomConstantWithType(document, key, type);
+ }
+
+ public void addRandomConstantWithType(BaseDocument document, String key, ArangoDBSchema.ArangoDBDataType dataType) {
+ ArangoDBConstant constant;
+ switch (dataType) {
+ case STRING:
+ constant = new ArangoDBConstant.ArangoDBStringConstant(globalState.getRandomly().getString());
+ constant.setValueInDocument(document, key);
+ return;
+ case DOUBLE:
+ constant = new ArangoDBConstant.ArangoDBDoubleConstant(globalState.getRandomly().getDouble());
+ constant.setValueInDocument(document, key);
+ return;
+ case BOOLEAN:
+ constant = new ArangoDBConstant.ArangoDBBooleanConstant(Randomly.getBoolean());
+ constant.setValueInDocument(document, key);
+ return;
+ case INTEGER:
+ constant = new ArangoDBConstant.ArangoDBIntegerConstant((int) globalState.getRandomly().getInteger());
+ constant.setValueInDocument(document, key);
+ return;
+ default:
+ throw new AssertionError(dataType);
+ }
+
+ }
+}
diff --git a/src/sqlancer/arangodb/query/ArangoDBCreateIndexQuery.java b/src/sqlancer/arangodb/query/ArangoDBCreateIndexQuery.java
new file mode 100644
index 000000000..6c2cc1b75
--- /dev/null
+++ b/src/sqlancer/arangodb/query/ArangoDBCreateIndexQuery.java
@@ -0,0 +1,54 @@
+package sqlancer.arangodb.query;
+
+import java.util.Collections;
+
+import com.arangodb.ArangoCollection;
+
+import sqlancer.GlobalState;
+import sqlancer.Main;
+import sqlancer.arangodb.ArangoDBConnection;
+import sqlancer.arangodb.ArangoDBQueryAdapter;
+import sqlancer.arangodb.ArangoDBSchema;
+import sqlancer.common.query.ExpectedErrors;
+
+public class ArangoDBCreateIndexQuery extends ArangoDBQueryAdapter {
+
+ private final ArangoDBSchema.ArangoDBColumn column;
+
+ public ArangoDBCreateIndexQuery(ArangoDBSchema.ArangoDBColumn column) {
+ this.column = column;
+ }
+
+ @Override
+ public boolean couldAffectSchema() {
+ return false;
+ }
+
+ @Override
+ public > boolean execute(G globalState, String... fills)
+ throws Exception {
+ try {
+ ArangoCollection collection = globalState.getConnection().getDatabase()
+ .collection(column.getTable().getName());
+ collection.ensureHashIndex(Collections.singletonList(column.getName()), null);
+ Main.nrSuccessfulActions.addAndGet(1);
+ return true;
+ } catch (Exception e) {
+ Main.nrUnsuccessfulActions.addAndGet(1);
+ throw e;
+ }
+ }
+
+ @Override
+ public ExpectedErrors getExpectedErrors() {
+ return new ExpectedErrors();
+ }
+
+ @Override
+ public String getLogString() {
+ StringBuilder stringBuilder = new StringBuilder();
+ stringBuilder.append("db.").append(column.getTable().getName())
+ .append(".ensureIndex({type: \"hash\", fields: [ \"").append(column.getName()).append("\" ]});");
+ return stringBuilder.toString();
+ }
+}
diff --git a/src/sqlancer/arangodb/query/ArangoDBCreateTableQuery.java b/src/sqlancer/arangodb/query/ArangoDBCreateTableQuery.java
new file mode 100644
index 000000000..00b3276d0
--- /dev/null
+++ b/src/sqlancer/arangodb/query/ArangoDBCreateTableQuery.java
@@ -0,0 +1,44 @@
+package sqlancer.arangodb.query;
+
+import sqlancer.GlobalState;
+import sqlancer.Main;
+import sqlancer.arangodb.ArangoDBConnection;
+import sqlancer.arangodb.ArangoDBQueryAdapter;
+import sqlancer.common.query.ExpectedErrors;
+
+public class ArangoDBCreateTableQuery extends ArangoDBQueryAdapter {
+
+ private final String tableName;
+
+ public ArangoDBCreateTableQuery(String tableName) {
+ this.tableName = tableName;
+ }
+
+ @Override
+ public boolean couldAffectSchema() {
+ return true;
+ }
+
+ @Override
+ public > boolean execute(G globalState, String... fills)
+ throws Exception {
+ try {
+ globalState.getConnection().getDatabase().createCollection(tableName);
+ Main.nrSuccessfulActions.addAndGet(1);
+ return true;
+ } catch (Exception e) {
+ Main.nrUnsuccessfulActions.addAndGet(1);
+ throw e;
+ }
+ }
+
+ @Override
+ public ExpectedErrors getExpectedErrors() {
+ return new ExpectedErrors();
+ }
+
+ @Override
+ public String getLogString() {
+ return "db._create(\"" + tableName + "\")";
+ }
+}
diff --git a/src/sqlancer/arangodb/query/ArangoDBInsertQuery.java b/src/sqlancer/arangodb/query/ArangoDBInsertQuery.java
new file mode 100644
index 000000000..9a3612062
--- /dev/null
+++ b/src/sqlancer/arangodb/query/ArangoDBInsertQuery.java
@@ -0,0 +1,66 @@
+package sqlancer.arangodb.query;
+
+import java.util.Map;
+
+import com.arangodb.entity.BaseDocument;
+
+import sqlancer.GlobalState;
+import sqlancer.Main;
+import sqlancer.arangodb.ArangoDBConnection;
+import sqlancer.arangodb.ArangoDBQueryAdapter;
+import sqlancer.arangodb.ArangoDBSchema;
+import sqlancer.common.query.ExpectedErrors;
+
+public class ArangoDBInsertQuery extends ArangoDBQueryAdapter {
+
+ private final ArangoDBSchema.ArangoDBTable table;
+ private final BaseDocument documentToBeInserted;
+
+ public ArangoDBInsertQuery(ArangoDBSchema.ArangoDBTable table, BaseDocument documentToBeInserted) {
+ this.table = table;
+ this.documentToBeInserted = documentToBeInserted;
+ }
+
+ @Override
+ public boolean couldAffectSchema() {
+ return true;
+ }
+
+ @Override
+ public > boolean execute(G globalState, String... fills)
+ throws Exception {
+ try {
+ globalState.getConnection().getDatabase().collection(table.getName()).insertDocument(documentToBeInserted);
+ Main.nrSuccessfulActions.addAndGet(1);
+ return true;
+ } catch (Exception e) {
+ Main.nrUnsuccessfulActions.addAndGet(1);
+ throw e;
+ }
+ }
+
+ @Override
+ public ExpectedErrors getExpectedErrors() {
+ return new ExpectedErrors();
+ }
+
+ @Override
+ public String getLogString() {
+ StringBuilder stringBuilder = new StringBuilder();
+ stringBuilder.append("db._query(\"INSERT { ");
+ String filler = "";
+ for (Map.Entry stringObjectEntry : documentToBeInserted.getProperties().entrySet()) {
+ stringBuilder.append(filler);
+ filler = ", ";
+ stringBuilder.append(stringObjectEntry.getKey()).append(": ");
+ Object value = stringObjectEntry.getValue();
+ if (value instanceof String) {
+ stringBuilder.append("'").append(value).append("'");
+ } else {
+ stringBuilder.append(value);
+ }
+ }
+ stringBuilder.append("} IN ").append(table.getName()).append("\")");
+ return stringBuilder.toString();
+ }
+}
diff --git a/src/sqlancer/arangodb/query/ArangoDBSelectQuery.java b/src/sqlancer/arangodb/query/ArangoDBSelectQuery.java
new file mode 100644
index 000000000..400585ca2
--- /dev/null
+++ b/src/sqlancer/arangodb/query/ArangoDBSelectQuery.java
@@ -0,0 +1,65 @@
+package sqlancer.arangodb.query;
+
+import java.io.IOException;
+import java.util.List;
+
+import com.arangodb.ArangoCursor;
+import com.arangodb.entity.BaseDocument;
+
+import sqlancer.GlobalState;
+import sqlancer.arangodb.ArangoDBConnection;
+import sqlancer.arangodb.ArangoDBQueryAdapter;
+import sqlancer.common.query.ExpectedErrors;
+import sqlancer.common.query.SQLancerResultSet;
+
+public class ArangoDBSelectQuery extends ArangoDBQueryAdapter {
+
+ private final String query;
+
+ private List resultSet;
+
+ public ArangoDBSelectQuery(String query) {
+ this.query = query;
+ }
+
+ @Override
+ public boolean couldAffectSchema() {
+ return false;
+ }
+
+ @Override
+ public > boolean execute(G globalState, String... fills)
+ throws Exception {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public ExpectedErrors getExpectedErrors() {
+ return new ExpectedErrors();
+ }
+
+ @Override
+ public String getLogString() {
+ return "db._query(\"" + query + "\")";
+ }
+
+ @Override
+ public > SQLancerResultSet executeAndGet(G globalState,
+ String... fills) throws Exception {
+ if (globalState.getOptions().logEachSelect()) {
+ globalState.getLogger().writeCurrent(this.getLogString());
+ try {
+ globalState.getLogger().getCurrentFileWriter().flush();
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
+ }
+ ArangoCursor cursor = globalState.getConnection().getDatabase().query(query, BaseDocument.class);
+ resultSet = cursor.asListRemaining();
+ return null;
+ }
+
+ public List getResultSet() {
+ return resultSet;
+ }
+}
diff --git a/src/sqlancer/arangodb/test/ArangoDBQueryPartitioningBase.java b/src/sqlancer/arangodb/test/ArangoDBQueryPartitioningBase.java
new file mode 100644
index 000000000..f583ed04f
--- /dev/null
+++ b/src/sqlancer/arangodb/test/ArangoDBQueryPartitioningBase.java
@@ -0,0 +1,67 @@
+package sqlancer.arangodb.test;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import sqlancer.Randomly;
+import sqlancer.arangodb.ArangoDBProvider;
+import sqlancer.arangodb.ArangoDBSchema;
+import sqlancer.arangodb.ast.ArangoDBExpression;
+import sqlancer.arangodb.ast.ArangoDBSelect;
+import sqlancer.arangodb.gen.ArangoDBComputedExpressionGenerator;
+import sqlancer.arangodb.gen.ArangoDBFilterExpressionGenerator;
+import sqlancer.common.ast.newast.Node;
+import sqlancer.common.gen.ExpressionGenerator;
+import sqlancer.common.oracle.TernaryLogicPartitioningOracleBase;
+import sqlancer.common.oracle.TestOracle;
+
+public class ArangoDBQueryPartitioningBase
+ extends TernaryLogicPartitioningOracleBase, ArangoDBProvider.ArangoDBGlobalState>
+ implements TestOracle {
+
+ protected ArangoDBSchema schema;
+ protected List targetColumns;
+ protected ArangoDBFilterExpressionGenerator expressionGenerator;
+ protected ArangoDBSelect select;
+ protected int numberComputedColumns;
+
+ protected ArangoDBQueryPartitioningBase(ArangoDBProvider.ArangoDBGlobalState state) {
+ super(state);
+ }
+
+ @Override
+ protected ExpressionGenerator> getGen() {
+ return expressionGenerator;
+ }
+
+ @Override
+ public void check() throws Exception {
+ numberComputedColumns = state.getRandomly().getInteger(0, 4);
+ schema = state.getSchema();
+ generateTargetColumns();
+ expressionGenerator = new ArangoDBFilterExpressionGenerator(state).setColumns(targetColumns);
+ expressionGenerator.setNumberOfComputedVariables(numberComputedColumns);
+ initializeTernaryPredicateVariants();
+ select = new ArangoDBSelect<>();
+ select.setFromColumns(targetColumns);
+ select.setProjectionColumns(Randomly.nonEmptySubset(targetColumns));
+ generateComputedClause();
+ }
+
+ private void generateComputedClause() {
+ List> computedColumns = new ArrayList<>();
+ ArangoDBComputedExpressionGenerator generator = new ArangoDBComputedExpressionGenerator(state);
+ generator.setColumns(targetColumns);
+ for (int i = 0; i < numberComputedColumns; i++) {
+ computedColumns.add(generator.generateExpression());
+ }
+ select.setComputedClause(computedColumns);
+ }
+
+ private void generateTargetColumns() {
+ ArangoDBSchema.ArangoDBTables targetTables;
+ targetTables = schema.getRandomTableNonEmptyTables();
+ List allColumns = targetTables.getColumns();
+ targetColumns = Randomly.nonEmptySubset(allColumns);
+ }
+}
diff --git a/src/sqlancer/arangodb/test/ArangoDBQueryPartitioningWhereTester.java b/src/sqlancer/arangodb/test/ArangoDBQueryPartitioningWhereTester.java
new file mode 100644
index 000000000..80b7d46bf
--- /dev/null
+++ b/src/sqlancer/arangodb/test/ArangoDBQueryPartitioningWhereTester.java
@@ -0,0 +1,38 @@
+package sqlancer.arangodb.test;
+
+import static sqlancer.arangodb.ArangoDBComparatorHelper.assumeResultSetsAreEqual;
+import static sqlancer.arangodb.ArangoDBComparatorHelper.getResultSetAsDocumentList;
+
+import java.util.List;
+
+import com.arangodb.entity.BaseDocument;
+
+import sqlancer.arangodb.ArangoDBProvider;
+import sqlancer.arangodb.query.ArangoDBSelectQuery;
+import sqlancer.arangodb.visitor.ArangoDBVisitor;
+
+public class ArangoDBQueryPartitioningWhereTester extends ArangoDBQueryPartitioningBase {
+ public ArangoDBQueryPartitioningWhereTester(ArangoDBProvider.ArangoDBGlobalState state) {
+ super(state);
+ }
+
+ @Override
+ public void check() throws Exception {
+ super.check();
+ select.setFilterClause(null);
+
+ ArangoDBSelectQuery query = ArangoDBVisitor.asSelectQuery(select);
+ List firstResultSet = getResultSetAsDocumentList(query, state);
+
+ select.setFilterClause(predicate);
+ query = ArangoDBVisitor.asSelectQuery(select);
+ List secondResultSet = getResultSetAsDocumentList(query, state);
+
+ select.setFilterClause(negatedPredicate);
+ query = ArangoDBVisitor.asSelectQuery(select);
+ List thirdResultSet = getResultSetAsDocumentList(query, state);
+
+ secondResultSet.addAll(thirdResultSet);
+ assumeResultSetsAreEqual(firstResultSet, secondResultSet, query);
+ }
+}
diff --git a/src/sqlancer/arangodb/visitor/ArangoDBToQueryVisitor.java b/src/sqlancer/arangodb/visitor/ArangoDBToQueryVisitor.java
new file mode 100644
index 000000000..f82995d5e
--- /dev/null
+++ b/src/sqlancer/arangodb/visitor/ArangoDBToQueryVisitor.java
@@ -0,0 +1,134 @@
+package sqlancer.arangodb.visitor;
+
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import sqlancer.arangodb.ArangoDBSchema;
+import sqlancer.arangodb.ast.ArangoDBConstant;
+import sqlancer.arangodb.ast.ArangoDBExpression;
+import sqlancer.arangodb.ast.ArangoDBSelect;
+import sqlancer.arangodb.gen.ArangoDBComputedExpressionGenerator;
+import sqlancer.arangodb.query.ArangoDBSelectQuery;
+import sqlancer.common.ast.newast.ColumnReferenceNode;
+import sqlancer.common.ast.newast.NewBinaryOperatorNode;
+import sqlancer.common.ast.newast.NewFunctionNode;
+import sqlancer.common.ast.newast.NewUnaryPrefixOperatorNode;
+import sqlancer.common.ast.newast.Node;
+
+public class ArangoDBToQueryVisitor extends ArangoDBVisitor {
+
+ private final StringBuilder stringBuilder;
+
+ public ArangoDBToQueryVisitor() {
+ stringBuilder = new StringBuilder();
+ }
+
+ @Override
+ protected void visit(ArangoDBSelect expression) {
+ generateFrom(expression);
+ generateComputed(expression);
+ generateFilter(expression);
+ generateProject(expression);
+ }
+
+ private void generateFilter(ArangoDBSelect expression) {
+ if (expression.hasFilter()) {
+ stringBuilder.append("FILTER ");
+ visit(expression.getFilterClause());
+ stringBuilder.append(" ");
+ }
+ }
+
+ private void generateComputed(ArangoDBSelect expression) {
+ if (expression.hasComputed()) {
+ List> computedClause = expression.getComputedClause();
+ int computedNumber = 0;
+ for (Node computedExpression : computedClause) {
+ stringBuilder.append("LET c").append(computedNumber).append(" = ");
+ visit(computedExpression);
+ stringBuilder.append(" ");
+ computedNumber++;
+ }
+ }
+ }
+
+ @Override
+ protected void visit(ColumnReferenceNode expression) {
+ if (expression.getColumn().getTable().getName().equals("")) {
+ stringBuilder.append(expression.getColumn().getName());
+ } else {
+ stringBuilder.append("r").append(expression.getColumn().getTable().getName()).append(".")
+ .append(expression.getColumn().getName());
+ }
+ }
+
+ @Override
+ protected void visit(ArangoDBConstant expression) {
+ stringBuilder.append(expression.getValue());
+ }
+
+ @Override
+ protected void visit(NewBinaryOperatorNode expression) {
+ stringBuilder.append("(");
+ visit(expression.getLeft());
+ stringBuilder.append(" ").append(expression.getOperatorRepresentation()).append(" ");
+ visit(expression.getRight());
+ stringBuilder.append(")");
+ }
+
+ @Override
+ protected void visit(NewUnaryPrefixOperatorNode expression) {
+ stringBuilder.append(expression.getOperatorRepresentation()).append("(");
+ visit(expression.getExpr());
+ stringBuilder.append(")");
+ }
+
+ @Override
+ protected void visit(NewFunctionNode expression) {
+ if (!(expression.getFunc() instanceof ArangoDBComputedExpressionGenerator.ComputedFunction)) {
+ throw new UnsupportedOperationException();
+ }
+ ArangoDBComputedExpressionGenerator.ComputedFunction function = (ArangoDBComputedExpressionGenerator.ComputedFunction) expression
+ .getFunc();
+ // TODO: Support functions with a different number of arguments.
+ if (function.getNrArgs() != 2) {
+ throw new UnsupportedOperationException();
+ }
+ stringBuilder.append("(");
+ visit(expression.getArgs().get(0));
+ stringBuilder.append(" ").append(function.getOperatorName()).append(" ");
+ visit(expression.getArgs().get(1));
+ stringBuilder.append(")");
+ }
+
+ private void generateFrom(ArangoDBSelect expression) {
+ List forColumns = expression.getFromColumns();
+ Set tables = new HashSet<>();
+ for (ArangoDBSchema.ArangoDBColumn column : forColumns) {
+ tables.add(column.getTable());
+ }
+
+ for (ArangoDBSchema.ArangoDBTable table : tables) {
+ stringBuilder.append("FOR r").append(table.getName()).append(" IN ").append(table.getName()).append(" ");
+ }
+ }
+
+ private void generateProject(ArangoDBSelect expression) {
+ List projectColumns = expression.getProjectionColumns();
+ stringBuilder.append("RETURN {");
+ String filler = "";
+ for (ArangoDBSchema.ArangoDBColumn column : projectColumns) {
+ stringBuilder.append(filler);
+ filler = ", ";
+ stringBuilder.append(column.getTable().getName()).append("_").append(column.getName()).append(": r")
+ .append(column.getTable().getName()).append(".").append(column.getName());
+ }
+ stringBuilder.append("}");
+ }
+
+ public ArangoDBSelectQuery getQuery() {
+ return new ArangoDBSelectQuery(stringBuilder.toString());
+ }
+
+}
diff --git a/src/sqlancer/arangodb/visitor/ArangoDBVisitor.java b/src/sqlancer/arangodb/visitor/ArangoDBVisitor.java
new file mode 100644
index 000000000..f1db84cf5
--- /dev/null
+++ b/src/sqlancer/arangodb/visitor/ArangoDBVisitor.java
@@ -0,0 +1,51 @@
+package sqlancer.arangodb.visitor;
+
+import sqlancer.arangodb.ast.ArangoDBConstant;
+import sqlancer.arangodb.ast.ArangoDBExpression;
+import sqlancer.arangodb.ast.ArangoDBSelect;
+import sqlancer.arangodb.query.ArangoDBSelectQuery;
+import sqlancer.common.ast.newast.ColumnReferenceNode;
+import sqlancer.common.ast.newast.NewBinaryOperatorNode;
+import sqlancer.common.ast.newast.NewFunctionNode;
+import sqlancer.common.ast.newast.NewUnaryPrefixOperatorNode;
+import sqlancer.common.ast.newast.Node;
+
+public abstract class ArangoDBVisitor {
+
+ protected abstract void visit(ArangoDBSelect expression);
+
+ protected abstract void visit(ColumnReferenceNode expression);
+
+ protected abstract void visit(ArangoDBConstant expression);
+
+ protected abstract void visit(NewBinaryOperatorNode expression);
+
+ protected abstract void visit(NewUnaryPrefixOperatorNode expression);
+
+ protected abstract void visit(NewFunctionNode expression);
+
+ @SuppressWarnings("unchecked")
+ public void visit(Node expressionNode) {
+ if (expressionNode instanceof ArangoDBSelect) {
+ visit((ArangoDBSelect) expressionNode);
+ } else if (expressionNode instanceof ColumnReferenceNode, ?>) {
+ visit((ColumnReferenceNode) expressionNode);
+ } else if (expressionNode instanceof ArangoDBConstant) {
+ visit((ArangoDBConstant) expressionNode);
+ } else if (expressionNode instanceof NewBinaryOperatorNode>) {
+ visit((NewBinaryOperatorNode) expressionNode);
+ } else if (expressionNode instanceof NewUnaryPrefixOperatorNode>) {
+ visit((NewUnaryPrefixOperatorNode) expressionNode);
+ } else if (expressionNode instanceof NewFunctionNode, ?>) {
+ visit((NewFunctionNode) expressionNode);
+ } else {
+ throw new AssertionError(expressionNode);
+ }
+ }
+
+ public static ArangoDBSelectQuery asSelectQuery(Node expressionNode) {
+ ArangoDBToQueryVisitor visitor = new ArangoDBToQueryVisitor();
+ visitor.visit(expressionNode);
+ return visitor.getQuery();
+ }
+}
diff --git a/src/sqlancer/common/oracle/RemoveReduceOracleBase.java b/src/sqlancer/common/oracle/RemoveReduceOracleBase.java
new file mode 100644
index 000000000..c177f072e
--- /dev/null
+++ b/src/sqlancer/common/oracle/RemoveReduceOracleBase.java
@@ -0,0 +1,29 @@
+package sqlancer.common.oracle;
+
+import sqlancer.GlobalState;
+import sqlancer.common.gen.ExpressionGenerator;
+
+public abstract class RemoveReduceOracleBase> implements TestOracle {
+
+ protected E predicate;
+
+ protected final S state;
+
+ protected RemoveReduceOracleBase(S state) {
+ this.state = state;
+ }
+
+ protected void initializeRemoveReduceOracle() {
+ ExpressionGenerator gen = getGen();
+ if (gen == null) {
+ throw new IllegalStateException();
+ }
+ predicate = gen.generatePredicate();
+ if (predicate == null) {
+ throw new IllegalStateException();
+ }
+ }
+
+ protected abstract ExpressionGenerator getGen();
+
+}
diff --git a/src/sqlancer/common/oracle/TernaryLogicPartitioningOracleBase.java b/src/sqlancer/common/oracle/TernaryLogicPartitioningOracleBase.java
index 991824628..3b5d87814 100644
--- a/src/sqlancer/common/oracle/TernaryLogicPartitioningOracleBase.java
+++ b/src/sqlancer/common/oracle/TernaryLogicPartitioningOracleBase.java
@@ -1,6 +1,6 @@
package sqlancer.common.oracle;
-import sqlancer.SQLGlobalState;
+import sqlancer.GlobalState;
import sqlancer.common.gen.ExpressionGenerator;
import sqlancer.common.query.ExpectedErrors;
@@ -14,7 +14,7 @@
* @param
* the global state type
*/
-public abstract class TernaryLogicPartitioningOracleBase> implements TestOracle {
+public abstract class TernaryLogicPartitioningOracleBase> implements TestOracle {
protected E predicate;
protected E negatedPredicate;
diff --git a/src/sqlancer/cosmos/CosmosProvider.java b/src/sqlancer/cosmos/CosmosProvider.java
new file mode 100644
index 000000000..a8681d3af
--- /dev/null
+++ b/src/sqlancer/cosmos/CosmosProvider.java
@@ -0,0 +1,74 @@
+package sqlancer.cosmos;
+
+import com.mongodb.ConnectionString;
+import com.mongodb.MongoClientSettings;
+import com.mongodb.client.MongoClient;
+import com.mongodb.client.MongoClients;
+import com.mongodb.client.MongoDatabase;
+
+import sqlancer.IgnoreMeException;
+import sqlancer.ProviderAdapter;
+import sqlancer.Randomly;
+import sqlancer.StatementExecutor;
+import sqlancer.common.log.LoggableFactory;
+import sqlancer.mongodb.MongoDBConnection;
+import sqlancer.mongodb.MongoDBLoggableFactory;
+import sqlancer.mongodb.MongoDBOptions;
+import sqlancer.mongodb.MongoDBQueryAdapter;
+import sqlancer.mongodb.gen.MongoDBTableGenerator;
+
+public class CosmosProvider extends
+ ProviderAdapter {
+
+ public CosmosProvider() {
+ super(sqlancer.mongodb.MongoDBProvider.MongoDBGlobalState.class, MongoDBOptions.class);
+ }
+
+ @Override
+ public void generateDatabase(sqlancer.mongodb.MongoDBProvider.MongoDBGlobalState globalState) throws Exception {
+ for (int i = 0; i < Randomly.fromOptions(4, 5, 6); i++) {
+ boolean success;
+ do {
+ MongoDBQueryAdapter query = new MongoDBTableGenerator(globalState).getQuery(globalState);
+ success = globalState.executeStatement(query);
+ } while (!success);
+ }
+ StatementExecutor se = new StatementExecutor<>(
+ globalState, sqlancer.mongodb.MongoDBProvider.Action.values(),
+ sqlancer.mongodb.MongoDBProvider::mapActions, (q) -> {
+ if (globalState.getSchema().getDatabaseTables().isEmpty()) {
+ throw new IgnoreMeException();
+ }
+ });
+ se.executeStatements();
+ }
+
+ @Override
+ public MongoDBConnection createDatabase(sqlancer.mongodb.MongoDBProvider.MongoDBGlobalState globalState)
+ throws Exception {
+ String connectionString = "";
+ if (connectionString.equals("")) {
+ throw new AssertionError("Please set connection string for cosmos database, located in CosmosProvider");
+ }
+ MongoClientSettings settings = MongoClientSettings.builder()
+ .applyConnectionString(new ConnectionString(connectionString)).build();
+ MongoClient mongoClient = MongoClients.create(settings);
+ MongoDatabase database = mongoClient.getDatabase(globalState.getDatabaseName());
+ database.drop();
+ return new MongoDBConnection(mongoClient, database);
+ }
+
+ @Override
+ public String getDBMSName() {
+ return "cosmos";
+ }
+
+ @Override
+ public LoggableFactory getLoggableFactory() {
+ return new MongoDBLoggableFactory();
+ }
+
+ @Override
+ protected void checkViewsAreValid(sqlancer.mongodb.MongoDBProvider.MongoDBGlobalState globalState) {
+ }
+}
diff --git a/src/sqlancer/mongodb/MongoDBComparatorHelper.java b/src/sqlancer/mongodb/MongoDBComparatorHelper.java
new file mode 100644
index 000000000..49b692645
--- /dev/null
+++ b/src/sqlancer/mongodb/MongoDBComparatorHelper.java
@@ -0,0 +1,97 @@
+package sqlancer.mongodb;
+
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import org.bson.Document;
+
+import sqlancer.IgnoreMeException;
+import sqlancer.Main;
+import sqlancer.common.query.ExpectedErrors;
+import sqlancer.mongodb.MongoDBProvider.MongoDBGlobalState;
+import sqlancer.mongodb.query.MongoDBSelectQuery;
+
+public final class MongoDBComparatorHelper {
+
+ private MongoDBComparatorHelper() {
+ }
+
+ public static List getResultSetAsDocumentList(MongoDBSelectQuery adapter, MongoDBGlobalState state)
+ throws Exception {
+ ExpectedErrors errors = adapter.getExpectedErrors();
+ List result;
+ try {
+ adapter.executeAndGet(state);
+ Main.nrSuccessfulActions.addAndGet(1);
+ result = adapter.getResultSet();
+ return result;
+ } catch (Exception e) {
+ if (e instanceof IgnoreMeException) {
+ throw e;
+ }
+ Main.nrUnsuccessfulActions.addAndGet(1);
+ if (e.getMessage() == null) {
+ throw new AssertionError(adapter.getLogString(), e);
+ }
+ if (errors.errorIsExpected(e.getMessage())) {
+ throw new IgnoreMeException();
+ }
+ throw new AssertionError(adapter.getLogString(), e);
+ }
+ }
+
+ public static void assumeCountIsEqual(List resultSet, List secondResultSet,
+ MongoDBSelectQuery originalQuery) {
+ int originalSize = resultSet.size();
+ if (secondResultSet.isEmpty()) {
+ if (originalSize == 0) {
+ return;
+ } else {
+ String assertMessage = String.format("The Count of the result set mismatches!\n %s",
+ originalQuery.getLogString());
+ throw new AssertionError(assertMessage);
+ }
+ }
+ if (secondResultSet.size() != 1) {
+ throw new AssertionError(
+ String.format("Count query result bigger than one \n %s", originalQuery.getLogString()));
+ }
+ int withCount = (int) secondResultSet.get(0).get("count");
+ if (originalSize != withCount) {
+ String assertMessage = String.format("The Count of the result set mismatches!\n %s",
+ originalQuery.getLogString());
+ throw new AssertionError(assertMessage);
+ }
+ }
+
+ public static void assumeResultSetsAreEqual(List resultSet, List secondResultSet,
+ MongoDBSelectQuery originalQuery) {
+ if (resultSet.size() != secondResultSet.size()) {
+ String assertionMessage = String.format("The Size of the result sets mismatch (%d and %d)!\n%s",
+ resultSet.size(), secondResultSet.size(), originalQuery.getLogString());
+ throw new AssertionError(assertionMessage);
+ }
+
+ Set firstHashSet = new HashSet<>(resultSet);
+ Set secondHashSet = new HashSet<>(secondResultSet);
+
+ if (!firstHashSet.equals(secondHashSet)) {
+ Set firstResultSetMisses = new HashSet<>(firstHashSet);
+ firstResultSetMisses.removeAll(secondHashSet);
+ Set secondResultSetMisses = new HashSet<>(secondHashSet);
+ secondResultSetMisses.removeAll(firstHashSet);
+ StringBuilder firstMisses = new StringBuilder();
+ for (Document document : firstResultSetMisses) {
+ firstMisses.append(document.toJson()).append(" ");
+ }
+ StringBuilder secondMisses = new StringBuilder();
+ for (Document document : secondResultSetMisses) {
+ secondMisses.append(document.toJson()).append(" ");
+ }
+ String assertMessage = String.format("The Content of the result sets mismatch!\n %s \n %s\n %s",
+ firstMisses.toString(), secondMisses.toString(), originalQuery.getLogString());
+ throw new AssertionError(assertMessage);
+ }
+ }
+}
diff --git a/src/sqlancer/mongodb/MongoDBConnection.java b/src/sqlancer/mongodb/MongoDBConnection.java
new file mode 100644
index 000000000..6971bd79c
--- /dev/null
+++ b/src/sqlancer/mongodb/MongoDBConnection.java
@@ -0,0 +1,35 @@
+package sqlancer.mongodb;
+
+import org.bson.BsonDocument;
+import org.bson.BsonString;
+
+import com.mongodb.client.MongoClient;
+import com.mongodb.client.MongoDatabase;
+
+import sqlancer.SQLancerDBConnection;
+
+public class MongoDBConnection implements SQLancerDBConnection {
+
+ private final MongoClient client;
+ private final MongoDatabase database;
+
+ public MongoDBConnection(MongoClient client, MongoDatabase database) {
+ this.client = client;
+ this.database = database;
+ }
+
+ @Override
+ public String getDatabaseVersion() throws Exception {
+ return client.getDatabase("dbname").runCommand(new BsonDocument("buildinfo", new BsonString(""))).get("version")
+ .toString();
+ }
+
+ @Override
+ public void close() throws Exception {
+ client.close();
+ }
+
+ public MongoDatabase getDatabase() {
+ return database;
+ }
+}
diff --git a/src/sqlancer/mongodb/MongoDBLoggableFactory.java b/src/sqlancer/mongodb/MongoDBLoggableFactory.java
new file mode 100644
index 000000000..b668301b3
--- /dev/null
+++ b/src/sqlancer/mongodb/MongoDBLoggableFactory.java
@@ -0,0 +1,40 @@
+package sqlancer.mongodb;
+
+import java.util.Arrays;
+
+import sqlancer.common.log.Loggable;
+import sqlancer.common.log.LoggableFactory;
+import sqlancer.common.log.LoggedString;
+import sqlancer.common.query.Query;
+
+public class MongoDBLoggableFactory extends LoggableFactory {
+ @Override
+ protected Loggable createLoggable(String input, String suffix) {
+ return new LoggedString(input + suffix);
+ }
+
+ @Override
+ public Query> getQueryForStateToReproduce(String queryString) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public Query> commentOutQuery(Query> query) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ protected Loggable infoToLoggable(String time, String databaseName, String databaseVersion, long seedValue) {
+ StringBuilder sb = new StringBuilder();
+ sb.append("// Time: ").append(time).append("\n");
+ sb.append("// Database: ").append(databaseName).append("\n");
+ sb.append("// Database version: ").append(databaseVersion).append("\n");
+ sb.append("// seed value: ").append(seedValue).append("\n");
+ return new LoggedString(sb.toString());
+ }
+
+ @Override
+ public Loggable convertStacktraceToLoggable(Throwable throwable) {
+ return new LoggedString(Arrays.toString(throwable.getStackTrace()) + "\n" + throwable.getMessage());
+ }
+}
diff --git a/src/sqlancer/mongodb/MongoDBOptions.java b/src/sqlancer/mongodb/MongoDBOptions.java
new file mode 100644
index 000000000..cf632085b
--- /dev/null
+++ b/src/sqlancer/mongodb/MongoDBOptions.java
@@ -0,0 +1,71 @@
+package sqlancer.mongodb;
+
+import static sqlancer.mongodb.MongoDBOptions.MongoDBOracleFactory.QUERY_PARTITIONING;
+import static sqlancer.mongodb.MongoDBOptions.MongoDBOracleFactory.REMOVE_REDUCE;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import com.beust.jcommander.Parameter;
+
+import sqlancer.DBMSSpecificOptions;
+import sqlancer.OracleFactory;
+import sqlancer.common.oracle.CompositeTestOracle;
+import sqlancer.common.oracle.TestOracle;
+import sqlancer.mongodb.test.MongoDBQueryPartitioningWhereTester;
+import sqlancer.mongodb.test.MongoDBRemoveReduceTester;
+
+public class MongoDBOptions implements DBMSSpecificOptions {
+
+ @Parameter(names = "--test-validation", description = "Enable/Disable validation of schema with Schema Validation", arity = 1)
+ public boolean testValidation = true;
+
+ @Parameter(names = "--test-null-inserts", description = "Enables to test inserting with null values, validation has to be off", arity = 1)
+ public boolean testNullInserts;
+
+ @Parameter(names = "--test-random-types", description = "Insert random types instead of schema types, validation has to be off", arity = 1)
+ public boolean testRandomTypes;
+
+ @Parameter(names = "--max-number-indexes", description = "The maximum number of indexes used.", arity = 1)
+ public int maxNumberIndexes = 15;
+
+ @Parameter(names = "--test-computed-values", description = "Enable adding computed values to query", arity = 1)
+ public boolean testComputedValues;
+
+ @Parameter(names = "--test-with-regex", description = "Enable Regex Leaf Nodes", arity = 1)
+ public boolean testWithRegex;
+
+ @Parameter(names = "--test-with-count", description = "Count the number of documents and check with count command", arity = 1)
+ public boolean testWithCount;
+
+ @Parameter(names = "--null-safety", description = "", arity = 1)
+ public boolean nullSafety;
+
+ @Parameter(names = "--oracle")
+ public List oracles = Arrays.asList(QUERY_PARTITIONING, REMOVE_REDUCE);
+
+ @Override
+ public List getTestOracleFactory() {
+ return oracles;
+ }
+
+ public enum MongoDBOracleFactory implements OracleFactory {
+ QUERY_PARTITIONING {
+ @Override
+ public TestOracle create(MongoDBProvider.MongoDBGlobalState globalState) throws Exception {
+ List oracles = new ArrayList<>();
+ oracles.add(new MongoDBQueryPartitioningWhereTester(globalState));
+ return new CompositeTestOracle(oracles, globalState);
+ }
+ },
+ REMOVE_REDUCE {
+ @Override
+ public TestOracle create(MongoDBProvider.MongoDBGlobalState globalState) throws Exception {
+ List oracles = new ArrayList<>();
+ oracles.add(new MongoDBRemoveReduceTester(globalState));
+ return new CompositeTestOracle(oracles, globalState);
+ }
+ }
+ }
+}
diff --git a/src/sqlancer/mongodb/MongoDBProvider.java b/src/sqlancer/mongodb/MongoDBProvider.java
new file mode 100644
index 000000000..5ff549d32
--- /dev/null
+++ b/src/sqlancer/mongodb/MongoDBProvider.java
@@ -0,0 +1,125 @@
+package sqlancer.mongodb;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import com.mongodb.client.MongoClient;
+import com.mongodb.client.MongoClients;
+import com.mongodb.client.MongoDatabase;
+
+import sqlancer.AbstractAction;
+import sqlancer.ExecutionTimer;
+import sqlancer.GlobalState;
+import sqlancer.IgnoreMeException;
+import sqlancer.ProviderAdapter;
+import sqlancer.Randomly;
+import sqlancer.StatementExecutor;
+import sqlancer.common.log.LoggableFactory;
+import sqlancer.common.query.Query;
+import sqlancer.mongodb.MongoDBSchema.MongoDBTable;
+import sqlancer.mongodb.gen.MongoDBIndexGenerator;
+import sqlancer.mongodb.gen.MongoDBInsertGenerator;
+import sqlancer.mongodb.gen.MongoDBTableGenerator;
+
+public class MongoDBProvider
+ extends ProviderAdapter {
+
+ public MongoDBProvider() {
+ super(MongoDBGlobalState.class, MongoDBOptions.class);
+ }
+
+ public enum Action implements AbstractAction {
+ INSERT(MongoDBInsertGenerator::getQuery), CREATE_INDEX(MongoDBIndexGenerator::getQuery);
+
+ private final MongoDBQueryProvider queryProvider;
+
+ Action(MongoDBQueryProvider queryProvider) {
+ this.queryProvider = queryProvider;
+ }
+
+ @Override
+ public Query getQuery(MongoDBGlobalState globalState) throws Exception {
+ return queryProvider.getQuery(globalState);
+ }
+ }
+
+ public static int mapActions(MongoDBGlobalState globalState, Action a) {
+ Randomly r = globalState.getRandomly();
+ switch (a) {
+ case INSERT:
+ return r.getInteger(0, globalState.getOptions().getMaxNumberInserts());
+ case CREATE_INDEX:
+ return r.getInteger(0, globalState.getDmbsSpecificOptions().maxNumberIndexes);
+ default:
+ throw new AssertionError(a);
+ }
+ }
+
+ public static class MongoDBGlobalState extends GlobalState {
+
+ private final List schemaTables = new ArrayList<>();
+
+ public void addTable(MongoDBTable table) {
+ schemaTables.add(table);
+ }
+
+ @Override
+ protected void executeEpilogue(Query> q, boolean success, ExecutionTimer timer) throws Exception {
+ boolean logExecutionTime = getOptions().logExecutionTime();
+ if (success && getOptions().printSucceedingStatements()) {
+ System.out.println(q.getLogString());
+ }
+ if (logExecutionTime) {
+ getLogger().writeCurrent("// " + timer.end().asString());
+ }
+ if (q.couldAffectSchema()) {
+ updateSchema();
+ }
+ }
+
+ @Override
+ protected MongoDBSchema readSchema() throws Exception {
+ return new MongoDBSchema(schemaTables);
+ }
+ }
+
+ @Override
+ public void generateDatabase(MongoDBGlobalState globalState) throws Exception {
+ for (int i = 0; i < Randomly.fromOptions(4, 5, 6); i++) {
+ boolean success;
+ do {
+ MongoDBQueryAdapter query = new MongoDBTableGenerator(globalState).getQuery(globalState);
+ success = globalState.executeStatement(query);
+ } while (!success);
+ }
+ StatementExecutor se = new StatementExecutor<>(globalState, Action.values(),
+ MongoDBProvider::mapActions, (q) -> {
+ if (globalState.getSchema().getDatabaseTables().isEmpty()) {
+ throw new IgnoreMeException();
+ }
+ });
+ se.executeStatements();
+ }
+
+ @Override
+ public MongoDBConnection createDatabase(MongoDBGlobalState globalState) throws Exception {
+ MongoClient mongoClient = MongoClients.create();
+ MongoDatabase database = mongoClient.getDatabase(globalState.getDatabaseName());
+ database.drop();
+ return new MongoDBConnection(mongoClient, database);
+ }
+
+ @Override
+ public String getDBMSName() {
+ return "mongodb";
+ }
+
+ @Override
+ public LoggableFactory getLoggableFactory() {
+ return new MongoDBLoggableFactory();
+ }
+
+ @Override
+ protected void checkViewsAreValid(MongoDBGlobalState globalState) {
+ }
+}
diff --git a/src/sqlancer/mongodb/MongoDBQueryAdapter.java b/src/sqlancer/mongodb/MongoDBQueryAdapter.java
new file mode 100644
index 000000000..e2add3242
--- /dev/null
+++ b/src/sqlancer/mongodb/MongoDBQueryAdapter.java
@@ -0,0 +1,15 @@
+package sqlancer.mongodb;
+
+import sqlancer.common.query.Query;
+
+public abstract class MongoDBQueryAdapter extends Query {
+ @Override
+ public String getQueryString() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public String getUnterminatedQueryString() {
+ throw new UnsupportedOperationException();
+ }
+}
diff --git a/src/sqlancer/mongodb/MongoDBQueryProvider.java b/src/sqlancer/mongodb/MongoDBQueryProvider.java
new file mode 100644
index 000000000..970c90cea
--- /dev/null
+++ b/src/sqlancer/mongodb/MongoDBQueryProvider.java
@@ -0,0 +1,6 @@
+package sqlancer.mongodb;
+
+@FunctionalInterface
+public interface MongoDBQueryProvider {
+ MongoDBQueryAdapter getQuery(S globalState) throws Exception;
+}
diff --git a/src/sqlancer/mongodb/MongoDBSchema.java b/src/sqlancer/mongodb/MongoDBSchema.java
new file mode 100644
index 000000000..e9a0afb99
--- /dev/null
+++ b/src/sqlancer/mongodb/MongoDBSchema.java
@@ -0,0 +1,97 @@
+package sqlancer.mongodb;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import org.bson.BsonType;
+
+import com.mongodb.client.MongoDatabase;
+
+import sqlancer.Randomly;
+import sqlancer.common.schema.AbstractSchema;
+import sqlancer.common.schema.AbstractTable;
+import sqlancer.common.schema.AbstractTableColumn;
+import sqlancer.common.schema.AbstractTables;
+import sqlancer.common.schema.TableIndex;
+import sqlancer.mongodb.MongoDBProvider.MongoDBGlobalState;
+
+public class MongoDBSchema extends AbstractSchema {
+
+ public enum MongoDBDataType {
+ INTEGER(BsonType.INT32), STRING(BsonType.STRING), BOOLEAN(BsonType.BOOLEAN), DOUBLE(BsonType.DOUBLE),
+ DATE_TIME(BsonType.DATE_TIME), TIMESTAMP(BsonType.TIMESTAMP);
+
+ private final BsonType bsonType;
+
+ MongoDBDataType(BsonType type) {
+ this.bsonType = type;
+ }
+
+ public BsonType getBsonType() {
+ return bsonType;
+ }
+
+ public static MongoDBDataType getRandom(MongoDBGlobalState state) {
+ Set valueSet = new HashSet<>(Arrays.asList(values()));
+ if (state.getDmbsSpecificOptions().nullSafety) {
+ valueSet.remove(STRING);
+ }
+ MongoDBDataType[] configuredValues = new MongoDBDataType[valueSet.size()];
+ return Randomly.fromOptions(valueSet.toArray(configuredValues));
+ }
+ }
+
+ public static class MongoDBColumn extends AbstractTableColumn {
+
+ private final boolean isId;
+ private final boolean isNullable;
+
+ public MongoDBColumn(String name, MongoDBDataType type, boolean isId, boolean isNullable) {
+ super(name, null, type);
+ this.isId = isId;
+ this.isNullable = isNullable;
+ }
+
+ public boolean isId() {
+ return isId;
+ }
+
+ public boolean isNullable() {
+ return isNullable;
+ }
+
+ }
+
+ public static class MongoDBTables extends AbstractTables {
+
+ public MongoDBTables(List tables) {
+ super(tables);
+ }
+ }
+
+ public MongoDBSchema(List databaseTables) {
+ super(databaseTables);
+ }
+
+ public static class MongoDBTable extends AbstractTable {
+ public MongoDBTable(String name, List columns, boolean isView) {
+ super(name, columns, Collections.emptyList(), isView);
+ }
+
+ @Override
+ public long getNrRows(MongoDBGlobalState globalState) {
+ throw new UnsupportedOperationException();
+ }
+ }
+
+ public static MongoDBSchema fromConnection(MongoDatabase connection, String databaseName) {
+ throw new UnsupportedOperationException();
+ }
+
+ public MongoDBTables getRandomTableNonEmptyTables() {
+ return new MongoDBTables(Randomly.nonEmptySubset(getDatabaseTables()));
+ }
+}
diff --git a/src/sqlancer/mongodb/ast/MongoDBBinaryComparisonNode.java b/src/sqlancer/mongodb/ast/MongoDBBinaryComparisonNode.java
new file mode 100644
index 000000000..21675250a
--- /dev/null
+++ b/src/sqlancer/mongodb/ast/MongoDBBinaryComparisonNode.java
@@ -0,0 +1,16 @@
+package sqlancer.mongodb.ast;
+
+import sqlancer.common.ast.newast.NewBinaryOperatorNode;
+import sqlancer.common.ast.newast.Node;
+import sqlancer.mongodb.gen.MongoDBMatchExpressionGenerator.MongoDBBinaryComparisonOperator;
+
+public class MongoDBBinaryComparisonNode extends NewBinaryOperatorNode {
+ public MongoDBBinaryComparisonNode(Node left, Node right,
+ MongoDBBinaryComparisonOperator op) {
+ super(left, right, op);
+ }
+
+ public MongoDBBinaryComparisonOperator operator() {
+ return (MongoDBBinaryComparisonOperator) op;
+ }
+}
diff --git a/src/sqlancer/mongodb/ast/MongoDBBinaryLogicalNode.java b/src/sqlancer/mongodb/ast/MongoDBBinaryLogicalNode.java
new file mode 100644
index 000000000..efb8d8294
--- /dev/null
+++ b/src/sqlancer/mongodb/ast/MongoDBBinaryLogicalNode.java
@@ -0,0 +1,16 @@
+package sqlancer.mongodb.ast;
+
+import sqlancer.common.ast.newast.NewBinaryOperatorNode;
+import sqlancer.common.ast.newast.Node;
+import sqlancer.mongodb.gen.MongoDBMatchExpressionGenerator.MongoDBBinaryLogicalOperator;
+
+public class MongoDBBinaryLogicalNode extends NewBinaryOperatorNode {
+ public MongoDBBinaryLogicalNode(Node left, Node right,
+ MongoDBBinaryLogicalOperator op) {
+ super(left, right, op);
+ }
+
+ public MongoDBBinaryLogicalOperator operator() {
+ return (MongoDBBinaryLogicalOperator) op;
+ }
+}
diff --git a/src/sqlancer/mongodb/ast/MongoDBConstant.java b/src/sqlancer/mongodb/ast/MongoDBConstant.java
new file mode 100644
index 000000000..86f783b48
--- /dev/null
+++ b/src/sqlancer/mongodb/ast/MongoDBConstant.java
@@ -0,0 +1,252 @@
+package sqlancer.mongodb.ast;
+
+import java.io.Serializable;
+
+import org.bson.BsonDateTime;
+import org.bson.BsonTimestamp;
+import org.bson.Document;
+
+import sqlancer.common.ast.newast.Node;
+
+public abstract class MongoDBConstant implements Node {
+ private MongoDBConstant() {
+ }
+
+ public abstract void setValueInDocument(Document document, String key);
+
+ public abstract String getLogValue();
+
+ public abstract Object getValue();
+
+ public abstract Serializable getSerializedValue();
+
+ public static class MongoDBNullConstant extends MongoDBConstant {
+
+ @Override
+ public void setValueInDocument(Document document, String key) {
+ document.append(key, null);
+ }
+
+ @Override
+ public String getLogValue() {
+ return "null";
+ }
+
+ @Override
+ public Object getValue() {
+ return null;
+ }
+
+ @Override
+ public Serializable getSerializedValue() {
+ return null;
+ }
+ }
+
+ public static Node createNullConstant() {
+ return new MongoDBNullConstant();
+ }
+
+ public static class MongoDBIntegerConstant extends MongoDBConstant {
+
+ private final int value;
+
+ public MongoDBIntegerConstant(int value) {
+ this.value = value;
+ }
+
+ @Override
+ public void setValueInDocument(Document document, String key) {
+ document.append(key, value);
+ }
+
+ @Override
+ public String getLogValue() {
+ return "NumberInt(" + value + ")";
+ }
+
+ @Override
+ public Integer getValue() {
+ return value;
+ }
+
+ @Override
+ public Serializable getSerializedValue() {
+ return value;
+ }
+ }
+
+ public static Node createIntegerConstant(int value) {
+ return new MongoDBIntegerConstant(value);
+ }
+
+ public static class MongoDBStringConstant extends MongoDBConstant {
+
+ private final String value;
+
+ public MongoDBStringConstant(String value) {
+ this.value = value;
+ }
+
+ public String getStringValue() {
+ return value;
+ }
+
+ @Override
+ public void setValueInDocument(Document document, String key) {
+ document.append(key, value);
+ }
+
+ @Override
+ public String getLogValue() {
+ return "\"" + value.replace("\\", "\\\\").replace("\"", "\\\"").replace("\n", "\\n") + "\"";
+ }
+
+ @Override
+ public String getValue() {
+ return value;
+ }
+
+ @Override
+ public Serializable getSerializedValue() {
+ return value;
+ }
+ }
+
+ public static Node createStringConstant(String value) {
+ return new MongoDBStringConstant(value);
+ }
+
+ public static class MongoDBBooleanConstant extends MongoDBConstant {
+
+ private final boolean value;
+
+ public MongoDBBooleanConstant(boolean value) {
+ this.value = value;
+ }
+
+ @Override
+ public void setValueInDocument(Document document, String key) {
+ document.append(key, value);
+ }
+
+ @Override
+ public String getLogValue() {
+ return String.valueOf(value);
+ }
+
+ @Override
+ public Boolean getValue() {
+ return value;
+ }
+
+ @Override
+ public Serializable getSerializedValue() {
+ return value;
+ }
+ }
+
+ public static Node createBooleanConstant(boolean value) {
+ return new MongoDBBooleanConstant(value);
+ }
+
+ public static class MongoDBDoubleConstant extends MongoDBConstant {
+
+ private final double value;
+
+ public MongoDBDoubleConstant(double value) {
+ this.value = value;
+ }
+
+ @Override
+ public void setValueInDocument(Document document, String key) {
+ document.append(key, value);
+ }
+
+ @Override
+ public String getLogValue() {
+ return String.valueOf(value);
+ }
+
+ @Override
+ public Double getValue() {
+ return value;
+ }
+
+ @Override
+ public Serializable getSerializedValue() {
+ return value;
+ }
+ }
+
+ public static Node createDoubleConstant(double value) {
+ return new MongoDBDoubleConstant(value);
+ }
+
+ public static class MongoDBDateTimeConstant extends MongoDBConstant {
+
+ private final BsonDateTime value;
+
+ public MongoDBDateTimeConstant(long val) {
+ this.value = new BsonDateTime(val);
+ }
+
+ @Override
+ public void setValueInDocument(Document document, String key) {
+ document.append(key, value);
+ }
+
+ @Override
+ public String getLogValue() {
+ return "new Date(" + value.getValue() + ")";
+ }
+
+ @Override
+ public BsonDateTime getValue() {
+ return value;
+ }
+
+ @Override
+ public Serializable getSerializedValue() {
+ return value.getValue();
+ }
+ }
+
+ public static Node createDateTimeConstant(long value) {
+ return new MongoDBDateTimeConstant(value);
+ }
+
+ public static class MongoDBTimestampConstant extends MongoDBConstant {
+
+ private final BsonTimestamp value;
+
+ public MongoDBTimestampConstant(long value) {
+ this.value = new BsonTimestamp(value);
+ }
+
+ @Override
+ public void setValueInDocument(Document document, String key) {
+ document.append(key, value);
+ }
+
+ @Override
+ public String getLogValue() {
+ return "Timestamp(" + value.getValue() + ",1)";
+ }
+
+ @Override
+ public BsonTimestamp getValue() {
+ return value;
+ }
+
+ @Override
+ public Serializable getSerializedValue() {
+ return value.getValue();
+ }
+ }
+
+ public static Node createTimestampConstant(long value) {
+ return new MongoDBTimestampConstant(value);
+ }
+
+}
diff --git a/src/sqlancer/mongodb/ast/MongoDBExpression.java b/src/sqlancer/mongodb/ast/MongoDBExpression.java
new file mode 100644
index 000000000..1235a1fbc
--- /dev/null
+++ b/src/sqlancer/mongodb/ast/MongoDBExpression.java
@@ -0,0 +1,4 @@
+package sqlancer.mongodb.ast;
+
+public interface MongoDBExpression {
+}
diff --git a/src/sqlancer/mongodb/ast/MongoDBRegexNode.java b/src/sqlancer/mongodb/ast/MongoDBRegexNode.java
new file mode 100644
index 000000000..76c608586
--- /dev/null
+++ b/src/sqlancer/mongodb/ast/MongoDBRegexNode.java
@@ -0,0 +1,24 @@
+package sqlancer.mongodb.ast;
+
+import static sqlancer.mongodb.gen.MongoDBMatchExpressionGenerator.MongoDBRegexOperator.REGEX;
+
+import sqlancer.common.ast.newast.NewBinaryOperatorNode;
+import sqlancer.common.ast.newast.Node;
+import sqlancer.mongodb.gen.MongoDBMatchExpressionGenerator.MongoDBRegexOperator;
+
+public class MongoDBRegexNode extends NewBinaryOperatorNode {
+ private final String options;
+
+ public MongoDBRegexNode(Node left, Node right, String options) {
+ super(left, right, REGEX);
+ this.options = options;
+ }
+
+ public String getOptions() {
+ return options;
+ }
+
+ public MongoDBRegexOperator operator() {
+ return (MongoDBRegexOperator) op;
+ }
+}
diff --git a/src/sqlancer/mongodb/ast/MongoDBSelect.java b/src/sqlancer/mongodb/ast/MongoDBSelect.java
new file mode 100644
index 000000000..0fe91ba4a
--- /dev/null
+++ b/src/sqlancer/mongodb/ast/MongoDBSelect.java
@@ -0,0 +1,104 @@
+package sqlancer.mongodb.ast;
+
+import java.util.List;
+
+import sqlancer.common.ast.newast.Node;
+import sqlancer.mongodb.test.MongoDBColumnTestReference;
+
+public class MongoDBSelect implements Node {
+
+ private final String mainTableName;
+ private final MongoDBColumnTestReference joinColumn;
+ List projectionColumns;
+ List lookupList;
+ boolean hasFilter;
+ Node filterClause;
+ boolean hasComputed;
+ List> computedClauses;
+ private boolean withCountClause;
+
+ public MongoDBSelect(String mainTableName, MongoDBColumnTestReference joinColumn) {
+ this.mainTableName = mainTableName;
+ this.joinColumn = joinColumn;
+ }
+
+ public String getMainTableName() {
+ return mainTableName;
+ }
+
+ public MongoDBColumnTestReference getJoinColumn() {
+ return joinColumn;
+ }
+
+ public void setProjectionList(List fetchColumns) {
+ if (fetchColumns == null || fetchColumns.isEmpty()) {
+ throw new IllegalArgumentException();
+ }
+ this.projectionColumns = fetchColumns;
+ }
+
+ public List getProjectionList() {
+ if (projectionColumns == null) {
+ throw new IllegalStateException();
+ }
+ return projectionColumns;
+ }
+
+ public void setLookupList(List lookupList) {
+ if (lookupList == null || lookupList.isEmpty()) {
+ throw new IllegalArgumentException();
+ }
+ this.lookupList = lookupList;
+ }
+
+ public List getLookupList() {
+ if (lookupList == null) {
+ throw new IllegalStateException();
+ }
+ return lookupList;
+ }
+
+ public void setFilterClause(Node filterClause) {
+ if (filterClause == null) {
+ hasFilter = false;
+ this.filterClause = null;
+ return;
+ }
+ hasFilter = true;
+ this.filterClause = filterClause;
+ }
+
+ public Node getFilterClause() {
+ return filterClause;
+ }
+
+ public boolean hasFilter() {
+ return hasFilter;
+ }
+
+ public void setComputedClause(List> computedClause) {
+ if (computedClause == null) {
+ hasComputed = false;
+ this.computedClauses = null;
+ return;
+ }
+ hasComputed = true;
+ this.computedClauses = computedClause;
+ }
+
+ public List> getComputedClause() {
+ return computedClauses;
+ }
+
+ public boolean hasComputed() {
+ return hasComputed;
+ }
+
+ public boolean getWithCountClause() {
+ return withCountClause;
+ }
+
+ public void setWithCountClause(boolean withCountClause) {
+ this.withCountClause = withCountClause;
+ }
+}
diff --git a/src/sqlancer/mongodb/ast/MongoDBUnaryLogicalOperatorNode.java b/src/sqlancer/mongodb/ast/MongoDBUnaryLogicalOperatorNode.java
new file mode 100644
index 000000000..a34fe27e5
--- /dev/null
+++ b/src/sqlancer/mongodb/ast/MongoDBUnaryLogicalOperatorNode.java
@@ -0,0 +1,16 @@
+package sqlancer.mongodb.ast;
+
+import sqlancer.common.ast.newast.NewUnaryPrefixOperatorNode;
+import sqlancer.common.ast.newast.Node;
+import sqlancer.mongodb.gen.MongoDBMatchExpressionGenerator.MongoDBUnaryLogicalOperator;
+
+public class MongoDBUnaryLogicalOperatorNode extends NewUnaryPrefixOperatorNode {
+
+ public MongoDBUnaryLogicalOperatorNode(Node expr, MongoDBUnaryLogicalOperator op) {
+ super(expr, op);
+ }
+
+ public MongoDBUnaryLogicalOperator operator() {
+ return (MongoDBUnaryLogicalOperator) op;
+ }
+}
diff --git a/src/sqlancer/mongodb/ast/MongoDBUnsupportedPredicate.java b/src/sqlancer/mongodb/ast/MongoDBUnsupportedPredicate.java
new file mode 100644
index 000000000..eae143e7d
--- /dev/null
+++ b/src/sqlancer/mongodb/ast/MongoDBUnsupportedPredicate.java
@@ -0,0 +1,7 @@
+package sqlancer.mongodb.ast;
+
+import sqlancer.common.ast.newast.Node;
+
+public class MongoDBUnsupportedPredicate implements Node {
+
+}
diff --git a/src/sqlancer/mongodb/gen/MongoDBComputedExpressionGenerator.java b/src/sqlancer/mongodb/gen/MongoDBComputedExpressionGenerator.java
new file mode 100644
index 000000000..fd5959ce6
--- /dev/null
+++ b/src/sqlancer/mongodb/gen/MongoDBComputedExpressionGenerator.java
@@ -0,0 +1,89 @@
+package sqlancer.mongodb.gen;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import sqlancer.Randomly;
+import sqlancer.common.ast.newast.NewFunctionNode;
+import sqlancer.common.ast.newast.Node;
+import sqlancer.common.gen.UntypedExpressionGenerator;
+import sqlancer.mongodb.MongoDBProvider.MongoDBGlobalState;
+import sqlancer.mongodb.MongoDBSchema;
+import sqlancer.mongodb.ast.MongoDBExpression;
+import sqlancer.mongodb.test.MongoDBColumnTestReference;
+
+public class MongoDBComputedExpressionGenerator
+ extends UntypedExpressionGenerator, MongoDBColumnTestReference> {
+
+ private final MongoDBGlobalState globalState;
+
+ @Override
+ public Node generateLeafNode() {
+ ComputedFunction function = ComputedFunction.getRandom();
+ List> expressions = new ArrayList<>();
+ for (int i = 0; i < function.getNrArgs(); i++) {
+ expressions.add(super.generateLeafNode());
+ }
+ return new NewFunctionNode<>(expressions, function);
+ }
+
+ @Override
+ protected Node generateExpression(int depth) {
+ if (depth >= globalState.getOptions().getMaxExpressionDepth() || Randomly.getBoolean()) {
+ return generateLeafNode();
+ }
+ ComputedFunction func = ComputedFunction.getRandom();
+ return new NewFunctionNode<>(generateExpressions(depth + 1, func.getNrArgs()), func);
+ }
+
+ public MongoDBComputedExpressionGenerator(MongoDBGlobalState globalState) {
+ this.globalState = globalState;
+ }
+
+ public enum ComputedFunction {
+ ADD(2, "$add"), MULTIPLY(2, "$multiply"), DIVIDE(2, "$divide"), POW(2, "$pow"), SQRT(1, "$sqrt"),
+ LOG(2, "$log"), AVG(2, "$avg"), EXP(1, "$exp");
+
+ private final int nrArgs;
+ private final String operatorName;
+
+ ComputedFunction(int nrArgs, String operatorName) {
+ this.nrArgs = nrArgs;
+ this.operatorName = operatorName;
+ }
+
+ public static ComputedFunction getRandom() {
+ return Randomly.fromOptions(values());
+ }
+
+ public int getNrArgs() {
+ return nrArgs;
+ }
+
+ public String getOperator() {
+ return operatorName;
+ }
+ }
+
+ @Override
+ public Node generateConstant() {
+ MongoDBSchema.MongoDBDataType type = MongoDBSchema.MongoDBDataType.getRandom(globalState);
+ MongoDBConstantGenerator generator = new MongoDBConstantGenerator(globalState);
+ return generator.generateConstantWithType(type);
+ }
+
+ @Override
+ protected Node generateColumn() {
+ return Randomly.fromList(columns);
+ }
+
+ @Override
+ public Node negatePredicate(Node predicate) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public Node isNull(Node expr) {
+ throw new UnsupportedOperationException();
+ }
+}
diff --git a/src/sqlancer/mongodb/gen/MongoDBConstantGenerator.java b/src/sqlancer/mongodb/gen/MongoDBConstantGenerator.java
new file mode 100644
index 000000000..e81291543
--- /dev/null
+++ b/src/sqlancer/mongodb/gen/MongoDBConstantGenerator.java
@@ -0,0 +1,86 @@
+package sqlancer.mongodb.gen;
+
+import org.bson.Document;
+
+import sqlancer.Randomly;
+import sqlancer.common.ast.newast.Node;
+import sqlancer.mongodb.MongoDBProvider.MongoDBGlobalState;
+import sqlancer.mongodb.MongoDBSchema.MongoDBDataType;
+import sqlancer.mongodb.ast.MongoDBConstant;
+import sqlancer.mongodb.ast.MongoDBConstant.MongoDBBooleanConstant;
+import sqlancer.mongodb.ast.MongoDBConstant.MongoDBDateTimeConstant;
+import sqlancer.mongodb.ast.MongoDBConstant.MongoDBDoubleConstant;
+import sqlancer.mongodb.ast.MongoDBConstant.MongoDBIntegerConstant;
+import sqlancer.mongodb.ast.MongoDBConstant.MongoDBNullConstant;
+import sqlancer.mongodb.ast.MongoDBConstant.MongoDBTimestampConstant;
+import sqlancer.mongodb.ast.MongoDBExpression;
+
+public class MongoDBConstantGenerator {
+ private final MongoDBGlobalState globalState;
+
+ public MongoDBConstantGenerator(MongoDBGlobalState globalState) {
+ this.globalState = globalState;
+ }
+
+ public Node generateConstantWithType(MongoDBDataType option) {
+ switch (option) {
+ case DATE_TIME:
+ return MongoDBConstant.createDateTimeConstant(globalState.getRandomly().getInteger());
+ case BOOLEAN:
+ return MongoDBConstant.createBooleanConstant(Randomly.getBoolean());
+ case DOUBLE:
+ return MongoDBConstant.createDoubleConstant(globalState.getRandomly().getDouble());
+ case STRING:
+ return MongoDBConstant.createStringConstant(globalState.getRandomly().getString());
+ case INTEGER:
+ return MongoDBConstant.createIntegerConstant((int) globalState.getRandomly().getInteger());
+ case TIMESTAMP:
+ return MongoDBConstant.createTimestampConstant(globalState.getRandomly().getInteger());
+ default:
+ throw new AssertionError(option);
+ }
+ }
+
+ public void addRandomConstant(Document document, String key) {
+ MongoDBDataType type = MongoDBDataType.getRandom(globalState);
+ addRandomConstantWithType(document, key, type);
+ }
+
+ public void addRandomConstantWithType(Document document, String key, MongoDBDataType option) {
+ MongoDBConstant constant;
+ if (globalState.getDmbsSpecificOptions().testNullInserts && Randomly.getBooleanWithSmallProbability()) {
+ constant = new MongoDBNullConstant();
+ constant.setValueInDocument(document, key);
+ return;
+ }
+ switch (option) {
+ case DATE_TIME:
+ constant = new MongoDBDateTimeConstant(globalState.getRandomly().getInteger());
+ constant.setValueInDocument(document, key);
+ return;
+
+ case BOOLEAN:
+ constant = new MongoDBBooleanConstant(Randomly.getBoolean());
+ constant.setValueInDocument(document, key);
+ return;
+ case DOUBLE:
+ constant = new MongoDBDoubleConstant(globalState.getRandomly().getDouble());
+ constant.setValueInDocument(document, key);
+ return;
+ case STRING:
+ constant = new MongoDBConstant.MongoDBStringConstant(globalState.getRandomly().getString());
+ constant.setValueInDocument(document, key);
+ return;
+ case INTEGER:
+ constant = new MongoDBIntegerConstant((int) globalState.getRandomly().getInteger());
+ constant.setValueInDocument(document, key);
+ return;
+ case TIMESTAMP:
+ constant = new MongoDBTimestampConstant(globalState.getRandomly().getInteger());
+ constant.setValueInDocument(document, key);
+ return;
+ default:
+ throw new AssertionError(option);
+ }
+ }
+}
diff --git a/src/sqlancer/mongodb/gen/MongoDBIndexGenerator.java b/src/sqlancer/mongodb/gen/MongoDBIndexGenerator.java
new file mode 100644
index 000000000..8687fd45c
--- /dev/null
+++ b/src/sqlancer/mongodb/gen/MongoDBIndexGenerator.java
@@ -0,0 +1,25 @@
+package sqlancer.mongodb.gen;
+
+import java.util.List;
+
+import sqlancer.Randomly;
+import sqlancer.mongodb.MongoDBProvider.MongoDBGlobalState;
+import sqlancer.mongodb.MongoDBQueryAdapter;
+import sqlancer.mongodb.MongoDBSchema.MongoDBColumn;
+import sqlancer.mongodb.MongoDBSchema.MongoDBTable;
+import sqlancer.mongodb.query.MongoDBCreateIndexQuery;
+
+public final class MongoDBIndexGenerator {
+ private MongoDBIndexGenerator() {
+ }
+
+ public static MongoDBQueryAdapter getQuery(MongoDBGlobalState globalState) {
+ MongoDBTable randomTable = globalState.getSchema().getRandomTable();
+ List columns = Randomly.nonEmptySubset(randomTable.getColumns());
+ MongoDBCreateIndexQuery createIndexQuery = new MongoDBCreateIndexQuery(randomTable);
+ for (MongoDBColumn column : columns) {
+ createIndexQuery.addIndex(column.getName(), Randomly.getBoolean());
+ }
+ return createIndexQuery;
+ }
+}
diff --git a/src/sqlancer/mongodb/gen/MongoDBInsertGenerator.java b/src/sqlancer/mongodb/gen/MongoDBInsertGenerator.java
new file mode 100644
index 000000000..4501971b4
--- /dev/null
+++ b/src/sqlancer/mongodb/gen/MongoDBInsertGenerator.java
@@ -0,0 +1,38 @@
+package sqlancer.mongodb.gen;
+
+import org.bson.Document;
+
+import sqlancer.mongodb.MongoDBProvider.MongoDBGlobalState;
+import sqlancer.mongodb.MongoDBQueryAdapter;
+import sqlancer.mongodb.MongoDBSchema.MongoDBTable;
+import sqlancer.mongodb.query.MongoDBInsertQuery;
+
+public final class MongoDBInsertGenerator {
+
+ private final MongoDBGlobalState globalState;
+
+ private MongoDBInsertGenerator(MongoDBGlobalState globalState) {
+ this.globalState = globalState;
+ }
+
+ public static MongoDBQueryAdapter getQuery(MongoDBGlobalState globalState) {
+ return new MongoDBInsertGenerator(globalState).generate();
+ }
+
+ public MongoDBQueryAdapter generate() {
+ Document result = new Document();
+ MongoDBTable table = globalState.getSchema().getRandomTable();
+ MongoDBConstantGenerator constantGenerator = new MongoDBConstantGenerator(globalState);
+
+ for (int i = 0; i < table.getColumns().size(); i++) {
+ if (!globalState.getDmbsSpecificOptions().testRandomTypes) {
+ constantGenerator.addRandomConstantWithType(result, table.getColumns().get(i).getName(),
+ table.getColumns().get(i).getType());
+ } else {
+ constantGenerator.addRandomConstant(result, table.getColumns().get(i).getName());
+ }
+ }
+
+ return new MongoDBInsertQuery(table, result);
+ }
+}
diff --git a/src/sqlancer/mongodb/gen/MongoDBMatchExpressionGenerator.java b/src/sqlancer/mongodb/gen/MongoDBMatchExpressionGenerator.java
new file mode 100644
index 000000000..0645d7f24
--- /dev/null
+++ b/src/sqlancer/mongodb/gen/MongoDBMatchExpressionGenerator.java
@@ -0,0 +1,291 @@
+package sqlancer.mongodb.gen;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import org.bson.conversions.Bson;
+
+import com.mongodb.client.model.Filters;
+
+import sqlancer.Randomly;
+import sqlancer.common.ast.BinaryOperatorNode.Operator;
+import sqlancer.common.ast.newast.Node;
+import sqlancer.common.gen.UntypedExpressionGenerator;
+import sqlancer.mongodb.MongoDBProvider.MongoDBGlobalState;
+import sqlancer.mongodb.MongoDBSchema.MongoDBDataType;
+import sqlancer.mongodb.ast.MongoDBBinaryComparisonNode;
+import sqlancer.mongodb.ast.MongoDBBinaryLogicalNode;
+import sqlancer.mongodb.ast.MongoDBConstant;
+import sqlancer.mongodb.ast.MongoDBExpression;
+import sqlancer.mongodb.ast.MongoDBRegexNode;
+import sqlancer.mongodb.ast.MongoDBUnaryLogicalOperatorNode;
+import sqlancer.mongodb.ast.MongoDBUnsupportedPredicate;
+import sqlancer.mongodb.test.MongoDBColumnTestReference;
+import sqlancer.mongodb.visitor.MongoDBNegateVisitor;
+
+public class MongoDBMatchExpressionGenerator
+ extends UntypedExpressionGenerator, MongoDBColumnTestReference> {
+
+ private final MongoDBGlobalState globalState;
+
+ private enum LeafExpression {
+ BINARY_COMPARISON, REGEX
+ }
+
+ private enum NonLeafExpression {
+ BINARY_LOGICAL, UNARY_LOGICAL
+ }
+
+ public MongoDBMatchExpressionGenerator(MongoDBGlobalState globalState) {
+ this.globalState = globalState;
+ }
+
+ @Override
+ public Node generateLeafNode() {
+ List possibleOptions = new ArrayList<>(Arrays.asList(LeafExpression.values()));
+ if (!globalState.getDmbsSpecificOptions().testWithRegex) {
+ possibleOptions.remove(LeafExpression.REGEX);
+ }
+ LeafExpression expr = Randomly.fromList(possibleOptions);
+ switch (expr) {
+ case BINARY_COMPARISON:
+ MongoDBBinaryComparisonOperator operator = MongoDBBinaryComparisonOperator.getRandom();
+ MongoDBColumnTestReference reference = (MongoDBColumnTestReference) generateColumn();
+
+ return new MongoDBBinaryComparisonNode(reference,
+ generateConstant(reference.getColumnReference().getType()), operator);
+ case REGEX:
+ return new MongoDBRegexNode(generateColumn(),
+ new MongoDBConstantGenerator(globalState).generateConstantWithType(MongoDBDataType.STRING),
+ getRandomizedRegexOptions());
+ default:
+ throw new AssertionError();
+ }
+ }
+
+ @Override
+ protected Node generateExpression(int depth) {
+ if (depth >= globalState.getOptions().getMaxExpressionDepth() || Randomly.getBoolean()) {
+ return generateLeafNode();
+ }
+
+ List possibleOptions = new ArrayList<>(Arrays.asList(NonLeafExpression.values()));
+ NonLeafExpression expr = Randomly.fromList(possibleOptions);
+ switch (expr) {
+ case BINARY_LOGICAL:
+ MongoDBBinaryLogicalOperator binaryOperator = MongoDBBinaryLogicalOperator.getRandom();
+ return new MongoDBBinaryLogicalNode(generateExpression(depth + 1), generateExpression(depth + 1),
+ binaryOperator);
+ case UNARY_LOGICAL:
+ MongoDBUnaryLogicalOperator unaryOperator = MongoDBUnaryLogicalOperator.getRandom();
+ return new MongoDBUnaryLogicalOperatorNode(generateExpression(depth + 1), unaryOperator);
+ default:
+ throw new AssertionError();
+ }
+ }
+
+ @Override
+ public Node generateConstant() {
+ MongoDBDataType type = MongoDBDataType.getRandom(globalState);
+ MongoDBConstantGenerator generator = new MongoDBConstantGenerator(globalState);
+ if (Randomly.getBooleanWithSmallProbability()) {
+ return MongoDBConstant.createNullConstant();
+ }
+ return generator.generateConstantWithType(type);
+ }
+
+ public Node generateConstant(MongoDBDataType type) {
+ MongoDBConstantGenerator generator = new MongoDBConstantGenerator(globalState);
+ if (Randomly.getBooleanWithSmallProbability() && !globalState.getDmbsSpecificOptions().nullSafety) {
+ return MongoDBConstant.createNullConstant();
+ }
+ return generator.generateConstantWithType(type);
+ }
+
+ private String getRandomizedRegexOptions() {
+ List s = Randomly.subset("i", "m", "x", "s");
+ return String.join("", s);
+ }
+
+ @Override
+ protected Node generateColumn() {
+ return Randomly.fromList(columns);
+ }
+
+ @Override
+ public Node generatePredicate() {
+ Node result = super.generatePredicate();
+ return MongoDBNegateVisitor.cleanNegations(result);
+ }
+
+ @Override
+ public Node negatePredicate(Node predicate) {
+ Node result = new MongoDBUnaryLogicalOperatorNode(predicate,
+ MongoDBUnaryLogicalOperator.NOT);
+ return MongoDBNegateVisitor.cleanNegations(result);
+ }
+
+ @Override
+ public Node isNull(Node expr) {
+ return new MongoDBUnsupportedPredicate<>();
+ }
+
+ public enum MongoDBUnaryLogicalOperator implements Operator {
+ NOT {
+ @Override
+ public Bson applyOperator(Bson inner) {
+ return Filters.not(inner);
+ }
+
+ @Override
+ public String getTextRepresentation() {
+ return "$not";
+ }
+ };
+
+ public abstract Bson applyOperator(Bson inner);
+
+ public static MongoDBUnaryLogicalOperator getRandom() {
+ return Randomly.fromOptions(values());
+ }
+ }
+
+ public enum MongoDBBinaryLogicalOperator implements Operator {
+ AND {
+ @Override
+ public Bson applyOperator(Bson left, Bson right) {
+ return Filters.and(left, right);
+ }
+
+ @Override
+ public String getTextRepresentation() {
+ return "$and";
+ }
+ },
+ OR {
+ @Override
+ public Bson applyOperator(Bson left, Bson right) {
+ return Filters.or(left, right);
+ }
+
+ @Override
+ public String getTextRepresentation() {
+ return "$or";
+ }
+ },
+ NOR {
+ @Override
+ public Bson applyOperator(Bson left, Bson right) {
+ return Filters.nor(left, right);
+ }
+
+ @Override
+ public String getTextRepresentation() {
+ return "$nor";
+ }
+ };
+
+ public abstract Bson applyOperator(Bson left, Bson right);
+
+ public static MongoDBBinaryLogicalOperator getRandom() {
+ return Randomly.fromOptions(values());
+ }
+ }
+
+ public enum MongoDBBinaryComparisonOperator implements Operator {
+ EQUALS {
+ @Override
+ public Bson applyOperator(String columnName, MongoDBConstant constant) {
+ return Filters.eq(columnName, constant.getValue());
+ }
+
+ @Override
+ public String getTextRepresentation() {
+ return "$eq";
+ }
+ },
+ NOT_EQUALS {
+ @Override
+ public Bson applyOperator(String columnName, MongoDBConstant constant) {
+ return Filters.ne(columnName, constant.getValue());
+ }
+
+ @Override
+ public String getTextRepresentation() {
+ return "$ne";
+ }
+ },
+ GREATER {
+ @Override
+ public Bson applyOperator(String columnName, MongoDBConstant constant) {
+ return Filters.gt(columnName, constant.getValue());
+ }
+
+ @Override
+ public String getTextRepresentation() {
+ return "$gt";
+ }
+
+ },
+ LESS {
+ @Override
+ public Bson applyOperator(String columnName, MongoDBConstant constant) {
+ return Filters.lt(columnName, constant.getValue());
+ }
+
+ @Override
+ public String getTextRepresentation() {
+ return "$lt";
+ }
+
+ },
+ GREATER_EQUAL {
+ @Override
+ public Bson applyOperator(String columnName, MongoDBConstant constant) {
+ return Filters.gte(columnName, constant.getValue());
+
+ }
+
+ @Override
+ public String getTextRepresentation() {
+ return "$gte";
+ }
+
+ },
+ LESS_EQUAL {
+ @Override
+ public Bson applyOperator(String columnName, MongoDBConstant constant) {
+ return Filters.lte(columnName, constant.getValue());
+ }
+
+ @Override
+ public String getTextRepresentation() {
+ return "$lte";
+ }
+ };
+
+ public abstract Bson applyOperator(String columnName, MongoDBConstant constant);
+
+ public static MongoDBBinaryComparisonOperator getRandom() {
+ return Randomly.fromOptions(values());
+ }
+ }
+
+ public enum MongoDBRegexOperator implements Operator {
+ REGEX {
+ @Override
+ public Bson applyOperator(String columnName, MongoDBConstant.MongoDBStringConstant regex, String options) {
+ return Filters.regex(columnName, regex.getStringValue(), options);
+ }
+
+ @Override
+ public String getTextRepresentation() {
+ return "$regex";
+ }
+ };
+
+ public abstract Bson applyOperator(String columnName, MongoDBConstant.MongoDBStringConstant regex,
+ String options);
+ }
+}
diff --git a/src/sqlancer/mongodb/gen/MongoDBTableGenerator.java b/src/sqlancer/mongodb/gen/MongoDBTableGenerator.java
new file mode 100644
index 000000000..6a4f33d38
--- /dev/null
+++ b/src/sqlancer/mongodb/gen/MongoDBTableGenerator.java
@@ -0,0 +1,54 @@
+package sqlancer.mongodb.gen;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import sqlancer.Randomly;
+import sqlancer.mongodb.MongoDBProvider.MongoDBGlobalState;
+import sqlancer.mongodb.MongoDBQueryAdapter;
+import sqlancer.mongodb.MongoDBSchema.MongoDBColumn;
+import sqlancer.mongodb.MongoDBSchema.MongoDBDataType;
+import sqlancer.mongodb.MongoDBSchema.MongoDBTable;
+import sqlancer.mongodb.query.MongoDBCreateTableQuery;
+
+public class MongoDBTableGenerator {
+
+ private MongoDBTable table;
+ private final List columnsToBeAdded = new ArrayList<>();
+ private final MongoDBGlobalState state;
+
+ public MongoDBTableGenerator(MongoDBGlobalState state) {
+ this.state = state;
+ }
+
+ public MongoDBQueryAdapter getQuery(MongoDBGlobalState globalState) {
+ String tableName = globalState.getSchema().getFreeTableName();
+ MongoDBCreateTableQuery createTableQuery = new MongoDBCreateTableQuery(tableName);
+ table = new MongoDBTable(tableName, columnsToBeAdded, false);
+ for (int i = 0; i < Randomly.smallNumber() + 1; i++) {
+ String columnName = String.format("c%d", i);
+ MongoDBDataType type = createColumn(columnName);
+ if (globalState.getDmbsSpecificOptions().testValidation) {
+ createTableQuery.addValidation(columnName, type.getBsonType());
+ }
+ }
+ globalState.addTable(table);
+ return createTableQuery;
+ }
+
+ private MongoDBDataType createColumn(String columnName) {
+ MongoDBDataType columnType = MongoDBDataType.getRandom(state);
+ MongoDBColumn newColumn = new MongoDBColumn(columnName, columnType, false, false);
+ newColumn.setTable(table);
+ columnsToBeAdded.add(newColumn);
+ return columnType;
+ }
+
+ public String getTableName() {
+ return table.getName();
+ }
+
+ public MongoDBTable getGeneratedTable() {
+ return table;
+ }
+}
diff --git a/src/sqlancer/mongodb/query/MongoDBCreateIndexQuery.java b/src/sqlancer/mongodb/query/MongoDBCreateIndexQuery.java
new file mode 100644
index 000000000..c873b5924
--- /dev/null
+++ b/src/sqlancer/mongodb/query/MongoDBCreateIndexQuery.java
@@ -0,0 +1,77 @@
+package sqlancer.mongodb.query;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.bson.conversions.Bson;
+
+import com.mongodb.client.model.Indexes;
+
+import sqlancer.GlobalState;
+import sqlancer.Main;
+import sqlancer.common.query.ExpectedErrors;
+import sqlancer.mongodb.MongoDBConnection;
+import sqlancer.mongodb.MongoDBQueryAdapter;
+import sqlancer.mongodb.MongoDBSchema.MongoDBTable;
+
+public class MongoDBCreateIndexQuery extends MongoDBQueryAdapter {
+
+ private final MongoDBTable table;
+ private final List indeces;
+ private final List logIndeces;
+
+ public MongoDBCreateIndexQuery(MongoDBTable table) {
+ this.table = table;
+ this.indeces = new ArrayList<>();
+ this.logIndeces = new ArrayList<>();
+ }
+
+ public void addIndex(String column, boolean ascending) {
+ if (ascending) {
+ indeces.add(Indexes.ascending(column));
+ logIndeces.add(column + ": 1");
+ } else {
+ indeces.add(Indexes.descending(column));
+ logIndeces.add(column + ": -1");
+ }
+ }
+
+ @Override
+ public String getLogString() {
+ StringBuilder sb = new StringBuilder();
+ sb.append("db.").append(table.getName()).append(".createIndex({");
+ String helper = "";
+ for (String index : logIndeces) {
+ sb.append(helper);
+ helper = ",";
+ sb.append(index);
+ }
+ sb.append("})\n");
+ return sb.toString();
+ }
+
+ @Override
+ public boolean couldAffectSchema() {
+ return false;
+ }
+
+ @Override
+ public > boolean execute(G globalState, String... fills)
+ throws Exception {
+ Main.nrSuccessfulActions.addAndGet(1);
+ Bson index;
+ if (indeces.size() > 1) {
+ index = Indexes.compoundIndex(indeces);
+ } else {
+ index = indeces.get(0);
+ }
+ globalState.getConnection().getDatabase().getCollection(table.getName()).createIndex(index);
+ return true;
+ }
+
+ @Override
+ public ExpectedErrors getExpectedErrors() {
+ return new ExpectedErrors();
+ }
+
+}
diff --git a/src/sqlancer/mongodb/query/MongoDBCreateTableQuery.java b/src/sqlancer/mongodb/query/MongoDBCreateTableQuery.java
new file mode 100644
index 000000000..7bc174c77
--- /dev/null
+++ b/src/sqlancer/mongodb/query/MongoDBCreateTableQuery.java
@@ -0,0 +1,115 @@
+package sqlancer.mongodb.query;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.bson.BsonType;
+import org.bson.conversions.Bson;
+
+import com.mongodb.client.model.CreateCollectionOptions;
+import com.mongodb.client.model.Filters;
+import com.mongodb.client.model.ValidationOptions;
+
+import sqlancer.GlobalState;
+import sqlancer.Main;
+import sqlancer.common.query.ExpectedErrors;
+import sqlancer.mongodb.MongoDBConnection;
+import sqlancer.mongodb.MongoDBQueryAdapter;
+
+public class MongoDBCreateTableQuery extends MongoDBQueryAdapter {
+
+ private final String tableName;
+ private Bson validationFilter;
+ private final List logRequiredList;
+ private final List logPropertiesList;
+
+ public MongoDBCreateTableQuery(String tableName) {
+ this.tableName = tableName;
+ this.validationFilter = null;
+ logRequiredList = new ArrayList<>();
+ logPropertiesList = new ArrayList<>();
+ }
+
+ @Override
+ public boolean couldAffectSchema() {
+ return true;
+ }
+
+ @Override
+ public > boolean execute(G globalState, String... fills)
+ throws Exception {
+ ValidationOptions collOptions = new ValidationOptions().validator(this.validationFilter);
+ Main.nrSuccessfulActions.addAndGet(1);
+ globalState.getConnection().getDatabase().createCollection(tableName,
+ new CreateCollectionOptions().validationOptions(collOptions));
+ return true;
+ }
+
+ @Override
+ public ExpectedErrors getExpectedErrors() {
+ return new ExpectedErrors();
+ }
+
+ @Override
+ public String getLogString() {
+ String helper = "";
+ StringBuilder sb = new StringBuilder();
+ sb.append("db.createCollection(\"").append(tableName).append("\", {\n");
+
+ if (!logPropertiesList.isEmpty()) {
+ sb.append("validator: {");
+ sb.append("$jsonSchema: {");
+ sb.append("bsonType:\"object\",");
+ sb.append("required: [\n");
+ for (String req : logRequiredList) {
+ sb.append(helper);
+ helper = ",";
+ sb.append(req);
+ }
+ sb.append("],");
+ sb.append("properties: {\n");
+ for (String prop : logPropertiesList) {
+ sb.append(prop);
+ }
+ sb.append("}}}})");
+ } else {
+ sb.append("})");
+ }
+
+ return sb.toString();
+ }
+
+ public void addValidation(String columnName, BsonType type) {
+ Bson nameFilter = Filters.exists(columnName);
+ Bson typeFilter = Filters.type(columnName, type);
+
+ if (validationFilter == null) {
+ validationFilter = Filters.and(nameFilter, typeFilter);
+ } else {
+ validationFilter = Filters.and(validationFilter, Filters.and(nameFilter, typeFilter));
+ }
+
+ logRequiredList.add("\"" + columnName + "\"");
+ logPropertiesList.add(columnName + ": { bsonType:\"" + bsonTypeToString(type) + "\"},\n");
+ }
+
+ public String bsonTypeToString(BsonType type) {
+ switch (type) {
+ case DOUBLE:
+ return "double";
+ case STRING:
+ return "string";
+ case BOOLEAN:
+ return "bool";
+ case INT32:
+ case INT64:
+ return "int";
+ case DATE_TIME:
+ return "date";
+ case TIMESTAMP:
+ return "timestamp";
+ default:
+ throw new IllegalStateException();
+ }
+ }
+}
diff --git a/src/sqlancer/mongodb/query/MongoDBInsertQuery.java b/src/sqlancer/mongodb/query/MongoDBInsertQuery.java
new file mode 100644
index 000000000..127dc82d0
--- /dev/null
+++ b/src/sqlancer/mongodb/query/MongoDBInsertQuery.java
@@ -0,0 +1,87 @@
+package sqlancer.mongodb.query;
+
+import org.bson.BsonDateTime;
+import org.bson.BsonTimestamp;
+import org.bson.Document;
+import org.bson.types.ObjectId;
+
+import com.mongodb.client.result.InsertOneResult;
+
+import sqlancer.GlobalState;
+import sqlancer.Main;
+import sqlancer.common.query.ExpectedErrors;
+import sqlancer.mongodb.MongoDBConnection;
+import sqlancer.mongodb.MongoDBQueryAdapter;
+import sqlancer.mongodb.MongoDBSchema.MongoDBTable;
+
+public class MongoDBInsertQuery extends MongoDBQueryAdapter {
+ boolean excluded;
+ private final MongoDBTable table;
+ private final Document documentToBeInserted;
+
+ public MongoDBInsertQuery(MongoDBTable table, Document documentToBeInserted) {
+ this.table = table;
+ this.documentToBeInserted = documentToBeInserted;
+ this.excluded = false;
+ }
+
+ @Override
+ public String getLogString() {
+ StringBuilder sb = new StringBuilder();
+ sb.append("db." + table.getName() + ".insert({");
+ String helper = "";
+ for (String key : documentToBeInserted.keySet()) {
+ sb.append(helper);
+ helper = ", ";
+ if (documentToBeInserted.get(key) instanceof ObjectId) {
+ continue;
+ }
+ Object value = documentToBeInserted.get(key);
+ sb.append(key);
+ sb.append(": ");
+ sb.append(getStringRepresentation(value));
+ }
+ sb.append("})\n");
+
+ return sb.toString();
+ }
+
+ private String getStringRepresentation(Object value) {
+ if (value instanceof Double) {
+ return String.valueOf(value);
+ } else if (value instanceof Integer) {
+ return "NumberInt(" + value + ")";
+ } else if (value instanceof String) {
+ return "\"" + value + "\"";
+ } else if (value instanceof BsonDateTime) {
+ return "new Date(" + ((BsonDateTime) value).getValue() + ")";
+ } else if (value instanceof BsonTimestamp) {
+ return "Timestamp(" + ((BsonTimestamp) value).getValue() + ",1)";
+ } else if (value instanceof Boolean) {
+ return String.valueOf(value);
+ } else if (value == null) {
+ return "null";
+ } else {
+ throw new IllegalStateException();
+ }
+ }
+
+ @Override
+ public boolean couldAffectSchema() {
+ return true;
+ }
+
+ @Override
+ public > boolean execute(G globalState, String... fills)
+ throws Exception {
+ Main.nrSuccessfulActions.addAndGet(1);
+ InsertOneResult result = globalState.getConnection().getDatabase().getCollection(table.getName())
+ .insertOne(documentToBeInserted);
+ return result.wasAcknowledged();
+ }
+
+ @Override
+ public ExpectedErrors getExpectedErrors() {
+ return new ExpectedErrors();
+ }
+}
diff --git a/src/sqlancer/mongodb/query/MongoDBRemoveQuery.java b/src/sqlancer/mongodb/query/MongoDBRemoveQuery.java
new file mode 100644
index 000000000..6fe1c9e3f
--- /dev/null
+++ b/src/sqlancer/mongodb/query/MongoDBRemoveQuery.java
@@ -0,0 +1,59 @@
+package sqlancer.mongodb.query;
+
+import org.bson.Document;
+import org.bson.types.ObjectId;
+
+import com.mongodb.client.result.DeleteResult;
+
+import sqlancer.GlobalState;
+import sqlancer.Main;
+import sqlancer.common.query.ExpectedErrors;
+import sqlancer.mongodb.MongoDBConnection;
+import sqlancer.mongodb.MongoDBQueryAdapter;
+import sqlancer.mongodb.MongoDBSchema;
+
+public class MongoDBRemoveQuery extends MongoDBQueryAdapter {
+
+ private final String objectId;
+ private final MongoDBSchema.MongoDBTable table;
+
+ public MongoDBRemoveQuery(MongoDBSchema.MongoDBTable table, String objectId) {
+ this.objectId = objectId;
+ this.table = table;
+ }
+
+ @Override
+ public boolean couldAffectSchema() {
+ return true;
+ }
+
+ @Override
+ public > boolean execute(G globalState, String... fills)
+ throws Exception {
+ try {
+ DeleteResult result = globalState.getConnection().getDatabase().getCollection(table.getName())
+ .deleteOne(new Document("_id", new ObjectId(objectId)));
+ if (result.wasAcknowledged()) {
+ Main.nrSuccessfulActions.addAndGet(1);
+ } else {
+ Main.nrUnsuccessfulActions.addAndGet(1);
+ }
+ return result.wasAcknowledged();
+ } catch (Exception e) {
+ Main.nrUnsuccessfulActions.addAndGet(1);
+ return false;
+ }
+ }
+
+ @Override
+ public ExpectedErrors getExpectedErrors() {
+ return new ExpectedErrors();
+ }
+
+ @Override
+ public String getLogString() {
+ StringBuilder stringBuilder = new StringBuilder();
+ stringBuilder.append("db.").append(table.getName()).append(".remove({'_id': '").append(objectId).append("'})");
+ return stringBuilder.toString();
+ }
+}
diff --git a/src/sqlancer/mongodb/query/MongoDBSelectQuery.java b/src/sqlancer/mongodb/query/MongoDBSelectQuery.java
new file mode 100644
index 000000000..1288f114c
--- /dev/null
+++ b/src/sqlancer/mongodb/query/MongoDBSelectQuery.java
@@ -0,0 +1,147 @@
+package sqlancer.mongodb.query;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
+import org.bson.Document;
+import org.bson.conversions.Bson;
+
+import com.mongodb.client.MongoCollection;
+import com.mongodb.client.MongoCursor;
+
+import sqlancer.GlobalState;
+import sqlancer.common.query.ExpectedErrors;
+import sqlancer.common.query.SQLancerResultSet;
+import sqlancer.mongodb.MongoDBConnection;
+import sqlancer.mongodb.MongoDBQueryAdapter;
+import sqlancer.mongodb.ast.MongoDBExpression;
+import sqlancer.mongodb.ast.MongoDBSelect;
+import sqlancer.mongodb.visitor.MongoDBVisitor;
+
+public class MongoDBSelectQuery extends MongoDBQueryAdapter {
+
+ private final MongoDBSelect select;
+
+ private List resultSet;
+
+ public MongoDBSelectQuery(MongoDBSelect select) {
+ this.select = select;
+ }
+
+ @Override
+ public boolean couldAffectSchema() {
+ return false;
+ }
+
+ @Override
+ public > boolean execute(G globalState, String... fills)
+ throws Exception {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public ExpectedErrors getExpectedErrors() {
+ ExpectedErrors errors = new ExpectedErrors();
+ // ARITHMETIC
+ errors.add("Failed to optimize pipeline :: caused by :: Can't coerce out of range value");
+ errors.add("Can't coerce out of range value");
+ errors.add("date overflow in $add");
+ errors.add("Failed to optimize pipeline :: caused by :: $sqrt only supports numeric types, not");
+ errors.add("Failed to optimize pipeline :: caused by :: $sqrt's argument must be greater than or equal to 0");
+ errors.add("Failed to optimize pipeline :: caused by :: $pow's base must be numeric, not");
+ errors.add("Failed to optimize pipeline :: caused by :: $pow cannot take a base of 0 and a negative exponent");
+ errors.add("Failed to optimize pipeline :: caused by :: $add only supports numeric or date types, not");
+ errors.add("Failed to optimize pipeline :: caused by :: $exp only supports numeric types, not");
+ errors.add("Failed to optimize pipeline :: caused by :: $log's base must be numeric, not");
+ errors.add("Failed to optimize pipeline :: caused by :: $log's base must be a positive number not equal to 1");
+ errors.add("Failed to optimize pipeline :: caused by :: $multiply only supports numeric types, not");
+ errors.add("$log's argument must be numeric, not");
+ errors.add("$log's argument must be a positive number, but");
+ errors.add("$log's base must be numeric, not");
+ errors.add("$log's base must be a positive number not equal to 1");
+ errors.add("$divide only supports numeric types, not");
+ errors.add("can't $divide by zero");
+ errors.add("$pow's exponent must be numeric, not");
+ errors.add("$pow's base must be numeric, not");
+ errors.add("$pow cannot take a base of 0 and a negative exponent");
+ errors.add("$add only supports numeric or date types, not");
+ errors.add("only one date allowed in an $add expression");
+ errors.add("$multiply only supports numeric types, not");
+ errors.add("$exp only supports numeric types, not");
+ errors.add("$sqrt's argument must be greater than or equal to 0");
+ errors.add("$sqrt only supports numeric types, not");
+
+ // REGEX
+ errors.add("Regular expression is invalid: nothing to repeat");
+ errors.add("Regular expression is invalid: missing terminating ] for character class");
+ errors.add("Regular expression is invalid: unmatched parentheses");
+ errors.add("Regular expression is invalid: missing )");
+ errors.add("Regular expression is invalid: invalid UTF-8 string");
+ errors.add("Regular expression is invalid: \\k is not followed by a braced, angle-bracketed, or quoted name");
+ errors.add("Regular expression is invalid: missing opening brace after \\\\o");
+ errors.add("Regular expression is invalid: reference to non-existent subpattern");
+ errors.add("Regular expression is invalid: \\ at end of pattern");
+ errors.add("Regular expression is invalid: PCRE does not support \\L, \\l, \\N{name}, \\U, or \\u");
+ errors.add("Regular expression is invalid: (?R or (?[+-]digits must be followed by )");
+ errors.add("Regular expression is invalid: unknown property name after \\P or \\p");
+ errors.add("Regular expression is invalid: (*VERB) not recognized or malformed");
+ errors.add("Regular expression is invalid: a numbered reference must not be zero");
+ errors.add("Regular expression is invalid: unrecognized character after (? or (?-");
+ errors.add("Regular expression is invalid: \\c at end of pattern");
+ errors.add("Regular expression is invalid: malformed \\P or \\p sequence");
+ errors.add("Regular expression is invalid: range out of order in character class");
+ errors.add("Regular expression is invalid: group name must start with a non-digit");
+ errors.add("Regular expression is invalid: \\c must be followed by an ASCII character");
+ errors.add("Regular expression is invalid: subpattern name expected");
+ errors.add("Regular expression is invalid: POSIX collating elements are not supported");
+ errors.add("Regular expression is invalid: closing ) for (?C expected");
+ errors.add("Regular expression is invalid: syntax error in subpattern name (missing terminator)");
+ errors.add("Regular expression is invalid: \\\\N is not supported in a class");
+ errors.add("Regular expression is invalid: non-octal character in \\o{} (closing brace missing?)");
+ errors.add("Regular expression is invalid: non-hex character in \\x{} (closing brace missing?)");
+ errors.add(
+ "Regular expression is invalid: \\g is not followed by a braced, angle-bracketed, or quoted name/number or by a plain number");
+ errors.add("Regular expression is invalid: digits missing in \\x{} or \\o{}");
+ errors.add("Regular expression is invalid: malformed number or name after (?(");
+ errors.add("Regular expression is invalid: digit expected after (?+");
+ errors.add("Regular expression is invalid: assertion expected after (?( or (?(?C)");
+ errors.add("Regular expression is invalid: unrecognized character after (?P");
+
+ return errors;
+ }
+
+ @Override
+ public > SQLancerResultSet executeAndGet(G globalState,
+ String... fills) throws Exception {
+ if (globalState.getOptions().logEachSelect()) {
+ globalState.getLogger().writeCurrent(this.getLogString());
+ try {
+ globalState.getLogger().getCurrentFileWriter().flush();
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
+ }
+ List pipeline = MongoDBVisitor.asQuery(select);
+
+ MongoCollection collection = globalState.getConnection().getDatabase()
+ .getCollection(select.getMainTableName());
+ MongoCursor cursor = collection.aggregate(pipeline).cursor();
+ resultSet = new ArrayList<>();
+ while (cursor.hasNext()) {
+ Document document = cursor.next();
+ resultSet.add(document);
+ }
+ return null;
+ }
+
+ @Override
+ public String getLogString() {
+ return MongoDBVisitor.asStringLog(select);
+ }
+
+ public List getResultSet() {
+ return resultSet;
+ }
+
+}
diff --git a/src/sqlancer/mongodb/test/MongoDBColumnTestReference.java b/src/sqlancer/mongodb/test/MongoDBColumnTestReference.java
new file mode 100644
index 000000000..59a2a6724
--- /dev/null
+++ b/src/sqlancer/mongodb/test/MongoDBColumnTestReference.java
@@ -0,0 +1,40 @@
+package sqlancer.mongodb.test;
+
+import sqlancer.common.ast.newast.Node;
+import sqlancer.mongodb.MongoDBSchema.MongoDBColumn;
+import sqlancer.mongodb.ast.MongoDBExpression;
+
+public class MongoDBColumnTestReference implements Node {
+
+ private final MongoDBColumn columnReference;
+ private final boolean inMainTable;
+
+ public MongoDBColumnTestReference(MongoDBColumn columnReference, boolean inMainTable) {
+ this.columnReference = columnReference;
+ this.inMainTable = inMainTable;
+ }
+
+ public String getQueryString() {
+ if (inMainTable) {
+ return this.columnReference.getName();
+ } else {
+ return "join_" + this.columnReference.getTable().getName() + "." + this.columnReference.getName();
+ }
+ }
+
+ public boolean inMainTable() {
+ return inMainTable;
+ }
+
+ public String getTableName() {
+ return this.columnReference.getTable().getName();
+ }
+
+ public String getPlainName() {
+ return this.columnReference.getName();
+ }
+
+ public MongoDBColumn getColumnReference() {
+ return columnReference;
+ }
+}
diff --git a/src/sqlancer/mongodb/test/MongoDBQueryPartitioningBase.java b/src/sqlancer/mongodb/test/MongoDBQueryPartitioningBase.java
new file mode 100644
index 000000000..b26aa6028
--- /dev/null
+++ b/src/sqlancer/mongodb/test/MongoDBQueryPartitioningBase.java
@@ -0,0 +1,92 @@
+package sqlancer.mongodb.test;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import sqlancer.Randomly;
+import sqlancer.common.ast.newast.Node;
+import sqlancer.common.gen.ExpressionGenerator;
+import sqlancer.common.oracle.TernaryLogicPartitioningOracleBase;
+import sqlancer.common.oracle.TestOracle;
+import sqlancer.mongodb.MongoDBProvider.MongoDBGlobalState;
+import sqlancer.mongodb.MongoDBSchema;
+import sqlancer.mongodb.MongoDBSchema.MongoDBColumn;
+import sqlancer.mongodb.MongoDBSchema.MongoDBTable;
+import sqlancer.mongodb.MongoDBSchema.MongoDBTables;
+import sqlancer.mongodb.ast.MongoDBExpression;
+import sqlancer.mongodb.ast.MongoDBSelect;
+import sqlancer.mongodb.gen.MongoDBComputedExpressionGenerator;
+import sqlancer.mongodb.gen.MongoDBMatchExpressionGenerator;
+
+public class MongoDBQueryPartitioningBase
+ extends TernaryLogicPartitioningOracleBase, MongoDBGlobalState> implements TestOracle {
+
+ protected MongoDBSchema schema;
+ protected MongoDBTables targetTables;
+ protected MongoDBTable mainTable;
+ protected List targetColumns;
+ protected MongoDBMatchExpressionGenerator expressionGenerator;
+ protected MongoDBSelect select;
+
+ public MongoDBQueryPartitioningBase(MongoDBGlobalState state) {
+ super(state);
+ }
+
+ @Override
+ public void check() throws Exception {
+ schema = state.getSchema();
+ targetTables = schema.getRandomTableNonEmptyTables();
+ mainTable = targetTables.getTables().get(0);
+ generateTargetColumns();
+ expressionGenerator = new MongoDBMatchExpressionGenerator(state).setColumns(targetColumns);
+ initializeTernaryPredicateVariants();
+ select = new MongoDBSelect<>(mainTable.getName(), targetColumns.get(0));
+ select.setProjectionList(targetColumns);
+ if (Randomly.getBooleanWithRatherLowProbability()) {
+ select.setLookupList(targetColumns);
+ } else {
+ select.setLookupList(Randomly.nonEmptySubset(targetColumns));
+ }
+ if (state.getDmbsSpecificOptions().testComputedValues) {
+ generateComputedColumns();
+ }
+ }
+
+ private void generateComputedColumns() {
+ List> computedColumns = new ArrayList<>();
+ int numberComputedColumns = state.getRandomly().getInteger(1, 4);
+ MongoDBComputedExpressionGenerator generator = new MongoDBComputedExpressionGenerator(state)
+ .setColumns(targetColumns);
+ for (int i = 0; i < numberComputedColumns; i++) {
+ computedColumns.add(generator.generateExpression());
+ }
+ select.setComputedClause(computedColumns);
+ }
+
+ private void generateTargetColumns() {
+ targetColumns = new ArrayList<>();
+ for (MongoDBColumn c : mainTable.getColumns()) {
+ targetColumns.add(new MongoDBColumnTestReference(c, true));
+ }
+ List joinsOtherTables = new ArrayList<>();
+ if (!state.getDmbsSpecificOptions().nullSafety) {
+ for (int i = 1; i < targetTables.getTables().size(); i++) {
+ MongoDBTable procTable = targetTables.getTables().get(i);
+ for (MongoDBColumn c : procTable.getColumns()) {
+ joinsOtherTables.add(new MongoDBColumnTestReference(c, false));
+ }
+ }
+ }
+ if (!joinsOtherTables.isEmpty()) {
+ int randNumber = state.getRandomly().getInteger(1, Math.min(joinsOtherTables.size(), 4));
+ List subsetJoinsOtherTables = Randomly.nonEmptySubset(joinsOtherTables,
+ randNumber);
+ targetColumns.addAll(subsetJoinsOtherTables);
+ }
+ }
+
+ @Override
+ protected ExpressionGenerator> getGen() {
+ return expressionGenerator;
+ }
+}
diff --git a/src/sqlancer/mongodb/test/MongoDBQueryPartitioningWhereTester.java b/src/sqlancer/mongodb/test/MongoDBQueryPartitioningWhereTester.java
new file mode 100644
index 000000000..5a7507672
--- /dev/null
+++ b/src/sqlancer/mongodb/test/MongoDBQueryPartitioningWhereTester.java
@@ -0,0 +1,48 @@
+package sqlancer.mongodb.test;
+
+import static sqlancer.mongodb.MongoDBComparatorHelper.getResultSetAsDocumentList;
+
+import java.util.List;
+
+import org.bson.Document;
+
+import sqlancer.mongodb.MongoDBComparatorHelper;
+import sqlancer.mongodb.MongoDBProvider.MongoDBGlobalState;
+import sqlancer.mongodb.query.MongoDBSelectQuery;
+
+public class MongoDBQueryPartitioningWhereTester extends MongoDBQueryPartitioningBase {
+ public MongoDBQueryPartitioningWhereTester(MongoDBGlobalState state) {
+ super(state);
+ }
+
+ @Override
+ public void check() throws Exception {
+ super.check();
+
+ select.setWithCountClause(false);
+
+ select.setFilterClause(null);
+ MongoDBSelectQuery q = new MongoDBSelectQuery(select);
+ List firstResultSet = getResultSetAsDocumentList(q, state);
+
+ select.setFilterClause(predicate);
+ q = new MongoDBSelectQuery(select);
+ List secondResultSet = getResultSetAsDocumentList(q, state);
+
+ select.setFilterClause(negatedPredicate);
+ q = new MongoDBSelectQuery(select);
+ List thirdResultSet = getResultSetAsDocumentList(q, state);
+
+ if (state.getDmbsSpecificOptions().testWithCount) {
+ select.setWithCountClause(true);
+ select.setFilterClause(predicate);
+ q = new MongoDBSelectQuery(select);
+ List forthResultSet = getResultSetAsDocumentList(q, state);
+ MongoDBComparatorHelper.assumeCountIsEqual(secondResultSet, forthResultSet, q);
+ }
+
+ secondResultSet.addAll(thirdResultSet);
+ MongoDBComparatorHelper.assumeResultSetsAreEqual(firstResultSet, secondResultSet, q);
+
+ }
+}
diff --git a/src/sqlancer/mongodb/test/MongoDBRemoveReduceBase.java b/src/sqlancer/mongodb/test/MongoDBRemoveReduceBase.java
new file mode 100644
index 000000000..2a2a54744
--- /dev/null
+++ b/src/sqlancer/mongodb/test/MongoDBRemoveReduceBase.java
@@ -0,0 +1,89 @@
+package sqlancer.mongodb.test;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import sqlancer.Randomly;
+import sqlancer.common.ast.newast.Node;
+import sqlancer.common.gen.ExpressionGenerator;
+import sqlancer.common.oracle.RemoveReduceOracleBase;
+import sqlancer.common.oracle.TestOracle;
+import sqlancer.mongodb.MongoDBProvider;
+import sqlancer.mongodb.MongoDBSchema;
+import sqlancer.mongodb.ast.MongoDBExpression;
+import sqlancer.mongodb.ast.MongoDBSelect;
+import sqlancer.mongodb.gen.MongoDBComputedExpressionGenerator;
+import sqlancer.mongodb.gen.MongoDBMatchExpressionGenerator;
+
+public class MongoDBRemoveReduceBase extends
+ RemoveReduceOracleBase, MongoDBProvider.MongoDBGlobalState> implements TestOracle {
+
+ protected MongoDBSchema schema;
+ protected MongoDBSchema.MongoDBTables targetTables;
+ protected MongoDBSchema.MongoDBTable mainTable;
+ protected List targetColumns;
+ protected MongoDBMatchExpressionGenerator expressionGenerator;
+ protected MongoDBSelect select;
+
+ protected MongoDBRemoveReduceBase(MongoDBProvider.MongoDBGlobalState state) {
+ super(state);
+ }
+
+ @Override
+ public void check() throws Exception {
+ schema = state.getSchema();
+ targetTables = schema.getRandomTableNonEmptyTables();
+ mainTable = targetTables.getTables().get(0);
+ generateTargetColumns();
+ expressionGenerator = new MongoDBMatchExpressionGenerator(state).setColumns(targetColumns);
+ initializeRemoveReduceOracle();
+ select = new MongoDBSelect<>(mainTable.getName(), targetColumns.get(0));
+ select.setProjectionList(targetColumns);
+ if (Randomly.getBooleanWithRatherLowProbability()) {
+ select.setLookupList(targetColumns);
+ } else {
+ select.setLookupList(Randomly.nonEmptySubset(targetColumns));
+ }
+ if (state.getDmbsSpecificOptions().testComputedValues) {
+ generateComputedColumns();
+ }
+ }
+
+ private void generateTargetColumns() {
+ targetColumns = new ArrayList<>();
+ for (MongoDBSchema.MongoDBColumn c : mainTable.getColumns()) {
+ targetColumns.add(new MongoDBColumnTestReference(c, true));
+ }
+ List