Revert "fix: 修复order by和limit、page的sql注入风险"

This reverts commit ead7a214a0.
This commit is contained in:
wangjie 2025-08-22 17:50:13 +08:00
parent fe6c9c6009
commit 0d113a5473

View File

@ -4,7 +4,6 @@ import (
"errors" "errors"
"git.listensoft.net/tool/lxutils/lxUtil" "git.listensoft.net/tool/lxutils/lxUtil"
"gorm.io/gorm" "gorm.io/gorm"
"regexp"
"strings" "strings"
) )
@ -332,17 +331,8 @@ func SqlQueryNew(tx *gorm.DB, sql string, list interface{}, q *PaginationQuery,
// 条件字段 // 条件字段
if q != nil { if q != nil {
where, args := q.BuildRawWhere() where, args := q.BuildRawWhere()
// 安全地添加 WHERE 子句
if hasWhere(sql) { // 原SQL已有WHERE子句 if hasWhere(sql) { // 原SQL已有WHERE子句
// 确保 where 子句以 AND 或 OR 开头,然后安全添加 builder.WriteString(where) // 去掉where 前头的and or ..
if strings.HasPrefix(where, " AND ") || strings.HasPrefix(where, " OR ") {
builder.WriteString(where)
} else {
// 如果 where 不是以 AND/OR 开头,添加 AND 前缀
builder.WriteString(" AND ")
builder.WriteString(strings.TrimSpace(where))
}
} else { // 原SQL没有WHERE子句 } else { // 原SQL没有WHERE子句
if strings.HasPrefix(where, " AND ") { if strings.HasPrefix(where, " AND ") {
where = strings.Replace(where, " AND ", " WHERE ", 1) where = strings.Replace(where, " AND ", " WHERE ", 1)
@ -350,12 +340,10 @@ func SqlQueryNew(tx *gorm.DB, sql string, list interface{}, q *PaginationQuery,
} else if strings.HasPrefix(where, " OR ") { } else if strings.HasPrefix(where, " OR ") {
where = strings.Replace(where, " OR ", " WHERE ", 1) where = strings.Replace(where, " OR ", " WHERE ", 1)
builder.WriteString(where) builder.WriteString(where)
} else if where != "" { } else {
builder.WriteString(" WHERE ") builder.WriteString(where) // "" 或者 " GROUP BY ..."
builder.WriteString(where)
} }
} }
if len(args) > 0 { if len(args) > 0 {
params = append(params, args...) params = append(params, args...)
} }
@ -366,31 +354,31 @@ func SqlQueryNew(tx *gorm.DB, sql string, list interface{}, q *PaginationQuery,
// 记录条数 // 记录条数
if needDoCount(q) { if needDoCount(q) {
var total int64 var total int64
// 使用参数化查询 //tx = tx.Count(&total)
//tx.Raw("SELECT COUNT(*) as total FROM ("+sql2+") aaaa", params...).Take(&total)
//todo lcs 优化速度 start
// 替换 SELECT 和 FROM 之间的部分,并去掉 GROUP BY
replacedSQL := replaceSelectAndRemoveGroupBy(sql2) replacedSQL := replaceSelectAndRemoveGroupBy(sql2)
if err := tx.Raw(replacedSQL, params...).Take(&total).Error; err != nil { tx.Raw(replacedSQL, params...).Take(&total)
return err //todo end
}
q.Total = int(total) q.Total = int(total)
if total == 0 { if total == 0 { // 如果查了记录条数并且是0, 就不需要查记录和汇总了
return return
} }
// 获取汇总信息 // 获取汇总信息 // TODO: 汇总应该放到查询列表的后面
if q.Summary != "" && len(q.SummarySql) == 0 { if q.Summary != "" && len(q.SummarySql) == 0 {
// 安全地构建汇总字段
q.SummarySql = fieldsToSumSql(q.Summary) q.SummarySql = fieldsToSumSql(q.Summary)
} }
if len(q.Summary) != 0 { if len(q.Summary) != 0 {
tx = tx.Offset(-1) tx = tx.Offset(-1) // 需要去除offset, 否则结果可能为空, 注意: 设置0不起作用.
var summary = make(map[string]interface{}) var summary = make(map[string]interface{})
//tx.Order("") // FIXME: 怎么去掉order by, sum是不需要order by的, 影响性能.
//tx.Select(q.SummarySql).Take(&summary) // 不适合rawsql?
// 安全构建汇总查询 - 使用参数化查询 tx.Raw("SELECT "+strings.Join(q.SummarySql, ", ")+" FROM ("+sql2+") ssss", params...).Take(&summary)
summarySQL := "SELECT " + strings.Join(q.SummarySql, ", ") + " FROM (" + sql2 + ") ssss"
if err := tx.Raw(summarySQL, params...).Take(&summary).Error; err != nil {
return err
}
// []byte 转 string. 不太合理, 应该返回int或float
for k, v := range summary { for k, v := range summary {
if bs, ok := v.([]byte); ok { if bs, ok := v.([]byte); ok {
summary[k] = string(bs) summary[k] = string(bs)
@ -400,70 +388,30 @@ func SqlQueryNew(tx *gorm.DB, sql string, list interface{}, q *PaginationQuery,
} }
} }
// 安全地处理排序 - 使用白名单验证字段名 // 排序处理
if q.OrderBy != "" { if q.OrderBy != "" {
safeOrderBy := sanitizeOrderBy(q.OrderBy) s := fmt.Sprintf(" ORDER BY %s", lxUtil.FieldToColumn(q.OrderBy)) // TODO: q.OrderBy是字符串,可能多个字段 会有问题吗
if safeOrderBy != "" { builder.WriteString(s)
builder.WriteString(" ORDER BY ")
builder.WriteString(safeOrderBy)
}
} }
// 安全地处理分页 - 使用参数化查询 // 偏移量处理
if q.Limit > 0 { if q.Limit > 0 {
if q.Offset > 0 { if q.Offset > 0 {
offset := (q.Offset - 1) * q.Limit offset := (q.Offset - 1) * q.Limit
builder.WriteString(" LIMIT ?, ?") s := fmt.Sprintf(" LIMIT %d, %d", offset, q.Limit)
params = append(params, offset, q.Limit) builder.WriteString(s)
} else { } else {
builder.WriteString(" LIMIT ?") s := fmt.Sprintf(" LIMIT %d", q.Limit)
params = append(params, q.Limit) builder.WriteString(s)
} }
} }
} }
// 执行最终查询 //tx.Raw(builder.String(), params...).Scan(list) // FIXME: unsupported data type: &[] why?
if err := tx.Raw(builder.String(), params...).Find(list).Error; err != nil { tx.Raw(builder.String(), params...).Find(list) // Find与Scan区别: list传入[]时, 查询为空的情况下, Find返回的是[], 而Scan返回的是nil.
return err // ref: What is the difference between Find and Scan: https://github.com/go-gorm/gorm/issues/4218
}
return nil return
}
// 辅助函数:安全处理 ORDER BY 字段
func sanitizeOrderBy(orderBy string) string {
// 移除可能危险的字符
orderBy = strings.TrimSpace(orderBy)
// 分割字段和排序方向
parts := strings.Fields(orderBy)
if len(parts) == 0 {
return ""
}
// 验证字段名(简单示例,应根据实际需求扩展)
field := parts[0]
if !isValidFieldName(field) {
return ""
}
// 如果有排序方向,验证它
if len(parts) > 1 {
direction := strings.ToUpper(parts[1])
if direction != "ASC" && direction != "DESC" {
return field // 默认不添加方向
}
return field + " " + direction
}
return field
}
// 辅助函数:验证字段名是否安全
func isValidFieldName(field string) bool {
// 允许字母、数字、下划线和点(用于表名.字段名)
matched, _ := regexp.MatchString(`^[a-zA-Z0-9_.]+$`, field)
return matched
} }
// 是否需要查询记录条数 // 是否需要查询记录条数