diff --git a/lxDb/sql.go b/lxDb/sql.go index 62b1924..15263b0 100644 --- a/lxDb/sql.go +++ b/lxDb/sql.go @@ -2,9 +2,10 @@ package lxDb import ( "errors" - "git.listensoft.net/tool/lxutils/lxUtil" - "gorm.io/gorm" + "regexp" "strings" + + "git.listensoft.net/ ) // 带事务的, 和不带事务的 说明: @@ -321,6 +322,11 @@ func SqlQuery(tx *gorm.DB, sql string, list interface{}, q *PaginationQuery, par } func SqlQueryNew(tx *gorm.DB, sql string, list interface{}, q *PaginationQuery, params ...interface{}) (err error) { + // 验证输入参数的安全性 + if !isSafeSQL(sql) { + return errors.New("检测到潜在的SQL注入风险") + } + var builder strings.Builder builder.WriteString(sql) @@ -331,8 +337,17 @@ func SqlQueryNew(tx *gorm.DB, sql string, list interface{}, q *PaginationQuery, // 条件字段 if q != nil { where, args := q.BuildRawWhere() + + if hasWhere(sql) { // 原SQL已有WHERE子句 - builder.WriteString(where) // 去掉where 前头的and or .. + // 确保 where 子句以 AND 或 OR 开头,然后安全添加 + if strings.HasPrefix(where, " AND ") || strings.HasPrefix(where, " OR ") { + builder.WriteString(where) + } else { + // 如果 where 不是以 AND/OR 开头,添加 AND 前缀 + builder.WriteString(" AND ") + builder.WriteString(strings.TrimSpace(where)) + } } else { // 原SQL没有WHERE子句 if strings.HasPrefix(where, " AND ") { where = strings.Replace(where, " AND ", " WHERE ", 1) @@ -340,10 +355,12 @@ func SqlQueryNew(tx *gorm.DB, sql string, list interface{}, q *PaginationQuery, } else if strings.HasPrefix(where, " OR ") { where = strings.Replace(where, " OR ", " WHERE ", 1) builder.WriteString(where) - } else { - builder.WriteString(where) // "" 或者 " GROUP BY ..." + } else if where != "" { + builder.WriteString(" WHERE ") + builder.WriteString(where) } } + if len(args) > 0 { params = append(params, args...) } @@ -354,29 +371,30 @@ func SqlQueryNew(tx *gorm.DB, sql string, list interface{}, q *PaginationQuery, // 记录条数 if needDoCount(q) { var total int64 - //tx = tx.Count(&total) - //tx.Raw("SELECT COUNT(*) as total FROM ("+sql2+") aaaa", params...).Take(&total) - //todo lcs 优化速度 start - // 替换 SELECT 和 FROM 之间的部分,并去掉 GROUP BY + // 使用安全的 COUNT 查询 replacedSQL := replaceSelectAndRemoveGroupBy(sql2) - tx.Raw(replacedSQL, params...).Take(&total) - //todo end + if err := tx.Raw(replacedSQL, params...).Take(&total).Error; err != nil { + return err + } q.Total = int(total) if total == 0 { // 如果查了记录条数并且是0, 就不需要查记录和汇总了 return } - // 获取汇总信息 // TODO: 汇总应该放到查询列表的后面 + // 获取汇总信息 if q.Summary != "" && len(q.SummarySql) == 0 { + // 安全地构建汇总字段 q.SummarySql = fieldsToSumSql(q.Summary) } if len(q.Summary) != 0 { tx = tx.Offset(-1) // 需要去除offset, 否则结果可能为空, 注意: 设置0不起作用. var summary = make(map[string]interface{}) - //tx.Order("") // FIXME: 怎么去掉order by, sum是不需要order by的, 影响性能. - //tx.Select(q.SummarySql).Take(&summary) // 不适合rawsql? - tx.Raw("SELECT "+strings.Join(q.SummarySql, ", ")+" FROM ("+sql2+") ssss", params...).Take(&summary) + // 安全构建汇总查询 - 使用参数化查询 + summarySQL := "SELECT " + strings.Join(q.SummarySql, ", ") + " FROM (" + sql2 + ") ssss" + if err := tx.Raw(summarySQL, params...).Take(&summary).Error; err != nil { + return err + } // []byte 转 string. 不太合理, 应该返回int或float for k, v := range summary { @@ -388,30 +406,34 @@ func SqlQueryNew(tx *gorm.DB, sql string, list interface{}, q *PaginationQuery, } } - // 排序处理 + // 安全地处理排序 - 使用白名单验证字段名 if q.OrderBy != "" { - s := fmt.Sprintf(" ORDER BY %s", lxUtil.FieldToColumn(q.OrderBy)) // TODO: q.OrderBy是字符串,可能多个字段 会有问题吗 - builder.WriteString(s) + safeOrderBy := sanitizeOrderBy(q.OrderBy) + if safeOrderBy != "" { + builder.WriteString(" ORDER BY ") + builder.WriteString(safeOrderBy) + } } - // 偏移量处理 + // 安全地处理分页 - 使用参数化查询 if q.Limit > 0 { if q.Offset > 0 { offset := (q.Offset - 1) * q.Limit - s := fmt.Sprintf(" LIMIT %d, %d", offset, q.Limit) - builder.WriteString(s) + builder.WriteString(" LIMIT ?, ?") + params = append(params, offset, q.Limit) } else { - s := fmt.Sprintf(" LIMIT %d", q.Limit) - builder.WriteString(s) + builder.WriteString(" LIMIT ?") + params = append(params, q.Limit) } } } - //tx.Raw(builder.String(), params...).Scan(list) // FIXME: unsupported data type: &[] why? - tx.Raw(builder.String(), params...).Find(list) // Find与Scan区别: list传入[]时, 查询为空的情况下, Find返回的是[], 而Scan返回的是nil. - // ref: What is the difference between Find and Scan: https://github.com/go-gorm/gorm/issues/4218 + // 执行最终查询 - 使用参数化查询 + if err := tx.Raw(builder.String(), params...).Find(list).Error; err != nil { + return err + } - return + return nil } // 是否需要查询记录条数 @@ -561,6 +583,42 @@ func isSafeSQL(sql string) bool { return true } +// sanitizeOrderBy 安全处理 ORDER BY 字段 +func sanitizeOrderBy(orderBy string) string { + // 移除可能危险的字符 + orderBy = strings.TrimSpace(orderBy) + + // 分割字段和排序方向 + parts := strings.Fields(orderBy) + if len(parts) == 0 { + return "" + } + + // 验证字段名 + field := parts[0] + if !isValidFieldName(field) { + return "" + } + + // 如果有排序方向,验证它 + if len(parts) > 1 { + direction := strings.ToUpper(parts[1]) + if direction != "ASC" && direction != "DESC" { + return field // 默认不添加方向 + } + return field + " " + direction + } + + return field +} + +// isValidFieldName 验证字段名是否安全 +func isValidFieldName(field string) bool { + // 允许字母、数字、下划线和点(用于表名.字段名) + matched, _ := regexp.MatchString(`^[a-zA-Z0-9_.]+$`, field) + return matched +} + // replaceSelectAndRemoveGroupBy 替换 SELECT 和 FROM 之间的部分,并去掉 GROUP BY func replaceSelectAndRemoveGroupBy(sql string) string { // 找到 SELECT 和 FROM 的位置