SQL解析器使用指南
hzero-starter-sqlparser是基于druid的sql解析器移植而来,未来会逐渐完善功能、增强api。
sql语句首先需要由parser转化为AST,并通过visitor获取sql语句的类型、from条件、where条件、groupby等信息。
AST(abstract syntax tree)
中文含义为抽象语法树,以下是一个简单sql语句的ast图。
SELECT a FROM table_a WHERE table_a.a = ‘a’;
简单理解,就是把字符串形式的sql语句结构化,变成具有层次的树形结构,便于sql语句的分析、改造。如果需要对现有的sql语句改造,则可以直接更新ast上的节点信息。
ast节点主要包括SQLObject、SQLExpr、SQLStatement三种抽象类型。
interface SQLObject {}
interface SQLExpr extends SQLObject {}
interface SQLStatement extends SQLObject {}
SQLExpr
// SQLName是一种的SQLExpr的Expr,包括SQLIdentifierExpr、SQLPropertyExpr等
public interface SQLName extends SQLExpr {}
// 例如 ID = 3 这里的ID是一个SQLIdentifierExpr
class SQLIdentifierExpr implements SQLExpr, SQLName {
String name;
}
// 例如 A.ID = 3 这里的A.ID是一个SQLPropertyExpr
class SQLPropertyExpr implements SQLExpr, SQLName {
SQLExpr owner;
String name;
}
// 例如 ID = 3 这是一个SQLBinaryOpExpr
// left是ID (SQLIdentifierExpr)
// right是3 (SQLIntegerExpr)
class SQLBinaryOpExpr implements SQLExpr {
SQLExpr left;
SQLExpr right;
SQLBinaryOperator operator;
}
// 例如 select * from where id = ?,这里的?是一个SQLVariantRefExpr,name是'?'
class SQLVariantRefExpr extends SQLExprImpl {
String name;
}
// 例如 ID = 3 这里的3是一个SQLIntegerExpr
public class SQLIntegerExpr extends SQLNumericLiteralExpr implements SQLValuableExpr {
Number number;
// 所有实现了SQLValuableExpr接口的SQLExpr都可以直接调用这个方法求值
@Override
public Object getValue() {
return this.number;
}
}
// 例如 NAME = 'jobs' 这里的'jobs'是一个SQLCharExpr
public class SQLCharExpr extends SQLTextLiteralExpr implements SQLValuableExpr{
String text;
}
SQLStatement
class SQLSelectStatement implements SQLStatement {
SQLSelect select;
}
class SQLUpdateStatement implements SQLStatement {
SQLExprTableSource tableSource;
List<SQLUpdateSetItem> items;
SQLExpr where;
}
class SQLDeleteStatement implements SQLStatement {
SQLTableSource tableSource;
SQLExpr where;
}
class SQLInsertStatement implements SQLStatement {
SQLExprTableSource tableSource;
List<SQLExpr> columns;
SQLSelect query;
}
SQLTableSource
class SQLTableSourceImpl extends SQLObjectImpl implements SQLTableSource {
String alias;
}
// 例如 select * from emp where i = 3,这里的from emp是一个SQLExprTableSource
// 其中expr是一个name=emp的SQLIdentifierExpr
class SQLExprTableSource extends SQLTableSourceImpl {
SQLExpr expr;
}
// 例如 select * from emp e inner join org o on e.org_id = o.id
// 其中left 'emp e' 是一个SQLExprTableSource,right 'org o'也是一个SQLExprTableSource
// condition 'e.org_id = o.id'是一个SQLBinaryOpExpr
class SQLJoinTableSource extends SQLTableSourceImpl {
SQLTableSource left;
SQLTableSource right;
JoinType joinType; // INNER_JOIN/CROSS_JOIN/LEFT_OUTER_JOIN/RIGHT_OUTER_JOIN/...
SQLExpr condition;
}
// 例如 select * from (select * from temp) a,这里第一层from(...)是一个SQLSubqueryTableSource
SQLSubqueryTableSource extends SQLTableSourceImpl {
SQLSelect select;
}
/*
例如
WITH RECURSIVE ancestors AS (
SELECT *
FROM org
UNION
SELECT f.*
FROM org f, ancestors a
WHERE f.id = a.parent_id
)
SELECT *
FROM ancestors;
这里的ancestors AS (...) 是一个SQLWithSubqueryClause.Entry
*/
class SQLWithSubqueryClause {
static class Entry extends SQLTableSourceImpl {
SQLSelect subQuery;
}
}
SQLSelectStatement
SQLSelectStatement包含一个SQLSelect,SQLSelect包含一个SQLSelectQuery,都是组成的关系。SQLSelectQuery有主要的两个派生类,分别是SQLSelectQueryBlock和SQLUnionQuery。
class SQLSelect extends SQLObjectImpl {
SQLWithSubqueryClause withSubQuery;
SQLSelectQuery query;
}
interface SQLSelectQuery extends SQLObject {}
class SQLSelectQueryBlock implements SQLSelectQuery {
List<SQLSelectItem> selectList;
SQLTableSource from;
SQLExprTableSource into;
SQLExpr where;
SQLSelectGroupByClause groupBy;
SQLOrderBy orderBy;
SQLLimit limit;
}
class SQLUnionQuery implements SQLSelectQuery {
SQLSelectQuery left;
SQLSelectQuery right;
SQLUnionOperator operator; // UNION/UNION_ALL/MINUS/INTERSECT
}
SQLCreateTableStatement
public class SQLCreateTableStatement extends SQLStatementImpl implements SQLDDLStatement, SQLCreateStatement {
SQLExprTableSource tableSource;
List<SQLTableElement> tableElementList;
Select select;
// 忽略大小写的查找SQLCreateTableStatement中的SQLColumnDefinition
public SQLColumnDefinition findColumn(String columName) {}
// 忽略大小写的查找SQLCreateTableStatement中的column关联的索引
public SQLTableElement findIndex(String columnName) {}
// 是否外键依赖另外一个表
public boolean isReferenced(String tableName) {}
}
使用案例
ast全语法示例代码
public class Test {
public static void main(String[] args) {
String sql = "select a,b from (select * from table_a) temp where temp.a = 'a';";
// 解析
List<SQLStatement> statements = SQLUtils.parseStatements(sql, JdbcConstants.MYSQL);
// 只考虑一条语句
SQLStatement statement = statements.get(0);
// 1 查询语句
SQLSelectStatement sqlSelectStatement = (SQLSelectStatement) statement;
SQLSelectQuery sqlSelectQuery = sqlSelectStatement.getSelect().getQuery();
// 1.1 非union的查询语句
if (sqlSelectQuery instanceof SQLSelectQueryBlock) {
SQLSelectQueryBlock sqlSelectQueryBlock = (SQLSelectQueryBlock) sqlSelectQuery;
// 1.1.1获取字段列表
List<SQLSelectItem> selectItems = sqlSelectQueryBlock.getSelectList();
selectItems.forEach(x -> {
// 处理---------------------
});
// 1.1.2 获取表
SQLTableSource table = sqlSelectQueryBlock.getFrom();
// 1.1.2.1 普通单表
if (table instanceof SQLExprTableSource) {
// 处理---------------------
// 1.1.2.2 join多表
} else if (table instanceof SQLJoinTableSource) {
// 处理---------------------
// 1.1.2.3 子查询作为表
} else if (table instanceof SQLSubqueryTableSource) {
// 处理---------------------
}
// 1.1.3 获取where条件
SQLExpr where = sqlSelectQueryBlock.getWhere();
// 1.1.3.1 如果是二元表达式
if (where instanceof SQLBinaryOpExpr) {
SQLBinaryOpExpr sqlBinaryOpExpr = (SQLBinaryOpExpr) where;
SQLExpr left = sqlBinaryOpExpr.getLeft();
SQLBinaryOperator operator = sqlBinaryOpExpr.getOperator();
SQLExpr right = sqlBinaryOpExpr.getRight();
// 处理---------------------
// 1.1.3.2 如果是子查询
} else if (where instanceof SQLInSubQueryExpr) {
SQLInSubQueryExpr sqlInSubQueryExpr = (SQLInSubQueryExpr) where;
// 处理---------------------
}
// 1.1.4 获取分组
SQLSelectGroupByClause groupBy = sqlSelectQueryBlock.getGroupBy();
// 处理---------------------
// 1.1.5 获取排序
SQLOrderBy orderBy = sqlSelectQueryBlock.getOrderBy();
// 处理---------------------
// 1.1.6 获取分页
SQLLimit limit = sqlSelectQueryBlock.getLimit();
// 处理---------------------
// 1.2 union的查询语句
} else if (sqlSelectQuery instanceof SQLUnionQuery) {
// 处理---------------------
}
// 2 插入语句
SQLInsertStatement sqlInsertStatement = (SQLInsertStatement) statement;
// 2.1 with语句
SQLWithSubqueryClause with = sqlInsertStatement.getWith();
// 2.2 获取表
SQLTableSource insertTable = sqlInsertStatement.getTableSource();
// 2.3 获取插入列名
List<SQLExpr> insertColumns = sqlInsertStatement.getColumns();
// 2.4 获取插入子查询
SQLSelect insertSelect = sqlInsertStatement.getQuery();
// 2.5 获取批量插入values
List<SQLInsertStatement.ValuesClause> values = sqlInsertStatement.getValuesList();
// 3 更新语句
SQLUpdateStatement sqlUpdateStatement = (SQLUpdateStatement) statement;
// 3.1 获取更新条目
List<SQLUpdateSetItem> updateSetItems = sqlUpdateStatement.getItems();
// 3.2 获取表
SQLTableSource updateTable = sqlUpdateStatement.getFrom();
// 3.3 获取条件
SQLExpr updateWhere = sqlUpdateStatement.getWhere();
// 4 删除语句
SQLDeleteStatement sqlDeleteStatement = (SQLDeleteStatement) statement;
// 4.1 获取表
SQLTableSource deleteTable = sqlDeleteStatement.getFrom();
// 4.2 获取条件
SQLExpr deleteWhere = sqlDeleteStatement.getWhere();
// 5 建表语句
SQLCreateTableStatement sqlCreateTableStatement = (SQLCreateTableStatement) statement;
// 5.1 获取列索引
sqlCreateTableStatement.findIndex("a");
// 5.2 获取列
sqlCreateTableStatement.findColumn("a");
// ...
}
}
使用visitor获取sql信息
public class Test {
public static void main(String[] args) {
String sql = "select a,b from (select * from table_a) temp where temp.a = 'a';";
// 解析
List<SQLStatement> statements = SQLUtils.parseStatements(sql, JdbcConstants.MYSQL);
// 只考虑一条语句
SQLStatement statement = statements.get(0);
// 根据数据库类型构造visitor
MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor();
// 使用visitor访问ast
statement.accept(visitor);
// 获取数据库类型
System.out.println("数据库类型\t\t" + visitor.getDbType());
// 获取字段名称
System.out.println("查询的字段\t\t" + visitor.getColumns());
// 获取表名称
System.out.println("表名\t\t\t" + visitor.getTables().keySet());
// 条件字段
System.out.println("条件\t\t\t" + visitor.getConditions());
// group by
System.out.println("group by\t\t" + visitor.getGroupByColumns());
// order by
System.out.println("order by\t\t" + visitor.getOrderByColumns());
}
}