• SQL解析器使用指南

    hzero-starter-sqlparser是基于druid的sql解析器移植而来,未来会逐渐完善功能、增强api。

    sql语句首先需要由parser转化为AST,并通过visitor获取sql语句的类型、from条件、where条件、groupby等信息。

    AST(abstract syntax tree)

    中文含义为抽象语法树,以下是一个简单sql语句的ast图。

    SELECT a FROM table_a WHERE table_a.a = ‘a’;

    简单sql的AST图

    简单理解,就是把字符串形式的sql语句结构化,变成具有层次的树形结构,便于sql语句的分析、改造。如果需要对现有的sql语句改造,则可以直接更新ast上的节点信息。

    ast节点主要包括SQLObject、SQLExpr、SQLStatement三种抽象类型。

    interface SQLObject {}
    interface SQLExpr extends SQLObject {}
    interface SQLStatement extends SQLObject {}
    

    SQLExpr

    // SQLName是一种的SQLExpr的Expr,包括SQLIdentifierExpr、SQLPropertyExpr等
    public interface SQLName extends SQLExpr {}
    
    // 例如 ID = 3 这里的ID是一个SQLIdentifierExpr
    class SQLIdentifierExpr implements SQLExpr, SQLName {
        String name;
    } 
    
    // 例如 A.ID = 3 这里的A.ID是一个SQLPropertyExpr
    class SQLPropertyExpr implements SQLExpr, SQLName {
        SQLExpr owner;
        String name;
    } 
    
    // 例如 ID = 3 这是一个SQLBinaryOpExpr
    // left是ID (SQLIdentifierExpr)
    // right是3 (SQLIntegerExpr)
    class SQLBinaryOpExpr implements SQLExpr {
        SQLExpr left;
        SQLExpr right;
        SQLBinaryOperator operator;
    }
    
    // 例如 select * from where id = ?,这里的?是一个SQLVariantRefExpr,name是'?'
    class SQLVariantRefExpr extends SQLExprImpl { 
        String name;
    }
    
    // 例如 ID = 3 这里的3是一个SQLIntegerExpr
    public class SQLIntegerExpr extends SQLNumericLiteralExpr implements SQLValuableExpr { 
        Number number;
    
        // 所有实现了SQLValuableExpr接口的SQLExpr都可以直接调用这个方法求值
        @Override
        public Object getValue() {
            return this.number;
        }
    }
    
    // 例如 NAME = 'jobs' 这里的'jobs'是一个SQLCharExpr
    public class SQLCharExpr extends SQLTextLiteralExpr implements SQLValuableExpr{
        String text;
    }
    

    SQLStatement

    class SQLSelectStatement implements SQLStatement {
        SQLSelect select;
    }
    class SQLUpdateStatement implements SQLStatement {
        SQLExprTableSource tableSource;
         List<SQLUpdateSetItem> items;
         SQLExpr where;
    }
    class SQLDeleteStatement implements SQLStatement {
        SQLTableSource tableSource; 
        SQLExpr where;
    }
    class SQLInsertStatement implements SQLStatement {
        SQLExprTableSource tableSource;
        List<SQLExpr> columns;
        SQLSelect query;
    }
    

    SQLTableSource

    class SQLTableSourceImpl extends SQLObjectImpl implements SQLTableSource { 
        String alias;
    }
    
    // 例如 select * from emp where i = 3,这里的from emp是一个SQLExprTableSource
    // 其中expr是一个name=emp的SQLIdentifierExpr
    class SQLExprTableSource extends SQLTableSourceImpl {
        SQLExpr expr;
    }
    
    // 例如 select * from emp e inner join org o on e.org_id = o.id
    // 其中left 'emp e' 是一个SQLExprTableSource,right 'org o'也是一个SQLExprTableSource
    // condition 'e.org_id = o.id'是一个SQLBinaryOpExpr
    class SQLJoinTableSource extends SQLTableSourceImpl {
        SQLTableSource left;
        SQLTableSource right;
        JoinType joinType; // INNER_JOIN/CROSS_JOIN/LEFT_OUTER_JOIN/RIGHT_OUTER_JOIN/...
        SQLExpr condition;
    }
    
    // 例如 select * from (select * from temp) a,这里第一层from(...)是一个SQLSubqueryTableSource
    SQLSubqueryTableSource extends SQLTableSourceImpl {
        SQLSelect select;
    }
    
    /* 
    例如
    WITH RECURSIVE ancestors AS (
        SELECT *
        FROM org
        UNION
        SELECT f.*
        FROM org f, ancestors a
        WHERE f.id = a.parent_id
    )
    SELECT *
    FROM ancestors;
    
    这里的ancestors AS (...) 是一个SQLWithSubqueryClause.Entry
    */
    class SQLWithSubqueryClause {
        static class Entry extends SQLTableSourceImpl { 
             SQLSelect subQuery;
        }
    }
    

    SQLSelectStatement

    SQLSelectStatement包含一个SQLSelect,SQLSelect包含一个SQLSelectQuery,都是组成的关系。SQLSelectQuery有主要的两个派生类,分别是SQLSelectQueryBlock和SQLUnionQuery。

    class SQLSelect extends SQLObjectImpl { 
        SQLWithSubqueryClause withSubQuery;
        SQLSelectQuery query;
    }
    
    interface SQLSelectQuery extends SQLObject {}
    
    class SQLSelectQueryBlock implements SQLSelectQuery {
        List<SQLSelectItem> selectList;
        SQLTableSource from;
        SQLExprTableSource into;
        SQLExpr where;
        SQLSelectGroupByClause groupBy;
        SQLOrderBy orderBy;
        SQLLimit limit;
    }
    
    class SQLUnionQuery implements SQLSelectQuery {
        SQLSelectQuery left;
        SQLSelectQuery right;
        SQLUnionOperator operator; // UNION/UNION_ALL/MINUS/INTERSECT
    }
    

    SQLCreateTableStatement

    public class SQLCreateTableStatement extends SQLStatementImpl implements SQLDDLStatement, SQLCreateStatement {
        SQLExprTableSource tableSource;
        List<SQLTableElement> tableElementList;
        Select select;
    
        // 忽略大小写的查找SQLCreateTableStatement中的SQLColumnDefinition
        public SQLColumnDefinition findColumn(String columName) {}
    
        // 忽略大小写的查找SQLCreateTableStatement中的column关联的索引
        public SQLTableElement findIndex(String columnName) {}
    
        // 是否外键依赖另外一个表
        public boolean isReferenced(String tableName) {}
    }
    

    使用案例

    ast全语法示例代码

    public class Test {
    
        public static void main(String[] args) {
            String sql = "select a,b from (select * from table_a) temp where temp.a = 'a';";
            // 解析
            List<SQLStatement> statements = SQLUtils.parseStatements(sql, JdbcConstants.MYSQL);
            // 只考虑一条语句
            SQLStatement statement = statements.get(0);
            // 1 查询语句
            SQLSelectStatement sqlSelectStatement = (SQLSelectStatement) statement;
            SQLSelectQuery sqlSelectQuery = sqlSelectStatement.getSelect().getQuery();
            // 1.1 非union的查询语句
            if (sqlSelectQuery instanceof SQLSelectQueryBlock) {
                SQLSelectQueryBlock sqlSelectQueryBlock = (SQLSelectQueryBlock) sqlSelectQuery;
                // 1.1.1获取字段列表
                List<SQLSelectItem> selectItems         = sqlSelectQueryBlock.getSelectList();
                selectItems.forEach(x -> {
                    // 处理---------------------
                });
                // 1.1.2 获取表
                SQLTableSource table = sqlSelectQueryBlock.getFrom();
                // 1.1.2.1 普通单表
                if (table instanceof SQLExprTableSource) {
                    // 处理---------------------
                    // 1.1.2.2 join多表
                } else if (table instanceof SQLJoinTableSource) {
                    // 处理---------------------
                    // 1.1.2.3 子查询作为表
                } else if (table instanceof SQLSubqueryTableSource) {
                    // 处理---------------------
                }
                // 1.1.3 获取where条件
                SQLExpr where = sqlSelectQueryBlock.getWhere();
                // 1.1.3.1 如果是二元表达式
                if (where instanceof SQLBinaryOpExpr) {
                    SQLBinaryOpExpr   sqlBinaryOpExpr = (SQLBinaryOpExpr) where;
                    SQLExpr           left            = sqlBinaryOpExpr.getLeft();
                    SQLBinaryOperator operator        = sqlBinaryOpExpr.getOperator();
                    SQLExpr           right           = sqlBinaryOpExpr.getRight();
                    // 处理---------------------
                    // 1.1.3.2 如果是子查询
                } else if (where instanceof SQLInSubQueryExpr) {
                    SQLInSubQueryExpr sqlInSubQueryExpr = (SQLInSubQueryExpr) where;
                    // 处理---------------------
                }
                // 1.1.4 获取分组
                SQLSelectGroupByClause groupBy = sqlSelectQueryBlock.getGroupBy();
                // 处理---------------------
                // 1.1.5 获取排序
                SQLOrderBy orderBy = sqlSelectQueryBlock.getOrderBy();
                // 处理---------------------
                // 1.1.6 获取分页
                SQLLimit limit = sqlSelectQueryBlock.getLimit();
                // 处理---------------------
                // 1.2 union的查询语句
            } else if (sqlSelectQuery instanceof SQLUnionQuery) {
                // 处理---------------------
            }
    
            // 2 插入语句
            SQLInsertStatement sqlInsertStatement = (SQLInsertStatement) statement;
            // 2.1 with语句
            SQLWithSubqueryClause with = sqlInsertStatement.getWith();
            // 2.2 获取表
            SQLTableSource insertTable = sqlInsertStatement.getTableSource();
            // 2.3 获取插入列名
            List<SQLExpr> insertColumns = sqlInsertStatement.getColumns();
            // 2.4 获取插入子查询
            SQLSelect insertSelect = sqlInsertStatement.getQuery();
            // 2.5 获取批量插入values
            List<SQLInsertStatement.ValuesClause> values = sqlInsertStatement.getValuesList();
    
            // 3 更新语句
            SQLUpdateStatement sqlUpdateStatement = (SQLUpdateStatement) statement;
            // 3.1 获取更新条目
            List<SQLUpdateSetItem> updateSetItems = sqlUpdateStatement.getItems();
            // 3.2 获取表
            SQLTableSource updateTable = sqlUpdateStatement.getFrom();
            // 3.3 获取条件
            SQLExpr updateWhere = sqlUpdateStatement.getWhere();
    
            // 4 删除语句
            SQLDeleteStatement sqlDeleteStatement = (SQLDeleteStatement) statement;
            // 4.1 获取表
            SQLTableSource deleteTable = sqlDeleteStatement.getFrom();
            // 4.2 获取条件
            SQLExpr deleteWhere = sqlDeleteStatement.getWhere();
    
            // 5 建表语句
            SQLCreateTableStatement sqlCreateTableStatement = (SQLCreateTableStatement) statement;
            // 5.1 获取列索引
            sqlCreateTableStatement.findIndex("a");
            // 5.2 获取列
            sqlCreateTableStatement.findColumn("a");
    
            // ...
        }
    
    }
    

    使用visitor获取sql信息

    public class Test {
    
        public static void main(String[] args) {
            String sql = "select a,b from (select * from table_a) temp where temp.a = 'a';";
            // 解析
            List<SQLStatement> statements = SQLUtils.parseStatements(sql, JdbcConstants.MYSQL);
            // 只考虑一条语句
            SQLStatement statement = statements.get(0);
            // 根据数据库类型构造visitor
            MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor();
            // 使用visitor访问ast
            statement.accept(visitor);
    
            // 获取数据库类型
            System.out.println("数据库类型\t\t" + visitor.getDbType());
            // 获取字段名称
            System.out.println("查询的字段\t\t" + visitor.getColumns());
            // 获取表名称
            System.out.println("表名\t\t\t" + visitor.getTables().keySet());
            // 条件字段
            System.out.println("条件\t\t\t" + visitor.getConditions());
            // group by
            System.out.println("group by\t\t" + visitor.getGroupByColumns());
            // order by
            System.out.println("order by\t\t" + visitor.getOrderByColumns());
        }
    
    }