diff --git a/pom.xml b/pom.xml index 3dea1ef15..2b8783fb8 100644 --- a/pom.xml +++ b/pom.xml @@ -266,6 +266,16 @@ h2 1.4.200 + + org.mongodb + mongodb-driver-sync + 4.1.1 + + + com.arangodb + arangodb-java-driver + 6.9.0 + diff --git a/src/sqlancer/GlobalState.java b/src/sqlancer/GlobalState.java index 97c5a7a9d..642c9da6a 100644 --- a/src/sqlancer/GlobalState.java +++ b/src/sqlancer/GlobalState.java @@ -89,13 +89,13 @@ private ExecutionTimer executePrologue(Query q) throws Exception { timer = new ExecutionTimer().start(); } if (getOptions().printAllStatements()) { - System.out.println(q.getQueryString()); + System.out.println(q.getLogString()); } if (getOptions().logEachSelect()) { if (logExecutionTime) { - getLogger().writeCurrentNoLineBreak(q.getQueryString()); + getLogger().writeCurrentNoLineBreak(q.getLogString()); } else { - getLogger().writeCurrent(q.getQueryString()); + getLogger().writeCurrent(q.getLogString()); } } return timer; diff --git a/src/sqlancer/Main.java b/src/sqlancer/Main.java index 755ce36a8..c5da9ec98 100644 --- a/src/sqlancer/Main.java +++ b/src/sqlancer/Main.java @@ -21,15 +21,18 @@ import com.beust.jcommander.JCommander; import com.beust.jcommander.JCommander.Builder; +import sqlancer.arangodb.ArangoDBProvider; import sqlancer.citus.CitusProvider; import sqlancer.clickhouse.ClickHouseProvider; import sqlancer.cockroachdb.CockroachDBProvider; import sqlancer.common.log.Loggable; import sqlancer.common.query.Query; import sqlancer.common.query.SQLancerResultSet; +import sqlancer.cosmos.CosmosProvider; import sqlancer.duckdb.DuckDBProvider; import sqlancer.h2.H2Provider; import sqlancer.mariadb.MariaDBProvider; +import sqlancer.mongodb.MongoDBProvider; import sqlancer.mysql.MySQLProvider; import sqlancer.postgres.PostgresProvider; import sqlancer.sqlite3.SQLite3Provider; @@ -209,7 +212,7 @@ private void printState(FileWriter writer, StateToReproduce state) { .getInfo(state.getDatabaseName(), state.getDatabaseVersion(), state.getSeedValue()).getLogString()); for (Query s : state.getStatements()) { - sb.append(s.getQueryString()); + sb.append(s.getLogString()); sb.append('\n'); } try { @@ -554,6 +557,9 @@ private boolean run(MainOptions options, ExecutorService execService, providers.add(new ClickHouseProvider()); providers.add(new DuckDBProvider()); providers.add(new H2Provider()); + providers.add(new MongoDBProvider()); + providers.add(new CosmosProvider()); + providers.add(new ArangoDBProvider()); return providers; } diff --git a/src/sqlancer/arangodb/ArangoDBComparatorHelper.java b/src/sqlancer/arangodb/ArangoDBComparatorHelper.java new file mode 100644 index 000000000..2a00a312d --- /dev/null +++ b/src/sqlancer/arangodb/ArangoDBComparatorHelper.java @@ -0,0 +1,73 @@ +package sqlancer.arangodb; + +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import com.arangodb.entity.BaseDocument; + +import sqlancer.IgnoreMeException; +import sqlancer.Main; +import sqlancer.arangodb.query.ArangoDBSelectQuery; +import sqlancer.common.query.ExpectedErrors; + +public final class ArangoDBComparatorHelper { + + private ArangoDBComparatorHelper() { + + } + + public static List getResultSetAsDocumentList(ArangoDBSelectQuery query, + ArangoDBProvider.ArangoDBGlobalState state) throws Exception { + ExpectedErrors errors = query.getExpectedErrors(); + List result; + try { + query.executeAndGet(state); + Main.nrSuccessfulActions.addAndGet(1); + result = query.getResultSet(); + return result; + } catch (Exception e) { + if (e instanceof IgnoreMeException) { + throw e; + } + Main.nrUnsuccessfulActions.addAndGet(1); + if (e.getMessage() == null) { + throw new AssertionError(query.getLogString(), e); + } + if (errors.errorIsExpected(e.getMessage())) { + throw new IgnoreMeException(); + } + throw new AssertionError(query.getLogString(), e); + } + + } + + public static void assumeResultSetsAreEqual(List resultSet, List secondResultSet, + ArangoDBSelectQuery originalQuery) { + if (resultSet.size() != secondResultSet.size()) { + String assertionMessage = String.format("The Size of the result sets mismatch (%d and %d)!\n%s", + resultSet.size(), secondResultSet.size(), originalQuery.getLogString()); + throw new AssertionError(assertionMessage); + } + Set firstHashSet = new HashSet<>(resultSet); + Set secondHashSet = new HashSet<>(secondResultSet); + + if (!firstHashSet.equals(secondHashSet)) { + Set firstResultSetMisses = new HashSet<>(firstHashSet); + firstResultSetMisses.removeAll(secondHashSet); + Set secondResultSetMisses = new HashSet<>(secondHashSet); + secondResultSetMisses.removeAll(firstHashSet); + StringBuilder firstMisses = new StringBuilder(); + for (BaseDocument document : firstResultSetMisses) { + firstMisses.append(document).append(" "); + } + StringBuilder secondMisses = new StringBuilder(); + for (BaseDocument document : secondResultSetMisses) { + secondMisses.append(document).append(" "); + } + String assertMessage = String.format("The Content of the result sets mismatch!\n %s \n %s\n %s", + firstMisses.toString(), secondMisses.toString(), originalQuery.getLogString()); + throw new AssertionError(assertMessage); + } + } +} diff --git a/src/sqlancer/arangodb/ArangoDBConnection.java b/src/sqlancer/arangodb/ArangoDBConnection.java new file mode 100644 index 000000000..b3e5b85d3 --- /dev/null +++ b/src/sqlancer/arangodb/ArangoDBConnection.java @@ -0,0 +1,31 @@ +package sqlancer.arangodb; + +import com.arangodb.ArangoDB; +import com.arangodb.ArangoDatabase; + +import sqlancer.SQLancerDBConnection; + +public class ArangoDBConnection implements SQLancerDBConnection { + + private final ArangoDB client; + private final ArangoDatabase database; + + public ArangoDBConnection(ArangoDB client, ArangoDatabase database) { + this.client = client; + this.database = database; + } + + @Override + public String getDatabaseVersion() throws Exception { + return client.getVersion().getVersion(); + } + + @Override + public void close() throws Exception { + client.shutdown(); + } + + public ArangoDatabase getDatabase() { + return database; + } +} diff --git a/src/sqlancer/arangodb/ArangoDBLoggableFactory.java b/src/sqlancer/arangodb/ArangoDBLoggableFactory.java new file mode 100644 index 000000000..927d9f320 --- /dev/null +++ b/src/sqlancer/arangodb/ArangoDBLoggableFactory.java @@ -0,0 +1,40 @@ +package sqlancer.arangodb; + +import java.util.Arrays; + +import sqlancer.common.log.Loggable; +import sqlancer.common.log.LoggableFactory; +import sqlancer.common.log.LoggedString; +import sqlancer.common.query.Query; + +public class ArangoDBLoggableFactory extends LoggableFactory { + @Override + protected Loggable createLoggable(String input, String suffix) { + return new LoggedString(input + suffix); + } + + @Override + public Query getQueryForStateToReproduce(String queryString) { + throw new UnsupportedOperationException(); + } + + @Override + public Query commentOutQuery(Query query) { + throw new UnsupportedOperationException(); + } + + @Override + protected Loggable infoToLoggable(String time, String databaseName, String databaseVersion, long seedValue) { + StringBuilder sb = new StringBuilder(); + sb.append("// Time: ").append(time).append("\n"); + sb.append("// Database: ").append(databaseName).append("\n"); + sb.append("// Database version: ").append(databaseVersion).append("\n"); + sb.append("// seed value: ").append(seedValue).append("\n"); + return new LoggedString(sb.toString()); + } + + @Override + public Loggable convertStacktraceToLoggable(Throwable throwable) { + return new LoggedString(Arrays.toString(throwable.getStackTrace()) + "\n" + throwable.getMessage()); + } +} diff --git a/src/sqlancer/arangodb/ArangoDBOptions.java b/src/sqlancer/arangodb/ArangoDBOptions.java new file mode 100644 index 000000000..cdb7ee759 --- /dev/null +++ b/src/sqlancer/arangodb/ArangoDBOptions.java @@ -0,0 +1,44 @@ +package sqlancer.arangodb; + +import static sqlancer.arangodb.ArangoDBOptions.ArangoDBOracleFactory.QUERY_PARTITIONING; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import com.beust.jcommander.Parameter; + +import sqlancer.DBMSSpecificOptions; +import sqlancer.OracleFactory; +import sqlancer.arangodb.test.ArangoDBQueryPartitioningWhereTester; +import sqlancer.common.oracle.CompositeTestOracle; +import sqlancer.common.oracle.TestOracle; + +public class ArangoDBOptions implements DBMSSpecificOptions { + + @Parameter(names = "--oracle") + public List oracles = Arrays.asList(QUERY_PARTITIONING); + + @Parameter(names = "--test-random-type-inserts", description = "Insert random types instead of schema types.") + public boolean testRandomTypeInserts; + + @Parameter(names = "--max-number-indexes", description = "The maximum number of indexes used.", arity = 1) + public int maxNumberIndexes = 15; + + @Override + public List getTestOracleFactory() { + return oracles; + } + + public enum ArangoDBOracleFactory implements OracleFactory { + QUERY_PARTITIONING { + @Override + public TestOracle create(ArangoDBProvider.ArangoDBGlobalState globalState) throws Exception { + List oracles = new ArrayList<>(); + oracles.add(new ArangoDBQueryPartitioningWhereTester(globalState)); + return new CompositeTestOracle(oracles, globalState); + } + } + + } +} diff --git a/src/sqlancer/arangodb/ArangoDBProvider.java b/src/sqlancer/arangodb/ArangoDBProvider.java new file mode 100644 index 000000000..d63d4a1cc --- /dev/null +++ b/src/sqlancer/arangodb/ArangoDBProvider.java @@ -0,0 +1,134 @@ +package sqlancer.arangodb; + +import java.util.ArrayList; +import java.util.List; + +import com.arangodb.ArangoDB; +import com.arangodb.ArangoDatabase; + +import sqlancer.AbstractAction; +import sqlancer.ExecutionTimer; +import sqlancer.GlobalState; +import sqlancer.IgnoreMeException; +import sqlancer.ProviderAdapter; +import sqlancer.Randomly; +import sqlancer.StatementExecutor; +import sqlancer.arangodb.gen.ArangoDBCreateIndexGenerator; +import sqlancer.arangodb.gen.ArangoDBInsertGenerator; +import sqlancer.arangodb.gen.ArangoDBTableGenerator; +import sqlancer.common.log.LoggableFactory; +import sqlancer.common.query.Query; + +public class ArangoDBProvider + extends ProviderAdapter { + + public ArangoDBProvider() { + super(ArangoDBGlobalState.class, ArangoDBOptions.class); + } + + enum Action implements AbstractAction { + INSERT(ArangoDBInsertGenerator::getQuery), CREATE_INDEX(ArangoDBCreateIndexGenerator::getQuery); + + private final ArangoDBQueryProvider queryProvider; + + Action(ArangoDBQueryProvider queryProvider) { + this.queryProvider = queryProvider; + } + + @Override + public Query getQuery(ArangoDBGlobalState globalState) throws Exception { + return queryProvider.getQuery(globalState); + } + } + + private static int mapActions(ArangoDBGlobalState globalState, Action a) { + Randomly r = globalState.getRandomly(); + switch (a) { + case INSERT: + return r.getInteger(0, globalState.getOptions().getMaxNumberInserts()); + case CREATE_INDEX: + return r.getInteger(0, globalState.getDmbsSpecificOptions().maxNumberIndexes); + default: + throw new AssertionError(a); + } + } + + public static class ArangoDBGlobalState extends GlobalState { + + private final List schemaTables = new ArrayList<>(); + + public void addTable(ArangoDBSchema.ArangoDBTable table) { + schemaTables.add(table); + } + + @Override + protected void executeEpilogue(Query q, boolean success, ExecutionTimer timer) throws Exception { + boolean logExecutionTime = getOptions().logExecutionTime(); + if (success && getOptions().printSucceedingStatements()) { + System.out.println(q.getLogString()); + } + if (logExecutionTime) { + getLogger().writeCurrent("//" + timer.end().asString()); + } + if (q.couldAffectSchema()) { + updateSchema(); + } + } + + @Override + protected ArangoDBSchema readSchema() throws Exception { + return new ArangoDBSchema(schemaTables); + } + } + + @Override + protected void checkViewsAreValid(ArangoDBGlobalState globalState) { + + } + + @Override + public void generateDatabase(ArangoDBGlobalState globalState) throws Exception { + for (int i = 0; i < Randomly.fromOptions(4, 5, 6); i++) { + boolean success; + do { + ArangoDBQueryAdapter queryAdapter = new ArangoDBTableGenerator().getQuery(globalState); + success = globalState.executeStatement(queryAdapter); + } while (!success); + } + StatementExecutor se = new StatementExecutor<>(globalState, Action.values(), + ArangoDBProvider::mapActions, (q) -> { + if (globalState.getSchema().getDatabaseTables().isEmpty()) { + throw new IgnoreMeException(); + } + }); + se.executeStatements(); + } + + @Override + public ArangoDBConnection createDatabase(ArangoDBGlobalState globalState) throws Exception { + ArangoDB arangoDB = new ArangoDB.Builder().user(globalState.getOptions().getUserName()) + .password(globalState.getOptions().getPassword()).build(); + ArangoDatabase database = arangoDB.db(globalState.getDatabaseName()); + try { + database.drop(); + // When the database does not exist, an ArangoDB exception is thrown. Since we are not sure + // if this is the first time the database is used, the simplest is dropping it and ignoring + // the exception. + } catch (Exception ignored) { + + } + arangoDB.createDatabase(globalState.getDatabaseName()); + database = arangoDB.db(globalState.getDatabaseName()); + return new ArangoDBConnection(arangoDB, database); + } + + @Override + public String getDBMSName() { + return "arangodb"; + } + + @Override + public LoggableFactory getLoggableFactory() { + return new ArangoDBLoggableFactory(); + } +} diff --git a/src/sqlancer/arangodb/ArangoDBQueryAdapter.java b/src/sqlancer/arangodb/ArangoDBQueryAdapter.java new file mode 100644 index 000000000..34cdb3709 --- /dev/null +++ b/src/sqlancer/arangodb/ArangoDBQueryAdapter.java @@ -0,0 +1,16 @@ +package sqlancer.arangodb; + +import sqlancer.common.query.Query; + +public abstract class ArangoDBQueryAdapter extends Query { + @Override + public String getQueryString() { + // Should not be called as it is used only in SQL dependent classes + throw new UnsupportedOperationException(); + } + + @Override + public String getUnterminatedQueryString() { + throw new UnsupportedOperationException(); + } +} diff --git a/src/sqlancer/arangodb/ArangoDBQueryProvider.java b/src/sqlancer/arangodb/ArangoDBQueryProvider.java new file mode 100644 index 000000000..94a4ffda3 --- /dev/null +++ b/src/sqlancer/arangodb/ArangoDBQueryProvider.java @@ -0,0 +1,6 @@ +package sqlancer.arangodb; + +@FunctionalInterface +public interface ArangoDBQueryProvider { + ArangoDBQueryAdapter getQuery(S globalState) throws Exception; +} diff --git a/src/sqlancer/arangodb/ArangoDBSchema.java b/src/sqlancer/arangodb/ArangoDBSchema.java new file mode 100644 index 000000000..35e251b8b --- /dev/null +++ b/src/sqlancer/arangodb/ArangoDBSchema.java @@ -0,0 +1,70 @@ +package sqlancer.arangodb; + +import java.util.Collections; +import java.util.List; + +import sqlancer.Randomly; +import sqlancer.common.schema.AbstractSchema; +import sqlancer.common.schema.AbstractTable; +import sqlancer.common.schema.AbstractTableColumn; +import sqlancer.common.schema.AbstractTables; +import sqlancer.common.schema.TableIndex; + +public class ArangoDBSchema extends AbstractSchema { + + public enum ArangoDBDataType { + INTEGER, DOUBLE, STRING, BOOLEAN; + + public static ArangoDBDataType getRandom() { + return Randomly.fromOptions(values()); + } + } + + public static class ArangoDBColumn extends AbstractTableColumn { + + private final boolean isId; + private final boolean isNullable; + + public ArangoDBColumn(String name, ArangoDBDataType type, boolean isId, boolean isNullable) { + super(name, null, type); + this.isId = isId; + this.isNullable = isNullable; + } + + public boolean isId() { + return isId; + } + + public boolean isNullable() { + return isNullable; + } + } + + public ArangoDBSchema(List databaseTables) { + super(databaseTables); + } + + public static class ArangoDBTables extends AbstractTables { + + public ArangoDBTables(List tables) { + super(tables); + } + } + + public static class ArangoDBTable + extends AbstractTable { + + public ArangoDBTable(String name, List columns, boolean isView) { + super(name, columns, Collections.emptyList(), isView); + } + + @Override + public long getNrRows(ArangoDBProvider.ArangoDBGlobalState globalState) { + throw new UnsupportedOperationException(); + } + } + + public ArangoDBTables getRandomTableNonEmptyTables() { + return new ArangoDBTables(Randomly.nonEmptySubset(getDatabaseTables())); + } +} diff --git a/src/sqlancer/arangodb/ast/ArangoDBConstant.java b/src/sqlancer/arangodb/ast/ArangoDBConstant.java new file mode 100644 index 000000000..351dbd822 --- /dev/null +++ b/src/sqlancer/arangodb/ast/ArangoDBConstant.java @@ -0,0 +1,108 @@ +package sqlancer.arangodb.ast; + +import com.arangodb.entity.BaseDocument; + +import sqlancer.common.ast.newast.Node; + +public abstract class ArangoDBConstant implements Node { + private ArangoDBConstant() { + + } + + public abstract void setValueInDocument(BaseDocument document, String key); + + public abstract Object getValue(); + + public static class ArangoDBIntegerConstant extends ArangoDBConstant { + + private final int value; + + public ArangoDBIntegerConstant(int value) { + this.value = value; + } + + @Override + public void setValueInDocument(BaseDocument document, String key) { + document.addAttribute(key, value); + } + + @Override + public Object getValue() { + return value; + } + } + + public static Node createIntegerConstant(int value) { + return new ArangoDBIntegerConstant(value); + } + + public static class ArangoDBStringConstant extends ArangoDBConstant { + private final String value; + + public ArangoDBStringConstant(String value) { + this.value = value; + } + + @Override + public void setValueInDocument(BaseDocument document, String key) { + document.addAttribute(key, value); + } + + @Override + public Object getValue() { + return "'" + value.replace("\\", "\\\\").replace("'", "\\'") + "'"; + } + } + + public static Node createStringConstant(String value) { + return new ArangoDBStringConstant(value); + } + + public static class ArangoDBBooleanConstant extends ArangoDBConstant { + private final boolean value; + + public ArangoDBBooleanConstant(boolean value) { + this.value = value; + } + + @Override + public void setValueInDocument(BaseDocument document, String key) { + document.addAttribute(key, value); + } + + @Override + public Object getValue() { + return value; + } + } + + public static Node createBooleanConstant(boolean value) { + return new ArangoDBBooleanConstant(value); + } + + public static class ArangoDBDoubleConstant extends ArangoDBConstant { + private final double value; + + public ArangoDBDoubleConstant(double value) { + if (Double.isInfinite(value) || Double.isNaN(value)) { + this.value = 0.0; + } else { + this.value = value; + } + } + + @Override + public void setValueInDocument(BaseDocument document, String key) { + document.addAttribute(key, value); + } + + @Override + public Object getValue() { + return value; + } + } + + public static Node createDoubleConstant(double value) { + return new ArangoDBDoubleConstant(value); + } +} diff --git a/src/sqlancer/arangodb/ast/ArangoDBExpression.java b/src/sqlancer/arangodb/ast/ArangoDBExpression.java new file mode 100644 index 000000000..facbbfe9e --- /dev/null +++ b/src/sqlancer/arangodb/ast/ArangoDBExpression.java @@ -0,0 +1,4 @@ +package sqlancer.arangodb.ast; + +public interface ArangoDBExpression { +} diff --git a/src/sqlancer/arangodb/ast/ArangoDBSelect.java b/src/sqlancer/arangodb/ast/ArangoDBSelect.java new file mode 100644 index 000000000..9fb91d553 --- /dev/null +++ b/src/sqlancer/arangodb/ast/ArangoDBSelect.java @@ -0,0 +1,79 @@ +package sqlancer.arangodb.ast; + +import java.util.List; + +import sqlancer.arangodb.ArangoDBSchema; +import sqlancer.common.ast.newast.Node; + +public class ArangoDBSelect implements Node { + private List fromColumns; + private List projectionColumns; + private boolean hasFilter; + private Node filterClause; + private boolean hasComputed; + private List> computedClause; + + public List getFromColumns() { + if (fromColumns == null || fromColumns.isEmpty()) { + throw new IllegalStateException(); + } + return fromColumns; + } + + public void setFromColumns(List fromColumns) { + if (fromColumns == null || fromColumns.isEmpty()) { + throw new IllegalStateException(); + } + this.fromColumns = fromColumns; + } + + public List getProjectionColumns() { + if (projectionColumns == null) { + throw new IllegalStateException(); + } + return projectionColumns; + } + + public void setProjectionColumns(List projectionColumns) { + if (projectionColumns == null) { + throw new IllegalStateException(); + } + this.projectionColumns = projectionColumns; + } + + public void setFilterClause(Node filterClause) { + if (filterClause == null) { + hasFilter = false; + this.filterClause = null; + return; + } + hasFilter = true; + this.filterClause = filterClause; + } + + public Node getFilterClause() { + return filterClause; + } + + public boolean hasFilter() { + return hasFilter; + } + + public void setComputedClause(List> computedColumns) { + if (computedColumns == null || computedColumns.isEmpty()) { + hasComputed = false; + this.computedClause = null; + return; + } + hasComputed = true; + this.computedClause = computedColumns; + } + + public List> getComputedClause() { + return computedClause; + } + + public boolean hasComputed() { + return hasComputed; + } +} diff --git a/src/sqlancer/arangodb/ast/ArangoDBUnsupportedPredicate.java b/src/sqlancer/arangodb/ast/ArangoDBUnsupportedPredicate.java new file mode 100644 index 000000000..eabd25578 --- /dev/null +++ b/src/sqlancer/arangodb/ast/ArangoDBUnsupportedPredicate.java @@ -0,0 +1,6 @@ +package sqlancer.arangodb.ast; + +import sqlancer.common.ast.newast.Node; + +public class ArangoDBUnsupportedPredicate implements Node { +} diff --git a/src/sqlancer/arangodb/gen/ArangoDBComputedExpressionGenerator.java b/src/sqlancer/arangodb/gen/ArangoDBComputedExpressionGenerator.java new file mode 100644 index 000000000..8a3b98871 --- /dev/null +++ b/src/sqlancer/arangodb/gen/ArangoDBComputedExpressionGenerator.java @@ -0,0 +1,85 @@ +package sqlancer.arangodb.gen; + +import sqlancer.Randomly; +import sqlancer.arangodb.ArangoDBProvider; +import sqlancer.arangodb.ArangoDBSchema; +import sqlancer.arangodb.ast.ArangoDBConstant; +import sqlancer.arangodb.ast.ArangoDBExpression; +import sqlancer.common.ast.newast.ColumnReferenceNode; +import sqlancer.common.ast.newast.NewFunctionNode; +import sqlancer.common.ast.newast.Node; +import sqlancer.common.gen.UntypedExpressionGenerator; + +public class ArangoDBComputedExpressionGenerator + extends UntypedExpressionGenerator, ArangoDBSchema.ArangoDBColumn> { + private final ArangoDBProvider.ArangoDBGlobalState globalState; + + public ArangoDBComputedExpressionGenerator(ArangoDBProvider.ArangoDBGlobalState globalState) { + this.globalState = globalState; + } + + @Override + public Node generateConstant() { + ArangoDBSchema.ArangoDBDataType dataType = ArangoDBSchema.ArangoDBDataType.getRandom(); + switch (dataType) { + case INTEGER: + return ArangoDBConstant.createIntegerConstant((int) globalState.getRandomly().getInteger()); + case BOOLEAN: + return ArangoDBConstant.createBooleanConstant(Randomly.getBoolean()); + case DOUBLE: + return ArangoDBConstant.createDoubleConstant(globalState.getRandomly().getDouble()); + case STRING: + return ArangoDBConstant.createStringConstant(globalState.getRandomly().getString()); + default: + throw new AssertionError(dataType); + } + } + + public enum ComputedFunction { + ADD(2, "+"), MINUS(2, "-"), MULTIPLY(2, "*"), DIVISION(2, "/"), MODULUS(2, "%"); + + private final int nrArgs; + private final String operatorName; + + ComputedFunction(int nrArgs, String operatorName) { + this.nrArgs = nrArgs; + this.operatorName = operatorName; + } + + public static ComputedFunction getRandom() { + return Randomly.fromOptions(values()); + } + + public int getNrArgs() { + return nrArgs; + } + + public String getOperatorName() { + return operatorName; + } + } + + @Override + protected Node generateExpression(int depth) { + if (depth >= globalState.getOptions().getMaxExpressionDepth() || Randomly.getBoolean()) { + return generateLeafNode(); + } + ComputedFunction function = ComputedFunction.getRandom(); + return new NewFunctionNode<>(generateExpressions(depth + 1, function.getNrArgs()), function); + } + + @Override + protected Node generateColumn() { + return new ColumnReferenceNode<>(Randomly.fromList(columns)); + } + + @Override + public Node negatePredicate(Node predicate) { + throw new UnsupportedOperationException(); + } + + @Override + public Node isNull(Node expr) { + throw new UnsupportedOperationException(); + } +} diff --git a/src/sqlancer/arangodb/gen/ArangoDBCreateIndexGenerator.java b/src/sqlancer/arangodb/gen/ArangoDBCreateIndexGenerator.java new file mode 100644 index 000000000..6a1b872da --- /dev/null +++ b/src/sqlancer/arangodb/gen/ArangoDBCreateIndexGenerator.java @@ -0,0 +1,18 @@ +package sqlancer.arangodb.gen; + +import sqlancer.arangodb.ArangoDBProvider; +import sqlancer.arangodb.ArangoDBQueryAdapter; +import sqlancer.arangodb.ArangoDBSchema; +import sqlancer.arangodb.query.ArangoDBCreateIndexQuery; + +public final class ArangoDBCreateIndexGenerator { + private ArangoDBCreateIndexGenerator() { + + } + + public static ArangoDBQueryAdapter getQuery(ArangoDBProvider.ArangoDBGlobalState globalState) { + ArangoDBSchema.ArangoDBTable randomTable = globalState.getSchema().getRandomTable(); + ArangoDBSchema.ArangoDBColumn column = randomTable.getRandomColumn(); + return new ArangoDBCreateIndexQuery(column); + } +} diff --git a/src/sqlancer/arangodb/gen/ArangoDBFilterExpressionGenerator.java b/src/sqlancer/arangodb/gen/ArangoDBFilterExpressionGenerator.java new file mode 100644 index 000000000..1a2fc4b5e --- /dev/null +++ b/src/sqlancer/arangodb/gen/ArangoDBFilterExpressionGenerator.java @@ -0,0 +1,153 @@ +package sqlancer.arangodb.gen; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import sqlancer.Randomly; +import sqlancer.arangodb.ArangoDBProvider; +import sqlancer.arangodb.ArangoDBSchema; +import sqlancer.arangodb.ast.ArangoDBConstant; +import sqlancer.arangodb.ast.ArangoDBExpression; +import sqlancer.arangodb.ast.ArangoDBUnsupportedPredicate; +import sqlancer.common.ast.BinaryOperatorNode; +import sqlancer.common.ast.newast.ColumnReferenceNode; +import sqlancer.common.ast.newast.NewBinaryOperatorNode; +import sqlancer.common.ast.newast.NewUnaryPrefixOperatorNode; +import sqlancer.common.ast.newast.Node; +import sqlancer.common.gen.UntypedExpressionGenerator; + +public class ArangoDBFilterExpressionGenerator + extends UntypedExpressionGenerator, ArangoDBSchema.ArangoDBColumn> { + + private final ArangoDBProvider.ArangoDBGlobalState globalState; + private int numberOfComputedVariables; + + private enum Expression { + BINARY_LOGICAL, UNARY_PREFIX, BINARY_COMPARISON + } + + public ArangoDBFilterExpressionGenerator(ArangoDBProvider.ArangoDBGlobalState globalState) { + this.globalState = globalState; + } + + public void setNumberOfComputedVariables(int numberOfComputedVariables) { + this.numberOfComputedVariables = numberOfComputedVariables; + } + + @Override + public Node generateConstant() { + ArangoDBSchema.ArangoDBDataType dataType = ArangoDBSchema.ArangoDBDataType.getRandom(); + switch (dataType) { + case INTEGER: + return ArangoDBConstant.createIntegerConstant((int) globalState.getRandomly().getInteger()); + case BOOLEAN: + return ArangoDBConstant.createBooleanConstant(Randomly.getBoolean()); + case DOUBLE: + return ArangoDBConstant.createDoubleConstant(globalState.getRandomly().getDouble()); + case STRING: + return ArangoDBConstant.createStringConstant(globalState.getRandomly().getString()); + default: + throw new AssertionError(dataType); + } + } + + @Override + protected Node generateExpression(int depth) { + if (depth >= globalState.getOptions().getMaxExpressionDepth() || Randomly.getBoolean()) { + return generateLeafNode(); + } + List possibleOptions = new ArrayList<>(Arrays.asList(Expression.values())); + Expression expression = Randomly.fromList(possibleOptions); + switch (expression) { + case BINARY_COMPARISON: + BinaryOperatorNode.Operator op = ArangoDBBinaryComparisonOperator.getRandom(); + return new NewBinaryOperatorNode<>(generateExpression(depth + 1), generateExpression(depth + 1), op); + case UNARY_PREFIX: + return new NewUnaryPrefixOperatorNode<>(generateExpression(depth + 1), + ArangoDBUnaryPrefixOperator.getRandom()); + case BINARY_LOGICAL: + op = ArangoDBBinaryLogicalOperator.getRandom(); + return new NewBinaryOperatorNode<>(generateExpression(depth + 1), generateExpression(depth + 1), op); + default: + throw new AssertionError(expression); + } + } + + @Override + protected Node generateColumn() { + ArangoDBSchema.ArangoDBTable dummy = new ArangoDBSchema.ArangoDBTable("", new ArrayList<>(), false); + if (Randomly.getBoolean() || numberOfComputedVariables == 0) { + ArangoDBSchema.ArangoDBColumn column = Randomly.fromList(columns); + return new ColumnReferenceNode<>(column); + } else { + int maxNumber = globalState.getRandomly().getInteger(0, numberOfComputedVariables); + ArangoDBSchema.ArangoDBColumn column = new ArangoDBSchema.ArangoDBColumn("c" + maxNumber, + ArangoDBSchema.ArangoDBDataType.INTEGER, false, false); + column.setTable(dummy); + return new ColumnReferenceNode<>(column); + } + } + + @Override + public Node negatePredicate(Node predicate) { + return new NewUnaryPrefixOperatorNode<>(predicate, ArangoDBUnaryPrefixOperator.NOT); + } + + @Override + public Node isNull(Node expr) { + return new ArangoDBUnsupportedPredicate<>(); + } + + public enum ArangoDBBinaryComparisonOperator implements BinaryOperatorNode.Operator { + EQUALS("=="), NOT_EQUALS("!="), LESS_THAN("<"), LESS_OR_EQUAL("<="), GREATER_THAN(">"), GREATER_OR_EQUAL(">="); + + private final String representation; + + ArangoDBBinaryComparisonOperator(String representation) { + this.representation = representation; + } + + @Override + public String getTextRepresentation() { + return representation; + } + + public static ArangoDBBinaryComparisonOperator getRandom() { + return Randomly.fromOptions(values()); + } + } + + public enum ArangoDBUnaryPrefixOperator implements BinaryOperatorNode.Operator { + NOT("!"); + + private final String representation; + + ArangoDBUnaryPrefixOperator(String representation) { + this.representation = representation; + } + + @Override + public String getTextRepresentation() { + return representation; + } + + public static ArangoDBUnaryPrefixOperator getRandom() { + return Randomly.fromOptions(values()); + } + } + + public enum ArangoDBBinaryLogicalOperator implements BinaryOperatorNode.Operator { + AND, OR; + + @Override + public String getTextRepresentation() { + return toString(); + } + + public static BinaryOperatorNode.Operator getRandom() { + return Randomly.fromOptions(values()); + } + } + +} diff --git a/src/sqlancer/arangodb/gen/ArangoDBInsertGenerator.java b/src/sqlancer/arangodb/gen/ArangoDBInsertGenerator.java new file mode 100644 index 000000000..9a27ccd57 --- /dev/null +++ b/src/sqlancer/arangodb/gen/ArangoDBInsertGenerator.java @@ -0,0 +1,39 @@ +package sqlancer.arangodb.gen; + +import com.arangodb.entity.BaseDocument; + +import sqlancer.arangodb.ArangoDBProvider; +import sqlancer.arangodb.ArangoDBQueryAdapter; +import sqlancer.arangodb.ArangoDBSchema; +import sqlancer.arangodb.query.ArangoDBConstantGenerator; +import sqlancer.arangodb.query.ArangoDBInsertQuery; + +public final class ArangoDBInsertGenerator { + + private final ArangoDBProvider.ArangoDBGlobalState globalState; + + private ArangoDBInsertGenerator(ArangoDBProvider.ArangoDBGlobalState globalState) { + this.globalState = globalState; + } + + public static ArangoDBQueryAdapter getQuery(ArangoDBProvider.ArangoDBGlobalState globalState) { + return new ArangoDBInsertGenerator(globalState).generate(); + } + + private ArangoDBQueryAdapter generate() { + BaseDocument result = new BaseDocument(); + ArangoDBSchema.ArangoDBTable table = globalState.getSchema().getRandomTable(); + ArangoDBConstantGenerator constantGenerator = new ArangoDBConstantGenerator(globalState); + + for (int i = 0; i < table.getColumns().size(); i++) { + if (!globalState.getDmbsSpecificOptions().testRandomTypeInserts) { + constantGenerator.addRandomConstantWithType(result, table.getColumns().get(i).getName(), + table.getColumns().get(i).getType()); + } else { + constantGenerator.addRandomConstant(result, table.getColumns().get(i).getName()); + } + } + + return new ArangoDBInsertQuery(table, result); + } +} diff --git a/src/sqlancer/arangodb/gen/ArangoDBTableGenerator.java b/src/sqlancer/arangodb/gen/ArangoDBTableGenerator.java new file mode 100644 index 000000000..1236c3ce4 --- /dev/null +++ b/src/sqlancer/arangodb/gen/ArangoDBTableGenerator.java @@ -0,0 +1,44 @@ +package sqlancer.arangodb.gen; + +import java.util.ArrayList; +import java.util.List; + +import sqlancer.Randomly; +import sqlancer.arangodb.ArangoDBProvider; +import sqlancer.arangodb.ArangoDBQueryAdapter; +import sqlancer.arangodb.ArangoDBSchema; +import sqlancer.arangodb.query.ArangoDBCreateTableQuery; + +public class ArangoDBTableGenerator { + + private ArangoDBSchema.ArangoDBTable table; + private final List columnsToBeAdded = new ArrayList<>(); + + public ArangoDBQueryAdapter getQuery(ArangoDBProvider.ArangoDBGlobalState globalState) { + String tableName = globalState.getSchema().getFreeTableName(); + ArangoDBCreateTableQuery createTableQuery = new ArangoDBCreateTableQuery(tableName); + table = new ArangoDBSchema.ArangoDBTable(tableName, columnsToBeAdded, false); + for (int i = 0; i < Randomly.smallNumber() + 1; i++) { + String columnName = String.format("c%d", i); + createColumn(columnName); + } + globalState.addTable(table); + return createTableQuery; + } + + private ArangoDBSchema.ArangoDBDataType createColumn(String columnName) { + ArangoDBSchema.ArangoDBDataType dataType = ArangoDBSchema.ArangoDBDataType.getRandom(); + ArangoDBSchema.ArangoDBColumn newColumn = new ArangoDBSchema.ArangoDBColumn(columnName, dataType, false, false); + newColumn.setTable(table); + columnsToBeAdded.add(newColumn); + return dataType; + } + + public String getTableName() { + return table.getName(); + } + + public ArangoDBSchema.ArangoDBTable getGeneratedTable() { + return table; + } +} diff --git a/src/sqlancer/arangodb/query/ArangoDBConstantGenerator.java b/src/sqlancer/arangodb/query/ArangoDBConstantGenerator.java new file mode 100644 index 000000000..406e8adca --- /dev/null +++ b/src/sqlancer/arangodb/query/ArangoDBConstantGenerator.java @@ -0,0 +1,46 @@ +package sqlancer.arangodb.query; + +import com.arangodb.entity.BaseDocument; + +import sqlancer.Randomly; +import sqlancer.arangodb.ArangoDBProvider; +import sqlancer.arangodb.ArangoDBSchema; +import sqlancer.arangodb.ast.ArangoDBConstant; + +public class ArangoDBConstantGenerator { + private final ArangoDBProvider.ArangoDBGlobalState globalState; + + public ArangoDBConstantGenerator(ArangoDBProvider.ArangoDBGlobalState globalState) { + this.globalState = globalState; + } + + public void addRandomConstant(BaseDocument document, String key) { + ArangoDBSchema.ArangoDBDataType type = ArangoDBSchema.ArangoDBDataType.getRandom(); + addRandomConstantWithType(document, key, type); + } + + public void addRandomConstantWithType(BaseDocument document, String key, ArangoDBSchema.ArangoDBDataType dataType) { + ArangoDBConstant constant; + switch (dataType) { + case STRING: + constant = new ArangoDBConstant.ArangoDBStringConstant(globalState.getRandomly().getString()); + constant.setValueInDocument(document, key); + return; + case DOUBLE: + constant = new ArangoDBConstant.ArangoDBDoubleConstant(globalState.getRandomly().getDouble()); + constant.setValueInDocument(document, key); + return; + case BOOLEAN: + constant = new ArangoDBConstant.ArangoDBBooleanConstant(Randomly.getBoolean()); + constant.setValueInDocument(document, key); + return; + case INTEGER: + constant = new ArangoDBConstant.ArangoDBIntegerConstant((int) globalState.getRandomly().getInteger()); + constant.setValueInDocument(document, key); + return; + default: + throw new AssertionError(dataType); + } + + } +} diff --git a/src/sqlancer/arangodb/query/ArangoDBCreateIndexQuery.java b/src/sqlancer/arangodb/query/ArangoDBCreateIndexQuery.java new file mode 100644 index 000000000..6c2cc1b75 --- /dev/null +++ b/src/sqlancer/arangodb/query/ArangoDBCreateIndexQuery.java @@ -0,0 +1,54 @@ +package sqlancer.arangodb.query; + +import java.util.Collections; + +import com.arangodb.ArangoCollection; + +import sqlancer.GlobalState; +import sqlancer.Main; +import sqlancer.arangodb.ArangoDBConnection; +import sqlancer.arangodb.ArangoDBQueryAdapter; +import sqlancer.arangodb.ArangoDBSchema; +import sqlancer.common.query.ExpectedErrors; + +public class ArangoDBCreateIndexQuery extends ArangoDBQueryAdapter { + + private final ArangoDBSchema.ArangoDBColumn column; + + public ArangoDBCreateIndexQuery(ArangoDBSchema.ArangoDBColumn column) { + this.column = column; + } + + @Override + public boolean couldAffectSchema() { + return false; + } + + @Override + public > boolean execute(G globalState, String... fills) + throws Exception { + try { + ArangoCollection collection = globalState.getConnection().getDatabase() + .collection(column.getTable().getName()); + collection.ensureHashIndex(Collections.singletonList(column.getName()), null); + Main.nrSuccessfulActions.addAndGet(1); + return true; + } catch (Exception e) { + Main.nrUnsuccessfulActions.addAndGet(1); + throw e; + } + } + + @Override + public ExpectedErrors getExpectedErrors() { + return new ExpectedErrors(); + } + + @Override + public String getLogString() { + StringBuilder stringBuilder = new StringBuilder(); + stringBuilder.append("db.").append(column.getTable().getName()) + .append(".ensureIndex({type: \"hash\", fields: [ \"").append(column.getName()).append("\" ]});"); + return stringBuilder.toString(); + } +} diff --git a/src/sqlancer/arangodb/query/ArangoDBCreateTableQuery.java b/src/sqlancer/arangodb/query/ArangoDBCreateTableQuery.java new file mode 100644 index 000000000..00b3276d0 --- /dev/null +++ b/src/sqlancer/arangodb/query/ArangoDBCreateTableQuery.java @@ -0,0 +1,44 @@ +package sqlancer.arangodb.query; + +import sqlancer.GlobalState; +import sqlancer.Main; +import sqlancer.arangodb.ArangoDBConnection; +import sqlancer.arangodb.ArangoDBQueryAdapter; +import sqlancer.common.query.ExpectedErrors; + +public class ArangoDBCreateTableQuery extends ArangoDBQueryAdapter { + + private final String tableName; + + public ArangoDBCreateTableQuery(String tableName) { + this.tableName = tableName; + } + + @Override + public boolean couldAffectSchema() { + return true; + } + + @Override + public > boolean execute(G globalState, String... fills) + throws Exception { + try { + globalState.getConnection().getDatabase().createCollection(tableName); + Main.nrSuccessfulActions.addAndGet(1); + return true; + } catch (Exception e) { + Main.nrUnsuccessfulActions.addAndGet(1); + throw e; + } + } + + @Override + public ExpectedErrors getExpectedErrors() { + return new ExpectedErrors(); + } + + @Override + public String getLogString() { + return "db._create(\"" + tableName + "\")"; + } +} diff --git a/src/sqlancer/arangodb/query/ArangoDBInsertQuery.java b/src/sqlancer/arangodb/query/ArangoDBInsertQuery.java new file mode 100644 index 000000000..9a3612062 --- /dev/null +++ b/src/sqlancer/arangodb/query/ArangoDBInsertQuery.java @@ -0,0 +1,66 @@ +package sqlancer.arangodb.query; + +import java.util.Map; + +import com.arangodb.entity.BaseDocument; + +import sqlancer.GlobalState; +import sqlancer.Main; +import sqlancer.arangodb.ArangoDBConnection; +import sqlancer.arangodb.ArangoDBQueryAdapter; +import sqlancer.arangodb.ArangoDBSchema; +import sqlancer.common.query.ExpectedErrors; + +public class ArangoDBInsertQuery extends ArangoDBQueryAdapter { + + private final ArangoDBSchema.ArangoDBTable table; + private final BaseDocument documentToBeInserted; + + public ArangoDBInsertQuery(ArangoDBSchema.ArangoDBTable table, BaseDocument documentToBeInserted) { + this.table = table; + this.documentToBeInserted = documentToBeInserted; + } + + @Override + public boolean couldAffectSchema() { + return true; + } + + @Override + public > boolean execute(G globalState, String... fills) + throws Exception { + try { + globalState.getConnection().getDatabase().collection(table.getName()).insertDocument(documentToBeInserted); + Main.nrSuccessfulActions.addAndGet(1); + return true; + } catch (Exception e) { + Main.nrUnsuccessfulActions.addAndGet(1); + throw e; + } + } + + @Override + public ExpectedErrors getExpectedErrors() { + return new ExpectedErrors(); + } + + @Override + public String getLogString() { + StringBuilder stringBuilder = new StringBuilder(); + stringBuilder.append("db._query(\"INSERT { "); + String filler = ""; + for (Map.Entry stringObjectEntry : documentToBeInserted.getProperties().entrySet()) { + stringBuilder.append(filler); + filler = ", "; + stringBuilder.append(stringObjectEntry.getKey()).append(": "); + Object value = stringObjectEntry.getValue(); + if (value instanceof String) { + stringBuilder.append("'").append(value).append("'"); + } else { + stringBuilder.append(value); + } + } + stringBuilder.append("} IN ").append(table.getName()).append("\")"); + return stringBuilder.toString(); + } +} diff --git a/src/sqlancer/arangodb/query/ArangoDBSelectQuery.java b/src/sqlancer/arangodb/query/ArangoDBSelectQuery.java new file mode 100644 index 000000000..400585ca2 --- /dev/null +++ b/src/sqlancer/arangodb/query/ArangoDBSelectQuery.java @@ -0,0 +1,65 @@ +package sqlancer.arangodb.query; + +import java.io.IOException; +import java.util.List; + +import com.arangodb.ArangoCursor; +import com.arangodb.entity.BaseDocument; + +import sqlancer.GlobalState; +import sqlancer.arangodb.ArangoDBConnection; +import sqlancer.arangodb.ArangoDBQueryAdapter; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLancerResultSet; + +public class ArangoDBSelectQuery extends ArangoDBQueryAdapter { + + private final String query; + + private List resultSet; + + public ArangoDBSelectQuery(String query) { + this.query = query; + } + + @Override + public boolean couldAffectSchema() { + return false; + } + + @Override + public > boolean execute(G globalState, String... fills) + throws Exception { + throw new UnsupportedOperationException(); + } + + @Override + public ExpectedErrors getExpectedErrors() { + return new ExpectedErrors(); + } + + @Override + public String getLogString() { + return "db._query(\"" + query + "\")"; + } + + @Override + public > SQLancerResultSet executeAndGet(G globalState, + String... fills) throws Exception { + if (globalState.getOptions().logEachSelect()) { + globalState.getLogger().writeCurrent(this.getLogString()); + try { + globalState.getLogger().getCurrentFileWriter().flush(); + } catch (IOException e) { + e.printStackTrace(); + } + } + ArangoCursor cursor = globalState.getConnection().getDatabase().query(query, BaseDocument.class); + resultSet = cursor.asListRemaining(); + return null; + } + + public List getResultSet() { + return resultSet; + } +} diff --git a/src/sqlancer/arangodb/test/ArangoDBQueryPartitioningBase.java b/src/sqlancer/arangodb/test/ArangoDBQueryPartitioningBase.java new file mode 100644 index 000000000..f583ed04f --- /dev/null +++ b/src/sqlancer/arangodb/test/ArangoDBQueryPartitioningBase.java @@ -0,0 +1,67 @@ +package sqlancer.arangodb.test; + +import java.util.ArrayList; +import java.util.List; + +import sqlancer.Randomly; +import sqlancer.arangodb.ArangoDBProvider; +import sqlancer.arangodb.ArangoDBSchema; +import sqlancer.arangodb.ast.ArangoDBExpression; +import sqlancer.arangodb.ast.ArangoDBSelect; +import sqlancer.arangodb.gen.ArangoDBComputedExpressionGenerator; +import sqlancer.arangodb.gen.ArangoDBFilterExpressionGenerator; +import sqlancer.common.ast.newast.Node; +import sqlancer.common.gen.ExpressionGenerator; +import sqlancer.common.oracle.TernaryLogicPartitioningOracleBase; +import sqlancer.common.oracle.TestOracle; + +public class ArangoDBQueryPartitioningBase + extends TernaryLogicPartitioningOracleBase, ArangoDBProvider.ArangoDBGlobalState> + implements TestOracle { + + protected ArangoDBSchema schema; + protected List targetColumns; + protected ArangoDBFilterExpressionGenerator expressionGenerator; + protected ArangoDBSelect select; + protected int numberComputedColumns; + + protected ArangoDBQueryPartitioningBase(ArangoDBProvider.ArangoDBGlobalState state) { + super(state); + } + + @Override + protected ExpressionGenerator> getGen() { + return expressionGenerator; + } + + @Override + public void check() throws Exception { + numberComputedColumns = state.getRandomly().getInteger(0, 4); + schema = state.getSchema(); + generateTargetColumns(); + expressionGenerator = new ArangoDBFilterExpressionGenerator(state).setColumns(targetColumns); + expressionGenerator.setNumberOfComputedVariables(numberComputedColumns); + initializeTernaryPredicateVariants(); + select = new ArangoDBSelect<>(); + select.setFromColumns(targetColumns); + select.setProjectionColumns(Randomly.nonEmptySubset(targetColumns)); + generateComputedClause(); + } + + private void generateComputedClause() { + List> computedColumns = new ArrayList<>(); + ArangoDBComputedExpressionGenerator generator = new ArangoDBComputedExpressionGenerator(state); + generator.setColumns(targetColumns); + for (int i = 0; i < numberComputedColumns; i++) { + computedColumns.add(generator.generateExpression()); + } + select.setComputedClause(computedColumns); + } + + private void generateTargetColumns() { + ArangoDBSchema.ArangoDBTables targetTables; + targetTables = schema.getRandomTableNonEmptyTables(); + List allColumns = targetTables.getColumns(); + targetColumns = Randomly.nonEmptySubset(allColumns); + } +} diff --git a/src/sqlancer/arangodb/test/ArangoDBQueryPartitioningWhereTester.java b/src/sqlancer/arangodb/test/ArangoDBQueryPartitioningWhereTester.java new file mode 100644 index 000000000..80b7d46bf --- /dev/null +++ b/src/sqlancer/arangodb/test/ArangoDBQueryPartitioningWhereTester.java @@ -0,0 +1,38 @@ +package sqlancer.arangodb.test; + +import static sqlancer.arangodb.ArangoDBComparatorHelper.assumeResultSetsAreEqual; +import static sqlancer.arangodb.ArangoDBComparatorHelper.getResultSetAsDocumentList; + +import java.util.List; + +import com.arangodb.entity.BaseDocument; + +import sqlancer.arangodb.ArangoDBProvider; +import sqlancer.arangodb.query.ArangoDBSelectQuery; +import sqlancer.arangodb.visitor.ArangoDBVisitor; + +public class ArangoDBQueryPartitioningWhereTester extends ArangoDBQueryPartitioningBase { + public ArangoDBQueryPartitioningWhereTester(ArangoDBProvider.ArangoDBGlobalState state) { + super(state); + } + + @Override + public void check() throws Exception { + super.check(); + select.setFilterClause(null); + + ArangoDBSelectQuery query = ArangoDBVisitor.asSelectQuery(select); + List firstResultSet = getResultSetAsDocumentList(query, state); + + select.setFilterClause(predicate); + query = ArangoDBVisitor.asSelectQuery(select); + List secondResultSet = getResultSetAsDocumentList(query, state); + + select.setFilterClause(negatedPredicate); + query = ArangoDBVisitor.asSelectQuery(select); + List thirdResultSet = getResultSetAsDocumentList(query, state); + + secondResultSet.addAll(thirdResultSet); + assumeResultSetsAreEqual(firstResultSet, secondResultSet, query); + } +} diff --git a/src/sqlancer/arangodb/visitor/ArangoDBToQueryVisitor.java b/src/sqlancer/arangodb/visitor/ArangoDBToQueryVisitor.java new file mode 100644 index 000000000..f82995d5e --- /dev/null +++ b/src/sqlancer/arangodb/visitor/ArangoDBToQueryVisitor.java @@ -0,0 +1,134 @@ +package sqlancer.arangodb.visitor; + +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import sqlancer.arangodb.ArangoDBSchema; +import sqlancer.arangodb.ast.ArangoDBConstant; +import sqlancer.arangodb.ast.ArangoDBExpression; +import sqlancer.arangodb.ast.ArangoDBSelect; +import sqlancer.arangodb.gen.ArangoDBComputedExpressionGenerator; +import sqlancer.arangodb.query.ArangoDBSelectQuery; +import sqlancer.common.ast.newast.ColumnReferenceNode; +import sqlancer.common.ast.newast.NewBinaryOperatorNode; +import sqlancer.common.ast.newast.NewFunctionNode; +import sqlancer.common.ast.newast.NewUnaryPrefixOperatorNode; +import sqlancer.common.ast.newast.Node; + +public class ArangoDBToQueryVisitor extends ArangoDBVisitor { + + private final StringBuilder stringBuilder; + + public ArangoDBToQueryVisitor() { + stringBuilder = new StringBuilder(); + } + + @Override + protected void visit(ArangoDBSelect expression) { + generateFrom(expression); + generateComputed(expression); + generateFilter(expression); + generateProject(expression); + } + + private void generateFilter(ArangoDBSelect expression) { + if (expression.hasFilter()) { + stringBuilder.append("FILTER "); + visit(expression.getFilterClause()); + stringBuilder.append(" "); + } + } + + private void generateComputed(ArangoDBSelect expression) { + if (expression.hasComputed()) { + List> computedClause = expression.getComputedClause(); + int computedNumber = 0; + for (Node computedExpression : computedClause) { + stringBuilder.append("LET c").append(computedNumber).append(" = "); + visit(computedExpression); + stringBuilder.append(" "); + computedNumber++; + } + } + } + + @Override + protected void visit(ColumnReferenceNode expression) { + if (expression.getColumn().getTable().getName().equals("")) { + stringBuilder.append(expression.getColumn().getName()); + } else { + stringBuilder.append("r").append(expression.getColumn().getTable().getName()).append(".") + .append(expression.getColumn().getName()); + } + } + + @Override + protected void visit(ArangoDBConstant expression) { + stringBuilder.append(expression.getValue()); + } + + @Override + protected void visit(NewBinaryOperatorNode expression) { + stringBuilder.append("("); + visit(expression.getLeft()); + stringBuilder.append(" ").append(expression.getOperatorRepresentation()).append(" "); + visit(expression.getRight()); + stringBuilder.append(")"); + } + + @Override + protected void visit(NewUnaryPrefixOperatorNode expression) { + stringBuilder.append(expression.getOperatorRepresentation()).append("("); + visit(expression.getExpr()); + stringBuilder.append(")"); + } + + @Override + protected void visit(NewFunctionNode expression) { + if (!(expression.getFunc() instanceof ArangoDBComputedExpressionGenerator.ComputedFunction)) { + throw new UnsupportedOperationException(); + } + ArangoDBComputedExpressionGenerator.ComputedFunction function = (ArangoDBComputedExpressionGenerator.ComputedFunction) expression + .getFunc(); + // TODO: Support functions with a different number of arguments. + if (function.getNrArgs() != 2) { + throw new UnsupportedOperationException(); + } + stringBuilder.append("("); + visit(expression.getArgs().get(0)); + stringBuilder.append(" ").append(function.getOperatorName()).append(" "); + visit(expression.getArgs().get(1)); + stringBuilder.append(")"); + } + + private void generateFrom(ArangoDBSelect expression) { + List forColumns = expression.getFromColumns(); + Set tables = new HashSet<>(); + for (ArangoDBSchema.ArangoDBColumn column : forColumns) { + tables.add(column.getTable()); + } + + for (ArangoDBSchema.ArangoDBTable table : tables) { + stringBuilder.append("FOR r").append(table.getName()).append(" IN ").append(table.getName()).append(" "); + } + } + + private void generateProject(ArangoDBSelect expression) { + List projectColumns = expression.getProjectionColumns(); + stringBuilder.append("RETURN {"); + String filler = ""; + for (ArangoDBSchema.ArangoDBColumn column : projectColumns) { + stringBuilder.append(filler); + filler = ", "; + stringBuilder.append(column.getTable().getName()).append("_").append(column.getName()).append(": r") + .append(column.getTable().getName()).append(".").append(column.getName()); + } + stringBuilder.append("}"); + } + + public ArangoDBSelectQuery getQuery() { + return new ArangoDBSelectQuery(stringBuilder.toString()); + } + +} diff --git a/src/sqlancer/arangodb/visitor/ArangoDBVisitor.java b/src/sqlancer/arangodb/visitor/ArangoDBVisitor.java new file mode 100644 index 000000000..f1db84cf5 --- /dev/null +++ b/src/sqlancer/arangodb/visitor/ArangoDBVisitor.java @@ -0,0 +1,51 @@ +package sqlancer.arangodb.visitor; + +import sqlancer.arangodb.ast.ArangoDBConstant; +import sqlancer.arangodb.ast.ArangoDBExpression; +import sqlancer.arangodb.ast.ArangoDBSelect; +import sqlancer.arangodb.query.ArangoDBSelectQuery; +import sqlancer.common.ast.newast.ColumnReferenceNode; +import sqlancer.common.ast.newast.NewBinaryOperatorNode; +import sqlancer.common.ast.newast.NewFunctionNode; +import sqlancer.common.ast.newast.NewUnaryPrefixOperatorNode; +import sqlancer.common.ast.newast.Node; + +public abstract class ArangoDBVisitor { + + protected abstract void visit(ArangoDBSelect expression); + + protected abstract void visit(ColumnReferenceNode expression); + + protected abstract void visit(ArangoDBConstant expression); + + protected abstract void visit(NewBinaryOperatorNode expression); + + protected abstract void visit(NewUnaryPrefixOperatorNode expression); + + protected abstract void visit(NewFunctionNode expression); + + @SuppressWarnings("unchecked") + public void visit(Node expressionNode) { + if (expressionNode instanceof ArangoDBSelect) { + visit((ArangoDBSelect) expressionNode); + } else if (expressionNode instanceof ColumnReferenceNode) { + visit((ColumnReferenceNode) expressionNode); + } else if (expressionNode instanceof ArangoDBConstant) { + visit((ArangoDBConstant) expressionNode); + } else if (expressionNode instanceof NewBinaryOperatorNode) { + visit((NewBinaryOperatorNode) expressionNode); + } else if (expressionNode instanceof NewUnaryPrefixOperatorNode) { + visit((NewUnaryPrefixOperatorNode) expressionNode); + } else if (expressionNode instanceof NewFunctionNode) { + visit((NewFunctionNode) expressionNode); + } else { + throw new AssertionError(expressionNode); + } + } + + public static ArangoDBSelectQuery asSelectQuery(Node expressionNode) { + ArangoDBToQueryVisitor visitor = new ArangoDBToQueryVisitor(); + visitor.visit(expressionNode); + return visitor.getQuery(); + } +} diff --git a/src/sqlancer/common/oracle/RemoveReduceOracleBase.java b/src/sqlancer/common/oracle/RemoveReduceOracleBase.java new file mode 100644 index 000000000..c177f072e --- /dev/null +++ b/src/sqlancer/common/oracle/RemoveReduceOracleBase.java @@ -0,0 +1,29 @@ +package sqlancer.common.oracle; + +import sqlancer.GlobalState; +import sqlancer.common.gen.ExpressionGenerator; + +public abstract class RemoveReduceOracleBase> implements TestOracle { + + protected E predicate; + + protected final S state; + + protected RemoveReduceOracleBase(S state) { + this.state = state; + } + + protected void initializeRemoveReduceOracle() { + ExpressionGenerator gen = getGen(); + if (gen == null) { + throw new IllegalStateException(); + } + predicate = gen.generatePredicate(); + if (predicate == null) { + throw new IllegalStateException(); + } + } + + protected abstract ExpressionGenerator getGen(); + +} diff --git a/src/sqlancer/common/oracle/TernaryLogicPartitioningOracleBase.java b/src/sqlancer/common/oracle/TernaryLogicPartitioningOracleBase.java index 991824628..3b5d87814 100644 --- a/src/sqlancer/common/oracle/TernaryLogicPartitioningOracleBase.java +++ b/src/sqlancer/common/oracle/TernaryLogicPartitioningOracleBase.java @@ -1,6 +1,6 @@ package sqlancer.common.oracle; -import sqlancer.SQLGlobalState; +import sqlancer.GlobalState; import sqlancer.common.gen.ExpressionGenerator; import sqlancer.common.query.ExpectedErrors; @@ -14,7 +14,7 @@ * @param * the global state type */ -public abstract class TernaryLogicPartitioningOracleBase> implements TestOracle { +public abstract class TernaryLogicPartitioningOracleBase> implements TestOracle { protected E predicate; protected E negatedPredicate; diff --git a/src/sqlancer/cosmos/CosmosProvider.java b/src/sqlancer/cosmos/CosmosProvider.java new file mode 100644 index 000000000..a8681d3af --- /dev/null +++ b/src/sqlancer/cosmos/CosmosProvider.java @@ -0,0 +1,74 @@ +package sqlancer.cosmos; + +import com.mongodb.ConnectionString; +import com.mongodb.MongoClientSettings; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoClients; +import com.mongodb.client.MongoDatabase; + +import sqlancer.IgnoreMeException; +import sqlancer.ProviderAdapter; +import sqlancer.Randomly; +import sqlancer.StatementExecutor; +import sqlancer.common.log.LoggableFactory; +import sqlancer.mongodb.MongoDBConnection; +import sqlancer.mongodb.MongoDBLoggableFactory; +import sqlancer.mongodb.MongoDBOptions; +import sqlancer.mongodb.MongoDBQueryAdapter; +import sqlancer.mongodb.gen.MongoDBTableGenerator; + +public class CosmosProvider extends + ProviderAdapter { + + public CosmosProvider() { + super(sqlancer.mongodb.MongoDBProvider.MongoDBGlobalState.class, MongoDBOptions.class); + } + + @Override + public void generateDatabase(sqlancer.mongodb.MongoDBProvider.MongoDBGlobalState globalState) throws Exception { + for (int i = 0; i < Randomly.fromOptions(4, 5, 6); i++) { + boolean success; + do { + MongoDBQueryAdapter query = new MongoDBTableGenerator(globalState).getQuery(globalState); + success = globalState.executeStatement(query); + } while (!success); + } + StatementExecutor se = new StatementExecutor<>( + globalState, sqlancer.mongodb.MongoDBProvider.Action.values(), + sqlancer.mongodb.MongoDBProvider::mapActions, (q) -> { + if (globalState.getSchema().getDatabaseTables().isEmpty()) { + throw new IgnoreMeException(); + } + }); + se.executeStatements(); + } + + @Override + public MongoDBConnection createDatabase(sqlancer.mongodb.MongoDBProvider.MongoDBGlobalState globalState) + throws Exception { + String connectionString = ""; + if (connectionString.equals("")) { + throw new AssertionError("Please set connection string for cosmos database, located in CosmosProvider"); + } + MongoClientSettings settings = MongoClientSettings.builder() + .applyConnectionString(new ConnectionString(connectionString)).build(); + MongoClient mongoClient = MongoClients.create(settings); + MongoDatabase database = mongoClient.getDatabase(globalState.getDatabaseName()); + database.drop(); + return new MongoDBConnection(mongoClient, database); + } + + @Override + public String getDBMSName() { + return "cosmos"; + } + + @Override + public LoggableFactory getLoggableFactory() { + return new MongoDBLoggableFactory(); + } + + @Override + protected void checkViewsAreValid(sqlancer.mongodb.MongoDBProvider.MongoDBGlobalState globalState) { + } +} diff --git a/src/sqlancer/mongodb/MongoDBComparatorHelper.java b/src/sqlancer/mongodb/MongoDBComparatorHelper.java new file mode 100644 index 000000000..49b692645 --- /dev/null +++ b/src/sqlancer/mongodb/MongoDBComparatorHelper.java @@ -0,0 +1,97 @@ +package sqlancer.mongodb; + +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import org.bson.Document; + +import sqlancer.IgnoreMeException; +import sqlancer.Main; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.mongodb.MongoDBProvider.MongoDBGlobalState; +import sqlancer.mongodb.query.MongoDBSelectQuery; + +public final class MongoDBComparatorHelper { + + private MongoDBComparatorHelper() { + } + + public static List getResultSetAsDocumentList(MongoDBSelectQuery adapter, MongoDBGlobalState state) + throws Exception { + ExpectedErrors errors = adapter.getExpectedErrors(); + List result; + try { + adapter.executeAndGet(state); + Main.nrSuccessfulActions.addAndGet(1); + result = adapter.getResultSet(); + return result; + } catch (Exception e) { + if (e instanceof IgnoreMeException) { + throw e; + } + Main.nrUnsuccessfulActions.addAndGet(1); + if (e.getMessage() == null) { + throw new AssertionError(adapter.getLogString(), e); + } + if (errors.errorIsExpected(e.getMessage())) { + throw new IgnoreMeException(); + } + throw new AssertionError(adapter.getLogString(), e); + } + } + + public static void assumeCountIsEqual(List resultSet, List secondResultSet, + MongoDBSelectQuery originalQuery) { + int originalSize = resultSet.size(); + if (secondResultSet.isEmpty()) { + if (originalSize == 0) { + return; + } else { + String assertMessage = String.format("The Count of the result set mismatches!\n %s", + originalQuery.getLogString()); + throw new AssertionError(assertMessage); + } + } + if (secondResultSet.size() != 1) { + throw new AssertionError( + String.format("Count query result bigger than one \n %s", originalQuery.getLogString())); + } + int withCount = (int) secondResultSet.get(0).get("count"); + if (originalSize != withCount) { + String assertMessage = String.format("The Count of the result set mismatches!\n %s", + originalQuery.getLogString()); + throw new AssertionError(assertMessage); + } + } + + public static void assumeResultSetsAreEqual(List resultSet, List secondResultSet, + MongoDBSelectQuery originalQuery) { + if (resultSet.size() != secondResultSet.size()) { + String assertionMessage = String.format("The Size of the result sets mismatch (%d and %d)!\n%s", + resultSet.size(), secondResultSet.size(), originalQuery.getLogString()); + throw new AssertionError(assertionMessage); + } + + Set firstHashSet = new HashSet<>(resultSet); + Set secondHashSet = new HashSet<>(secondResultSet); + + if (!firstHashSet.equals(secondHashSet)) { + Set firstResultSetMisses = new HashSet<>(firstHashSet); + firstResultSetMisses.removeAll(secondHashSet); + Set secondResultSetMisses = new HashSet<>(secondHashSet); + secondResultSetMisses.removeAll(firstHashSet); + StringBuilder firstMisses = new StringBuilder(); + for (Document document : firstResultSetMisses) { + firstMisses.append(document.toJson()).append(" "); + } + StringBuilder secondMisses = new StringBuilder(); + for (Document document : secondResultSetMisses) { + secondMisses.append(document.toJson()).append(" "); + } + String assertMessage = String.format("The Content of the result sets mismatch!\n %s \n %s\n %s", + firstMisses.toString(), secondMisses.toString(), originalQuery.getLogString()); + throw new AssertionError(assertMessage); + } + } +} diff --git a/src/sqlancer/mongodb/MongoDBConnection.java b/src/sqlancer/mongodb/MongoDBConnection.java new file mode 100644 index 000000000..6971bd79c --- /dev/null +++ b/src/sqlancer/mongodb/MongoDBConnection.java @@ -0,0 +1,35 @@ +package sqlancer.mongodb; + +import org.bson.BsonDocument; +import org.bson.BsonString; + +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoDatabase; + +import sqlancer.SQLancerDBConnection; + +public class MongoDBConnection implements SQLancerDBConnection { + + private final MongoClient client; + private final MongoDatabase database; + + public MongoDBConnection(MongoClient client, MongoDatabase database) { + this.client = client; + this.database = database; + } + + @Override + public String getDatabaseVersion() throws Exception { + return client.getDatabase("dbname").runCommand(new BsonDocument("buildinfo", new BsonString(""))).get("version") + .toString(); + } + + @Override + public void close() throws Exception { + client.close(); + } + + public MongoDatabase getDatabase() { + return database; + } +} diff --git a/src/sqlancer/mongodb/MongoDBLoggableFactory.java b/src/sqlancer/mongodb/MongoDBLoggableFactory.java new file mode 100644 index 000000000..b668301b3 --- /dev/null +++ b/src/sqlancer/mongodb/MongoDBLoggableFactory.java @@ -0,0 +1,40 @@ +package sqlancer.mongodb; + +import java.util.Arrays; + +import sqlancer.common.log.Loggable; +import sqlancer.common.log.LoggableFactory; +import sqlancer.common.log.LoggedString; +import sqlancer.common.query.Query; + +public class MongoDBLoggableFactory extends LoggableFactory { + @Override + protected Loggable createLoggable(String input, String suffix) { + return new LoggedString(input + suffix); + } + + @Override + public Query getQueryForStateToReproduce(String queryString) { + throw new UnsupportedOperationException(); + } + + @Override + public Query commentOutQuery(Query query) { + throw new UnsupportedOperationException(); + } + + @Override + protected Loggable infoToLoggable(String time, String databaseName, String databaseVersion, long seedValue) { + StringBuilder sb = new StringBuilder(); + sb.append("// Time: ").append(time).append("\n"); + sb.append("// Database: ").append(databaseName).append("\n"); + sb.append("// Database version: ").append(databaseVersion).append("\n"); + sb.append("// seed value: ").append(seedValue).append("\n"); + return new LoggedString(sb.toString()); + } + + @Override + public Loggable convertStacktraceToLoggable(Throwable throwable) { + return new LoggedString(Arrays.toString(throwable.getStackTrace()) + "\n" + throwable.getMessage()); + } +} diff --git a/src/sqlancer/mongodb/MongoDBOptions.java b/src/sqlancer/mongodb/MongoDBOptions.java new file mode 100644 index 000000000..cf632085b --- /dev/null +++ b/src/sqlancer/mongodb/MongoDBOptions.java @@ -0,0 +1,71 @@ +package sqlancer.mongodb; + +import static sqlancer.mongodb.MongoDBOptions.MongoDBOracleFactory.QUERY_PARTITIONING; +import static sqlancer.mongodb.MongoDBOptions.MongoDBOracleFactory.REMOVE_REDUCE; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import com.beust.jcommander.Parameter; + +import sqlancer.DBMSSpecificOptions; +import sqlancer.OracleFactory; +import sqlancer.common.oracle.CompositeTestOracle; +import sqlancer.common.oracle.TestOracle; +import sqlancer.mongodb.test.MongoDBQueryPartitioningWhereTester; +import sqlancer.mongodb.test.MongoDBRemoveReduceTester; + +public class MongoDBOptions implements DBMSSpecificOptions { + + @Parameter(names = "--test-validation", description = "Enable/Disable validation of schema with Schema Validation", arity = 1) + public boolean testValidation = true; + + @Parameter(names = "--test-null-inserts", description = "Enables to test inserting with null values, validation has to be off", arity = 1) + public boolean testNullInserts; + + @Parameter(names = "--test-random-types", description = "Insert random types instead of schema types, validation has to be off", arity = 1) + public boolean testRandomTypes; + + @Parameter(names = "--max-number-indexes", description = "The maximum number of indexes used.", arity = 1) + public int maxNumberIndexes = 15; + + @Parameter(names = "--test-computed-values", description = "Enable adding computed values to query", arity = 1) + public boolean testComputedValues; + + @Parameter(names = "--test-with-regex", description = "Enable Regex Leaf Nodes", arity = 1) + public boolean testWithRegex; + + @Parameter(names = "--test-with-count", description = "Count the number of documents and check with count command", arity = 1) + public boolean testWithCount; + + @Parameter(names = "--null-safety", description = "", arity = 1) + public boolean nullSafety; + + @Parameter(names = "--oracle") + public List oracles = Arrays.asList(QUERY_PARTITIONING, REMOVE_REDUCE); + + @Override + public List getTestOracleFactory() { + return oracles; + } + + public enum MongoDBOracleFactory implements OracleFactory { + QUERY_PARTITIONING { + @Override + public TestOracle create(MongoDBProvider.MongoDBGlobalState globalState) throws Exception { + List oracles = new ArrayList<>(); + oracles.add(new MongoDBQueryPartitioningWhereTester(globalState)); + return new CompositeTestOracle(oracles, globalState); + } + }, + REMOVE_REDUCE { + @Override + public TestOracle create(MongoDBProvider.MongoDBGlobalState globalState) throws Exception { + List oracles = new ArrayList<>(); + oracles.add(new MongoDBRemoveReduceTester(globalState)); + return new CompositeTestOracle(oracles, globalState); + } + } + } +} diff --git a/src/sqlancer/mongodb/MongoDBProvider.java b/src/sqlancer/mongodb/MongoDBProvider.java new file mode 100644 index 000000000..5ff549d32 --- /dev/null +++ b/src/sqlancer/mongodb/MongoDBProvider.java @@ -0,0 +1,125 @@ +package sqlancer.mongodb; + +import java.util.ArrayList; +import java.util.List; + +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoClients; +import com.mongodb.client.MongoDatabase; + +import sqlancer.AbstractAction; +import sqlancer.ExecutionTimer; +import sqlancer.GlobalState; +import sqlancer.IgnoreMeException; +import sqlancer.ProviderAdapter; +import sqlancer.Randomly; +import sqlancer.StatementExecutor; +import sqlancer.common.log.LoggableFactory; +import sqlancer.common.query.Query; +import sqlancer.mongodb.MongoDBSchema.MongoDBTable; +import sqlancer.mongodb.gen.MongoDBIndexGenerator; +import sqlancer.mongodb.gen.MongoDBInsertGenerator; +import sqlancer.mongodb.gen.MongoDBTableGenerator; + +public class MongoDBProvider + extends ProviderAdapter { + + public MongoDBProvider() { + super(MongoDBGlobalState.class, MongoDBOptions.class); + } + + public enum Action implements AbstractAction { + INSERT(MongoDBInsertGenerator::getQuery), CREATE_INDEX(MongoDBIndexGenerator::getQuery); + + private final MongoDBQueryProvider queryProvider; + + Action(MongoDBQueryProvider queryProvider) { + this.queryProvider = queryProvider; + } + + @Override + public Query getQuery(MongoDBGlobalState globalState) throws Exception { + return queryProvider.getQuery(globalState); + } + } + + public static int mapActions(MongoDBGlobalState globalState, Action a) { + Randomly r = globalState.getRandomly(); + switch (a) { + case INSERT: + return r.getInteger(0, globalState.getOptions().getMaxNumberInserts()); + case CREATE_INDEX: + return r.getInteger(0, globalState.getDmbsSpecificOptions().maxNumberIndexes); + default: + throw new AssertionError(a); + } + } + + public static class MongoDBGlobalState extends GlobalState { + + private final List schemaTables = new ArrayList<>(); + + public void addTable(MongoDBTable table) { + schemaTables.add(table); + } + + @Override + protected void executeEpilogue(Query q, boolean success, ExecutionTimer timer) throws Exception { + boolean logExecutionTime = getOptions().logExecutionTime(); + if (success && getOptions().printSucceedingStatements()) { + System.out.println(q.getLogString()); + } + if (logExecutionTime) { + getLogger().writeCurrent("// " + timer.end().asString()); + } + if (q.couldAffectSchema()) { + updateSchema(); + } + } + + @Override + protected MongoDBSchema readSchema() throws Exception { + return new MongoDBSchema(schemaTables); + } + } + + @Override + public void generateDatabase(MongoDBGlobalState globalState) throws Exception { + for (int i = 0; i < Randomly.fromOptions(4, 5, 6); i++) { + boolean success; + do { + MongoDBQueryAdapter query = new MongoDBTableGenerator(globalState).getQuery(globalState); + success = globalState.executeStatement(query); + } while (!success); + } + StatementExecutor se = new StatementExecutor<>(globalState, Action.values(), + MongoDBProvider::mapActions, (q) -> { + if (globalState.getSchema().getDatabaseTables().isEmpty()) { + throw new IgnoreMeException(); + } + }); + se.executeStatements(); + } + + @Override + public MongoDBConnection createDatabase(MongoDBGlobalState globalState) throws Exception { + MongoClient mongoClient = MongoClients.create(); + MongoDatabase database = mongoClient.getDatabase(globalState.getDatabaseName()); + database.drop(); + return new MongoDBConnection(mongoClient, database); + } + + @Override + public String getDBMSName() { + return "mongodb"; + } + + @Override + public LoggableFactory getLoggableFactory() { + return new MongoDBLoggableFactory(); + } + + @Override + protected void checkViewsAreValid(MongoDBGlobalState globalState) { + } +} diff --git a/src/sqlancer/mongodb/MongoDBQueryAdapter.java b/src/sqlancer/mongodb/MongoDBQueryAdapter.java new file mode 100644 index 000000000..e2add3242 --- /dev/null +++ b/src/sqlancer/mongodb/MongoDBQueryAdapter.java @@ -0,0 +1,15 @@ +package sqlancer.mongodb; + +import sqlancer.common.query.Query; + +public abstract class MongoDBQueryAdapter extends Query { + @Override + public String getQueryString() { + throw new UnsupportedOperationException(); + } + + @Override + public String getUnterminatedQueryString() { + throw new UnsupportedOperationException(); + } +} diff --git a/src/sqlancer/mongodb/MongoDBQueryProvider.java b/src/sqlancer/mongodb/MongoDBQueryProvider.java new file mode 100644 index 000000000..970c90cea --- /dev/null +++ b/src/sqlancer/mongodb/MongoDBQueryProvider.java @@ -0,0 +1,6 @@ +package sqlancer.mongodb; + +@FunctionalInterface +public interface MongoDBQueryProvider { + MongoDBQueryAdapter getQuery(S globalState) throws Exception; +} diff --git a/src/sqlancer/mongodb/MongoDBSchema.java b/src/sqlancer/mongodb/MongoDBSchema.java new file mode 100644 index 000000000..e9a0afb99 --- /dev/null +++ b/src/sqlancer/mongodb/MongoDBSchema.java @@ -0,0 +1,97 @@ +package sqlancer.mongodb; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import org.bson.BsonType; + +import com.mongodb.client.MongoDatabase; + +import sqlancer.Randomly; +import sqlancer.common.schema.AbstractSchema; +import sqlancer.common.schema.AbstractTable; +import sqlancer.common.schema.AbstractTableColumn; +import sqlancer.common.schema.AbstractTables; +import sqlancer.common.schema.TableIndex; +import sqlancer.mongodb.MongoDBProvider.MongoDBGlobalState; + +public class MongoDBSchema extends AbstractSchema { + + public enum MongoDBDataType { + INTEGER(BsonType.INT32), STRING(BsonType.STRING), BOOLEAN(BsonType.BOOLEAN), DOUBLE(BsonType.DOUBLE), + DATE_TIME(BsonType.DATE_TIME), TIMESTAMP(BsonType.TIMESTAMP); + + private final BsonType bsonType; + + MongoDBDataType(BsonType type) { + this.bsonType = type; + } + + public BsonType getBsonType() { + return bsonType; + } + + public static MongoDBDataType getRandom(MongoDBGlobalState state) { + Set valueSet = new HashSet<>(Arrays.asList(values())); + if (state.getDmbsSpecificOptions().nullSafety) { + valueSet.remove(STRING); + } + MongoDBDataType[] configuredValues = new MongoDBDataType[valueSet.size()]; + return Randomly.fromOptions(valueSet.toArray(configuredValues)); + } + } + + public static class MongoDBColumn extends AbstractTableColumn { + + private final boolean isId; + private final boolean isNullable; + + public MongoDBColumn(String name, MongoDBDataType type, boolean isId, boolean isNullable) { + super(name, null, type); + this.isId = isId; + this.isNullable = isNullable; + } + + public boolean isId() { + return isId; + } + + public boolean isNullable() { + return isNullable; + } + + } + + public static class MongoDBTables extends AbstractTables { + + public MongoDBTables(List tables) { + super(tables); + } + } + + public MongoDBSchema(List databaseTables) { + super(databaseTables); + } + + public static class MongoDBTable extends AbstractTable { + public MongoDBTable(String name, List columns, boolean isView) { + super(name, columns, Collections.emptyList(), isView); + } + + @Override + public long getNrRows(MongoDBGlobalState globalState) { + throw new UnsupportedOperationException(); + } + } + + public static MongoDBSchema fromConnection(MongoDatabase connection, String databaseName) { + throw new UnsupportedOperationException(); + } + + public MongoDBTables getRandomTableNonEmptyTables() { + return new MongoDBTables(Randomly.nonEmptySubset(getDatabaseTables())); + } +} diff --git a/src/sqlancer/mongodb/ast/MongoDBBinaryComparisonNode.java b/src/sqlancer/mongodb/ast/MongoDBBinaryComparisonNode.java new file mode 100644 index 000000000..21675250a --- /dev/null +++ b/src/sqlancer/mongodb/ast/MongoDBBinaryComparisonNode.java @@ -0,0 +1,16 @@ +package sqlancer.mongodb.ast; + +import sqlancer.common.ast.newast.NewBinaryOperatorNode; +import sqlancer.common.ast.newast.Node; +import sqlancer.mongodb.gen.MongoDBMatchExpressionGenerator.MongoDBBinaryComparisonOperator; + +public class MongoDBBinaryComparisonNode extends NewBinaryOperatorNode { + public MongoDBBinaryComparisonNode(Node left, Node right, + MongoDBBinaryComparisonOperator op) { + super(left, right, op); + } + + public MongoDBBinaryComparisonOperator operator() { + return (MongoDBBinaryComparisonOperator) op; + } +} diff --git a/src/sqlancer/mongodb/ast/MongoDBBinaryLogicalNode.java b/src/sqlancer/mongodb/ast/MongoDBBinaryLogicalNode.java new file mode 100644 index 000000000..efb8d8294 --- /dev/null +++ b/src/sqlancer/mongodb/ast/MongoDBBinaryLogicalNode.java @@ -0,0 +1,16 @@ +package sqlancer.mongodb.ast; + +import sqlancer.common.ast.newast.NewBinaryOperatorNode; +import sqlancer.common.ast.newast.Node; +import sqlancer.mongodb.gen.MongoDBMatchExpressionGenerator.MongoDBBinaryLogicalOperator; + +public class MongoDBBinaryLogicalNode extends NewBinaryOperatorNode { + public MongoDBBinaryLogicalNode(Node left, Node right, + MongoDBBinaryLogicalOperator op) { + super(left, right, op); + } + + public MongoDBBinaryLogicalOperator operator() { + return (MongoDBBinaryLogicalOperator) op; + } +} diff --git a/src/sqlancer/mongodb/ast/MongoDBConstant.java b/src/sqlancer/mongodb/ast/MongoDBConstant.java new file mode 100644 index 000000000..86f783b48 --- /dev/null +++ b/src/sqlancer/mongodb/ast/MongoDBConstant.java @@ -0,0 +1,252 @@ +package sqlancer.mongodb.ast; + +import java.io.Serializable; + +import org.bson.BsonDateTime; +import org.bson.BsonTimestamp; +import org.bson.Document; + +import sqlancer.common.ast.newast.Node; + +public abstract class MongoDBConstant implements Node { + private MongoDBConstant() { + } + + public abstract void setValueInDocument(Document document, String key); + + public abstract String getLogValue(); + + public abstract Object getValue(); + + public abstract Serializable getSerializedValue(); + + public static class MongoDBNullConstant extends MongoDBConstant { + + @Override + public void setValueInDocument(Document document, String key) { + document.append(key, null); + } + + @Override + public String getLogValue() { + return "null"; + } + + @Override + public Object getValue() { + return null; + } + + @Override + public Serializable getSerializedValue() { + return null; + } + } + + public static Node createNullConstant() { + return new MongoDBNullConstant(); + } + + public static class MongoDBIntegerConstant extends MongoDBConstant { + + private final int value; + + public MongoDBIntegerConstant(int value) { + this.value = value; + } + + @Override + public void setValueInDocument(Document document, String key) { + document.append(key, value); + } + + @Override + public String getLogValue() { + return "NumberInt(" + value + ")"; + } + + @Override + public Integer getValue() { + return value; + } + + @Override + public Serializable getSerializedValue() { + return value; + } + } + + public static Node createIntegerConstant(int value) { + return new MongoDBIntegerConstant(value); + } + + public static class MongoDBStringConstant extends MongoDBConstant { + + private final String value; + + public MongoDBStringConstant(String value) { + this.value = value; + } + + public String getStringValue() { + return value; + } + + @Override + public void setValueInDocument(Document document, String key) { + document.append(key, value); + } + + @Override + public String getLogValue() { + return "\"" + value.replace("\\", "\\\\").replace("\"", "\\\"").replace("\n", "\\n") + "\""; + } + + @Override + public String getValue() { + return value; + } + + @Override + public Serializable getSerializedValue() { + return value; + } + } + + public static Node createStringConstant(String value) { + return new MongoDBStringConstant(value); + } + + public static class MongoDBBooleanConstant extends MongoDBConstant { + + private final boolean value; + + public MongoDBBooleanConstant(boolean value) { + this.value = value; + } + + @Override + public void setValueInDocument(Document document, String key) { + document.append(key, value); + } + + @Override + public String getLogValue() { + return String.valueOf(value); + } + + @Override + public Boolean getValue() { + return value; + } + + @Override + public Serializable getSerializedValue() { + return value; + } + } + + public static Node createBooleanConstant(boolean value) { + return new MongoDBBooleanConstant(value); + } + + public static class MongoDBDoubleConstant extends MongoDBConstant { + + private final double value; + + public MongoDBDoubleConstant(double value) { + this.value = value; + } + + @Override + public void setValueInDocument(Document document, String key) { + document.append(key, value); + } + + @Override + public String getLogValue() { + return String.valueOf(value); + } + + @Override + public Double getValue() { + return value; + } + + @Override + public Serializable getSerializedValue() { + return value; + } + } + + public static Node createDoubleConstant(double value) { + return new MongoDBDoubleConstant(value); + } + + public static class MongoDBDateTimeConstant extends MongoDBConstant { + + private final BsonDateTime value; + + public MongoDBDateTimeConstant(long val) { + this.value = new BsonDateTime(val); + } + + @Override + public void setValueInDocument(Document document, String key) { + document.append(key, value); + } + + @Override + public String getLogValue() { + return "new Date(" + value.getValue() + ")"; + } + + @Override + public BsonDateTime getValue() { + return value; + } + + @Override + public Serializable getSerializedValue() { + return value.getValue(); + } + } + + public static Node createDateTimeConstant(long value) { + return new MongoDBDateTimeConstant(value); + } + + public static class MongoDBTimestampConstant extends MongoDBConstant { + + private final BsonTimestamp value; + + public MongoDBTimestampConstant(long value) { + this.value = new BsonTimestamp(value); + } + + @Override + public void setValueInDocument(Document document, String key) { + document.append(key, value); + } + + @Override + public String getLogValue() { + return "Timestamp(" + value.getValue() + ",1)"; + } + + @Override + public BsonTimestamp getValue() { + return value; + } + + @Override + public Serializable getSerializedValue() { + return value.getValue(); + } + } + + public static Node createTimestampConstant(long value) { + return new MongoDBTimestampConstant(value); + } + +} diff --git a/src/sqlancer/mongodb/ast/MongoDBExpression.java b/src/sqlancer/mongodb/ast/MongoDBExpression.java new file mode 100644 index 000000000..1235a1fbc --- /dev/null +++ b/src/sqlancer/mongodb/ast/MongoDBExpression.java @@ -0,0 +1,4 @@ +package sqlancer.mongodb.ast; + +public interface MongoDBExpression { +} diff --git a/src/sqlancer/mongodb/ast/MongoDBRegexNode.java b/src/sqlancer/mongodb/ast/MongoDBRegexNode.java new file mode 100644 index 000000000..76c608586 --- /dev/null +++ b/src/sqlancer/mongodb/ast/MongoDBRegexNode.java @@ -0,0 +1,24 @@ +package sqlancer.mongodb.ast; + +import static sqlancer.mongodb.gen.MongoDBMatchExpressionGenerator.MongoDBRegexOperator.REGEX; + +import sqlancer.common.ast.newast.NewBinaryOperatorNode; +import sqlancer.common.ast.newast.Node; +import sqlancer.mongodb.gen.MongoDBMatchExpressionGenerator.MongoDBRegexOperator; + +public class MongoDBRegexNode extends NewBinaryOperatorNode { + private final String options; + + public MongoDBRegexNode(Node left, Node right, String options) { + super(left, right, REGEX); + this.options = options; + } + + public String getOptions() { + return options; + } + + public MongoDBRegexOperator operator() { + return (MongoDBRegexOperator) op; + } +} diff --git a/src/sqlancer/mongodb/ast/MongoDBSelect.java b/src/sqlancer/mongodb/ast/MongoDBSelect.java new file mode 100644 index 000000000..0fe91ba4a --- /dev/null +++ b/src/sqlancer/mongodb/ast/MongoDBSelect.java @@ -0,0 +1,104 @@ +package sqlancer.mongodb.ast; + +import java.util.List; + +import sqlancer.common.ast.newast.Node; +import sqlancer.mongodb.test.MongoDBColumnTestReference; + +public class MongoDBSelect implements Node { + + private final String mainTableName; + private final MongoDBColumnTestReference joinColumn; + List projectionColumns; + List lookupList; + boolean hasFilter; + Node filterClause; + boolean hasComputed; + List> computedClauses; + private boolean withCountClause; + + public MongoDBSelect(String mainTableName, MongoDBColumnTestReference joinColumn) { + this.mainTableName = mainTableName; + this.joinColumn = joinColumn; + } + + public String getMainTableName() { + return mainTableName; + } + + public MongoDBColumnTestReference getJoinColumn() { + return joinColumn; + } + + public void setProjectionList(List fetchColumns) { + if (fetchColumns == null || fetchColumns.isEmpty()) { + throw new IllegalArgumentException(); + } + this.projectionColumns = fetchColumns; + } + + public List getProjectionList() { + if (projectionColumns == null) { + throw new IllegalStateException(); + } + return projectionColumns; + } + + public void setLookupList(List lookupList) { + if (lookupList == null || lookupList.isEmpty()) { + throw new IllegalArgumentException(); + } + this.lookupList = lookupList; + } + + public List getLookupList() { + if (lookupList == null) { + throw new IllegalStateException(); + } + return lookupList; + } + + public void setFilterClause(Node filterClause) { + if (filterClause == null) { + hasFilter = false; + this.filterClause = null; + return; + } + hasFilter = true; + this.filterClause = filterClause; + } + + public Node getFilterClause() { + return filterClause; + } + + public boolean hasFilter() { + return hasFilter; + } + + public void setComputedClause(List> computedClause) { + if (computedClause == null) { + hasComputed = false; + this.computedClauses = null; + return; + } + hasComputed = true; + this.computedClauses = computedClause; + } + + public List> getComputedClause() { + return computedClauses; + } + + public boolean hasComputed() { + return hasComputed; + } + + public boolean getWithCountClause() { + return withCountClause; + } + + public void setWithCountClause(boolean withCountClause) { + this.withCountClause = withCountClause; + } +} diff --git a/src/sqlancer/mongodb/ast/MongoDBUnaryLogicalOperatorNode.java b/src/sqlancer/mongodb/ast/MongoDBUnaryLogicalOperatorNode.java new file mode 100644 index 000000000..a34fe27e5 --- /dev/null +++ b/src/sqlancer/mongodb/ast/MongoDBUnaryLogicalOperatorNode.java @@ -0,0 +1,16 @@ +package sqlancer.mongodb.ast; + +import sqlancer.common.ast.newast.NewUnaryPrefixOperatorNode; +import sqlancer.common.ast.newast.Node; +import sqlancer.mongodb.gen.MongoDBMatchExpressionGenerator.MongoDBUnaryLogicalOperator; + +public class MongoDBUnaryLogicalOperatorNode extends NewUnaryPrefixOperatorNode { + + public MongoDBUnaryLogicalOperatorNode(Node expr, MongoDBUnaryLogicalOperator op) { + super(expr, op); + } + + public MongoDBUnaryLogicalOperator operator() { + return (MongoDBUnaryLogicalOperator) op; + } +} diff --git a/src/sqlancer/mongodb/ast/MongoDBUnsupportedPredicate.java b/src/sqlancer/mongodb/ast/MongoDBUnsupportedPredicate.java new file mode 100644 index 000000000..eae143e7d --- /dev/null +++ b/src/sqlancer/mongodb/ast/MongoDBUnsupportedPredicate.java @@ -0,0 +1,7 @@ +package sqlancer.mongodb.ast; + +import sqlancer.common.ast.newast.Node; + +public class MongoDBUnsupportedPredicate implements Node { + +} diff --git a/src/sqlancer/mongodb/gen/MongoDBComputedExpressionGenerator.java b/src/sqlancer/mongodb/gen/MongoDBComputedExpressionGenerator.java new file mode 100644 index 000000000..fd5959ce6 --- /dev/null +++ b/src/sqlancer/mongodb/gen/MongoDBComputedExpressionGenerator.java @@ -0,0 +1,89 @@ +package sqlancer.mongodb.gen; + +import java.util.ArrayList; +import java.util.List; + +import sqlancer.Randomly; +import sqlancer.common.ast.newast.NewFunctionNode; +import sqlancer.common.ast.newast.Node; +import sqlancer.common.gen.UntypedExpressionGenerator; +import sqlancer.mongodb.MongoDBProvider.MongoDBGlobalState; +import sqlancer.mongodb.MongoDBSchema; +import sqlancer.mongodb.ast.MongoDBExpression; +import sqlancer.mongodb.test.MongoDBColumnTestReference; + +public class MongoDBComputedExpressionGenerator + extends UntypedExpressionGenerator, MongoDBColumnTestReference> { + + private final MongoDBGlobalState globalState; + + @Override + public Node generateLeafNode() { + ComputedFunction function = ComputedFunction.getRandom(); + List> expressions = new ArrayList<>(); + for (int i = 0; i < function.getNrArgs(); i++) { + expressions.add(super.generateLeafNode()); + } + return new NewFunctionNode<>(expressions, function); + } + + @Override + protected Node generateExpression(int depth) { + if (depth >= globalState.getOptions().getMaxExpressionDepth() || Randomly.getBoolean()) { + return generateLeafNode(); + } + ComputedFunction func = ComputedFunction.getRandom(); + return new NewFunctionNode<>(generateExpressions(depth + 1, func.getNrArgs()), func); + } + + public MongoDBComputedExpressionGenerator(MongoDBGlobalState globalState) { + this.globalState = globalState; + } + + public enum ComputedFunction { + ADD(2, "$add"), MULTIPLY(2, "$multiply"), DIVIDE(2, "$divide"), POW(2, "$pow"), SQRT(1, "$sqrt"), + LOG(2, "$log"), AVG(2, "$avg"), EXP(1, "$exp"); + + private final int nrArgs; + private final String operatorName; + + ComputedFunction(int nrArgs, String operatorName) { + this.nrArgs = nrArgs; + this.operatorName = operatorName; + } + + public static ComputedFunction getRandom() { + return Randomly.fromOptions(values()); + } + + public int getNrArgs() { + return nrArgs; + } + + public String getOperator() { + return operatorName; + } + } + + @Override + public Node generateConstant() { + MongoDBSchema.MongoDBDataType type = MongoDBSchema.MongoDBDataType.getRandom(globalState); + MongoDBConstantGenerator generator = new MongoDBConstantGenerator(globalState); + return generator.generateConstantWithType(type); + } + + @Override + protected Node generateColumn() { + return Randomly.fromList(columns); + } + + @Override + public Node negatePredicate(Node predicate) { + throw new UnsupportedOperationException(); + } + + @Override + public Node isNull(Node expr) { + throw new UnsupportedOperationException(); + } +} diff --git a/src/sqlancer/mongodb/gen/MongoDBConstantGenerator.java b/src/sqlancer/mongodb/gen/MongoDBConstantGenerator.java new file mode 100644 index 000000000..e81291543 --- /dev/null +++ b/src/sqlancer/mongodb/gen/MongoDBConstantGenerator.java @@ -0,0 +1,86 @@ +package sqlancer.mongodb.gen; + +import org.bson.Document; + +import sqlancer.Randomly; +import sqlancer.common.ast.newast.Node; +import sqlancer.mongodb.MongoDBProvider.MongoDBGlobalState; +import sqlancer.mongodb.MongoDBSchema.MongoDBDataType; +import sqlancer.mongodb.ast.MongoDBConstant; +import sqlancer.mongodb.ast.MongoDBConstant.MongoDBBooleanConstant; +import sqlancer.mongodb.ast.MongoDBConstant.MongoDBDateTimeConstant; +import sqlancer.mongodb.ast.MongoDBConstant.MongoDBDoubleConstant; +import sqlancer.mongodb.ast.MongoDBConstant.MongoDBIntegerConstant; +import sqlancer.mongodb.ast.MongoDBConstant.MongoDBNullConstant; +import sqlancer.mongodb.ast.MongoDBConstant.MongoDBTimestampConstant; +import sqlancer.mongodb.ast.MongoDBExpression; + +public class MongoDBConstantGenerator { + private final MongoDBGlobalState globalState; + + public MongoDBConstantGenerator(MongoDBGlobalState globalState) { + this.globalState = globalState; + } + + public Node generateConstantWithType(MongoDBDataType option) { + switch (option) { + case DATE_TIME: + return MongoDBConstant.createDateTimeConstant(globalState.getRandomly().getInteger()); + case BOOLEAN: + return MongoDBConstant.createBooleanConstant(Randomly.getBoolean()); + case DOUBLE: + return MongoDBConstant.createDoubleConstant(globalState.getRandomly().getDouble()); + case STRING: + return MongoDBConstant.createStringConstant(globalState.getRandomly().getString()); + case INTEGER: + return MongoDBConstant.createIntegerConstant((int) globalState.getRandomly().getInteger()); + case TIMESTAMP: + return MongoDBConstant.createTimestampConstant(globalState.getRandomly().getInteger()); + default: + throw new AssertionError(option); + } + } + + public void addRandomConstant(Document document, String key) { + MongoDBDataType type = MongoDBDataType.getRandom(globalState); + addRandomConstantWithType(document, key, type); + } + + public void addRandomConstantWithType(Document document, String key, MongoDBDataType option) { + MongoDBConstant constant; + if (globalState.getDmbsSpecificOptions().testNullInserts && Randomly.getBooleanWithSmallProbability()) { + constant = new MongoDBNullConstant(); + constant.setValueInDocument(document, key); + return; + } + switch (option) { + case DATE_TIME: + constant = new MongoDBDateTimeConstant(globalState.getRandomly().getInteger()); + constant.setValueInDocument(document, key); + return; + + case BOOLEAN: + constant = new MongoDBBooleanConstant(Randomly.getBoolean()); + constant.setValueInDocument(document, key); + return; + case DOUBLE: + constant = new MongoDBDoubleConstant(globalState.getRandomly().getDouble()); + constant.setValueInDocument(document, key); + return; + case STRING: + constant = new MongoDBConstant.MongoDBStringConstant(globalState.getRandomly().getString()); + constant.setValueInDocument(document, key); + return; + case INTEGER: + constant = new MongoDBIntegerConstant((int) globalState.getRandomly().getInteger()); + constant.setValueInDocument(document, key); + return; + case TIMESTAMP: + constant = new MongoDBTimestampConstant(globalState.getRandomly().getInteger()); + constant.setValueInDocument(document, key); + return; + default: + throw new AssertionError(option); + } + } +} diff --git a/src/sqlancer/mongodb/gen/MongoDBIndexGenerator.java b/src/sqlancer/mongodb/gen/MongoDBIndexGenerator.java new file mode 100644 index 000000000..8687fd45c --- /dev/null +++ b/src/sqlancer/mongodb/gen/MongoDBIndexGenerator.java @@ -0,0 +1,25 @@ +package sqlancer.mongodb.gen; + +import java.util.List; + +import sqlancer.Randomly; +import sqlancer.mongodb.MongoDBProvider.MongoDBGlobalState; +import sqlancer.mongodb.MongoDBQueryAdapter; +import sqlancer.mongodb.MongoDBSchema.MongoDBColumn; +import sqlancer.mongodb.MongoDBSchema.MongoDBTable; +import sqlancer.mongodb.query.MongoDBCreateIndexQuery; + +public final class MongoDBIndexGenerator { + private MongoDBIndexGenerator() { + } + + public static MongoDBQueryAdapter getQuery(MongoDBGlobalState globalState) { + MongoDBTable randomTable = globalState.getSchema().getRandomTable(); + List columns = Randomly.nonEmptySubset(randomTable.getColumns()); + MongoDBCreateIndexQuery createIndexQuery = new MongoDBCreateIndexQuery(randomTable); + for (MongoDBColumn column : columns) { + createIndexQuery.addIndex(column.getName(), Randomly.getBoolean()); + } + return createIndexQuery; + } +} diff --git a/src/sqlancer/mongodb/gen/MongoDBInsertGenerator.java b/src/sqlancer/mongodb/gen/MongoDBInsertGenerator.java new file mode 100644 index 000000000..4501971b4 --- /dev/null +++ b/src/sqlancer/mongodb/gen/MongoDBInsertGenerator.java @@ -0,0 +1,38 @@ +package sqlancer.mongodb.gen; + +import org.bson.Document; + +import sqlancer.mongodb.MongoDBProvider.MongoDBGlobalState; +import sqlancer.mongodb.MongoDBQueryAdapter; +import sqlancer.mongodb.MongoDBSchema.MongoDBTable; +import sqlancer.mongodb.query.MongoDBInsertQuery; + +public final class MongoDBInsertGenerator { + + private final MongoDBGlobalState globalState; + + private MongoDBInsertGenerator(MongoDBGlobalState globalState) { + this.globalState = globalState; + } + + public static MongoDBQueryAdapter getQuery(MongoDBGlobalState globalState) { + return new MongoDBInsertGenerator(globalState).generate(); + } + + public MongoDBQueryAdapter generate() { + Document result = new Document(); + MongoDBTable table = globalState.getSchema().getRandomTable(); + MongoDBConstantGenerator constantGenerator = new MongoDBConstantGenerator(globalState); + + for (int i = 0; i < table.getColumns().size(); i++) { + if (!globalState.getDmbsSpecificOptions().testRandomTypes) { + constantGenerator.addRandomConstantWithType(result, table.getColumns().get(i).getName(), + table.getColumns().get(i).getType()); + } else { + constantGenerator.addRandomConstant(result, table.getColumns().get(i).getName()); + } + } + + return new MongoDBInsertQuery(table, result); + } +} diff --git a/src/sqlancer/mongodb/gen/MongoDBMatchExpressionGenerator.java b/src/sqlancer/mongodb/gen/MongoDBMatchExpressionGenerator.java new file mode 100644 index 000000000..0645d7f24 --- /dev/null +++ b/src/sqlancer/mongodb/gen/MongoDBMatchExpressionGenerator.java @@ -0,0 +1,291 @@ +package sqlancer.mongodb.gen; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import org.bson.conversions.Bson; + +import com.mongodb.client.model.Filters; + +import sqlancer.Randomly; +import sqlancer.common.ast.BinaryOperatorNode.Operator; +import sqlancer.common.ast.newast.Node; +import sqlancer.common.gen.UntypedExpressionGenerator; +import sqlancer.mongodb.MongoDBProvider.MongoDBGlobalState; +import sqlancer.mongodb.MongoDBSchema.MongoDBDataType; +import sqlancer.mongodb.ast.MongoDBBinaryComparisonNode; +import sqlancer.mongodb.ast.MongoDBBinaryLogicalNode; +import sqlancer.mongodb.ast.MongoDBConstant; +import sqlancer.mongodb.ast.MongoDBExpression; +import sqlancer.mongodb.ast.MongoDBRegexNode; +import sqlancer.mongodb.ast.MongoDBUnaryLogicalOperatorNode; +import sqlancer.mongodb.ast.MongoDBUnsupportedPredicate; +import sqlancer.mongodb.test.MongoDBColumnTestReference; +import sqlancer.mongodb.visitor.MongoDBNegateVisitor; + +public class MongoDBMatchExpressionGenerator + extends UntypedExpressionGenerator, MongoDBColumnTestReference> { + + private final MongoDBGlobalState globalState; + + private enum LeafExpression { + BINARY_COMPARISON, REGEX + } + + private enum NonLeafExpression { + BINARY_LOGICAL, UNARY_LOGICAL + } + + public MongoDBMatchExpressionGenerator(MongoDBGlobalState globalState) { + this.globalState = globalState; + } + + @Override + public Node generateLeafNode() { + List possibleOptions = new ArrayList<>(Arrays.asList(LeafExpression.values())); + if (!globalState.getDmbsSpecificOptions().testWithRegex) { + possibleOptions.remove(LeafExpression.REGEX); + } + LeafExpression expr = Randomly.fromList(possibleOptions); + switch (expr) { + case BINARY_COMPARISON: + MongoDBBinaryComparisonOperator operator = MongoDBBinaryComparisonOperator.getRandom(); + MongoDBColumnTestReference reference = (MongoDBColumnTestReference) generateColumn(); + + return new MongoDBBinaryComparisonNode(reference, + generateConstant(reference.getColumnReference().getType()), operator); + case REGEX: + return new MongoDBRegexNode(generateColumn(), + new MongoDBConstantGenerator(globalState).generateConstantWithType(MongoDBDataType.STRING), + getRandomizedRegexOptions()); + default: + throw new AssertionError(); + } + } + + @Override + protected Node generateExpression(int depth) { + if (depth >= globalState.getOptions().getMaxExpressionDepth() || Randomly.getBoolean()) { + return generateLeafNode(); + } + + List possibleOptions = new ArrayList<>(Arrays.asList(NonLeafExpression.values())); + NonLeafExpression expr = Randomly.fromList(possibleOptions); + switch (expr) { + case BINARY_LOGICAL: + MongoDBBinaryLogicalOperator binaryOperator = MongoDBBinaryLogicalOperator.getRandom(); + return new MongoDBBinaryLogicalNode(generateExpression(depth + 1), generateExpression(depth + 1), + binaryOperator); + case UNARY_LOGICAL: + MongoDBUnaryLogicalOperator unaryOperator = MongoDBUnaryLogicalOperator.getRandom(); + return new MongoDBUnaryLogicalOperatorNode(generateExpression(depth + 1), unaryOperator); + default: + throw new AssertionError(); + } + } + + @Override + public Node generateConstant() { + MongoDBDataType type = MongoDBDataType.getRandom(globalState); + MongoDBConstantGenerator generator = new MongoDBConstantGenerator(globalState); + if (Randomly.getBooleanWithSmallProbability()) { + return MongoDBConstant.createNullConstant(); + } + return generator.generateConstantWithType(type); + } + + public Node generateConstant(MongoDBDataType type) { + MongoDBConstantGenerator generator = new MongoDBConstantGenerator(globalState); + if (Randomly.getBooleanWithSmallProbability() && !globalState.getDmbsSpecificOptions().nullSafety) { + return MongoDBConstant.createNullConstant(); + } + return generator.generateConstantWithType(type); + } + + private String getRandomizedRegexOptions() { + List s = Randomly.subset("i", "m", "x", "s"); + return String.join("", s); + } + + @Override + protected Node generateColumn() { + return Randomly.fromList(columns); + } + + @Override + public Node generatePredicate() { + Node result = super.generatePredicate(); + return MongoDBNegateVisitor.cleanNegations(result); + } + + @Override + public Node negatePredicate(Node predicate) { + Node result = new MongoDBUnaryLogicalOperatorNode(predicate, + MongoDBUnaryLogicalOperator.NOT); + return MongoDBNegateVisitor.cleanNegations(result); + } + + @Override + public Node isNull(Node expr) { + return new MongoDBUnsupportedPredicate<>(); + } + + public enum MongoDBUnaryLogicalOperator implements Operator { + NOT { + @Override + public Bson applyOperator(Bson inner) { + return Filters.not(inner); + } + + @Override + public String getTextRepresentation() { + return "$not"; + } + }; + + public abstract Bson applyOperator(Bson inner); + + public static MongoDBUnaryLogicalOperator getRandom() { + return Randomly.fromOptions(values()); + } + } + + public enum MongoDBBinaryLogicalOperator implements Operator { + AND { + @Override + public Bson applyOperator(Bson left, Bson right) { + return Filters.and(left, right); + } + + @Override + public String getTextRepresentation() { + return "$and"; + } + }, + OR { + @Override + public Bson applyOperator(Bson left, Bson right) { + return Filters.or(left, right); + } + + @Override + public String getTextRepresentation() { + return "$or"; + } + }, + NOR { + @Override + public Bson applyOperator(Bson left, Bson right) { + return Filters.nor(left, right); + } + + @Override + public String getTextRepresentation() { + return "$nor"; + } + }; + + public abstract Bson applyOperator(Bson left, Bson right); + + public static MongoDBBinaryLogicalOperator getRandom() { + return Randomly.fromOptions(values()); + } + } + + public enum MongoDBBinaryComparisonOperator implements Operator { + EQUALS { + @Override + public Bson applyOperator(String columnName, MongoDBConstant constant) { + return Filters.eq(columnName, constant.getValue()); + } + + @Override + public String getTextRepresentation() { + return "$eq"; + } + }, + NOT_EQUALS { + @Override + public Bson applyOperator(String columnName, MongoDBConstant constant) { + return Filters.ne(columnName, constant.getValue()); + } + + @Override + public String getTextRepresentation() { + return "$ne"; + } + }, + GREATER { + @Override + public Bson applyOperator(String columnName, MongoDBConstant constant) { + return Filters.gt(columnName, constant.getValue()); + } + + @Override + public String getTextRepresentation() { + return "$gt"; + } + + }, + LESS { + @Override + public Bson applyOperator(String columnName, MongoDBConstant constant) { + return Filters.lt(columnName, constant.getValue()); + } + + @Override + public String getTextRepresentation() { + return "$lt"; + } + + }, + GREATER_EQUAL { + @Override + public Bson applyOperator(String columnName, MongoDBConstant constant) { + return Filters.gte(columnName, constant.getValue()); + + } + + @Override + public String getTextRepresentation() { + return "$gte"; + } + + }, + LESS_EQUAL { + @Override + public Bson applyOperator(String columnName, MongoDBConstant constant) { + return Filters.lte(columnName, constant.getValue()); + } + + @Override + public String getTextRepresentation() { + return "$lte"; + } + }; + + public abstract Bson applyOperator(String columnName, MongoDBConstant constant); + + public static MongoDBBinaryComparisonOperator getRandom() { + return Randomly.fromOptions(values()); + } + } + + public enum MongoDBRegexOperator implements Operator { + REGEX { + @Override + public Bson applyOperator(String columnName, MongoDBConstant.MongoDBStringConstant regex, String options) { + return Filters.regex(columnName, regex.getStringValue(), options); + } + + @Override + public String getTextRepresentation() { + return "$regex"; + } + }; + + public abstract Bson applyOperator(String columnName, MongoDBConstant.MongoDBStringConstant regex, + String options); + } +} diff --git a/src/sqlancer/mongodb/gen/MongoDBTableGenerator.java b/src/sqlancer/mongodb/gen/MongoDBTableGenerator.java new file mode 100644 index 000000000..6a4f33d38 --- /dev/null +++ b/src/sqlancer/mongodb/gen/MongoDBTableGenerator.java @@ -0,0 +1,54 @@ +package sqlancer.mongodb.gen; + +import java.util.ArrayList; +import java.util.List; + +import sqlancer.Randomly; +import sqlancer.mongodb.MongoDBProvider.MongoDBGlobalState; +import sqlancer.mongodb.MongoDBQueryAdapter; +import sqlancer.mongodb.MongoDBSchema.MongoDBColumn; +import sqlancer.mongodb.MongoDBSchema.MongoDBDataType; +import sqlancer.mongodb.MongoDBSchema.MongoDBTable; +import sqlancer.mongodb.query.MongoDBCreateTableQuery; + +public class MongoDBTableGenerator { + + private MongoDBTable table; + private final List columnsToBeAdded = new ArrayList<>(); + private final MongoDBGlobalState state; + + public MongoDBTableGenerator(MongoDBGlobalState state) { + this.state = state; + } + + public MongoDBQueryAdapter getQuery(MongoDBGlobalState globalState) { + String tableName = globalState.getSchema().getFreeTableName(); + MongoDBCreateTableQuery createTableQuery = new MongoDBCreateTableQuery(tableName); + table = new MongoDBTable(tableName, columnsToBeAdded, false); + for (int i = 0; i < Randomly.smallNumber() + 1; i++) { + String columnName = String.format("c%d", i); + MongoDBDataType type = createColumn(columnName); + if (globalState.getDmbsSpecificOptions().testValidation) { + createTableQuery.addValidation(columnName, type.getBsonType()); + } + } + globalState.addTable(table); + return createTableQuery; + } + + private MongoDBDataType createColumn(String columnName) { + MongoDBDataType columnType = MongoDBDataType.getRandom(state); + MongoDBColumn newColumn = new MongoDBColumn(columnName, columnType, false, false); + newColumn.setTable(table); + columnsToBeAdded.add(newColumn); + return columnType; + } + + public String getTableName() { + return table.getName(); + } + + public MongoDBTable getGeneratedTable() { + return table; + } +} diff --git a/src/sqlancer/mongodb/query/MongoDBCreateIndexQuery.java b/src/sqlancer/mongodb/query/MongoDBCreateIndexQuery.java new file mode 100644 index 000000000..c873b5924 --- /dev/null +++ b/src/sqlancer/mongodb/query/MongoDBCreateIndexQuery.java @@ -0,0 +1,77 @@ +package sqlancer.mongodb.query; + +import java.util.ArrayList; +import java.util.List; + +import org.bson.conversions.Bson; + +import com.mongodb.client.model.Indexes; + +import sqlancer.GlobalState; +import sqlancer.Main; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.mongodb.MongoDBConnection; +import sqlancer.mongodb.MongoDBQueryAdapter; +import sqlancer.mongodb.MongoDBSchema.MongoDBTable; + +public class MongoDBCreateIndexQuery extends MongoDBQueryAdapter { + + private final MongoDBTable table; + private final List indeces; + private final List logIndeces; + + public MongoDBCreateIndexQuery(MongoDBTable table) { + this.table = table; + this.indeces = new ArrayList<>(); + this.logIndeces = new ArrayList<>(); + } + + public void addIndex(String column, boolean ascending) { + if (ascending) { + indeces.add(Indexes.ascending(column)); + logIndeces.add(column + ": 1"); + } else { + indeces.add(Indexes.descending(column)); + logIndeces.add(column + ": -1"); + } + } + + @Override + public String getLogString() { + StringBuilder sb = new StringBuilder(); + sb.append("db.").append(table.getName()).append(".createIndex({"); + String helper = ""; + for (String index : logIndeces) { + sb.append(helper); + helper = ","; + sb.append(index); + } + sb.append("})\n"); + return sb.toString(); + } + + @Override + public boolean couldAffectSchema() { + return false; + } + + @Override + public > boolean execute(G globalState, String... fills) + throws Exception { + Main.nrSuccessfulActions.addAndGet(1); + Bson index; + if (indeces.size() > 1) { + index = Indexes.compoundIndex(indeces); + } else { + index = indeces.get(0); + } + globalState.getConnection().getDatabase().getCollection(table.getName()).createIndex(index); + return true; + } + + @Override + public ExpectedErrors getExpectedErrors() { + return new ExpectedErrors(); + } + +} diff --git a/src/sqlancer/mongodb/query/MongoDBCreateTableQuery.java b/src/sqlancer/mongodb/query/MongoDBCreateTableQuery.java new file mode 100644 index 000000000..7bc174c77 --- /dev/null +++ b/src/sqlancer/mongodb/query/MongoDBCreateTableQuery.java @@ -0,0 +1,115 @@ +package sqlancer.mongodb.query; + +import java.util.ArrayList; +import java.util.List; + +import org.bson.BsonType; +import org.bson.conversions.Bson; + +import com.mongodb.client.model.CreateCollectionOptions; +import com.mongodb.client.model.Filters; +import com.mongodb.client.model.ValidationOptions; + +import sqlancer.GlobalState; +import sqlancer.Main; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.mongodb.MongoDBConnection; +import sqlancer.mongodb.MongoDBQueryAdapter; + +public class MongoDBCreateTableQuery extends MongoDBQueryAdapter { + + private final String tableName; + private Bson validationFilter; + private final List logRequiredList; + private final List logPropertiesList; + + public MongoDBCreateTableQuery(String tableName) { + this.tableName = tableName; + this.validationFilter = null; + logRequiredList = new ArrayList<>(); + logPropertiesList = new ArrayList<>(); + } + + @Override + public boolean couldAffectSchema() { + return true; + } + + @Override + public > boolean execute(G globalState, String... fills) + throws Exception { + ValidationOptions collOptions = new ValidationOptions().validator(this.validationFilter); + Main.nrSuccessfulActions.addAndGet(1); + globalState.getConnection().getDatabase().createCollection(tableName, + new CreateCollectionOptions().validationOptions(collOptions)); + return true; + } + + @Override + public ExpectedErrors getExpectedErrors() { + return new ExpectedErrors(); + } + + @Override + public String getLogString() { + String helper = ""; + StringBuilder sb = new StringBuilder(); + sb.append("db.createCollection(\"").append(tableName).append("\", {\n"); + + if (!logPropertiesList.isEmpty()) { + sb.append("validator: {"); + sb.append("$jsonSchema: {"); + sb.append("bsonType:\"object\","); + sb.append("required: [\n"); + for (String req : logRequiredList) { + sb.append(helper); + helper = ","; + sb.append(req); + } + sb.append("],"); + sb.append("properties: {\n"); + for (String prop : logPropertiesList) { + sb.append(prop); + } + sb.append("}}}})"); + } else { + sb.append("})"); + } + + return sb.toString(); + } + + public void addValidation(String columnName, BsonType type) { + Bson nameFilter = Filters.exists(columnName); + Bson typeFilter = Filters.type(columnName, type); + + if (validationFilter == null) { + validationFilter = Filters.and(nameFilter, typeFilter); + } else { + validationFilter = Filters.and(validationFilter, Filters.and(nameFilter, typeFilter)); + } + + logRequiredList.add("\"" + columnName + "\""); + logPropertiesList.add(columnName + ": { bsonType:\"" + bsonTypeToString(type) + "\"},\n"); + } + + public String bsonTypeToString(BsonType type) { + switch (type) { + case DOUBLE: + return "double"; + case STRING: + return "string"; + case BOOLEAN: + return "bool"; + case INT32: + case INT64: + return "int"; + case DATE_TIME: + return "date"; + case TIMESTAMP: + return "timestamp"; + default: + throw new IllegalStateException(); + } + } +} diff --git a/src/sqlancer/mongodb/query/MongoDBInsertQuery.java b/src/sqlancer/mongodb/query/MongoDBInsertQuery.java new file mode 100644 index 000000000..127dc82d0 --- /dev/null +++ b/src/sqlancer/mongodb/query/MongoDBInsertQuery.java @@ -0,0 +1,87 @@ +package sqlancer.mongodb.query; + +import org.bson.BsonDateTime; +import org.bson.BsonTimestamp; +import org.bson.Document; +import org.bson.types.ObjectId; + +import com.mongodb.client.result.InsertOneResult; + +import sqlancer.GlobalState; +import sqlancer.Main; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.mongodb.MongoDBConnection; +import sqlancer.mongodb.MongoDBQueryAdapter; +import sqlancer.mongodb.MongoDBSchema.MongoDBTable; + +public class MongoDBInsertQuery extends MongoDBQueryAdapter { + boolean excluded; + private final MongoDBTable table; + private final Document documentToBeInserted; + + public MongoDBInsertQuery(MongoDBTable table, Document documentToBeInserted) { + this.table = table; + this.documentToBeInserted = documentToBeInserted; + this.excluded = false; + } + + @Override + public String getLogString() { + StringBuilder sb = new StringBuilder(); + sb.append("db." + table.getName() + ".insert({"); + String helper = ""; + for (String key : documentToBeInserted.keySet()) { + sb.append(helper); + helper = ", "; + if (documentToBeInserted.get(key) instanceof ObjectId) { + continue; + } + Object value = documentToBeInserted.get(key); + sb.append(key); + sb.append(": "); + sb.append(getStringRepresentation(value)); + } + sb.append("})\n"); + + return sb.toString(); + } + + private String getStringRepresentation(Object value) { + if (value instanceof Double) { + return String.valueOf(value); + } else if (value instanceof Integer) { + return "NumberInt(" + value + ")"; + } else if (value instanceof String) { + return "\"" + value + "\""; + } else if (value instanceof BsonDateTime) { + return "new Date(" + ((BsonDateTime) value).getValue() + ")"; + } else if (value instanceof BsonTimestamp) { + return "Timestamp(" + ((BsonTimestamp) value).getValue() + ",1)"; + } else if (value instanceof Boolean) { + return String.valueOf(value); + } else if (value == null) { + return "null"; + } else { + throw new IllegalStateException(); + } + } + + @Override + public boolean couldAffectSchema() { + return true; + } + + @Override + public > boolean execute(G globalState, String... fills) + throws Exception { + Main.nrSuccessfulActions.addAndGet(1); + InsertOneResult result = globalState.getConnection().getDatabase().getCollection(table.getName()) + .insertOne(documentToBeInserted); + return result.wasAcknowledged(); + } + + @Override + public ExpectedErrors getExpectedErrors() { + return new ExpectedErrors(); + } +} diff --git a/src/sqlancer/mongodb/query/MongoDBRemoveQuery.java b/src/sqlancer/mongodb/query/MongoDBRemoveQuery.java new file mode 100644 index 000000000..6fe1c9e3f --- /dev/null +++ b/src/sqlancer/mongodb/query/MongoDBRemoveQuery.java @@ -0,0 +1,59 @@ +package sqlancer.mongodb.query; + +import org.bson.Document; +import org.bson.types.ObjectId; + +import com.mongodb.client.result.DeleteResult; + +import sqlancer.GlobalState; +import sqlancer.Main; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.mongodb.MongoDBConnection; +import sqlancer.mongodb.MongoDBQueryAdapter; +import sqlancer.mongodb.MongoDBSchema; + +public class MongoDBRemoveQuery extends MongoDBQueryAdapter { + + private final String objectId; + private final MongoDBSchema.MongoDBTable table; + + public MongoDBRemoveQuery(MongoDBSchema.MongoDBTable table, String objectId) { + this.objectId = objectId; + this.table = table; + } + + @Override + public boolean couldAffectSchema() { + return true; + } + + @Override + public > boolean execute(G globalState, String... fills) + throws Exception { + try { + DeleteResult result = globalState.getConnection().getDatabase().getCollection(table.getName()) + .deleteOne(new Document("_id", new ObjectId(objectId))); + if (result.wasAcknowledged()) { + Main.nrSuccessfulActions.addAndGet(1); + } else { + Main.nrUnsuccessfulActions.addAndGet(1); + } + return result.wasAcknowledged(); + } catch (Exception e) { + Main.nrUnsuccessfulActions.addAndGet(1); + return false; + } + } + + @Override + public ExpectedErrors getExpectedErrors() { + return new ExpectedErrors(); + } + + @Override + public String getLogString() { + StringBuilder stringBuilder = new StringBuilder(); + stringBuilder.append("db.").append(table.getName()).append(".remove({'_id': '").append(objectId).append("'})"); + return stringBuilder.toString(); + } +} diff --git a/src/sqlancer/mongodb/query/MongoDBSelectQuery.java b/src/sqlancer/mongodb/query/MongoDBSelectQuery.java new file mode 100644 index 000000000..1288f114c --- /dev/null +++ b/src/sqlancer/mongodb/query/MongoDBSelectQuery.java @@ -0,0 +1,147 @@ +package sqlancer.mongodb.query; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import org.bson.Document; +import org.bson.conversions.Bson; + +import com.mongodb.client.MongoCollection; +import com.mongodb.client.MongoCursor; + +import sqlancer.GlobalState; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLancerResultSet; +import sqlancer.mongodb.MongoDBConnection; +import sqlancer.mongodb.MongoDBQueryAdapter; +import sqlancer.mongodb.ast.MongoDBExpression; +import sqlancer.mongodb.ast.MongoDBSelect; +import sqlancer.mongodb.visitor.MongoDBVisitor; + +public class MongoDBSelectQuery extends MongoDBQueryAdapter { + + private final MongoDBSelect select; + + private List resultSet; + + public MongoDBSelectQuery(MongoDBSelect select) { + this.select = select; + } + + @Override + public boolean couldAffectSchema() { + return false; + } + + @Override + public > boolean execute(G globalState, String... fills) + throws Exception { + throw new UnsupportedOperationException(); + } + + @Override + public ExpectedErrors getExpectedErrors() { + ExpectedErrors errors = new ExpectedErrors(); + // ARITHMETIC + errors.add("Failed to optimize pipeline :: caused by :: Can't coerce out of range value"); + errors.add("Can't coerce out of range value"); + errors.add("date overflow in $add"); + errors.add("Failed to optimize pipeline :: caused by :: $sqrt only supports numeric types, not"); + errors.add("Failed to optimize pipeline :: caused by :: $sqrt's argument must be greater than or equal to 0"); + errors.add("Failed to optimize pipeline :: caused by :: $pow's base must be numeric, not"); + errors.add("Failed to optimize pipeline :: caused by :: $pow cannot take a base of 0 and a negative exponent"); + errors.add("Failed to optimize pipeline :: caused by :: $add only supports numeric or date types, not"); + errors.add("Failed to optimize pipeline :: caused by :: $exp only supports numeric types, not"); + errors.add("Failed to optimize pipeline :: caused by :: $log's base must be numeric, not"); + errors.add("Failed to optimize pipeline :: caused by :: $log's base must be a positive number not equal to 1"); + errors.add("Failed to optimize pipeline :: caused by :: $multiply only supports numeric types, not"); + errors.add("$log's argument must be numeric, not"); + errors.add("$log's argument must be a positive number, but"); + errors.add("$log's base must be numeric, not"); + errors.add("$log's base must be a positive number not equal to 1"); + errors.add("$divide only supports numeric types, not"); + errors.add("can't $divide by zero"); + errors.add("$pow's exponent must be numeric, not"); + errors.add("$pow's base must be numeric, not"); + errors.add("$pow cannot take a base of 0 and a negative exponent"); + errors.add("$add only supports numeric or date types, not"); + errors.add("only one date allowed in an $add expression"); + errors.add("$multiply only supports numeric types, not"); + errors.add("$exp only supports numeric types, not"); + errors.add("$sqrt's argument must be greater than or equal to 0"); + errors.add("$sqrt only supports numeric types, not"); + + // REGEX + errors.add("Regular expression is invalid: nothing to repeat"); + errors.add("Regular expression is invalid: missing terminating ] for character class"); + errors.add("Regular expression is invalid: unmatched parentheses"); + errors.add("Regular expression is invalid: missing )"); + errors.add("Regular expression is invalid: invalid UTF-8 string"); + errors.add("Regular expression is invalid: \\k is not followed by a braced, angle-bracketed, or quoted name"); + errors.add("Regular expression is invalid: missing opening brace after \\\\o"); + errors.add("Regular expression is invalid: reference to non-existent subpattern"); + errors.add("Regular expression is invalid: \\ at end of pattern"); + errors.add("Regular expression is invalid: PCRE does not support \\L, \\l, \\N{name}, \\U, or \\u"); + errors.add("Regular expression is invalid: (?R or (?[+-]digits must be followed by )"); + errors.add("Regular expression is invalid: unknown property name after \\P or \\p"); + errors.add("Regular expression is invalid: (*VERB) not recognized or malformed"); + errors.add("Regular expression is invalid: a numbered reference must not be zero"); + errors.add("Regular expression is invalid: unrecognized character after (? or (?-"); + errors.add("Regular expression is invalid: \\c at end of pattern"); + errors.add("Regular expression is invalid: malformed \\P or \\p sequence"); + errors.add("Regular expression is invalid: range out of order in character class"); + errors.add("Regular expression is invalid: group name must start with a non-digit"); + errors.add("Regular expression is invalid: \\c must be followed by an ASCII character"); + errors.add("Regular expression is invalid: subpattern name expected"); + errors.add("Regular expression is invalid: POSIX collating elements are not supported"); + errors.add("Regular expression is invalid: closing ) for (?C expected"); + errors.add("Regular expression is invalid: syntax error in subpattern name (missing terminator)"); + errors.add("Regular expression is invalid: \\\\N is not supported in a class"); + errors.add("Regular expression is invalid: non-octal character in \\o{} (closing brace missing?)"); + errors.add("Regular expression is invalid: non-hex character in \\x{} (closing brace missing?)"); + errors.add( + "Regular expression is invalid: \\g is not followed by a braced, angle-bracketed, or quoted name/number or by a plain number"); + errors.add("Regular expression is invalid: digits missing in \\x{} or \\o{}"); + errors.add("Regular expression is invalid: malformed number or name after (?("); + errors.add("Regular expression is invalid: digit expected after (?+"); + errors.add("Regular expression is invalid: assertion expected after (?( or (?(?C)"); + errors.add("Regular expression is invalid: unrecognized character after (?P"); + + return errors; + } + + @Override + public > SQLancerResultSet executeAndGet(G globalState, + String... fills) throws Exception { + if (globalState.getOptions().logEachSelect()) { + globalState.getLogger().writeCurrent(this.getLogString()); + try { + globalState.getLogger().getCurrentFileWriter().flush(); + } catch (IOException e) { + e.printStackTrace(); + } + } + List pipeline = MongoDBVisitor.asQuery(select); + + MongoCollection collection = globalState.getConnection().getDatabase() + .getCollection(select.getMainTableName()); + MongoCursor cursor = collection.aggregate(pipeline).cursor(); + resultSet = new ArrayList<>(); + while (cursor.hasNext()) { + Document document = cursor.next(); + resultSet.add(document); + } + return null; + } + + @Override + public String getLogString() { + return MongoDBVisitor.asStringLog(select); + } + + public List getResultSet() { + return resultSet; + } + +} diff --git a/src/sqlancer/mongodb/test/MongoDBColumnTestReference.java b/src/sqlancer/mongodb/test/MongoDBColumnTestReference.java new file mode 100644 index 000000000..59a2a6724 --- /dev/null +++ b/src/sqlancer/mongodb/test/MongoDBColumnTestReference.java @@ -0,0 +1,40 @@ +package sqlancer.mongodb.test; + +import sqlancer.common.ast.newast.Node; +import sqlancer.mongodb.MongoDBSchema.MongoDBColumn; +import sqlancer.mongodb.ast.MongoDBExpression; + +public class MongoDBColumnTestReference implements Node { + + private final MongoDBColumn columnReference; + private final boolean inMainTable; + + public MongoDBColumnTestReference(MongoDBColumn columnReference, boolean inMainTable) { + this.columnReference = columnReference; + this.inMainTable = inMainTable; + } + + public String getQueryString() { + if (inMainTable) { + return this.columnReference.getName(); + } else { + return "join_" + this.columnReference.getTable().getName() + "." + this.columnReference.getName(); + } + } + + public boolean inMainTable() { + return inMainTable; + } + + public String getTableName() { + return this.columnReference.getTable().getName(); + } + + public String getPlainName() { + return this.columnReference.getName(); + } + + public MongoDBColumn getColumnReference() { + return columnReference; + } +} diff --git a/src/sqlancer/mongodb/test/MongoDBQueryPartitioningBase.java b/src/sqlancer/mongodb/test/MongoDBQueryPartitioningBase.java new file mode 100644 index 000000000..b26aa6028 --- /dev/null +++ b/src/sqlancer/mongodb/test/MongoDBQueryPartitioningBase.java @@ -0,0 +1,92 @@ +package sqlancer.mongodb.test; + +import java.util.ArrayList; +import java.util.List; + +import sqlancer.Randomly; +import sqlancer.common.ast.newast.Node; +import sqlancer.common.gen.ExpressionGenerator; +import sqlancer.common.oracle.TernaryLogicPartitioningOracleBase; +import sqlancer.common.oracle.TestOracle; +import sqlancer.mongodb.MongoDBProvider.MongoDBGlobalState; +import sqlancer.mongodb.MongoDBSchema; +import sqlancer.mongodb.MongoDBSchema.MongoDBColumn; +import sqlancer.mongodb.MongoDBSchema.MongoDBTable; +import sqlancer.mongodb.MongoDBSchema.MongoDBTables; +import sqlancer.mongodb.ast.MongoDBExpression; +import sqlancer.mongodb.ast.MongoDBSelect; +import sqlancer.mongodb.gen.MongoDBComputedExpressionGenerator; +import sqlancer.mongodb.gen.MongoDBMatchExpressionGenerator; + +public class MongoDBQueryPartitioningBase + extends TernaryLogicPartitioningOracleBase, MongoDBGlobalState> implements TestOracle { + + protected MongoDBSchema schema; + protected MongoDBTables targetTables; + protected MongoDBTable mainTable; + protected List targetColumns; + protected MongoDBMatchExpressionGenerator expressionGenerator; + protected MongoDBSelect select; + + public MongoDBQueryPartitioningBase(MongoDBGlobalState state) { + super(state); + } + + @Override + public void check() throws Exception { + schema = state.getSchema(); + targetTables = schema.getRandomTableNonEmptyTables(); + mainTable = targetTables.getTables().get(0); + generateTargetColumns(); + expressionGenerator = new MongoDBMatchExpressionGenerator(state).setColumns(targetColumns); + initializeTernaryPredicateVariants(); + select = new MongoDBSelect<>(mainTable.getName(), targetColumns.get(0)); + select.setProjectionList(targetColumns); + if (Randomly.getBooleanWithRatherLowProbability()) { + select.setLookupList(targetColumns); + } else { + select.setLookupList(Randomly.nonEmptySubset(targetColumns)); + } + if (state.getDmbsSpecificOptions().testComputedValues) { + generateComputedColumns(); + } + } + + private void generateComputedColumns() { + List> computedColumns = new ArrayList<>(); + int numberComputedColumns = state.getRandomly().getInteger(1, 4); + MongoDBComputedExpressionGenerator generator = new MongoDBComputedExpressionGenerator(state) + .setColumns(targetColumns); + for (int i = 0; i < numberComputedColumns; i++) { + computedColumns.add(generator.generateExpression()); + } + select.setComputedClause(computedColumns); + } + + private void generateTargetColumns() { + targetColumns = new ArrayList<>(); + for (MongoDBColumn c : mainTable.getColumns()) { + targetColumns.add(new MongoDBColumnTestReference(c, true)); + } + List joinsOtherTables = new ArrayList<>(); + if (!state.getDmbsSpecificOptions().nullSafety) { + for (int i = 1; i < targetTables.getTables().size(); i++) { + MongoDBTable procTable = targetTables.getTables().get(i); + for (MongoDBColumn c : procTable.getColumns()) { + joinsOtherTables.add(new MongoDBColumnTestReference(c, false)); + } + } + } + if (!joinsOtherTables.isEmpty()) { + int randNumber = state.getRandomly().getInteger(1, Math.min(joinsOtherTables.size(), 4)); + List subsetJoinsOtherTables = Randomly.nonEmptySubset(joinsOtherTables, + randNumber); + targetColumns.addAll(subsetJoinsOtherTables); + } + } + + @Override + protected ExpressionGenerator> getGen() { + return expressionGenerator; + } +} diff --git a/src/sqlancer/mongodb/test/MongoDBQueryPartitioningWhereTester.java b/src/sqlancer/mongodb/test/MongoDBQueryPartitioningWhereTester.java new file mode 100644 index 000000000..5a7507672 --- /dev/null +++ b/src/sqlancer/mongodb/test/MongoDBQueryPartitioningWhereTester.java @@ -0,0 +1,48 @@ +package sqlancer.mongodb.test; + +import static sqlancer.mongodb.MongoDBComparatorHelper.getResultSetAsDocumentList; + +import java.util.List; + +import org.bson.Document; + +import sqlancer.mongodb.MongoDBComparatorHelper; +import sqlancer.mongodb.MongoDBProvider.MongoDBGlobalState; +import sqlancer.mongodb.query.MongoDBSelectQuery; + +public class MongoDBQueryPartitioningWhereTester extends MongoDBQueryPartitioningBase { + public MongoDBQueryPartitioningWhereTester(MongoDBGlobalState state) { + super(state); + } + + @Override + public void check() throws Exception { + super.check(); + + select.setWithCountClause(false); + + select.setFilterClause(null); + MongoDBSelectQuery q = new MongoDBSelectQuery(select); + List firstResultSet = getResultSetAsDocumentList(q, state); + + select.setFilterClause(predicate); + q = new MongoDBSelectQuery(select); + List secondResultSet = getResultSetAsDocumentList(q, state); + + select.setFilterClause(negatedPredicate); + q = new MongoDBSelectQuery(select); + List thirdResultSet = getResultSetAsDocumentList(q, state); + + if (state.getDmbsSpecificOptions().testWithCount) { + select.setWithCountClause(true); + select.setFilterClause(predicate); + q = new MongoDBSelectQuery(select); + List forthResultSet = getResultSetAsDocumentList(q, state); + MongoDBComparatorHelper.assumeCountIsEqual(secondResultSet, forthResultSet, q); + } + + secondResultSet.addAll(thirdResultSet); + MongoDBComparatorHelper.assumeResultSetsAreEqual(firstResultSet, secondResultSet, q); + + } +} diff --git a/src/sqlancer/mongodb/test/MongoDBRemoveReduceBase.java b/src/sqlancer/mongodb/test/MongoDBRemoveReduceBase.java new file mode 100644 index 000000000..2a2a54744 --- /dev/null +++ b/src/sqlancer/mongodb/test/MongoDBRemoveReduceBase.java @@ -0,0 +1,89 @@ +package sqlancer.mongodb.test; + +import java.util.ArrayList; +import java.util.List; + +import sqlancer.Randomly; +import sqlancer.common.ast.newast.Node; +import sqlancer.common.gen.ExpressionGenerator; +import sqlancer.common.oracle.RemoveReduceOracleBase; +import sqlancer.common.oracle.TestOracle; +import sqlancer.mongodb.MongoDBProvider; +import sqlancer.mongodb.MongoDBSchema; +import sqlancer.mongodb.ast.MongoDBExpression; +import sqlancer.mongodb.ast.MongoDBSelect; +import sqlancer.mongodb.gen.MongoDBComputedExpressionGenerator; +import sqlancer.mongodb.gen.MongoDBMatchExpressionGenerator; + +public class MongoDBRemoveReduceBase extends + RemoveReduceOracleBase, MongoDBProvider.MongoDBGlobalState> implements TestOracle { + + protected MongoDBSchema schema; + protected MongoDBSchema.MongoDBTables targetTables; + protected MongoDBSchema.MongoDBTable mainTable; + protected List targetColumns; + protected MongoDBMatchExpressionGenerator expressionGenerator; + protected MongoDBSelect select; + + protected MongoDBRemoveReduceBase(MongoDBProvider.MongoDBGlobalState state) { + super(state); + } + + @Override + public void check() throws Exception { + schema = state.getSchema(); + targetTables = schema.getRandomTableNonEmptyTables(); + mainTable = targetTables.getTables().get(0); + generateTargetColumns(); + expressionGenerator = new MongoDBMatchExpressionGenerator(state).setColumns(targetColumns); + initializeRemoveReduceOracle(); + select = new MongoDBSelect<>(mainTable.getName(), targetColumns.get(0)); + select.setProjectionList(targetColumns); + if (Randomly.getBooleanWithRatherLowProbability()) { + select.setLookupList(targetColumns); + } else { + select.setLookupList(Randomly.nonEmptySubset(targetColumns)); + } + if (state.getDmbsSpecificOptions().testComputedValues) { + generateComputedColumns(); + } + } + + private void generateTargetColumns() { + targetColumns = new ArrayList<>(); + for (MongoDBSchema.MongoDBColumn c : mainTable.getColumns()) { + targetColumns.add(new MongoDBColumnTestReference(c, true)); + } + List joinsOtherTables = new ArrayList<>(); + if (!state.getDmbsSpecificOptions().nullSafety) { + for (int i = 1; i < targetTables.getTables().size(); i++) { + MongoDBSchema.MongoDBTable procTable = targetTables.getTables().get(i); + for (MongoDBSchema.MongoDBColumn c : procTable.getColumns()) { + joinsOtherTables.add(new MongoDBColumnTestReference(c, false)); + } + } + } + if (!joinsOtherTables.isEmpty()) { + int randNumber = state.getRandomly().getInteger(1, Math.min(joinsOtherTables.size(), 4)); + List subsetJoinsOtherTables = Randomly.nonEmptySubset(joinsOtherTables, + randNumber); + targetColumns.addAll(subsetJoinsOtherTables); + } + } + + private void generateComputedColumns() { + List> computedColumns = new ArrayList<>(); + int numberComputedColumns = state.getRandomly().getInteger(1, 4); + MongoDBComputedExpressionGenerator generator = new MongoDBComputedExpressionGenerator(state) + .setColumns(targetColumns); + for (int i = 0; i < numberComputedColumns; i++) { + computedColumns.add(generator.generateExpression()); + } + select.setComputedClause(computedColumns); + } + + @Override + protected ExpressionGenerator> getGen() { + return expressionGenerator; + } +} diff --git a/src/sqlancer/mongodb/test/MongoDBRemoveReduceTester.java b/src/sqlancer/mongodb/test/MongoDBRemoveReduceTester.java new file mode 100644 index 000000000..829f63f23 --- /dev/null +++ b/src/sqlancer/mongodb/test/MongoDBRemoveReduceTester.java @@ -0,0 +1,49 @@ +package sqlancer.mongodb.test; + +import static sqlancer.mongodb.MongoDBComparatorHelper.getResultSetAsDocumentList; + +import java.util.List; + +import org.bson.Document; + +import sqlancer.Randomly; +import sqlancer.mongodb.MongoDBProvider; +import sqlancer.mongodb.MongoDBQueryAdapter; +import sqlancer.mongodb.gen.MongoDBInsertGenerator; +import sqlancer.mongodb.query.MongoDBRemoveQuery; +import sqlancer.mongodb.query.MongoDBSelectQuery; + +public class MongoDBRemoveReduceTester extends MongoDBRemoveReduceBase { + public MongoDBRemoveReduceTester(MongoDBProvider.MongoDBGlobalState state) { + super(state); + } + + @Override + public void check() throws Exception { + super.check(); + + select.setWithCountClause(false); + + select.setFilterClause(predicate); + MongoDBSelectQuery selectQuery = new MongoDBSelectQuery(select); + List firstResultSet = getResultSetAsDocumentList(selectQuery, state); + if (firstResultSet == null || firstResultSet.isEmpty()) { + return; + } + + Document documentToRemove = Randomly.fromList(firstResultSet); + MongoDBRemoveQuery removeQuery = new MongoDBRemoveQuery(mainTable, documentToRemove.get("_id").toString()); + state.executeStatement(removeQuery); + + selectQuery = new MongoDBSelectQuery(select); + List secondResultSet = getResultSetAsDocumentList(selectQuery, state); + + MongoDBQueryAdapter insertQuery = MongoDBInsertGenerator.getQuery(state); + state.executeStatement(insertQuery); + + if (secondResultSet.size() + 1 != firstResultSet.size()) { + String assertMessage = "The Result Sizes mismatches!"; + throw new AssertionError(assertMessage); + } + } +} diff --git a/src/sqlancer/mongodb/visitor/MongoDBNegateVisitor.java b/src/sqlancer/mongodb/visitor/MongoDBNegateVisitor.java new file mode 100644 index 000000000..39b607f67 --- /dev/null +++ b/src/sqlancer/mongodb/visitor/MongoDBNegateVisitor.java @@ -0,0 +1,161 @@ +package sqlancer.mongodb.visitor; + +import static sqlancer.mongodb.gen.MongoDBMatchExpressionGenerator.MongoDBBinaryLogicalOperator.AND; +import static sqlancer.mongodb.gen.MongoDBMatchExpressionGenerator.MongoDBBinaryLogicalOperator.NOR; +import static sqlancer.mongodb.gen.MongoDBMatchExpressionGenerator.MongoDBBinaryLogicalOperator.OR; +import static sqlancer.mongodb.gen.MongoDBMatchExpressionGenerator.MongoDBUnaryLogicalOperator.NOT; + +import sqlancer.common.ast.newast.Node; +import sqlancer.mongodb.ast.MongoDBBinaryComparisonNode; +import sqlancer.mongodb.ast.MongoDBBinaryLogicalNode; +import sqlancer.mongodb.ast.MongoDBConstant; +import sqlancer.mongodb.ast.MongoDBExpression; +import sqlancer.mongodb.ast.MongoDBRegexNode; +import sqlancer.mongodb.ast.MongoDBSelect; +import sqlancer.mongodb.ast.MongoDBUnaryLogicalOperatorNode; +import sqlancer.mongodb.gen.MongoDBMatchExpressionGenerator; + +public class MongoDBNegateVisitor extends MongoDBVisitor { + + private boolean negate; + Node negatedExpression; + + public MongoDBNegateVisitor(boolean negate) { + this.negate = negate; + } + + @Override + public void visit(Node expr) { + if (expr instanceof MongoDBConstant) { + visit((MongoDBConstant) expr); + } else if (expr instanceof MongoDBSelect) { + visit((MongoDBSelect) expr); + } else if (expr instanceof MongoDBBinaryComparisonNode) { + visit((MongoDBBinaryComparisonNode) expr); + } else if (expr instanceof MongoDBUnaryLogicalOperatorNode) { + visit((MongoDBUnaryLogicalOperatorNode) expr); + } else if (expr instanceof MongoDBRegexNode) { + visit((MongoDBRegexNode) expr); + } else if (expr instanceof MongoDBBinaryLogicalNode) { + visit((MongoDBBinaryLogicalNode) expr); + } else { + throw new AssertionError(expr.getClass()); + } + } + + public void visit(MongoDBBinaryComparisonNode expr) { + + if (negate) { + negatedExpression = new MongoDBUnaryLogicalOperatorNode(expr, NOT); + switch (expr.operator()) { + case EQUALS: + negatedExpression = new MongoDBBinaryComparisonNode(expr.getLeft(), expr.getRight(), + MongoDBMatchExpressionGenerator.MongoDBBinaryComparisonOperator.NOT_EQUALS); + break; + case NOT_EQUALS: + negatedExpression = new MongoDBBinaryComparisonNode(expr.getLeft(), expr.getRight(), + MongoDBMatchExpressionGenerator.MongoDBBinaryComparisonOperator.EQUALS); + break; + case LESS: + negatedExpression = new MongoDBBinaryComparisonNode(expr.getLeft(), expr.getRight(), + MongoDBMatchExpressionGenerator.MongoDBBinaryComparisonOperator.GREATER_EQUAL); + break; + case LESS_EQUAL: + negatedExpression = new MongoDBBinaryComparisonNode(expr.getLeft(), expr.getRight(), + MongoDBMatchExpressionGenerator.MongoDBBinaryComparisonOperator.GREATER); + break; + case GREATER: + negatedExpression = new MongoDBBinaryComparisonNode(expr.getLeft(), expr.getRight(), + MongoDBMatchExpressionGenerator.MongoDBBinaryComparisonOperator.LESS_EQUAL); + break; + case GREATER_EQUAL: + negatedExpression = new MongoDBBinaryComparisonNode(expr.getLeft(), expr.getRight(), + MongoDBMatchExpressionGenerator.MongoDBBinaryComparisonOperator.LESS); + break; + default: + throw new UnsupportedOperationException(); + } + } else { + negatedExpression = expr; + } + } + + public void visit(MongoDBRegexNode expr) { + if (negate) { + negatedExpression = new MongoDBUnaryLogicalOperatorNode(expr, NOT); + } else { + negatedExpression = expr; + } + } + + public void visit(MongoDBUnaryLogicalOperatorNode expr) { + if (!(expr.operator().equals(NOT))) { + throw new UnsupportedOperationException(); + } + negate = !negate; + visit(expr.getExpr()); + } + + public void visit(MongoDBBinaryLogicalNode expr) { + boolean saveNegate = negate; + Node left; + Node right; + switch (expr.operator()) { + case OR: + negate = false; + visit(expr.getLeft()); + left = negatedExpression; + negate = false; + visit(expr.getRight()); + right = negatedExpression; + if (saveNegate) { + negatedExpression = new MongoDBBinaryLogicalNode(left, right, NOR); + } else { + negatedExpression = new MongoDBBinaryLogicalNode(left, right, OR); + } + break; + case AND: + negate = saveNegate; + visit(expr.getLeft()); + left = negatedExpression; + negate = saveNegate; + visit(expr.getRight()); + right = negatedExpression; + if (saveNegate) { + negatedExpression = new MongoDBBinaryLogicalNode(left, right, OR); + } else { + negatedExpression = new MongoDBBinaryLogicalNode(left, right, AND); + } + break; + case NOR: + negate = false; + visit(expr.getLeft()); + left = negatedExpression; + negate = false; + visit(expr.getRight()); + right = negatedExpression; + if (saveNegate) { + negatedExpression = new MongoDBBinaryLogicalNode(left, right, OR); + } else { + negatedExpression = new MongoDBBinaryLogicalNode(left, right, NOR); + } + break; + default: + throw new UnsupportedOperationException(expr.getOperatorRepresentation()); + } + } + + @Override + public void visit(MongoDBConstant c) { + negatedExpression = c; + } + + @Override + public void visit(MongoDBSelect s) { + throw new UnsupportedOperationException(); + } + + public Node getNegatedExpression() { + return negatedExpression; + } +} diff --git a/src/sqlancer/mongodb/visitor/MongoDBToLogVisitor.java b/src/sqlancer/mongodb/visitor/MongoDBToLogVisitor.java new file mode 100644 index 000000000..4c55e17a6 --- /dev/null +++ b/src/sqlancer/mongodb/visitor/MongoDBToLogVisitor.java @@ -0,0 +1,195 @@ +package sqlancer.mongodb.visitor; + +import java.util.ArrayList; +import java.util.List; + +import sqlancer.common.ast.newast.NewFunctionNode; +import sqlancer.common.ast.newast.Node; +import sqlancer.mongodb.ast.MongoDBBinaryComparisonNode; +import sqlancer.mongodb.ast.MongoDBBinaryLogicalNode; +import sqlancer.mongodb.ast.MongoDBConstant; +import sqlancer.mongodb.ast.MongoDBExpression; +import sqlancer.mongodb.ast.MongoDBRegexNode; +import sqlancer.mongodb.ast.MongoDBSelect; +import sqlancer.mongodb.ast.MongoDBUnaryLogicalOperatorNode; +import sqlancer.mongodb.gen.MongoDBComputedExpressionGenerator.ComputedFunction; +import sqlancer.mongodb.test.MongoDBColumnTestReference; + +public class MongoDBToLogVisitor extends MongoDBVisitor { + + private String mainTableName; + private List lookups; + private String filter; + private String projects; + private boolean hasFilter; + private boolean withCount; + + public String visitLog(Node expr) { + if (expr instanceof MongoDBUnaryLogicalOperatorNode) { + return visit((MongoDBUnaryLogicalOperatorNode) expr); + } else if (expr instanceof MongoDBBinaryLogicalNode) { + return visit((MongoDBBinaryLogicalNode) expr); + } else if (expr instanceof MongoDBBinaryComparisonNode) { + return visit((MongoDBBinaryComparisonNode) expr); + } else if (expr instanceof MongoDBRegexNode) { + return visit((MongoDBRegexNode) expr); + } else { + throw new AssertionError(expr.getClass()); + } + } + + public String visitComputed(Node expr) { + if (expr instanceof NewFunctionNode) { + return visitComputed((NewFunctionNode) expr); + } else { + throw new AssertionError(expr.getClass()); + } + } + + public String visitComputed(NewFunctionNode expr) { + List arguments = new ArrayList<>(); + for (int i = 0; i < expr.getArgs().size(); i++) { + if (expr.getArgs().get(i) instanceof MongoDBConstant) { + arguments.add(((MongoDBConstant) expr.getArgs().get(i)).getLogValue()); + continue; + } + if (expr.getArgs().get(i) instanceof MongoDBColumnTestReference) { + arguments.add("\"$" + ((MongoDBColumnTestReference) expr.getArgs().get(i)).getQueryString() + "\""); + continue; + } + if (expr.getArgs().get(i) instanceof NewFunctionNode) { + arguments.add(visitComputed((NewFunctionNode) expr.getArgs().get(i))); + } else { + throw new AssertionError(); + } + } + if (!(expr.getFunc() instanceof ComputedFunction)) { + throw new AssertionError(expr.getClass()); + } + + StringBuilder sb = new StringBuilder(); + sb.append("{"); + sb.append(((ComputedFunction) expr.getFunc()).getOperator()); + sb.append(": ["); + String helper = ""; + for (String arg : arguments) { + sb.append(helper); + helper = ", "; + sb.append(arg); + } + sb.append("]}"); + return sb.toString(); + } + + public String visit(MongoDBUnaryLogicalOperatorNode expr) { + String inner = visitLog(expr.getExpr()); + return "{ " + expr.operator().getTextRepresentation() + ": [" + inner + "]}"; + } + + public String visit(MongoDBBinaryLogicalNode expr) { + String left = visitLog(expr.getLeft()); + String right = visitLog(expr.getRight()); + + return "{" + expr.operator().getTextRepresentation() + ":[" + left + "," + right + "]}"; + } + + public String visit(MongoDBBinaryComparisonNode expr) { + Node left = expr.getLeft(); + Node right = expr.getRight(); + assert left instanceof MongoDBColumnTestReference; + assert right instanceof MongoDBConstant; + + return "{\"" + ((MongoDBColumnTestReference) left).getQueryString() + "\": {" + + expr.operator().getTextRepresentation() + ": " + ((MongoDBConstant) right).getLogValue() + "}}"; + } + + public String visit(MongoDBRegexNode expr) { + Node left = expr.getLeft(); + Node right = expr.getRight(); + + return "{\"" + ((MongoDBColumnTestReference) left).getQueryString() + "\": {" + + expr.operator().getTextRepresentation() + ": \'" + + ((MongoDBConstant.MongoDBStringConstant) right).getStringValue() + "\', $options: \'" + + expr.getOptions() + "\'}}"; + } + + @Override + public void visit(MongoDBConstant c) { + throw new UnsupportedOperationException(); + } + + @Override + public void visit(MongoDBSelect select) { + hasFilter = select.hasFilter(); + mainTableName = select.getMainTableName(); + setLookups(select); + if (hasFilter) { + setFilter(select); + } + setProjects(select); + withCount = select.getWithCountClause(); + } + + private void setFilter(MongoDBSelect select) { + filter = visitLog(select.getFilterClause()); + } + + private void setLookups(MongoDBSelect select) { + lookups = new ArrayList<>(); + for (MongoDBColumnTestReference testReference : select.getLookupList()) { + if (testReference.inMainTable()) { + continue; + } + String newLookup = "{ $lookup: { from: \"" + testReference.getTableName() + "\", localField: \"" + + select.getJoinColumn().getPlainName() + "\", foreignField: \"" + testReference.getPlainName() + + "\", as: \"" + testReference.getQueryString() + "\"}},\n"; + lookups.add(newLookup); + } + } + + private void setProjects(MongoDBSelect select) { + StringBuilder sb = new StringBuilder(); + sb.append("{"); + String helper = ""; + for (MongoDBColumnTestReference reference : select.getProjectionList()) { + sb.append(helper); + helper = ","; + sb.append("\"").append(reference.getQueryString()).append("\"").append(": 1"); + } + sb.append("\n"); + if (select.hasComputed()) { + String name = "computed"; + int number = 0; + for (Node expressionNode : select.getComputedClause()) { + sb.append(helper); + helper = ",\n"; + sb.append("\"" + name + number + "\": " + visitComputed(expressionNode)); + number++; + } + } + sb.append("}"); + projects = sb.toString(); + } + + public String getStringLog() { + StringBuilder sb = new StringBuilder(); + sb.append("db.").append(mainTableName).append(".aggregate([\n"); + for (String lookup : lookups) { + sb.append(lookup); + } + if (hasFilter) { + sb.append("{ $match: "); + sb.append(filter); + sb.append("},\n"); + } + sb.append("{ $project : "); + sb.append(projects); + sb.append("}"); + if (withCount) { + sb.append(",\n"); + sb.append(" {$count: \"count\"}\n"); + } + sb.append("])\n"); + return sb.toString(); + } +} diff --git a/src/sqlancer/mongodb/visitor/MongoDBToQueryVisitor.java b/src/sqlancer/mongodb/visitor/MongoDBToQueryVisitor.java new file mode 100644 index 000000000..8efbf97e4 --- /dev/null +++ b/src/sqlancer/mongodb/visitor/MongoDBToQueryVisitor.java @@ -0,0 +1,184 @@ +package sqlancer.mongodb.visitor; + +import static com.mongodb.client.model.Aggregates.match; +import static com.mongodb.client.model.Aggregates.project; +import static com.mongodb.client.model.Projections.fields; +import static com.mongodb.client.model.Projections.include; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; + +import org.bson.Document; +import org.bson.conversions.Bson; + +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Projections; + +import sqlancer.common.ast.newast.NewFunctionNode; +import sqlancer.common.ast.newast.Node; +import sqlancer.mongodb.ast.MongoDBBinaryComparisonNode; +import sqlancer.mongodb.ast.MongoDBBinaryLogicalNode; +import sqlancer.mongodb.ast.MongoDBConstant; +import sqlancer.mongodb.ast.MongoDBConstant.MongoDBStringConstant; +import sqlancer.mongodb.ast.MongoDBExpression; +import sqlancer.mongodb.ast.MongoDBRegexNode; +import sqlancer.mongodb.ast.MongoDBSelect; +import sqlancer.mongodb.ast.MongoDBUnaryLogicalOperatorNode; +import sqlancer.mongodb.gen.MongoDBComputedExpressionGenerator.ComputedFunction; +import sqlancer.mongodb.test.MongoDBColumnTestReference; + +public class MongoDBToQueryVisitor extends MongoDBVisitor { + + private List lookup; + private Bson filter; + private Bson projection; + private Bson count; + private boolean hasFilter; + private boolean hasCountClause; + + public Bson visitBson(Node expr) { + if (expr instanceof MongoDBUnaryLogicalOperatorNode) { + return visit((MongoDBUnaryLogicalOperatorNode) expr); + } else if (expr instanceof MongoDBBinaryLogicalNode) { + return visit((MongoDBBinaryLogicalNode) expr); + } else if (expr instanceof MongoDBBinaryComparisonNode) { + return visit((MongoDBBinaryComparisonNode) expr); + } else if (expr instanceof MongoDBRegexNode) { + return visit((MongoDBRegexNode) expr); + } else { + throw new AssertionError(expr.getClass()); + } + } + + public Document visitComputed(Node expr) { + if (expr instanceof NewFunctionNode) { + return visitComputed((NewFunctionNode) expr); + } else { + throw new AssertionError(expr.getClass()); + } + } + + public Document visitComputed(NewFunctionNode expr) { + List visitedArgs = new ArrayList<>(); + for (int i = 0; i < expr.getArgs().size(); i++) { + if (expr.getArgs().get(i) instanceof MongoDBConstant) { + visitedArgs.add(((MongoDBConstant) expr.getArgs().get(i)).getSerializedValue()); + continue; + } + if (expr.getArgs().get(i) instanceof MongoDBColumnTestReference) { + visitedArgs.add("$" + ((MongoDBColumnTestReference) expr.getArgs().get(i)).getQueryString()); + continue; + } + if (expr.getArgs().get(i) instanceof NewFunctionNode) { + visitedArgs.add(visitComputed((NewFunctionNode) expr.getArgs().get(i))); + } else { + throw new AssertionError(); + } + } + if (expr.getFunc() instanceof ComputedFunction) { + return new Document(((ComputedFunction) expr.getFunc()).getOperator(), visitedArgs); + } else { + throw new AssertionError(expr.getClass()); + } + + } + + public Bson visit(MongoDBUnaryLogicalOperatorNode expr) { + Bson inner = visitBson(expr.getExpr()); + return expr.operator().applyOperator(inner); + } + + public Bson visit(MongoDBBinaryLogicalNode expr) { + Bson left = visitBson(expr.getLeft()); + Bson right = visitBson(expr.getRight()); + return expr.operator().applyOperator(left, right); + } + + public Bson visit(MongoDBRegexNode expr) { + Node left = expr.getLeft(); + Node right = expr.getRight(); + + String columnName = ((MongoDBColumnTestReference) left).getQueryString(); + + return expr.operator().applyOperator(columnName, (MongoDBStringConstant) right, expr.getOptions()); + } + + public Bson visit(MongoDBBinaryComparisonNode expr) { + Node left = expr.getLeft(); + Node right = expr.getRight(); + assert left instanceof MongoDBColumnTestReference; + assert right instanceof MongoDBConstant; + + String columnName = ((MongoDBColumnTestReference) left).getQueryString(); + return expr.operator().applyOperator(columnName, (MongoDBConstant) right); + } + + @Override + public void visit(MongoDBConstant c) { + throw new UnsupportedOperationException(); + } + + @Override + public void visit(MongoDBSelect select) { + hasFilter = select.hasFilter(); + setLookup(select); + if (hasFilter) { + setFilter(select); + } + setProjection(select); + hasCountClause = select.getWithCountClause(); + if (hasCountClause) { + setCount(); + } + } + + private void setCount() { + count = Aggregates.count("count"); + } + + private void setFilter(MongoDBSelect select) { + filter = match(this.visitBson(select.getFilterClause())); + } + + private void setLookup(MongoDBSelect select) { + lookup = new ArrayList<>(); + for (MongoDBColumnTestReference reference : select.getLookupList()) { + if (reference.inMainTable()) { + continue; + } + lookup.add(Aggregates.lookup(reference.getTableName(), select.getJoinColumn().getPlainName(), + reference.getPlainName(), reference.getQueryString())); + } + } + + private void setProjection(MongoDBSelect select) { + List stringProjects = new ArrayList<>(); + for (MongoDBColumnTestReference ref : select.getProjectionList()) { + stringProjects.add(ref.getQueryString()); + } + List projections = new ArrayList<>(); + projections.add(include(stringProjects)); + if (select.hasComputed()) { + String name = "computed"; + int number = 0; + for (Node expressionNode : select.getComputedClause()) { + projections.add(Projections.computed(name + number, visitComputed(expressionNode))); + number++; + } + } + projection = project(fields(projections)); + } + + public List getPipeline() { + List result = new ArrayList<>(lookup); + if (hasFilter) { + result.add(filter); + } + result.add(projection); + if (hasCountClause) { + result.add(count); + } + return result; + } +} diff --git a/src/sqlancer/mongodb/visitor/MongoDBVisitor.java b/src/sqlancer/mongodb/visitor/MongoDBVisitor.java new file mode 100644 index 000000000..e02a50f02 --- /dev/null +++ b/src/sqlancer/mongodb/visitor/MongoDBVisitor.java @@ -0,0 +1,45 @@ +package sqlancer.mongodb.visitor; + +import java.util.List; + +import org.bson.conversions.Bson; + +import sqlancer.common.ast.newast.Node; +import sqlancer.mongodb.ast.MongoDBConstant; +import sqlancer.mongodb.ast.MongoDBExpression; +import sqlancer.mongodb.ast.MongoDBSelect; + +public abstract class MongoDBVisitor { + + public abstract void visit(MongoDBConstant c); + + public abstract void visit(MongoDBSelect s); + + public void visit(Node expr) { + if (expr instanceof MongoDBConstant) { + visit((MongoDBConstant) expr); + } else if (expr instanceof MongoDBSelect) { + visit((MongoDBSelect) expr); + } else { + throw new AssertionError(expr.getClass()); + } + } + + public static List asQuery(Node expr) { + MongoDBToQueryVisitor visitor = new MongoDBToQueryVisitor(); + visitor.visit(expr); + return visitor.getPipeline(); + } + + public static String asStringLog(Node expr) { + MongoDBToLogVisitor visitor = new MongoDBToLogVisitor(); + visitor.visit(expr); + return visitor.getStringLog(); + } + + public static Node cleanNegations(Node expr) { + MongoDBNegateVisitor visitor = new MongoDBNegateVisitor(false); + visitor.visit(expr); + return visitor.getNegatedExpression(); + } +}