From 1f5f7a3dc6e13eeac0ab654b0d64db3418dfa4bc Mon Sep 17 00:00:00 2001 From: wangning Date: Mon, 8 Sep 2025 09:50:00 +0800 Subject: [PATCH] =?UTF-8?q?fix:=E5=9B=9E=E9=80=80=E5=85=AC=E5=85=B1query?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lxDb/sql.go | 175 ++++++++++------------------------------------------ 1 file changed, 32 insertions(+), 143 deletions(-) diff --git a/lxDb/sql.go b/lxDb/sql.go index ae6d6ef..2945aa6 100644 --- a/lxDb/sql.go +++ b/lxDb/sql.go @@ -5,7 +5,6 @@ import ( "fmt" "git.listensoft.net/tool/lxutils/lxUtil" "gorm.io/gorm" - "regexp" "strings" ) @@ -208,7 +207,6 @@ func Query(tx *gorm.DB, m interface{}, list interface{}, q *PaginationQuery) (er //} func SqlQuery(tx *gorm.DB, sql string, list interface{}, q *PaginationQuery, params ...interface{}) (err error) { - var builder strings.Builder builder.WriteString(sql) @@ -219,8 +217,6 @@ func SqlQuery(tx *gorm.DB, sql string, list interface{}, q *PaginationQuery, par // 条件字段 if q != nil { where, args := q.BuildRawWhere() - - // 安全地添加 WHERE 子句 if hasWhere(sql) { // 原SQL已有WHERE子句 builder.WriteString(where) // 去掉where 前头的and or .. } else { // 原SQL没有WHERE子句 @@ -231,10 +227,9 @@ func SqlQuery(tx *gorm.DB, sql string, list interface{}, q *PaginationQuery, par where = strings.Replace(where, " OR ", " WHERE ", 1) builder.WriteString(where) } else { - builder.WriteString(where) + builder.WriteString(where) // "" 或者 " GROUP BY ..." } } - if len(args) > 0 { params = append(params, args...) } @@ -252,9 +247,8 @@ func SqlQuery(tx *gorm.DB, sql string, list interface{}, q *PaginationQuery, par return } - // 获取汇总信息 + // 获取汇总信息 // TODO: 汇总应该放到查询列表的后面 if q.Summary != "" && len(q.SummarySql) == 0 { - // 安全地构建汇总字段 q.SummarySql = fieldsToSumSql(q.Summary) } if len(q.Summary) != 0 { @@ -266,7 +260,6 @@ func SqlQuery(tx *gorm.DB, sql string, list interface{}, q *PaginationQuery, par tx.Raw("SELECT "+strings.Join(q.SummarySql, ", ")+" FROM ("+sql2+") ssss", params...).Take(&summary) // []byte 转 string. 不太合理, 应该返回int或float - for k, v := range summary { if bs, ok := v.([]byte); ok { summary[k] = string(bs) @@ -276,17 +269,13 @@ func SqlQuery(tx *gorm.DB, sql string, list interface{}, q *PaginationQuery, par } } - // 安全地处理排序 - 使用白名单验证字段名 + // 排序处理 if q.OrderBy != "" { - // 验证输入参数的安全性 - if !isSafeSQL(q.OrderBy) { - return errors.New("环境异常") - } s := fmt.Sprintf(" ORDER BY %s", lxUtil.FieldToColumn(q.OrderBy)) // TODO: q.OrderBy是字符串,可能多个字段 会有问题吗 builder.WriteString(s) } - // 安全地处理分页 - 使用参数化查询 + // 偏移量处理 if q.Limit > 0 { if q.Offset > 0 { offset := (q.Offset - 1) * q.Limit @@ -299,14 +288,14 @@ func SqlQuery(tx *gorm.DB, sql string, list interface{}, q *PaginationQuery, par } } - // 执行最终查询 - 使用参数化查询 - tx.Raw(builder.String(), params...).Find(list) + //tx.Raw(builder.String(), params...).Scan(list) // FIXME: unsupported data type: &[] why? + tx.Raw(builder.String(), params...).Find(list) // Find与Scan区别: list传入[]时, 查询为空的情况下, Find返回的是[], 而Scan返回的是nil. + // ref: What is the difference between Find and Scan: https://github.com/go-gorm/gorm/issues/4218 return } func SqlQueryNew(tx *gorm.DB, sql string, list interface{}, q *PaginationQuery, params ...interface{}) (err error) { - var builder strings.Builder builder.WriteString(sql) @@ -317,16 +306,8 @@ func SqlQueryNew(tx *gorm.DB, sql string, list interface{}, q *PaginationQuery, // 条件字段 if q != nil { where, args := q.BuildRawWhere() - if hasWhere(sql) { // 原SQL已有WHERE子句 - // 确保 where 子句以 AND 或 OR 开头,然后安全添加 - if strings.HasPrefix(where, " AND ") || strings.HasPrefix(where, " OR ") { - builder.WriteString(where) - } else { - // 如果 where 不是以 AND/OR 开头,添加 AND 前缀 - builder.WriteString(" AND ") - builder.WriteString(strings.TrimSpace(where)) - } + builder.WriteString(where) // 去掉where 前头的and or .. } else { // 原SQL没有WHERE子句 if strings.HasPrefix(where, " AND ") { where = strings.Replace(where, " AND ", " WHERE ", 1) @@ -334,12 +315,10 @@ func SqlQueryNew(tx *gorm.DB, sql string, list interface{}, q *PaginationQuery, } else if strings.HasPrefix(where, " OR ") { where = strings.Replace(where, " OR ", " WHERE ", 1) builder.WriteString(where) - } else if where != "" { - builder.WriteString(" WHERE ") - builder.WriteString(where) + } else { + builder.WriteString(where) // "" 或者 " GROUP BY ..." } } - if len(args) > 0 { params = append(params, args...) } @@ -350,30 +329,29 @@ func SqlQueryNew(tx *gorm.DB, sql string, list interface{}, q *PaginationQuery, // 记录条数 if needDoCount(q) { var total int64 - // 使用安全的 COUNT 查询 + //tx = tx.Count(&total) + //tx.Raw("SELECT COUNT(*) as total FROM ("+sql2+") aaaa", params...).Take(&total) + //todo lcs 优化速度 start + // 替换 SELECT 和 FROM 之间的部分,并去掉 GROUP BY replacedSQL := replaceSelectAndRemoveGroupBy(sql2) - if err := tx.Raw(replacedSQL, params...).Take(&total).Error; err != nil { - return err - } + tx.Raw(replacedSQL, params...).Take(&total) + //todo end q.Total = int(total) if total == 0 { // 如果查了记录条数并且是0, 就不需要查记录和汇总了 return } - // 获取汇总信息 + // 获取汇总信息 // TODO: 汇总应该放到查询列表的后面 if q.Summary != "" && len(q.SummarySql) == 0 { - // 安全地构建汇总字段 q.SummarySql = fieldsToSumSql(q.Summary) } if len(q.Summary) != 0 { tx = tx.Offset(-1) // 需要去除offset, 否则结果可能为空, 注意: 设置0不起作用. var summary = make(map[string]interface{}) + //tx.Order("") // FIXME: 怎么去掉order by, sum是不需要order by的, 影响性能. + //tx.Select(q.SummarySql).Take(&summary) // 不适合rawsql? - // 安全构建汇总查询 - 使用参数化查询 - summarySQL := "SELECT " + strings.Join(q.SummarySql, ", ") + " FROM (" + sql2 + ") ssss" - if err := tx.Raw(summarySQL, params...).Take(&summary).Error; err != nil { - return err - } + tx.Raw("SELECT "+strings.Join(q.SummarySql, ", ")+" FROM ("+sql2+") ssss", params...).Take(&summary) // []byte 转 string. 不太合理, 应该返回int或float for k, v := range summary { @@ -385,38 +363,30 @@ func SqlQueryNew(tx *gorm.DB, sql string, list interface{}, q *PaginationQuery, } } - // 安全地处理排序 - 使用白名单验证字段名 + // 排序处理 if q.OrderBy != "" { - // 验证输入参数的安全性 - if !isSafeSQL(q.OrderBy) { - return errors.New("环境异常") - } - safeOrderBy := sanitizeOrderBy(q.OrderBy) - if safeOrderBy != "" { - builder.WriteString(" ORDER BY ") - builder.WriteString(safeOrderBy) - } + s := fmt.Sprintf(" ORDER BY %s", lxUtil.FieldToColumn(q.OrderBy)) // TODO: q.OrderBy是字符串,可能多个字段 会有问题吗 + builder.WriteString(s) } - // 安全地处理分页 - 使用参数化查询 + // 偏移量处理 if q.Limit > 0 { if q.Offset > 0 { offset := (q.Offset - 1) * q.Limit - builder.WriteString(" LIMIT ?, ?") - params = append(params, offset, q.Limit) + s := fmt.Sprintf(" LIMIT %d, %d", offset, q.Limit) + builder.WriteString(s) } else { - builder.WriteString(" LIMIT ?") - params = append(params, q.Limit) + s := fmt.Sprintf(" LIMIT %d", q.Limit) + builder.WriteString(s) } } } - // 执行最终查询 - 使用参数化查询 - if err := tx.Raw(builder.String(), params...).Find(list).Error; err != nil { - return err - } + //tx.Raw(builder.String(), params...).Scan(list) // FIXME: unsupported data type: &[] why? + tx.Raw(builder.String(), params...).Find(list) // Find与Scan区别: list传入[]时, 查询为空的情况下, Find返回的是[], 而Scan返回的是nil. + // ref: What is the difference between Find and Scan: https://github.com/go-gorm/gorm/issues/4218 - return nil + return } // 是否需要查询记录条数 @@ -521,87 +491,6 @@ func Transaction(txItem, txMain *gorm.DB, fun func(txItem, txMain *gorm.DB) (err return } -// isSafeSQL 检查SQL语句是否安全,防止SQL注入 -func isSafeSQL(sql string) bool { - // 转换为大写进行关键字检查 - upperSQL := strings.ToUpper(sql) - - // 危险关键字列表 - dangerousKeywords := []string{ - "DROP", "DELETE", "UPDATE", "INSERT", "ALTER", "CREATE", "TRUNCATE", - "EXEC", "EXECUTE", "XP_", "SP_", "UNION", "JOIN", "HAVING", "GROUP BY", - "ORDER BY", "LIMIT", "OFFSET", "--", "/*", "*/", ";", "@@", "@", - "0X", "CHAR(", "ASCII(", "SUBSTRING(", "MID(", "LENGTH(", "LEN(", - "CONCAT(", "LOAD_FILE(", "BENCHMARK(", "SLEEP(", "WAITFOR", - "CAST(", "CONVERT(", "IF(", "CASE", "WHEN", "THEN", "END", - } - - upperSQL = strings.ReplaceAll(upperSQL, "CREATED_AT", "") - upperSQL = strings.ReplaceAll(upperSQL, "UPDATED_AT", "") - upperSQL = strings.ReplaceAll(upperSQL, "CREATED_TIME", "") - upperSQL = strings.ReplaceAll(upperSQL, "UPDATED_TIME", "") - - // 检查危险关键字 - for _, keyword := range dangerousKeywords { - if strings.Contains(upperSQL, keyword) { - return false - } - } - - // 检查SQL结构是否为简单的SELECT语句 - if strings.HasPrefix(upperSQL, "SELECT") { - return false - } - - // 检查是否包含注释 - if strings.Contains(upperSQL, "--") || strings.Contains(upperSQL, "/*") { - return false - } - - // 检查是否包含分号(多条语句) - if strings.Contains(upperSQL, ";") { - return false - } - - return true -} - -// sanitizeOrderBy 安全处理 ORDER BY 字段 -func sanitizeOrderBy(orderBy string) string { - // 移除可能危险的字符 - orderBy = strings.TrimSpace(orderBy) - - // 分割字段和排序方向 - parts := strings.Fields(orderBy) - if len(parts) == 0 { - return "" - } - - // 验证字段名 - field := parts[0] - if !isValidFieldName(field) { - return "" - } - - // 如果有排序方向,验证它 - if len(parts) > 1 { - direction := strings.ToUpper(parts[1]) - if direction != "ASC" && direction != "DESC" { - return field // 默认不添加方向 - } - return field + " " + direction - } - - return field -} - -// isValidFieldName 验证字段名是否安全 -func isValidFieldName(field string) bool { - // 允许字母、数字、下划线和点(用于表名.字段名) - matched, _ := regexp.MatchString(`^[a-zA-Z0-9_.]+$`, field) - return matched -} - // replaceSelectAndRemoveGroupBy 替换 SELECT 和 FROM 之间的部分,并去掉 GROUP BY func replaceSelectAndRemoveGroupBy(sql string) string { // 找到 SELECT 和 FROM 的位置