123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423 |
- package sql
- import (
- "context"
- "fmt"
- "strings"
- "github.com/xwb1989/sqlparser"
- )
- var SafeSqlFunc = []string{
- "SUM", "AVG", "MAX", "MIN", "COUNT", "CONCAT", "SUBSTRING", "CHAR_LENGTH",
- "LOWER", "UPPER", "NOW", "DATE_FORMAT", "DATE_ADD", "DATEDIFF", "ABS",
- "CEIL", "FLOOR", "EXP", "LOG", "AND", "OR", "NOT", "CAST", "CONVERT",
- "COALESCE", "NULLIF",
- }
- type Record struct {
- Msg string `json:"msg"`
- }
- func CheckSQL(ctx context.Context, sqlQuery string) (bool, string, error) {
- stmt, err := sqlparser.Parse(sqlQuery)
- if err != nil {
- return false, err.Error(), err
- }
- r := &Record{}
- return checkStmt(ctx, stmt, r), r.Msg, nil
- }
- func checkStmt(ctx context.Context, stmt sqlparser.Statement, r *Record) bool {
- if stmt == nil {
- return true
- }
- switch stmt := stmt.(type) {
- case *sqlparser.Select:
- for _, i := range stmt.SelectExprs {
- if !checkSelectExpr(ctx, i, r) {
- return false
- }
- }
- for _, i := range stmt.From {
- if !checkTableExpr(ctx, i, r) {
- return false
- }
- }
- if stmt.Where != nil {
- if !checkExpr(ctx, stmt.Where.Expr, r) {
- return false
- }
- }
- for _, g := range stmt.GroupBy {
- if !checkExpr(ctx, g, r) {
- return false
- }
- }
- if stmt.Having != nil {
- if !checkExpr(ctx, stmt.Having.Expr, r) {
- return false
- }
- }
- for _, o := range stmt.OrderBy {
- if o == nil {
- continue
- }
- if !checkExpr(ctx, o.Expr, r) {
- return false
- }
- }
- if stmt.Limit != nil {
- if !checkExpr(ctx, stmt.Limit.Offset, r) {
- return false
- }
- if !checkExpr(ctx, stmt.Limit.Rowcount, r) {
- return false
- }
- }
- default:
- r.Msg = "bad stmt operation"
- return false
- }
- return true
- }
- func checkTableExpr(ctx context.Context, expr sqlparser.TableExpr, r *Record) bool {
- if expr == nil {
- return true
- }
- switch expr := expr.(type) {
- case *sqlparser.AliasedTableExpr:
- return checkSimpleTableExpr(ctx, expr.Expr, r)
- case *sqlparser.ParenTableExpr:
- for _, e := range expr.Exprs {
- if !checkTableExpr(ctx, e, r) {
- return false
- }
- }
- return true
- case *sqlparser.JoinTableExpr:
- if !checkTableExpr(ctx, expr.LeftExpr, r) {
- return false
- }
- if !checkTableExpr(ctx, expr.RightExpr, r) {
- return false
- }
- if !checkExpr(ctx, expr.Condition.On, r) {
- return false
- }
- return true
- }
- r.Msg = "bad table expr"
- return false
- }
- func checkExpr(ctx context.Context, expr sqlparser.Expr, r *Record) bool {
- if expr == nil {
- return true
- }
- switch expr := expr.(type) {
- default:
- return false
- case *sqlparser.AndExpr:
- if !checkExpr(ctx, expr.Left, r) {
- return false
- }
- if !checkExpr(ctx, expr.Right, r) {
- return false
- }
- case *sqlparser.OrExpr:
- if !checkExpr(ctx, expr.Left, r) {
- return false
- }
- if !checkExpr(ctx, expr.Right, r) {
- return false
- }
- case *sqlparser.NotExpr:
- if !checkExpr(ctx, expr.Expr, r) {
- return false
- }
- case *sqlparser.ParenExpr:
- if !checkExpr(ctx, expr.Expr, r) {
- return false
- }
- case *sqlparser.ComparisonExpr:
- if !checkExpr(ctx, expr.Left, r) {
- return false
- }
- if !checkExpr(ctx, expr.Right, r) {
- return false
- }
- if !checkExpr(ctx, expr.Escape, r) {
- return false
- }
- case *sqlparser.RangeCond:
- if !checkExpr(ctx, expr.Left, r) {
- return false
- }
- if !checkExpr(ctx, expr.From, r) {
- return false
- }
- if !checkExpr(ctx, expr.To, r) {
- return false
- }
- case *sqlparser.IsExpr:
- if !checkExpr(ctx, expr.Expr, r) {
- return false
- }
- case *sqlparser.ExistsExpr:
- if !checkExpr(ctx, expr.Subquery, r) {
- return false
- }
- case *sqlparser.SQLVal:
- // 可以
- case *sqlparser.NullVal:
- // 可以
- case sqlparser.BoolVal:
- // 可以
- case *sqlparser.ColName:
- // 检查ColName
- case sqlparser.ValTuple:
- for _, e := range expr {
- if !checkExpr(ctx, e, r) {
- return false
- }
- }
- case *sqlparser.Subquery:
- if !checkStmt(ctx, expr.Select, r) {
- return false
- }
- case sqlparser.ListArg:
- // 可以
- case *sqlparser.BinaryExpr:
- if !checkExpr(ctx, expr.Left, r) {
- return false
- }
- if !checkExpr(ctx, expr.Right, r) {
- return false
- }
- case *sqlparser.UnaryExpr:
- if !checkExpr(ctx, expr.Expr, r) {
- return false
- }
- case *sqlparser.IntervalExpr:
- if !checkExpr(ctx, expr.Expr, r) {
- return false
- }
- case *sqlparser.CollateExpr:
- if !checkExpr(ctx, expr.Expr, r) {
- return false
- }
- case *sqlparser.FuncExpr:
- if !checkFuncName(ctx, expr.Qualifier.String(), expr.Name.String(), r) {
- return false
- }
- for _, s := range expr.Exprs {
- if !checkSelectExpr(ctx, s, r) {
- return false
- }
- }
- case *sqlparser.CaseExpr:
- if !checkExpr(ctx, expr.Expr, r) {
- return false
- }
- if !checkExpr(ctx, expr.Else, r) {
- return false
- }
- for _, w := range expr.Whens {
- if w == nil {
- continue
- }
- if !checkExpr(ctx, w.Cond, r) {
- return false
- }
- if !checkExpr(ctx, w.Val, r) {
- return false
- }
- }
- case *sqlparser.ValuesFuncExpr:
- r.Msg = "bad values expr"
- return false // values表达式用于插入
- case *sqlparser.ConvertExpr:
- if !checkExpr(ctx, expr.Expr, r) {
- return false
- }
- case *sqlparser.SubstrExpr:
- if !checkExpr(ctx, expr.From, r) {
- return false
- }
- if !checkExpr(ctx, expr.To, r) {
- return false
- }
- if !checkColName(ctx, expr.Name, r) {
- return false
- }
- case *sqlparser.ConvertUsingExpr:
- if !checkExpr(ctx, expr.Expr, r) {
- return false
- }
- case *sqlparser.MatchExpr:
- if !checkExpr(ctx, expr.Expr, r) {
- return false
- }
- for _, s := range expr.Columns {
- if !checkSelectExpr(ctx, s, r) {
- return false
- }
- }
- case *sqlparser.GroupConcatExpr:
- for _, s := range expr.Exprs {
- if !checkSelectExpr(ctx, s, r) {
- return false
- }
- }
- for _, o := range expr.OrderBy {
- if o == nil {
- continue
- }
- if !checkExpr(ctx, o.Expr, r) {
- return false
- }
- }
- case *sqlparser.Default:
- if !checkColNameString(ctx, expr.ColName, r) {
- return false
- }
- }
- return true
- }
- func checkSimpleTableExpr(ctx context.Context, expr sqlparser.SimpleTableExpr, r *Record) bool {
- if expr == nil {
- return true
- }
- switch expr := expr.(type) {
- case sqlparser.TableName:
- return checkTableName(ctx, expr, r)
- case *sqlparser.Subquery:
- return checkStmt(ctx, expr.Select, r)
- }
- r.Msg = "bad simple table expr"
- return false
- }
- func checkSelectExpr(ctx context.Context, expr sqlparser.SelectExpr, r *Record) bool {
- if expr == nil {
- return true
- }
- switch expr := expr.(type) {
- case *sqlparser.StarExpr:
- return checkTableName(ctx, expr.TableName, r)
- case *sqlparser.AliasedExpr:
- return checkExpr(ctx, expr.Expr, r)
- case sqlparser.Nextval:
- return checkExpr(ctx, expr.Expr, r)
- }
- r.Msg = "bad select expr"
- return false
- }
- func checkTableName(ctx context.Context, tableName sqlparser.TableName, r *Record) bool {
- if tableName.IsEmpty() {
- return true
- }
- allowTableName, ok := ctx.Value("Allow-Table-Name").([]string)
- if ok && !tableName.Name.IsEmpty() && !InList(allowTableName, tableName.Name.String()) {
- r.Msg = fmt.Sprintf("bad table name for %s", tableName.Name.String())
- return false
- }
- allowDBName, ok := ctx.Value("Allow-DataBase-Name").([]string)
- if ok {
- if !tableName.Qualifier.IsEmpty() && !InList(allowDBName, tableName.Qualifier.String()) {
- r.Msg = fmt.Sprintf("bad table qualifier for %s", tableName.Qualifier.String())
- return false
- }
- } else {
- if !tableName.Qualifier.IsEmpty() {
- r.Msg = fmt.Sprintf("bad table qualifier for %s", tableName.Qualifier.String())
- return false
- }
- }
- return true
- }
- func checkColName(ctx context.Context, colname *sqlparser.ColName, r *Record) bool {
- if !checkTableName(ctx, colname.Qualifier, r) {
- return false
- }
- return checkColNameString(ctx, colname.Name.String(), r)
- }
- func checkColNameString(ctx context.Context, colname string, r *Record) bool {
- if len(colname) == 0 {
- return true
- }
- allowColName, ok := ctx.Value("Allow-Col-Name").([]string)
- if ok && !InList(allowColName, colname) && !InList(allowColName, fmt.Sprintf("`%s`", colname)) {
- r.Msg = fmt.Sprintf("bad table col for %s", colname)
- return false
- }
- return true
- }
- func checkFuncName(ctx context.Context, ident string, funcName string, r *Record) bool {
- if len(funcName) == 0 && len(ident) == 0 {
- return false
- }
- if len(ident) != 0 {
- allowColIdent, ok := ctx.Value("Allow-Func-Ident").([]string)
- if ok && !InList(allowColIdent, ident) {
- r.Msg = fmt.Sprintf("bad func ident for %s", ident)
- return false
- }
- }
- allowColName, ok := ctx.Value("Allow-Func-Name").([]string)
- if ok && !InList(allowColName, strings.ToUpper(funcName)) {
- r.Msg = fmt.Sprintf("bad func name for %s", funcName)
- return false
- } else if !ok {
- r.Msg = fmt.Sprintf("bad func name for %s", funcName)
- return false
- }
- return true
- }
- func InList[T string | int64](lst []T, element T) bool {
- for _, i := range lst {
- if i == element {
- return true
- }
- }
- return false
- }
|