lxutils/lxDb/sql.go
2025-08-22 19:08:43 +08:00

646 lines
17 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package lxDb
import (
"errors"
"git.listensoft.net/tool/lxutils/lxUtil"
"gorm.io/gorm"
"regexp"
"strings"
)
// 带事务的, 和不带事务的 说明:
// 如果需要支持事务请调aaaTx方法, 并传开启事务的DB
// OneByIdTx 通过ID查询一条
func OneById(tx *gorm.DB, m interface{}, id uint) (err error) {
if id == 0 {
return errors.New("主键ID不能为空")
}
if tx.Take(m, id).Error != nil { // or First ?
return errors.New("未找到记录")
}
return nil
}
// OneTx 查询一条 使用m的有值属性作为条件, m必须是Model结构体. 条件为空时报错.
func One(tx *gorm.DB, m interface{}) error {
return oneTx(tx, m, "")
}
func oneTx(tx *gorm.DB, m interface{}, Type string) (err error) {
//reflectVal := reflect.ValueOf(m)
//t := reflect.Indirect(reflectVal).Type()
//newObj := reflect.New(t)
if lxUtil.IsZeroOfUnderlyingType(m) {
return errors.New("条件不能为空")
}
// 这里有一个特别的情况, 如果m.id有值, 生成sql的where里id条件出现两次, 但是不影响效果
if Type == "first" { // 第一个
err = tx.Where(m).First(m).Error
} else if Type == "last" { // 最后一个
err = tx.Where(m).Last(m).Error
} else { // 就是一个
err = tx.Where(m).Take(m).Error
}
if err != nil {
return err
}
return nil
}
// FirstTx 查询第一条 使用m的有值属性作为条件, m必须是Model结构体. 条件为空时报错.
func First(tx *gorm.DB, m interface{}) error {
return oneTx(tx, m, "first")
}
// LastTx 查询最后一条 使用m的有值属性作为条件, m必须是Model结构体. 条件为空时报错.
func Last(tx *gorm.DB, m interface{}) error {
return oneTx(tx, m, "last")
}
// One 查询一条. 这种方式时不行的, 实际上对应的表是 base_model
//func (m *BaseModel) One() (err error) {
// //if DB.Where(m).First(one).RecordNotFound() {
// if DB.Where(m).First(m).Error != nil {
// return errors.New("resource is not found")
// }
// return nil
//}
// CreateTx 带事务的, 新增
func Create(tx *gorm.DB, m interface{}) error {
return tx.Create(m).Error
}
// UpdateTx 带事务的, 更新一条数据的单个字段, m.ID必须有值
func Update(tx *gorm.DB, m interface{}, field string, value interface{}) error {
db := tx.Model(m).Update(field, value)
if err := db.Error; err != nil {
return err
}
if db.RowsAffected != 1 {
return errors.New("id is invalid and resource is not found")
}
return nil
}
// UpdatesTx 带事务的, 更新一条数据的多个字段, m.ID必须有值
func Updates(tx *gorm.DB, m interface{}, fields interface{}) error {
db := tx.Model(m).Updates(fields)
if err := db.Error; err != nil {
return err
}
if db.RowsAffected != 1 {
return errors.New("id is invalid and resource is not found")
}
return nil
}
// DeleteTx 删除, m.ID必须有值. tx不为空就是带事务的
func Delete(tx *gorm.DB, m interface{}) error {
//func DeleteTx(tx *gorm.DB, m interface{}, conds ...interface{}) error {
db := tx.Delete(m)
//db := tx.Delete(m, conds...)
if err := db.Error; err != nil {
return err
}
if db.RowsAffected != 1 {
return errors.New("未找到要删除的数据")
}
return nil
}
// ListTx 查询数据列表, m是库表结构体, m的有值属性会作为查询条件, 且必须有条件, list里个体的类型可以与m的类型不同
func List(tx *gorm.DB, m interface{}, list interface{}) (err error) {
if lxUtil.IsZeroOfUnderlyingType(m) {
return errors.New("条件不能为空")
}
if tx.Model(m).Where(m).Find(list).Error != nil {
// if tx.Where(m).Find(list).Error != nil {
return errors.New("查询出现错误")
}
return nil
}
// ListAllTx 查询所有数据, m是库表结构体, m的有值属性不会作为查询条件, list里个体的类型可以与m的类型不同
func ListAll(tx *gorm.DB, m interface{}, list interface{}) (err error) {
if tx.Model(m).Find(list).Error != nil {
// if tx.Where(m).Find(list).Error != nil {
return errors.New("查询出现错误")
}
return nil
}
// QueryTx 带事务的查询数据列表. tx为空就是不带事务, 否则认为是开启了事务. 你应该避免使用次方法, 而是使用Query
func Query(tx *gorm.DB, m interface{}, list interface{}, q *PaginationQuery) (err error) {
// !! 关于会话 新的Statement实例 及复用 ref: https://gorm.io/zh_CN/docs/method_chaining.html
if q == nil {
err = errors.New("paginationQuery不可为nil")
return
}
tx = tx.Model(m)
// 注意: count, 查询list, summary的顺序不能变.
tx = q.Build(tx) // 构造查询条件
// 记录条数
if needDoCount(q) {
var total int64
tx = tx.Count(&total)
q.Total = int(total)
if total == 0 { // 如果查了记录条数并且是0, 就不需要查记录和汇总了
return
}
}
if q.OrderBy != "" {
tx = tx.Order(lxUtil.FieldToColumn(q.OrderBy)) // TODO: q.OrderBy是字符串,可能多个字段 会有问题吗
//tx = tx.Order(q.OrderBy)
}
if q.Offset > 0 {
tx = tx.Offset((q.Offset - 1) * q.Limit)
}
if q.Limit > 0 {
tx = tx.Limit(q.Limit)
}
// 获取查询值
err = tx.Find(list).Error
// 获取汇总信息, 如果不需要查记录条数就不再查汇总
if needDoCount(q) {
if q.Summary != "" && len(q.SummarySql) == 0 {
q.SummarySql = fieldsToSumSql(q.Summary)
}
if len(q.Summary) != 0 {
tx = tx.Offset(-1) // 需要去除offset, 否则结果可能为空, 注意: 设置0不起作用.
var summary = make(map[string]interface{})
//tx.Order("") // FIXME: 怎么去掉order by, sum是不需要order by的, 影响性能.
tx.Select(q.SummarySql).Take(&summary)
// []byte 转 string. 不太合理, 应该返回int或float
for k, v := range summary {
if bs, ok := v.([]byte); ok {
summary[k] = string(bs)
}
}
q.SummaryResult = summary
}
}
return
}
//// SqlOne 原生SQL查询一个
//func SqlOne(sql string, m interface{}) {
// DB.Raw(sql, m)
//}
//
//// SqlList 原生SQL查询列表
//func SqlList() {
//
//}
func SqlQuery(tx *gorm.DB, sql string, list interface{}, q *PaginationQuery, params ...interface{}) (err error) {
var builder strings.Builder
builder.WriteString(sql)
if params == nil {
params = make([]interface{}, 0)
}
// 条件字段
if q != nil {
where, args := q.BuildRawWhere()
// 安全地添加 WHERE 子句
if hasWhere(sql) { // 原SQL已有WHERE子句
// 确保 where 子句以 AND 或 OR 开头,然后安全添加
if strings.HasPrefix(where, " AND ") || strings.HasPrefix(where, " OR ") {
builder.WriteString(where)
} else {
// 如果 where 不是以 AND/OR 开头,添加 AND 前缀
builder.WriteString(" AND ")
builder.WriteString(strings.TrimSpace(where))
}
} else { // 原SQL没有WHERE子句
if strings.HasPrefix(where, " AND ") {
where = strings.Replace(where, " AND ", " WHERE ", 1)
builder.WriteString(where)
} else if strings.HasPrefix(where, " OR ") {
where = strings.Replace(where, " OR ", " WHERE ", 1)
builder.WriteString(where)
} else if where != "" {
builder.WriteString(" WHERE ")
builder.WriteString(where)
}
}
if len(args) > 0 {
params = append(params, args...)
}
// 半成品 sql 用于查询其他信息
var sql2 = builder.String()
// 记录条数
if needDoCount(q) {
var total int64
// 使用安全的 COUNT 查询
countSQL := replaceSelectAndRemoveGroupBy(sql2)
if err := tx.Raw(countSQL, params...).Take(&total).Error; err != nil {
return err
}
q.Total = int(total)
if total == 0 { // 如果查了记录条数并且是0, 就不需要查记录和汇总了
return
}
// 获取汇总信息
if q.Summary != "" && len(q.SummarySql) == 0 {
// 安全地构建汇总字段
q.SummarySql = fieldsToSumSql(q.Summary)
}
if len(q.Summary) != 0 {
tx = tx.Offset(-1) // 需要去除offset, 否则结果可能为空, 注意: 设置0不起作用.
var summary = make(map[string]interface{})
// 安全构建汇总查询 - 使用参数化查询
summarySQL := "SELECT " + strings.Join(q.SummarySql, ", ") + " FROM (" + sql2 + ") ssss"
if err := tx.Raw(summarySQL, params...).Take(&summary).Error; err != nil {
return err
}
// []byte 转 string. 不太合理, 应该返回int或float
for k, v := range summary {
if bs, ok := v.([]byte); ok {
summary[k] = string(bs)
}
}
q.SummaryResult = summary
}
}
// 安全地处理排序 - 使用白名单验证字段名
if q.OrderBy != "" {
// 验证输入参数的安全性
if !isSafeSQL(q.OrderBy) {
return errors.New("环境异常")
}
safeOrderBy := sanitizeOrderBy(q.OrderBy)
if safeOrderBy != "" {
builder.WriteString(" ORDER BY ")
builder.WriteString(safeOrderBy)
}
}
// 安全地处理分页 - 使用参数化查询
if q.Limit > 0 {
if q.Offset > 0 {
offset := (q.Offset - 1) * q.Limit
builder.WriteString(" LIMIT ?, ?")
params = append(params, offset, q.Limit)
} else {
builder.WriteString(" LIMIT ?")
params = append(params, q.Limit)
}
}
}
// 执行最终查询 - 使用参数化查询
if err := tx.Raw(builder.String(), params...).Find(list).Error; err != nil {
return err
}
return
}
func SqlQueryNew(tx *gorm.DB, sql string, list interface{}, q *PaginationQuery, params ...interface{}) (err error) {
var builder strings.Builder
builder.WriteString(sql)
if params == nil {
params = make([]interface{}, 0)
}
// 条件字段
if q != nil {
where, args := q.BuildRawWhere()
if hasWhere(sql) { // 原SQL已有WHERE子句
// 确保 where 子句以 AND 或 OR 开头,然后安全添加
if strings.HasPrefix(where, " AND ") || strings.HasPrefix(where, " OR ") {
builder.WriteString(where)
} else {
// 如果 where 不是以 AND/OR 开头,添加 AND 前缀
builder.WriteString(" AND ")
builder.WriteString(strings.TrimSpace(where))
}
} else { // 原SQL没有WHERE子句
if strings.HasPrefix(where, " AND ") {
where = strings.Replace(where, " AND ", " WHERE ", 1)
builder.WriteString(where)
} else if strings.HasPrefix(where, " OR ") {
where = strings.Replace(where, " OR ", " WHERE ", 1)
builder.WriteString(where)
} else if where != "" {
builder.WriteString(" WHERE ")
builder.WriteString(where)
}
}
if len(args) > 0 {
params = append(params, args...)
}
// 半成品 sql 用于查询其他信息
var sql2 = builder.String()
// 记录条数
if needDoCount(q) {
var total int64
// 使用安全的 COUNT 查询
replacedSQL := replaceSelectAndRemoveGroupBy(sql2)
if err := tx.Raw(replacedSQL, params...).Take(&total).Error; err != nil {
return err
}
q.Total = int(total)
if total == 0 { // 如果查了记录条数并且是0, 就不需要查记录和汇总了
return
}
// 获取汇总信息
if q.Summary != "" && len(q.SummarySql) == 0 {
// 安全地构建汇总字段
q.SummarySql = fieldsToSumSql(q.Summary)
}
if len(q.Summary) != 0 {
tx = tx.Offset(-1) // 需要去除offset, 否则结果可能为空, 注意: 设置0不起作用.
var summary = make(map[string]interface{})
// 安全构建汇总查询 - 使用参数化查询
summarySQL := "SELECT " + strings.Join(q.SummarySql, ", ") + " FROM (" + sql2 + ") ssss"
if err := tx.Raw(summarySQL, params...).Take(&summary).Error; err != nil {
return err
}
// []byte 转 string. 不太合理, 应该返回int或float
for k, v := range summary {
if bs, ok := v.([]byte); ok {
summary[k] = string(bs)
}
}
q.SummaryResult = summary
}
}
// 安全地处理排序 - 使用白名单验证字段名
if q.OrderBy != "" {
// 验证输入参数的安全性
if !isSafeSQL(q.OrderBy) {
return errors.New("环境异常")
}
safeOrderBy := sanitizeOrderBy(q.OrderBy)
if safeOrderBy != "" {
builder.WriteString(" ORDER BY ")
builder.WriteString(safeOrderBy)
}
}
// 安全地处理分页 - 使用参数化查询
if q.Limit > 0 {
if q.Offset > 0 {
offset := (q.Offset - 1) * q.Limit
builder.WriteString(" LIMIT ?, ?")
params = append(params, offset, q.Limit)
} else {
builder.WriteString(" LIMIT ?")
params = append(params, q.Limit)
}
}
}
// 执行最终查询 - 使用参数化查询
if err := tx.Raw(builder.String(), params...).Find(list).Error; err != nil {
return err
}
return nil
}
// 是否需要查询记录条数
func needDoCount(q *PaginationQuery) bool {
if q.NoTotal {
return false
}
if q.Limit == 0 { // 不限制条数, 等同于查所有记录, 这时候就不需要查记录条数
return false
}
//return q.Offset <= 1 // todo lcs 为什么要这样写 第二页都没了
return true
}
// utils -----------------
// "inAmt,inCount" -> ["SUM(int_amt) AS inAmt", "SUM(int_count) AS inCount"]
func fieldsToSumSql(fields string) (sumSqls []string) {
strs := strings.Split(strings.TrimSpace(fields), ",")
for _, str := range strs {
field := strings.TrimSpace(str)
if field != "" {
sumSqls = append(sumSqls, "SUM("+lxUtil.FieldToColumn(field)+") AS "+field+"")
}
}
return
}
// SELECT...FROM...[WHERE] 句式的 SQL 中是否存在 WHERE 子句
func hasWhere(sql string) bool {
deep := 0 // 括号嵌套层数
step := 0 // "where" 匹配进度
// 遍历 sql 忽略 ' ( ` 判断是否存在 where
for i := 0; i < len(sql); i++ {
switch sql[i] {
case '(':
deep++
case ')':
deep--
case 96: // "`"
// 下一个 "`" 的下标
// 忽略其他字符
for i = i + 1; i < len(sql); i++ {
if sql[i] == 96 {
break
}
}
case 39: // "'"
// 下一个 "'" 的下标
// 忽略其他字符
for i = i + 1; i < len(sql); i++ {
if sql[i] == 39 {
break
}
}
default:
if deep != 0 {
continue
}
if step == 5 {
return true
}
if sql[i] == where[step][0] || sql[i] == where[step][1] {
step++
} else {
step = 0
}
}
}
return false
}
var where = []string{
"Ww",
"Hh",
"Ee",
"Rr",
"Ee",
}
// Transaction 同时开两个事务的方法
func Transaction(txItem, txMain *gorm.DB, fun func(txItem, txMain *gorm.DB) (err error)) (err error) {
tx1 := txItem.Begin()
tx2 := txMain.Begin()
err = fun(tx1, tx2)
if err != nil {
tx1.Rollback()
tx2.Rollback()
return err
}
err = tx1.Commit().Error
if err != nil {
return
}
err = tx2.Commit().Error
if err != nil {
return
}
return
}
// isSafeSQL 检查SQL语句是否安全防止SQL注入
func isSafeSQL(sql string) bool {
// 转换为大写进行关键字检查
upperSQL := strings.ToUpper(sql)
// 危险关键字列表
dangerousKeywords := []string{
"DROP", "DELETE", "UPDATE", "INSERT", "ALTER", "CREATE", "TRUNCATE",
"EXEC", "EXECUTE", "XP_", "SP_", "UNION", "JOIN", "HAVING", "GROUP BY",
"ORDER BY", "LIMIT", "OFFSET", "--", "/*", "*/", ";", "@@", "@",
"0X", "CHAR(", "ASCII(", "SUBSTRING(", "MID(", "LENGTH(", "LEN(",
"CONCAT(", "LOAD_FILE(", "BENCHMARK(", "SLEEP(", "WAITFOR",
"CAST(", "CONVERT(", "IF(", "CASE", "WHEN", "THEN", "END",
}
// 检查危险关键字
for _, keyword := range dangerousKeywords {
if strings.Contains(upperSQL, keyword) {
return false
}
}
// 检查SQL结构是否为简单的SELECT语句
if strings.HasPrefix(upperSQL, "SELECT") {
return false
}
// 检查是否包含注释
if strings.Contains(upperSQL, "--") || strings.Contains(upperSQL, "/*") {
return false
}
// 检查是否包含分号(多条语句)
if strings.Contains(upperSQL, ";") {
return false
}
return true
}
// sanitizeOrderBy 安全处理 ORDER BY 字段
func sanitizeOrderBy(orderBy string) string {
// 移除可能危险的字符
orderBy = strings.TrimSpace(orderBy)
// 分割字段和排序方向
parts := strings.Fields(orderBy)
if len(parts) == 0 {
return ""
}
// 验证字段名
field := parts[0]
if !isValidFieldName(field) {
return ""
}
// 如果有排序方向,验证它
if len(parts) > 1 {
direction := strings.ToUpper(parts[1])
if direction != "ASC" && direction != "DESC" {
return field // 默认不添加方向
}
return field + " " + direction
}
return field
}
// isValidFieldName 验证字段名是否安全
func isValidFieldName(field string) bool {
// 允许字母、数字、下划线和点(用于表名.字段名)
matched, _ := regexp.MatchString(`^[a-zA-Z0-9_.]+$`, field)
return matched
}
// replaceSelectAndRemoveGroupBy 替换 SELECT 和 FROM 之间的部分,并去掉 GROUP BY
func replaceSelectAndRemoveGroupBy(sql string) string {
// 找到 SELECT 和 FROM 的位置
selectIndex := strings.Index(strings.ToUpper(sql), "SELECT")
fromIndex := strings.Index(strings.ToUpper(sql), "FROM")
if selectIndex == -1 || fromIndex == -1 {
return sql // 如果没有找到 SELECT 或 FROM返回原始 SQL
}
// 替换 SELECT 和 FROM 之间的部分
newSelectClause := "COUNT(DISTINCT ***) AS total"
replacedSQL := sql[:selectIndex+6] + " " + newSelectClause + " " + sql[fromIndex:]
// 去掉 GROUP BY 子句
groupByIndex := strings.LastIndex(strings.ToUpper(replacedSQL), "GROUP BY")
if groupByIndex != -1 {
c := replacedSQL[groupByIndex+8:]
c = strings.TrimSpace(c)
c = strings.ReplaceAll(c, ";", "")
c = strings.Split(c, ",")[0]
replacedSQL = replacedSQL[:groupByIndex]
replacedSQL = strings.ReplaceAll(replacedSQL, "DISTINCT ***", "DISTINCT "+c)
} else {
//没有group by 的
replacedSQL = strings.ReplaceAll(replacedSQL, "DISTINCT ***", "*")
}
return replacedSQL
}