diff --git a/server/build.gradle b/server/build.gradle index 706423bb..e3d1da44 100644 --- a/server/build.gradle +++ b/server/build.gradle @@ -24,6 +24,9 @@ dependencies { implementation 'com.google.guava:guava:31.0.1-jre' implementation 'org.reflections:reflections:0.10.1' + implementation 'com.github.jsqlparser:jsqlparser:3.2' + implementation 'org.apache.commons:commons-lang3:3.0' + // Experimental implementation files('libs/keva-ioc-0.1.0-SNAPSHOT.jar') diff --git a/server/src/main/java/dev/keva/server/command/impl/kql/Kql.java b/server/src/main/java/dev/keva/server/command/impl/kql/Kql.java new file mode 100644 index 00000000..1cda58ad --- /dev/null +++ b/server/src/main/java/dev/keva/server/command/impl/kql/Kql.java @@ -0,0 +1,98 @@ +package dev.keva.server.command.impl.kql; + +import dev.keva.ioc.annotation.Autowired; +import dev.keva.ioc.annotation.Component; +import dev.keva.protocol.resp.reply.*; +import dev.keva.server.command.annotation.CommandImpl; +import dev.keva.server.command.annotation.Execute; +import dev.keva.server.command.annotation.ParamLength; +import dev.keva.server.command.impl.kql.manager.KqlManager; +import net.sf.jsqlparser.JSQLParserException; +import net.sf.jsqlparser.statement.Statement; +import net.sf.jsqlparser.statement.create.table.CreateTable; +import net.sf.jsqlparser.statement.delete.Delete; +import net.sf.jsqlparser.statement.drop.Drop; +import net.sf.jsqlparser.statement.insert.Insert; +import net.sf.jsqlparser.statement.select.Select; +import net.sf.jsqlparser.statement.update.Update; + +import java.util.List; + +@Component +@CommandImpl("kql") +@ParamLength(1) +public class Kql { + private final KqlManager kqlManager; + + @Autowired + public Kql(KqlManager kqlManager) { + this.kqlManager = kqlManager; + } + + @Execute + public Reply execute(byte[] sqlBytes) { + String sql = new String(sqlBytes); + Statement stmt; + try { + stmt = kqlManager.parse(sql); + } catch (JSQLParserException e) { + return new ErrorReply("ERR " + e.getMessage()); + } + if (stmt instanceof CreateTable) { + kqlManager.create(stmt); + Reply[] replies = new Reply[2]; + replies[0] = new StatusReply("DONE"); + replies[1] = new IntegerReply(0); + return new MultiBulkReply(replies); + } else if (stmt instanceof Drop) { + kqlManager.drop(stmt); + Reply[] replies = new Reply[2]; + replies[0] = new StatusReply("DONE"); + replies[1] = new IntegerReply(0); + return new MultiBulkReply(replies); + } else if (stmt instanceof Insert) { + kqlManager.insert(stmt); + Reply[] replies = new Reply[2]; + replies[0] = new StatusReply("DONE"); + replies[1] = new IntegerReply(1); + return new MultiBulkReply(replies); + } else if (stmt instanceof Update) { + int numOfUpdated = kqlManager.update(stmt); + Reply[] replies = new Reply[2]; + replies[0] = new StatusReply("DONE"); + replies[1] = new IntegerReply(numOfUpdated); + return new MultiBulkReply(replies); + } else if (stmt instanceof Delete) { + int numOfDeleted = kqlManager.delete(stmt); + Reply[] replies = new Reply[2]; + replies[0] = new StatusReply("DONE"); + replies[1] = new IntegerReply(numOfDeleted); + return new MultiBulkReply(replies); + } else if (stmt instanceof Select) { + List> result = kqlManager.select(stmt); + Reply[] rowReplies = new Reply[result.size()]; + for (int i = 0; i < result.size(); i++) { + Reply[] columnReplies = new Reply[result.get(i).size()]; + for (int j = 0; j < result.get(i).size(); j++) { + if (result.get(i).get(j) instanceof String) { + columnReplies[j] = new BulkReply((String) result.get(i).get(j)); + } else if (result.get(i).get(j) instanceof Integer) { + columnReplies[j] = new IntegerReply((Integer) result.get(i).get(j)); + } else if (result.get(i).get(j) instanceof Long) { + columnReplies[j] = new IntegerReply((Long) result.get(i).get(j)); + } else if (result.get(i).get(j) instanceof Double) { + columnReplies[j] = new BulkReply(result.get(i).get(j).toString()); + } else if (result.get(i).get(j) instanceof Boolean) { + columnReplies[j] = new IntegerReply((Boolean) result.get(i).get(j) ? 1 : 0); + } else if (result.get(i).get(j) == null) { + columnReplies[j] = BulkReply.NIL_REPLY; + } + } + rowReplies[i] = new MultiBulkReply(columnReplies); + } + return new MultiBulkReply(rowReplies); + } else { + return new ErrorReply("ERR unsupported statement"); + } + } +} diff --git a/server/src/main/java/dev/keva/server/command/impl/kql/manager/KevaColumnDefinition.java b/server/src/main/java/dev/keva/server/command/impl/kql/manager/KevaColumnDefinition.java new file mode 100644 index 00000000..08234980 --- /dev/null +++ b/server/src/main/java/dev/keva/server/command/impl/kql/manager/KevaColumnDefinition.java @@ -0,0 +1,11 @@ +package dev.keva.server.command.impl.kql.manager; + +import lombok.AllArgsConstructor; + +import java.io.Serializable; + +@AllArgsConstructor +public class KevaColumnDefinition implements Serializable { + public String name; + public String type; +} diff --git a/server/src/main/java/dev/keva/server/command/impl/kql/manager/KevaColumnFinder.java b/server/src/main/java/dev/keva/server/command/impl/kql/manager/KevaColumnFinder.java new file mode 100644 index 00000000..caf64feb --- /dev/null +++ b/server/src/main/java/dev/keva/server/command/impl/kql/manager/KevaColumnFinder.java @@ -0,0 +1,15 @@ +package dev.keva.server.command.impl.kql.manager; + +import java.util.List; + +public class KevaColumnFinder { + public static int findColumn(String columnName, List kevaColumns) { + for (int i = 0; i < kevaColumns.size(); i++) { + KevaColumnDefinition kevaColumn = kevaColumns.get(i); + if (kevaColumn.name.equals(columnName)) { + return i; + } + } + return -1; + } +} diff --git a/server/src/main/java/dev/keva/server/command/impl/kql/manager/KevaSQLConvertUtil.java b/server/src/main/java/dev/keva/server/command/impl/kql/manager/KevaSQLConvertUtil.java new file mode 100644 index 00000000..82bf4f33 --- /dev/null +++ b/server/src/main/java/dev/keva/server/command/impl/kql/manager/KevaSQLConvertUtil.java @@ -0,0 +1,25 @@ +package dev.keva.server.command.impl.kql.manager; + +public class KevaSQLConvertUtil { + public static Object convertToRowData(String type, String value) { + value = KevaSQLStringUtil.escape(value); + if (value.equalsIgnoreCase("null")) { + return null; + } + if (type.equals("CHAR") || type.equals("VARCHAR") || type.equals("TEXT")) { + return value; + } else if (type.equals("INT") || type.equals("INTEGER")) { + return Integer.parseInt(value); + } else if (type.equals("BIGINT")) { + return Long.parseLong(value); + } else if (type.equals("DOUBLE")) { + return Double.parseDouble(value); + } else if (type.equals("FLOAT")) { + return Float.parseFloat(value); + } else if (type.equals("BOOL") || type.equals("BOOLEAN")) { + return Boolean.parseBoolean(value); + } else { + throw new KevaSQLException("unknown type: " + type); + } + } +} diff --git a/server/src/main/java/dev/keva/server/command/impl/kql/manager/KevaSQLException.java b/server/src/main/java/dev/keva/server/command/impl/kql/manager/KevaSQLException.java new file mode 100644 index 00000000..a73e149d --- /dev/null +++ b/server/src/main/java/dev/keva/server/command/impl/kql/manager/KevaSQLException.java @@ -0,0 +1,7 @@ +package dev.keva.server.command.impl.kql.manager; + +public class KevaSQLException extends RuntimeException { + public KevaSQLException(String message) { + super(message); + } +} diff --git a/server/src/main/java/dev/keva/server/command/impl/kql/manager/KevaSQLResponseUtil.java b/server/src/main/java/dev/keva/server/command/impl/kql/manager/KevaSQLResponseUtil.java new file mode 100644 index 00000000..f5519160 --- /dev/null +++ b/server/src/main/java/dev/keva/server/command/impl/kql/manager/KevaSQLResponseUtil.java @@ -0,0 +1,10 @@ +package dev.keva.server.command.impl.kql.manager; + +import java.util.Collections; +import java.util.List; + +public class KevaSQLResponseUtil { + public static List> singleSelectResponse(Object object) { + return Collections.singletonList(Collections.singletonList(object)); + } +} diff --git a/server/src/main/java/dev/keva/server/command/impl/kql/manager/KevaSQLStringUtil.java b/server/src/main/java/dev/keva/server/command/impl/kql/manager/KevaSQLStringUtil.java new file mode 100644 index 00000000..7afd8b9c --- /dev/null +++ b/server/src/main/java/dev/keva/server/command/impl/kql/manager/KevaSQLStringUtil.java @@ -0,0 +1,10 @@ +package dev.keva.server.command.impl.kql.manager; + +public class KevaSQLStringUtil { + public static String escape(String str) { + if (str.startsWith("'") && str.endsWith("'")) { + return str.replaceAll("^.|.$", ""); + } + return str; + } +} diff --git a/server/src/main/java/dev/keva/server/command/impl/kql/manager/KevaTable.java b/server/src/main/java/dev/keva/server/command/impl/kql/manager/KevaTable.java new file mode 100644 index 00000000..16afd7df --- /dev/null +++ b/server/src/main/java/dev/keva/server/command/impl/kql/manager/KevaTable.java @@ -0,0 +1,20 @@ +package dev.keva.server.command.impl.kql.manager; + +import lombok.Getter; +import lombok.RequiredArgsConstructor; + +import java.io.Serializable; +import java.util.List; + +@RequiredArgsConstructor +public class KevaTable implements Serializable { + @Getter + private final List columns; + + @Getter + private long increment = 1; + + public void increment() { + increment++; + } +} diff --git a/server/src/main/java/dev/keva/server/command/impl/kql/manager/KqlExpressionVisitor.java b/server/src/main/java/dev/keva/server/command/impl/kql/manager/KqlExpressionVisitor.java new file mode 100644 index 00000000..2acfc070 --- /dev/null +++ b/server/src/main/java/dev/keva/server/command/impl/kql/manager/KqlExpressionVisitor.java @@ -0,0 +1,429 @@ +package dev.keva.server.command.impl.kql.manager; + +import lombok.Getter; +import net.sf.jsqlparser.expression.*; +import net.sf.jsqlparser.expression.operators.arithmetic.*; +import net.sf.jsqlparser.expression.operators.conditional.AndExpression; +import net.sf.jsqlparser.expression.operators.conditional.OrExpression; +import net.sf.jsqlparser.expression.operators.relational.*; +import net.sf.jsqlparser.schema.Column; +import net.sf.jsqlparser.statement.select.SubSelect; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; + +public class KqlExpressionVisitor implements ExpressionVisitor { + private final List kevaColumns; + private List> result; + @Getter + private List> temp; + + public KqlExpressionVisitor(List> result, List kevaColumns) { + this.result = result; + this.kevaColumns = kevaColumns; + this.temp = result; + } + + @Override + public void visit(BitwiseRightShift aThis) { + + } + + @Override + public void visit(BitwiseLeftShift aThis) { + + } + + @Override + public void visit(NullValue nullValue) { + + } + + @Override + public void visit(Function function) { + + } + + @Override + public void visit(SignedExpression signedExpression) { + + } + + @Override + public void visit(JdbcParameter jdbcParameter) { + + } + + @Override + public void visit(JdbcNamedParameter jdbcNamedParameter) { + + } + + @Override + public void visit(DoubleValue doubleValue) { + + } + + @Override + public void visit(LongValue longValue) { + + } + + @Override + public void visit(HexValue hexValue) { + + } + + @Override + public void visit(DateValue dateValue) { + + } + + @Override + public void visit(TimeValue timeValue) { + + } + + @Override + public void visit(TimestampValue timestampValue) { + + } + + @Override + public void visit(Parenthesis parenthesis) { + + } + + @Override + public void visit(StringValue stringValue) { + + } + + @Override + public void visit(Addition addition) { + + } + + @Override + public void visit(Division division) { + + } + + @Override + public void visit(IntegerDivision division) { + + } + + @Override + public void visit(Multiplication multiplication) { + + } + + @Override + public void visit(Subtraction subtraction) { + + } + + @Override + public void visit(AndExpression andExpression) { + andExpression.getLeftExpression().accept(this); + result = temp; + andExpression.getRightExpression().accept(this); + } + + @Override + public void visit(OrExpression orExpression) { + orExpression.getLeftExpression().accept(this); + List> combine = new ArrayList<>(temp); + orExpression.getRightExpression().accept(this); + combine.addAll(temp); + temp = combine; + } + + @Override + public void visit(Between between) { + + } + + @Override + public void visit(EqualsTo equalsTo) { + String columnName = equalsTo.getLeftExpression().toString(); + int columnIndex = KevaColumnFinder.findColumn(columnName, kevaColumns); + String type = kevaColumns.get(columnIndex).type; + String valueStr = KevaSQLStringUtil.escape(equalsTo.getRightExpression().toString()); + if (type.equals("CHAR") || type.equals("VARCHAR") || type.equals("TEXT")) { + temp = result.stream() + .filter(row -> row.get(columnIndex).equals(valueStr)) + .collect(Collectors.toList()); + } else if (type.equals("INT") || type.equals("INTEGER")) { + int value = Integer.parseInt(valueStr); + temp = result.stream() + .filter(row -> (int) row.get(columnIndex) == value) + .collect(Collectors.toList()); + } else if (type.equals("FLOAT") || type.equals("DOUBLE")) { + double value = Double.parseDouble(valueStr); + temp = result.stream() + .filter(row -> (double) row.get(columnIndex) == value) + .collect(Collectors.toList()); + } else if (type.equals("BOOL") || type.equals("BOOLEAN")) { + boolean value = Boolean.parseBoolean(valueStr); + temp = result.stream() + .filter(row -> (boolean) row.get(columnIndex) == value) + .collect(Collectors.toList()); + } + } + + @Override + public void visit(GreaterThan greaterThan) { + + } + + @Override + public void visit(GreaterThanEquals greaterThanEquals) { + + } + + @Override + public void visit(InExpression inExpression) { + + } + + @Override + public void visit(FullTextSearch fullTextSearch) { + + } + + @Override + public void visit(IsNullExpression isNullExpression) { + + } + + @Override + public void visit(IsBooleanExpression isBooleanExpression) { + + } + + @Override + public void visit(LikeExpression likeExpression) { + + } + + @Override + public void visit(MinorThan minorThan) { + + } + + @Override + public void visit(MinorThanEquals minorThanEquals) { + + } + + @Override + public void visit(NotEqualsTo notEqualsTo) { + String columnName = notEqualsTo.getLeftExpression().toString(); + int columnIndex = KevaColumnFinder.findColumn(columnName, kevaColumns); + String type = kevaColumns.get(columnIndex).type; + String valueStr = KevaSQLStringUtil.escape(notEqualsTo.getRightExpression().toString()); + if (type.equals("CHAR") || type.equals("VARCHAR") || type.equals("TEXT")) { + temp = result.stream() + .filter(row -> !row.get(columnIndex).equals(valueStr)) + .collect(Collectors.toList()); + } else if (type.equals("INT") || type.equals("INTEGER")) { + int value = Integer.parseInt(valueStr); + temp = result.stream() + .filter(row -> (int) row.get(columnIndex) != value) + .collect(Collectors.toList()); + } else if (type.equals("FLOAT") || type.equals("DOUBLE")) { + double value = Double.parseDouble(valueStr); + temp = result.stream() + .filter(row -> (double) row.get(columnIndex) != value) + .collect(Collectors.toList()); + } else if (type.equals("BOOL") || type.equals("BOOLEAN")) { + boolean value = Boolean.parseBoolean(valueStr); + temp = result.stream() + .filter(row -> (boolean) row.get(columnIndex) != value) + .collect(Collectors.toList()); + } + } + + @Override + public void visit(Column tableColumn) { + + } + + @Override + public void visit(SubSelect subSelect) { + + } + + @Override + public void visit(CaseExpression caseExpression) { + + } + + @Override + public void visit(WhenClause whenClause) { + + } + + @Override + public void visit(ExistsExpression existsExpression) { + + } + + @Override + public void visit(AllComparisonExpression allComparisonExpression) { + + } + + @Override + public void visit(AnyComparisonExpression anyComparisonExpression) { + + } + + @Override + public void visit(Concat concat) { + + } + + @Override + public void visit(Matches matches) { + + } + + @Override + public void visit(BitwiseAnd bitwiseAnd) { + + } + + @Override + public void visit(BitwiseOr bitwiseOr) { + + } + + @Override + public void visit(BitwiseXor bitwiseXor) { + + } + + @Override + public void visit(CastExpression cast) { + + } + + @Override + public void visit(Modulo modulo) { + + } + + @Override + public void visit(AnalyticExpression aexpr) { + + } + + @Override + public void visit(ExtractExpression eexpr) { + + } + + @Override + public void visit(IntervalExpression iexpr) { + + } + + @Override + public void visit(OracleHierarchicalExpression oexpr) { + + } + + @Override + public void visit(RegExpMatchOperator rexpr) { + + } + + @Override + public void visit(JsonExpression jsonExpr) { + + } + + @Override + public void visit(JsonOperator jsonExpr) { + + } + + @Override + public void visit(RegExpMySQLOperator regExpMySQLOperator) { + + } + + @Override + public void visit(UserVariable var) { + + } + + @Override + public void visit(NumericBind bind) { + + } + + @Override + public void visit(KeepExpression aexpr) { + + } + + @Override + public void visit(MySQLGroupConcat groupConcat) { + + } + + @Override + public void visit(ValueListExpression valueList) { + + } + + @Override + public void visit(RowConstructor rowConstructor) { + + } + + @Override + public void visit(OracleHint hint) { + + } + + @Override + public void visit(TimeKeyExpression timeKeyExpression) { + + } + + @Override + public void visit(DateTimeLiteralExpression literal) { + + } + + @Override + public void visit(NotExpression aThis) { + + } + + @Override + public void visit(NextValExpression aThis) { + + } + + @Override + public void visit(CollateExpression aThis) { + + } + + @Override + public void visit(SimilarToExpression aThis) { + + } + + @Override + public void visit(ArrayExpression aThis) { + + } +} diff --git a/server/src/main/java/dev/keva/server/command/impl/kql/manager/KqlManager.java b/server/src/main/java/dev/keva/server/command/impl/kql/manager/KqlManager.java new file mode 100644 index 00000000..101d64c8 --- /dev/null +++ b/server/src/main/java/dev/keva/server/command/impl/kql/manager/KqlManager.java @@ -0,0 +1,391 @@ +package dev.keva.server.command.impl.kql.manager; + +import dev.keva.ioc.annotation.Autowired; +import dev.keva.ioc.annotation.Component; +import dev.keva.store.KevaDatabase; +import net.sf.jsqlparser.JSQLParserException; +import net.sf.jsqlparser.expression.Expression; +import net.sf.jsqlparser.expression.operators.relational.ExpressionList; +import net.sf.jsqlparser.parser.CCJSqlParserUtil; +import net.sf.jsqlparser.schema.Column; +import net.sf.jsqlparser.schema.Table; +import net.sf.jsqlparser.statement.Statement; +import net.sf.jsqlparser.statement.create.table.ColumnDefinition; +import net.sf.jsqlparser.statement.create.table.CreateTable; +import net.sf.jsqlparser.statement.delete.Delete; +import net.sf.jsqlparser.statement.drop.Drop; +import net.sf.jsqlparser.statement.insert.Insert; +import net.sf.jsqlparser.statement.select.OrderByElement; +import net.sf.jsqlparser.statement.select.PlainSelect; +import net.sf.jsqlparser.statement.select.Select; +import net.sf.jsqlparser.statement.select.SelectItem; +import net.sf.jsqlparser.statement.update.Update; +import net.sf.jsqlparser.util.TablesNamesFinder; +import org.apache.commons.lang3.SerializationUtils; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Optional; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +@Component +public class KqlManager { + private final KevaDatabase database; + + @Autowired + public KqlManager(KevaDatabase database) { + this.database = database; + } + + public Statement parse(String sql) throws JSQLParserException { + return CCJSqlParserUtil.parse(sql); + } + + public Object sqlGet(String key) { + byte[] got = database.get(key.getBytes()); + if (got == null) { + return null; + } + return SerializationUtils.deserialize(got); + } + + public void sqlPut(String key, Object object) { + database.put(key.getBytes(), SerializationUtils.serialize((Serializable) object)); + } + + public void sqlRemove(String key) { + database.remove(key.getBytes()); + } + + public boolean sqlContainsKey(String key) { + return sqlGet(key) != null; + } + + public void create(Statement stmt) { + CreateTable createTable = (CreateTable) stmt; + String tableName = createTable.getTable().getName(); + if (sqlGet(tableName) != null) { + throw new KevaSQLException("table " + tableName + " already exists"); + } + List columnDefinitions = createTable.getColumnDefinitions(); + List kevaColumns = new ArrayList<>(); + for (ColumnDefinition columnDefinition : columnDefinitions) { + KevaColumnDefinition kevaColumn = new KevaColumnDefinition + (columnDefinition.getColumnName().toLowerCase(), + columnDefinition.getColDataType().getDataType().toUpperCase()); + kevaColumns.add(kevaColumn); + } + KevaTable kevaTable = new KevaTable(kevaColumns); + sqlPut(tableName, kevaTable); + } + + public void drop(Statement stmt) { + Drop dropStatement = (Drop) stmt; + String tableName = dropStatement.getName().getName(); + KevaTable kevaTable = (KevaTable) sqlGet(tableName); + if (kevaTable == null) { + throw new KevaSQLException("table " + tableName + " does not exist"); + } + + for (int i = 0; i < kevaTable.getIncrement(); i++) { + String key = tableName + ":" + i; + if (sqlContainsKey(key)) { + sqlRemove(key); + } + } + + sqlRemove(tableName); + } + + public void insert(Statement stmt) { + Insert insertStatement = (Insert) stmt; + Table table = insertStatement.getTable(); + String tableName = table.getName(); + KevaTable kevaTable = (KevaTable) sqlGet(tableName); + if (kevaTable == null) { + throw new KevaSQLException("table " + tableName + " does not exist"); + } + List columnDefinitions = kevaTable.getColumns(); + List insertValuesExpression = ((ExpressionList) insertStatement.getItemsList()).getExpressions(); + List values = new ArrayList<>(); + for (Expression expression : insertValuesExpression) { + values.add(expression.toString()); + } + + List result = new ArrayList<>(columnDefinitions.size()); + for (int i = 0; i < columnDefinitions.size(); i++) { + result.add(null); + } + List insertColumns = insertStatement.getColumns(); + if (insertColumns != null) { + for (Column column : insertColumns) { + int index = KevaColumnFinder.findColumn(column.getColumnName(), columnDefinitions); + if (index == -1) { + throw new KevaSQLException("column " + column + " does not exist"); + } + String type = columnDefinitions.get(index).type; + String value = values.get(index); + result.set(index, KevaSQLConvertUtil.convertToRowData(type, value)); + } + } else { + for (int i = 0; i < columnDefinitions.size(); i++) { + String type = columnDefinitions.get(i).type; + String value = values.get(i); + result.set(i, KevaSQLConvertUtil.convertToRowData(type, value)); + } + } + String id = Long.toString(kevaTable.getIncrement()); + sqlPut(tableName + ":" + id, result); + kevaTable.increment(); + sqlPut(tableName, kevaTable); + } + + @SuppressWarnings("unchecked") + public int update(Statement stmt) { + Update updateStatement = (Update) stmt; + Expression where = updateStatement.getWhere(); + if (where == null) { + throw new KevaSQLException("where clause is required"); + } + String tableName = updateStatement.getTable().getName(); + KevaTable kevaTable = (KevaTable) sqlGet(tableName); + if (kevaTable == null) { + throw new KevaSQLException("table " + tableName + " does not exist"); + } + List columnDefinitions = kevaTable.getColumns(); + List> result = new ArrayList<>(); + for (int i = 0; i < kevaTable.getIncrement(); i++) { + String key = tableName + ":" + i; + if (sqlContainsKey(key)) { + List value = (List) sqlGet(key); + result.addAll(Collections.singleton(value)); + } + } + KqlExpressionVisitor kqlExpressionVisitor = new KqlExpressionVisitor(result, columnDefinitions); + where.accept(kqlExpressionVisitor); + List> toBeUpdated = kqlExpressionVisitor.getTemp(); + int count = 0; + for (int i = 0; i < kevaTable.getIncrement(); i++) { + String key = tableName + ":" + i; + if (sqlContainsKey(key)) { + List value = (List) sqlGet(key); + if (value != null && toBeUpdated.contains(value)) { + // Update + List updateColumns = updateStatement.getColumns(); + for (int j = 0; j < updateColumns.size(); j++) { + int index = KevaColumnFinder.findColumn(updateColumns.get(j).getColumnName(), columnDefinitions); + if (index == -1) { + throw new KevaSQLException("column " + updateColumns.get(j).getColumnName() + " does not exist"); + } + String type = columnDefinitions.get(index).type; + String updatedValue = updateStatement.getExpressions().get(j).toString(); + value.set(index, KevaSQLConvertUtil.convertToRowData(type, updatedValue)); + } + sqlPut(key, value); + count++; + } + } + } + return count; + } + + @SuppressWarnings("unchecked") + public int delete(Statement stmt) { + Delete deleteStatement = (Delete) stmt; + Expression where = deleteStatement.getWhere(); + if (where == null) { + throw new KevaSQLException("where clause is required"); + } + String tableName = deleteStatement.getTable().getName(); + KevaTable kevaTable = (KevaTable) sqlGet(tableName); + if (kevaTable == null) { + throw new KevaSQLException("table " + tableName + " does not exist"); + } + List columnDefinitions = kevaTable.getColumns(); + List> result = new ArrayList<>(); + for (int i = 0; i < kevaTable.getIncrement(); i++) { + String key = tableName + ":" + i; + if (sqlContainsKey(key)) { + List value = (List) sqlGet(key); + result.addAll(Collections.singleton(value)); + } + } + KqlExpressionVisitor kqlExpressionVisitor = new KqlExpressionVisitor(result, columnDefinitions); + where.accept(kqlExpressionVisitor); + List> toBeDeleted = kqlExpressionVisitor.getTemp(); + int count = 0; + for (int i = 0; i < kevaTable.getIncrement(); i++) { + String key = tableName + ":" + i; + if (sqlContainsKey(key)) { + List value = (List) sqlGet(key); + if (toBeDeleted.contains(value)) { + sqlRemove(key); + count++; + } + } + } + return count; + } + + @SuppressWarnings("unchecked") + public List> select(Statement stmt) { + Select selectStatement = (Select) stmt; + TablesNamesFinder tablesNamesFinder = new TablesNamesFinder(); + List tableList = tablesNamesFinder.getTableList(selectStatement); + String tableName = tableList.get(0); + KevaTable kevaTable = (KevaTable) sqlGet(tableName); + if (kevaTable == null) { + throw new KevaSQLException("table " + tableName + " does not exist"); + } + List columnDefinitions = kevaTable.getColumns(); + List columns = new ArrayList<>(); + PlainSelect plainSelect = (PlainSelect) selectStatement.getSelectBody(); + for (SelectItem selectItem : plainSelect.getSelectItems()) { + columns.add(selectItem.toString()); + } + List> result = new ArrayList<>(); + for (int i = 0; i < kevaTable.getIncrement(); i++) { + String key = tableName + ":" + i; + if (sqlContainsKey(key)) { + List value = (List) sqlGet(key); + List row = null; + for (String column : columns) { + if (column.equals("*") || column.equals("COUNT(*)") || + column.startsWith("COUNT(") || column.startsWith("SUM(") || column.startsWith("AVG(") || + column.startsWith("MIN(") || column.startsWith("MAX(")) { + result.addAll(Collections.singleton(value)); + break; + } else { + int index = KevaColumnFinder.findColumn(column, columnDefinitions); + if (index == -1) { + throw new KevaSQLException("column " + column + " does not exist"); + } + if (row == null) { + row = new ArrayList<>(); + } + row.add(value.get(index)); + } + } + if (row != null) { + result.add(row); + } + } + } + + List> proceededResult = selectProcess(plainSelect, result, columns, columnDefinitions); + return selectPostProcess(plainSelect, proceededResult, columns, columnDefinitions); + } + + private List> selectProcess(PlainSelect plainSelect, List> result, + List columns, List columnDefinitions) { + KqlExpressionVisitor kqlExpressionVisitor = new KqlExpressionVisitor(result, columnDefinitions); + if (plainSelect.getWhere() != null) { + plainSelect.getWhere().accept(kqlExpressionVisitor); + } + if (columns.get(0).equals("COUNT(*)")) { + return Collections.singletonList(Collections.singletonList(kqlExpressionVisitor.getTemp().size())); + } else if (columns.get(0).startsWith("COUNT(") || columns.get(0).startsWith("AVG(") || columns.get(0).startsWith("SUM(") + || columns.get(0).startsWith("MIN(") || columns.get(0).startsWith("MAX(")) { + String columnInBracket = columns.get(0).substring(columns.get(0).indexOf("(") + 1, columns.get(0).indexOf(")")); + int index = KevaColumnFinder.findColumn(columnInBracket, columnDefinitions); + if (index == -1) { + throw new KevaSQLException("column " + columnInBracket + " does not exist"); + } + if (columns.get(0).startsWith("COUNT(")) { + int count = (int) kqlExpressionVisitor.getTemp().stream().filter(row -> row.get(index) != null).count(); + return KevaSQLResponseUtil.singleSelectResponse(count); + } else if (columns.get(0).startsWith("MIN(")) { + Optional minOptional = kqlExpressionVisitor.getTemp().stream().filter(row -> row.get(index) != null) + .map(row -> row.get(index).toString()) + .map(Double::parseDouble) + .min(Double::compare); + return minOptional.map(KevaSQLResponseUtil::singleSelectResponse) + .orElseGet(() -> KevaSQLResponseUtil.singleSelectResponse(null)); + } else if (columns.get(0).startsWith("MAX(")) { + Optional maxOptional = kqlExpressionVisitor.getTemp().stream().filter(row -> row.get(index) != null) + .map(row -> row.get(index).toString()) + .map(Double::parseDouble) + .max(Double::compare); + return maxOptional.map(KevaSQLResponseUtil::singleSelectResponse) + .orElseGet(() -> KevaSQLResponseUtil.singleSelectResponse(null)); + } + Double sum = 0.0D; + int count = 0; + List> temp = kqlExpressionVisitor.getTemp(); + for (List row : temp) { + if (row.get(index) != null) { + sum += (Double) row.get(index); + count++; + } + } + if (columns.get(0).startsWith("AVG(")) { + return KevaSQLResponseUtil.singleSelectResponse(sum / count); + } + return KevaSQLResponseUtil.singleSelectResponse(sum); + } + return kqlExpressionVisitor.getTemp(); + } + + private List> selectPostProcess(PlainSelect plainSelect, List> result, + List columns, List columnDefinitions) { + List> postProcessedResult = result; + if (plainSelect.getOrderByElements() != null) { + Stream> sortedResultStream; + for (OrderByElement orderByElement : plainSelect.getOrderByElements()) { + String column = orderByElement.getExpression().toString(); + int index = KevaColumnFinder.findColumn(column, columnDefinitions); + if (index == -1) { + throw new KevaSQLException("column " + column + " does not exist"); + } + String type = columnDefinitions.get(index).type; + sortedResultStream = postProcessedResult.stream() + .sorted((p1, p2) -> { + if (p1.get(index) == null && p2.get(index) == null) { + return 0; + } else if (p1.get(index) == null) { + return -1; + } else if (p2.get(index) == null) { + return 1; + } else if (type.equals("TEXT") || type.equals("VARCHAR") || type.equals("CHAR")) { + return p1.get(index).toString().compareTo(p2.get(index).toString()); + } else if (type.equals("INTEGER") || type.equals("INT")) { + return ((Integer) p1.get(index)).compareTo((Integer) p2.get(index)); + } else if (type.equals("DOUBLE") || type.equals("FLOAT")) { + return ((Double) p1.get(index)).compareTo((Double) p2.get(index)); + } else if (type.equals("BOOLEAN")) { + return ((Boolean) p1.get(index)).compareTo((Boolean) p2.get(index)); + } else { + return 0; + } + }); + if (orderByElement.isAsc()) { + postProcessedResult = sortedResultStream.collect(Collectors.toList()); + } else { + postProcessedResult = sortedResultStream.collect(Collectors.toList()); + Collections.reverse(postProcessedResult); + } + } + } + if (plainSelect.getOffset() != null) { + long offset = plainSelect.getOffset().getOffset(); + postProcessedResult = postProcessedResult.stream().skip(offset).collect(Collectors.toList()); + + } + if (plainSelect.getLimit() != null) { + Stream> limitedResultStream = postProcessedResult.stream(); + if (plainSelect.getLimit().getOffset() != null) { + long offset = Long.parseLong(plainSelect.getLimit().getOffset().toString()); + limitedResultStream = limitedResultStream.skip(offset); + } + if (plainSelect.getLimit().getRowCount() != null) { + long rowCount = Long.parseLong(plainSelect.getLimit().getRowCount().toString()); + limitedResultStream = limitedResultStream.limit(rowCount); + } + postProcessedResult = limitedResultStream.collect(Collectors.toList()); + } + return postProcessedResult; + } +}