fix:回退公共query
This commit is contained in:
parent
187f2a3016
commit
1f5f7a3dc6
175
lxDb/sql.go
175
lxDb/sql.go
@ -5,7 +5,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"git.listensoft.net/tool/lxutils/lxUtil"
|
"git.listensoft.net/tool/lxutils/lxUtil"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"regexp"
|
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -208,7 +207,6 @@ func Query(tx *gorm.DB, m interface{}, list interface{}, q *PaginationQuery) (er
|
|||||||
//}
|
//}
|
||||||
|
|
||||||
func SqlQuery(tx *gorm.DB, sql string, list interface{}, q *PaginationQuery, params ...interface{}) (err error) {
|
func SqlQuery(tx *gorm.DB, sql string, list interface{}, q *PaginationQuery, params ...interface{}) (err error) {
|
||||||
|
|
||||||
var builder strings.Builder
|
var builder strings.Builder
|
||||||
builder.WriteString(sql)
|
builder.WriteString(sql)
|
||||||
|
|
||||||
@ -219,8 +217,6 @@ func SqlQuery(tx *gorm.DB, sql string, list interface{}, q *PaginationQuery, par
|
|||||||
// 条件字段
|
// 条件字段
|
||||||
if q != nil {
|
if q != nil {
|
||||||
where, args := q.BuildRawWhere()
|
where, args := q.BuildRawWhere()
|
||||||
|
|
||||||
// 安全地添加 WHERE 子句
|
|
||||||
if hasWhere(sql) { // 原SQL已有WHERE子句
|
if hasWhere(sql) { // 原SQL已有WHERE子句
|
||||||
builder.WriteString(where) // 去掉where 前头的and or ..
|
builder.WriteString(where) // 去掉where 前头的and or ..
|
||||||
} else { // 原SQL没有WHERE子句
|
} else { // 原SQL没有WHERE子句
|
||||||
@ -231,10 +227,9 @@ func SqlQuery(tx *gorm.DB, sql string, list interface{}, q *PaginationQuery, par
|
|||||||
where = strings.Replace(where, " OR ", " WHERE ", 1)
|
where = strings.Replace(where, " OR ", " WHERE ", 1)
|
||||||
builder.WriteString(where)
|
builder.WriteString(where)
|
||||||
} else {
|
} else {
|
||||||
builder.WriteString(where)
|
builder.WriteString(where) // "" 或者 " GROUP BY ..."
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(args) > 0 {
|
if len(args) > 0 {
|
||||||
params = append(params, args...)
|
params = append(params, args...)
|
||||||
}
|
}
|
||||||
@ -252,9 +247,8 @@ func SqlQuery(tx *gorm.DB, sql string, list interface{}, q *PaginationQuery, par
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取汇总信息
|
// 获取汇总信息 // TODO: 汇总应该放到查询列表的后面
|
||||||
if q.Summary != "" && len(q.SummarySql) == 0 {
|
if q.Summary != "" && len(q.SummarySql) == 0 {
|
||||||
// 安全地构建汇总字段
|
|
||||||
q.SummarySql = fieldsToSumSql(q.Summary)
|
q.SummarySql = fieldsToSumSql(q.Summary)
|
||||||
}
|
}
|
||||||
if len(q.Summary) != 0 {
|
if len(q.Summary) != 0 {
|
||||||
@ -266,7 +260,6 @@ func SqlQuery(tx *gorm.DB, sql string, list interface{}, q *PaginationQuery, par
|
|||||||
tx.Raw("SELECT "+strings.Join(q.SummarySql, ", ")+" FROM ("+sql2+") ssss", params...).Take(&summary)
|
tx.Raw("SELECT "+strings.Join(q.SummarySql, ", ")+" FROM ("+sql2+") ssss", params...).Take(&summary)
|
||||||
|
|
||||||
// []byte 转 string. 不太合理, 应该返回int或float
|
// []byte 转 string. 不太合理, 应该返回int或float
|
||||||
|
|
||||||
for k, v := range summary {
|
for k, v := range summary {
|
||||||
if bs, ok := v.([]byte); ok {
|
if bs, ok := v.([]byte); ok {
|
||||||
summary[k] = string(bs)
|
summary[k] = string(bs)
|
||||||
@ -276,17 +269,13 @@ func SqlQuery(tx *gorm.DB, sql string, list interface{}, q *PaginationQuery, par
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 安全地处理排序 - 使用白名单验证字段名
|
// 排序处理
|
||||||
if q.OrderBy != "" {
|
if q.OrderBy != "" {
|
||||||
// 验证输入参数的安全性
|
|
||||||
if !isSafeSQL(q.OrderBy) {
|
|
||||||
return errors.New("环境异常")
|
|
||||||
}
|
|
||||||
s := fmt.Sprintf(" ORDER BY %s", lxUtil.FieldToColumn(q.OrderBy)) // TODO: q.OrderBy是字符串,可能多个字段 会有问题吗
|
s := fmt.Sprintf(" ORDER BY %s", lxUtil.FieldToColumn(q.OrderBy)) // TODO: q.OrderBy是字符串,可能多个字段 会有问题吗
|
||||||
builder.WriteString(s)
|
builder.WriteString(s)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 安全地处理分页 - 使用参数化查询
|
// 偏移量处理
|
||||||
if q.Limit > 0 {
|
if q.Limit > 0 {
|
||||||
if q.Offset > 0 {
|
if q.Offset > 0 {
|
||||||
offset := (q.Offset - 1) * q.Limit
|
offset := (q.Offset - 1) * q.Limit
|
||||||
@ -299,14 +288,14 @@ func SqlQuery(tx *gorm.DB, sql string, list interface{}, q *PaginationQuery, par
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 执行最终查询 - 使用参数化查询
|
//tx.Raw(builder.String(), params...).Scan(list) // FIXME: unsupported data type: &[] why?
|
||||||
tx.Raw(builder.String(), params...).Find(list)
|
tx.Raw(builder.String(), params...).Find(list) // Find与Scan区别: list传入[]时, 查询为空的情况下, Find返回的是[], 而Scan返回的是nil.
|
||||||
|
// ref: What is the difference between Find and Scan: https://github.com/go-gorm/gorm/issues/4218
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func SqlQueryNew(tx *gorm.DB, sql string, list interface{}, q *PaginationQuery, params ...interface{}) (err error) {
|
func SqlQueryNew(tx *gorm.DB, sql string, list interface{}, q *PaginationQuery, params ...interface{}) (err error) {
|
||||||
|
|
||||||
var builder strings.Builder
|
var builder strings.Builder
|
||||||
builder.WriteString(sql)
|
builder.WriteString(sql)
|
||||||
|
|
||||||
@ -317,16 +306,8 @@ func SqlQueryNew(tx *gorm.DB, sql string, list interface{}, q *PaginationQuery,
|
|||||||
// 条件字段
|
// 条件字段
|
||||||
if q != nil {
|
if q != nil {
|
||||||
where, args := q.BuildRawWhere()
|
where, args := q.BuildRawWhere()
|
||||||
|
|
||||||
if hasWhere(sql) { // 原SQL已有WHERE子句
|
if hasWhere(sql) { // 原SQL已有WHERE子句
|
||||||
// 确保 where 子句以 AND 或 OR 开头,然后安全添加
|
builder.WriteString(where) // 去掉where 前头的and or ..
|
||||||
if strings.HasPrefix(where, " AND ") || strings.HasPrefix(where, " OR ") {
|
|
||||||
builder.WriteString(where)
|
|
||||||
} else {
|
|
||||||
// 如果 where 不是以 AND/OR 开头,添加 AND 前缀
|
|
||||||
builder.WriteString(" AND ")
|
|
||||||
builder.WriteString(strings.TrimSpace(where))
|
|
||||||
}
|
|
||||||
} else { // 原SQL没有WHERE子句
|
} else { // 原SQL没有WHERE子句
|
||||||
if strings.HasPrefix(where, " AND ") {
|
if strings.HasPrefix(where, " AND ") {
|
||||||
where = strings.Replace(where, " AND ", " WHERE ", 1)
|
where = strings.Replace(where, " AND ", " WHERE ", 1)
|
||||||
@ -334,12 +315,10 @@ func SqlQueryNew(tx *gorm.DB, sql string, list interface{}, q *PaginationQuery,
|
|||||||
} else if strings.HasPrefix(where, " OR ") {
|
} else if strings.HasPrefix(where, " OR ") {
|
||||||
where = strings.Replace(where, " OR ", " WHERE ", 1)
|
where = strings.Replace(where, " OR ", " WHERE ", 1)
|
||||||
builder.WriteString(where)
|
builder.WriteString(where)
|
||||||
} else if where != "" {
|
} else {
|
||||||
builder.WriteString(" WHERE ")
|
builder.WriteString(where) // "" 或者 " GROUP BY ..."
|
||||||
builder.WriteString(where)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(args) > 0 {
|
if len(args) > 0 {
|
||||||
params = append(params, args...)
|
params = append(params, args...)
|
||||||
}
|
}
|
||||||
@ -350,30 +329,29 @@ func SqlQueryNew(tx *gorm.DB, sql string, list interface{}, q *PaginationQuery,
|
|||||||
// 记录条数
|
// 记录条数
|
||||||
if needDoCount(q) {
|
if needDoCount(q) {
|
||||||
var total int64
|
var total int64
|
||||||
// 使用安全的 COUNT 查询
|
//tx = tx.Count(&total)
|
||||||
|
//tx.Raw("SELECT COUNT(*) as total FROM ("+sql2+") aaaa", params...).Take(&total)
|
||||||
|
//todo lcs 优化速度 start
|
||||||
|
// 替换 SELECT 和 FROM 之间的部分,并去掉 GROUP BY
|
||||||
replacedSQL := replaceSelectAndRemoveGroupBy(sql2)
|
replacedSQL := replaceSelectAndRemoveGroupBy(sql2)
|
||||||
if err := tx.Raw(replacedSQL, params...).Take(&total).Error; err != nil {
|
tx.Raw(replacedSQL, params...).Take(&total)
|
||||||
return err
|
//todo end
|
||||||
}
|
|
||||||
q.Total = int(total)
|
q.Total = int(total)
|
||||||
if total == 0 { // 如果查了记录条数并且是0, 就不需要查记录和汇总了
|
if total == 0 { // 如果查了记录条数并且是0, 就不需要查记录和汇总了
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取汇总信息
|
// 获取汇总信息 // TODO: 汇总应该放到查询列表的后面
|
||||||
if q.Summary != "" && len(q.SummarySql) == 0 {
|
if q.Summary != "" && len(q.SummarySql) == 0 {
|
||||||
// 安全地构建汇总字段
|
|
||||||
q.SummarySql = fieldsToSumSql(q.Summary)
|
q.SummarySql = fieldsToSumSql(q.Summary)
|
||||||
}
|
}
|
||||||
if len(q.Summary) != 0 {
|
if len(q.Summary) != 0 {
|
||||||
tx = tx.Offset(-1) // 需要去除offset, 否则结果可能为空, 注意: 设置0不起作用.
|
tx = tx.Offset(-1) // 需要去除offset, 否则结果可能为空, 注意: 设置0不起作用.
|
||||||
var summary = make(map[string]interface{})
|
var summary = make(map[string]interface{})
|
||||||
|
//tx.Order("") // FIXME: 怎么去掉order by, sum是不需要order by的, 影响性能.
|
||||||
|
//tx.Select(q.SummarySql).Take(&summary) // 不适合rawsql?
|
||||||
|
|
||||||
// 安全构建汇总查询 - 使用参数化查询
|
tx.Raw("SELECT "+strings.Join(q.SummarySql, ", ")+" FROM ("+sql2+") ssss", params...).Take(&summary)
|
||||||
summarySQL := "SELECT " + strings.Join(q.SummarySql, ", ") + " FROM (" + sql2 + ") ssss"
|
|
||||||
if err := tx.Raw(summarySQL, params...).Take(&summary).Error; err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// []byte 转 string. 不太合理, 应该返回int或float
|
// []byte 转 string. 不太合理, 应该返回int或float
|
||||||
for k, v := range summary {
|
for k, v := range summary {
|
||||||
@ -385,38 +363,30 @@ func SqlQueryNew(tx *gorm.DB, sql string, list interface{}, q *PaginationQuery,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 安全地处理排序 - 使用白名单验证字段名
|
// 排序处理
|
||||||
if q.OrderBy != "" {
|
if q.OrderBy != "" {
|
||||||
// 验证输入参数的安全性
|
s := fmt.Sprintf(" ORDER BY %s", lxUtil.FieldToColumn(q.OrderBy)) // TODO: q.OrderBy是字符串,可能多个字段 会有问题吗
|
||||||
if !isSafeSQL(q.OrderBy) {
|
builder.WriteString(s)
|
||||||
return errors.New("环境异常")
|
|
||||||
}
|
|
||||||
safeOrderBy := sanitizeOrderBy(q.OrderBy)
|
|
||||||
if safeOrderBy != "" {
|
|
||||||
builder.WriteString(" ORDER BY ")
|
|
||||||
builder.WriteString(safeOrderBy)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 安全地处理分页 - 使用参数化查询
|
// 偏移量处理
|
||||||
if q.Limit > 0 {
|
if q.Limit > 0 {
|
||||||
if q.Offset > 0 {
|
if q.Offset > 0 {
|
||||||
offset := (q.Offset - 1) * q.Limit
|
offset := (q.Offset - 1) * q.Limit
|
||||||
builder.WriteString(" LIMIT ?, ?")
|
s := fmt.Sprintf(" LIMIT %d, %d", offset, q.Limit)
|
||||||
params = append(params, offset, q.Limit)
|
builder.WriteString(s)
|
||||||
} else {
|
} else {
|
||||||
builder.WriteString(" LIMIT ?")
|
s := fmt.Sprintf(" LIMIT %d", q.Limit)
|
||||||
params = append(params, q.Limit)
|
builder.WriteString(s)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 执行最终查询 - 使用参数化查询
|
//tx.Raw(builder.String(), params...).Scan(list) // FIXME: unsupported data type: &[] why?
|
||||||
if err := tx.Raw(builder.String(), params...).Find(list).Error; err != nil {
|
tx.Raw(builder.String(), params...).Find(list) // Find与Scan区别: list传入[]时, 查询为空的情况下, Find返回的是[], 而Scan返回的是nil.
|
||||||
return err
|
// ref: What is the difference between Find and Scan: https://github.com/go-gorm/gorm/issues/4218
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 是否需要查询记录条数
|
// 是否需要查询记录条数
|
||||||
@ -521,87 +491,6 @@ func Transaction(txItem, txMain *gorm.DB, fun func(txItem, txMain *gorm.DB) (err
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// isSafeSQL 检查SQL语句是否安全,防止SQL注入
|
|
||||||
func isSafeSQL(sql string) bool {
|
|
||||||
// 转换为大写进行关键字检查
|
|
||||||
upperSQL := strings.ToUpper(sql)
|
|
||||||
|
|
||||||
// 危险关键字列表
|
|
||||||
dangerousKeywords := []string{
|
|
||||||
"DROP", "DELETE", "UPDATE", "INSERT", "ALTER", "CREATE", "TRUNCATE",
|
|
||||||
"EXEC", "EXECUTE", "XP_", "SP_", "UNION", "JOIN", "HAVING", "GROUP BY",
|
|
||||||
"ORDER BY", "LIMIT", "OFFSET", "--", "/*", "*/", ";", "@@", "@",
|
|
||||||
"0X", "CHAR(", "ASCII(", "SUBSTRING(", "MID(", "LENGTH(", "LEN(",
|
|
||||||
"CONCAT(", "LOAD_FILE(", "BENCHMARK(", "SLEEP(", "WAITFOR",
|
|
||||||
"CAST(", "CONVERT(", "IF(", "CASE", "WHEN", "THEN", "END",
|
|
||||||
}
|
|
||||||
|
|
||||||
upperSQL = strings.ReplaceAll(upperSQL, "CREATED_AT", "")
|
|
||||||
upperSQL = strings.ReplaceAll(upperSQL, "UPDATED_AT", "")
|
|
||||||
upperSQL = strings.ReplaceAll(upperSQL, "CREATED_TIME", "")
|
|
||||||
upperSQL = strings.ReplaceAll(upperSQL, "UPDATED_TIME", "")
|
|
||||||
|
|
||||||
// 检查危险关键字
|
|
||||||
for _, keyword := range dangerousKeywords {
|
|
||||||
if strings.Contains(upperSQL, keyword) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 检查SQL结构是否为简单的SELECT语句
|
|
||||||
if strings.HasPrefix(upperSQL, "SELECT") {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// 检查是否包含注释
|
|
||||||
if strings.Contains(upperSQL, "--") || strings.Contains(upperSQL, "/*") {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// 检查是否包含分号(多条语句)
|
|
||||||
if strings.Contains(upperSQL, ";") {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// sanitizeOrderBy 安全处理 ORDER BY 字段
|
|
||||||
func sanitizeOrderBy(orderBy string) string {
|
|
||||||
// 移除可能危险的字符
|
|
||||||
orderBy = strings.TrimSpace(orderBy)
|
|
||||||
|
|
||||||
// 分割字段和排序方向
|
|
||||||
parts := strings.Fields(orderBy)
|
|
||||||
if len(parts) == 0 {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
// 验证字段名
|
|
||||||
field := parts[0]
|
|
||||||
if !isValidFieldName(field) {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
// 如果有排序方向,验证它
|
|
||||||
if len(parts) > 1 {
|
|
||||||
direction := strings.ToUpper(parts[1])
|
|
||||||
if direction != "ASC" && direction != "DESC" {
|
|
||||||
return field // 默认不添加方向
|
|
||||||
}
|
|
||||||
return field + " " + direction
|
|
||||||
}
|
|
||||||
|
|
||||||
return field
|
|
||||||
}
|
|
||||||
|
|
||||||
// isValidFieldName 验证字段名是否安全
|
|
||||||
func isValidFieldName(field string) bool {
|
|
||||||
// 允许字母、数字、下划线和点(用于表名.字段名)
|
|
||||||
matched, _ := regexp.MatchString(`^[a-zA-Z0-9_.]+$`, field)
|
|
||||||
return matched
|
|
||||||
}
|
|
||||||
|
|
||||||
// replaceSelectAndRemoveGroupBy 替换 SELECT 和 FROM 之间的部分,并去掉 GROUP BY
|
// replaceSelectAndRemoveGroupBy 替换 SELECT 和 FROM 之间的部分,并去掉 GROUP BY
|
||||||
func replaceSelectAndRemoveGroupBy(sql string) string {
|
func replaceSelectAndRemoveGroupBy(sql string) string {
|
||||||
// 找到 SELECT 和 FROM 的位置
|
// 找到 SELECT 和 FROM 的位置
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user