Browse Source

feat: 第一版

SongZihuan 1 year ago
commit
db715c1d80
4 changed files with 446 additions and 0 deletions
  1. 16 0
      README.md
  2. 5 0
      go.mod
  3. 2 0
      go.sum
  4. 423 0
      sql.go

+ 16 - 0
README.md

@@ -0,0 +1,16 @@
+# 安全的SQL
+## 介绍
+检查sql是否安全。  
+安全:无副作用,仅执行查询。
+
+## 使用方式
+### context.Context参数
+* `Allow-Func-Ident`:字符串列表
+* `Allow-Func-Name`:字符串列表
+* `Allow-Col-Name`:字符串列表
+* `Allow-DataBase-Name`:字符串列表
+* `Allow-Table-Name`:字符串列表
+
+### 调用
+调用`CheckSQL`,传入`ctx`和`query`。
+返回值:是否安全、失败原因、错误。

+ 5 - 0
go.mod

@@ -0,0 +1,5 @@
+module github.com/SuperH-0630/safe-sql
+
+go 1.19
+
+require github.com/xwb1989/sqlparser v0.0.0-20180606152119-120387863bf2

+ 2 - 0
go.sum

@@ -0,0 +1,2 @@
+github.com/xwb1989/sqlparser v0.0.0-20180606152119-120387863bf2 h1:zzrxE1FKn5ryBNl9eKOeqQ58Y/Qpo3Q9QNxKHX5uzzQ=
+github.com/xwb1989/sqlparser v0.0.0-20180606152119-120387863bf2/go.mod h1:hzfGeIUDq/j97IG+FhNqkowIyEcD88LrW6fyU3K3WqY=

+ 423 - 0
sql.go

@@ -0,0 +1,423 @@
+package sql
+
+import (
+	"context"
+	"fmt"
+	"strings"
+
+	"github.com/xwb1989/sqlparser"
+)
+
+var SafeSqlFunc = []string{
+	"SUM", "AVG", "MAX", "MIN", "COUNT", "CONCAT", "SUBSTRING", "CHAR_LENGTH",
+	"LOWER", "UPPER", "NOW", "DATE_FORMAT", "DATE_ADD", "DATEDIFF", "ABS",
+	"CEIL", "FLOOR", "EXP", "LOG", "AND", "OR", "NOT", "CAST", "CONVERT",
+	"COALESCE", "NULLIF",
+}
+
+type Record struct {
+	Msg string `json:"msg"`
+}
+
+func CheckSQL(ctx context.Context, sqlQuery string) (bool, string, error) {
+	stmt, err := sqlparser.Parse(sqlQuery)
+	if err != nil {
+		return false, err.Error(), err
+	}
+
+	r := &Record{}
+	return checkStmt(ctx, stmt, r), r.Msg, nil
+}
+
+func checkStmt(ctx context.Context, stmt sqlparser.Statement, r *Record) bool {
+	if stmt == nil {
+		return true
+	}
+
+	switch stmt := stmt.(type) {
+	case *sqlparser.Select:
+		for _, i := range stmt.SelectExprs {
+			if !checkSelectExpr(ctx, i, r) {
+				return false
+			}
+		}
+
+		for _, i := range stmt.From {
+			if !checkTableExpr(ctx, i, r) {
+				return false
+			}
+		}
+
+		if stmt.Where != nil {
+			if !checkExpr(ctx, stmt.Where.Expr, r) {
+				return false
+			}
+		}
+
+		for _, g := range stmt.GroupBy {
+			if !checkExpr(ctx, g, r) {
+				return false
+			}
+		}
+
+		if stmt.Having != nil {
+			if !checkExpr(ctx, stmt.Having.Expr, r) {
+				return false
+			}
+		}
+
+		for _, o := range stmt.OrderBy {
+			if o == nil {
+				continue
+			}
+			if !checkExpr(ctx, o.Expr, r) {
+				return false
+			}
+		}
+
+		if stmt.Limit != nil {
+			if !checkExpr(ctx, stmt.Limit.Offset, r) {
+				return false
+			}
+
+			if !checkExpr(ctx, stmt.Limit.Rowcount, r) {
+				return false
+			}
+		}
+	default:
+		r.Msg = "bad stmt operation"
+		return false
+	}
+
+	return true
+}
+
+func checkTableExpr(ctx context.Context, expr sqlparser.TableExpr, r *Record) bool {
+	if expr == nil {
+		return true
+	}
+
+	switch expr := expr.(type) {
+	case *sqlparser.AliasedTableExpr:
+		return checkSimpleTableExpr(ctx, expr.Expr, r)
+	case *sqlparser.ParenTableExpr:
+		for _, e := range expr.Exprs {
+			if !checkTableExpr(ctx, e, r) {
+				return false
+			}
+		}
+		return true
+	case *sqlparser.JoinTableExpr:
+		if !checkTableExpr(ctx, expr.LeftExpr, r) {
+			return false
+		}
+
+		if !checkTableExpr(ctx, expr.RightExpr, r) {
+			return false
+		}
+
+		if !checkExpr(ctx, expr.Condition.On, r) {
+			return false
+		}
+		return true
+	}
+
+	r.Msg = "bad table expr"
+	return false
+}
+
+func checkExpr(ctx context.Context, expr sqlparser.Expr, r *Record) bool {
+	if expr == nil {
+		return true
+	}
+
+	switch expr := expr.(type) {
+	default:
+		return false
+	case *sqlparser.AndExpr:
+		if !checkExpr(ctx, expr.Left, r) {
+			return false
+		}
+		if !checkExpr(ctx, expr.Right, r) {
+			return false
+		}
+	case *sqlparser.OrExpr:
+		if !checkExpr(ctx, expr.Left, r) {
+			return false
+		}
+		if !checkExpr(ctx, expr.Right, r) {
+			return false
+		}
+	case *sqlparser.NotExpr:
+		if !checkExpr(ctx, expr.Expr, r) {
+			return false
+		}
+	case *sqlparser.ParenExpr:
+		if !checkExpr(ctx, expr.Expr, r) {
+			return false
+		}
+	case *sqlparser.ComparisonExpr:
+		if !checkExpr(ctx, expr.Left, r) {
+			return false
+		}
+		if !checkExpr(ctx, expr.Right, r) {
+			return false
+		}
+		if !checkExpr(ctx, expr.Escape, r) {
+			return false
+		}
+	case *sqlparser.RangeCond:
+		if !checkExpr(ctx, expr.Left, r) {
+			return false
+		}
+		if !checkExpr(ctx, expr.From, r) {
+			return false
+		}
+		if !checkExpr(ctx, expr.To, r) {
+			return false
+		}
+	case *sqlparser.IsExpr:
+		if !checkExpr(ctx, expr.Expr, r) {
+			return false
+		}
+	case *sqlparser.ExistsExpr:
+		if !checkExpr(ctx, expr.Subquery, r) {
+			return false
+		}
+	case *sqlparser.SQLVal:
+		// 可以
+	case *sqlparser.NullVal:
+		// 可以
+	case sqlparser.BoolVal:
+		// 可以
+	case *sqlparser.ColName:
+		// 检查ColName
+	case sqlparser.ValTuple:
+		for _, e := range expr {
+			if !checkExpr(ctx, e, r) {
+				return false
+			}
+		}
+	case *sqlparser.Subquery:
+		if !checkStmt(ctx, expr.Select, r) {
+			return false
+		}
+	case sqlparser.ListArg:
+		// 可以
+	case *sqlparser.BinaryExpr:
+		if !checkExpr(ctx, expr.Left, r) {
+			return false
+		}
+		if !checkExpr(ctx, expr.Right, r) {
+			return false
+		}
+	case *sqlparser.UnaryExpr:
+		if !checkExpr(ctx, expr.Expr, r) {
+			return false
+		}
+	case *sqlparser.IntervalExpr:
+		if !checkExpr(ctx, expr.Expr, r) {
+			return false
+		}
+	case *sqlparser.CollateExpr:
+		if !checkExpr(ctx, expr.Expr, r) {
+			return false
+		}
+	case *sqlparser.FuncExpr:
+		if !checkFuncName(ctx, expr.Qualifier.String(), expr.Name.String(), r) {
+			return false
+		}
+
+		for _, s := range expr.Exprs {
+			if !checkSelectExpr(ctx, s, r) {
+				return false
+			}
+		}
+	case *sqlparser.CaseExpr:
+		if !checkExpr(ctx, expr.Expr, r) {
+			return false
+		}
+		if !checkExpr(ctx, expr.Else, r) {
+			return false
+		}
+		for _, w := range expr.Whens {
+			if w == nil {
+				continue
+			}
+			if !checkExpr(ctx, w.Cond, r) {
+				return false
+			}
+			if !checkExpr(ctx, w.Val, r) {
+				return false
+			}
+		}
+	case *sqlparser.ValuesFuncExpr:
+		r.Msg = "bad values expr"
+		return false // values表达式用于插入
+	case *sqlparser.ConvertExpr:
+		if !checkExpr(ctx, expr.Expr, r) {
+			return false
+		}
+	case *sqlparser.SubstrExpr:
+		if !checkExpr(ctx, expr.From, r) {
+			return false
+		}
+		if !checkExpr(ctx, expr.To, r) {
+			return false
+		}
+		if !checkColName(ctx, expr.Name, r) {
+			return false
+		}
+	case *sqlparser.ConvertUsingExpr:
+		if !checkExpr(ctx, expr.Expr, r) {
+			return false
+		}
+	case *sqlparser.MatchExpr:
+		if !checkExpr(ctx, expr.Expr, r) {
+			return false
+		}
+		for _, s := range expr.Columns {
+			if !checkSelectExpr(ctx, s, r) {
+				return false
+			}
+		}
+	case *sqlparser.GroupConcatExpr:
+		for _, s := range expr.Exprs {
+			if !checkSelectExpr(ctx, s, r) {
+				return false
+			}
+		}
+		for _, o := range expr.OrderBy {
+			if o == nil {
+				continue
+			}
+			if !checkExpr(ctx, o.Expr, r) {
+				return false
+			}
+		}
+	case *sqlparser.Default:
+		if !checkColNameString(ctx, expr.ColName, r) {
+			return false
+		}
+	}
+
+	return true
+}
+
+func checkSimpleTableExpr(ctx context.Context, expr sqlparser.SimpleTableExpr, r *Record) bool {
+	if expr == nil {
+		return true
+	}
+
+	switch expr := expr.(type) {
+	case sqlparser.TableName:
+		return checkTableName(ctx, expr, r)
+	case *sqlparser.Subquery:
+		return checkStmt(ctx, expr.Select, r)
+	}
+
+	r.Msg = "bad simple table expr"
+	return false
+}
+
+func checkSelectExpr(ctx context.Context, expr sqlparser.SelectExpr, r *Record) bool {
+	if expr == nil {
+		return true
+	}
+
+	switch expr := expr.(type) {
+	case *sqlparser.StarExpr:
+		return checkTableName(ctx, expr.TableName, r)
+	case *sqlparser.AliasedExpr:
+		return checkExpr(ctx, expr.Expr, r)
+	case sqlparser.Nextval:
+		return checkExpr(ctx, expr.Expr, r)
+	}
+
+	r.Msg = "bad select expr"
+	return false
+}
+
+func checkTableName(ctx context.Context, tableName sqlparser.TableName, r *Record) bool {
+	if tableName.IsEmpty() {
+		return true
+	}
+
+	allowTableName, ok := ctx.Value("Allow-Table-Name").([]string)
+	if ok && !tableName.Name.IsEmpty() && !InList(allowTableName, tableName.Name.String()) {
+		r.Msg = fmt.Sprintf("bad table name for %s", tableName.Name.String())
+		return false
+	}
+
+	allowDBName, ok := ctx.Value("Allow-DataBase-Name").([]string)
+	if ok {
+		if !tableName.Qualifier.IsEmpty() && !InList(allowDBName, tableName.Qualifier.String()) {
+			r.Msg = fmt.Sprintf("bad table qualifier for %s", tableName.Qualifier.String())
+			return false
+		}
+	} else {
+		if !tableName.Qualifier.IsEmpty() {
+			r.Msg = fmt.Sprintf("bad table qualifier for %s", tableName.Qualifier.String())
+			return false
+		}
+	}
+
+	return true
+}
+
+func checkColName(ctx context.Context, colname *sqlparser.ColName, r *Record) bool {
+	if !checkTableName(ctx, colname.Qualifier, r) {
+		return false
+	}
+
+	return checkColNameString(ctx, colname.Name.String(), r)
+}
+
+func checkColNameString(ctx context.Context, colname string, r *Record) bool {
+	if len(colname) == 0 {
+		return true
+	}
+
+	allowColName, ok := ctx.Value("Allow-Col-Name").([]string)
+	if ok && !InList(allowColName, colname) && !InList(allowColName, fmt.Sprintf("`%s`", colname)) {
+		r.Msg = fmt.Sprintf("bad table col for %s", colname)
+		return false
+	}
+
+	return true
+}
+
+func checkFuncName(ctx context.Context, ident string, funcName string, r *Record) bool {
+	if len(funcName) == 0 && len(ident) == 0 {
+		return false
+	}
+
+	if len(ident) != 0 {
+		allowColIdent, ok := ctx.Value("Allow-Func-Ident").([]string)
+		if ok && !InList(allowColIdent, ident) {
+			r.Msg = fmt.Sprintf("bad func ident for %s", ident)
+			return false
+		}
+	}
+
+	allowColName, ok := ctx.Value("Allow-Func-Name").([]string)
+	if ok && !InList(allowColName, strings.ToUpper(funcName)) {
+		r.Msg = fmt.Sprintf("bad func name for %s", funcName)
+		return false
+	} else if !ok {
+		r.Msg = fmt.Sprintf("bad func name for %s", funcName)
+		return false
+	}
+
+	return true
+}
+
+func InList[T string | int64](lst []T, element T) bool {
+	for _, i := range lst {
+		if i == element {
+			return true
+		}
+	}
+
+	return false
+}