149 lines
3.6 KiB
Go
149 lines
3.6 KiB
Go
package lxDb
|
||
|
||
import (
|
||
"context"
|
||
"crypto/tls"
|
||
"crypto/x509"
|
||
"fmt"
|
||
"io/ioutil"
|
||
"os"
|
||
|
||
"git.listensoft.net/tool/lxutils/lxlog"
|
||
"git.listensoft.net/tool/lxutils/lxzap"
|
||
"github.com/gin-gonic/gin"
|
||
"github.com/go-sql-driver/mysql"
|
||
"go.uber.org/zap"
|
||
gormMysql "gorm.io/driver/mysql"
|
||
"gorm.io/gorm"
|
||
"gorm.io/gorm/logger"
|
||
"gorm.io/gorm/schema"
|
||
)
|
||
|
||
// DB2.0 数据库连接 gorm2.0
|
||
var DBS map[string]*gorm.DB //gs 数据库
|
||
// DbConfig describes a single MySQL connection consumed by InitDB and InitDBS.
type DbConfig struct {
	// Name is the key under which the opened connection is stored in DBS.
	// InitDBS registers a lone config as "one" regardless of this value.
	Name string

	// Host and Port are joined as "host:port" for the tcp(...) part of the DSN.
	Host string
	Port string

	// User and Password are the login credentials placed in the DSN.
	User     string
	Password string

	// Database is the database name placed in the DSN path.
	Database string

	// Charset is the connection character set; InitDB uses "utf8mb4" when empty.
	Charset string

	// TLS, when non-empty, is the path of a PEM-encoded CA certificate used to
	// verify the server. When empty, no tls parameter is added to the DSN.
	TLS string
}
|
||
|
||
func GetDB(c *gin.Context, dbName ...string) *gorm.DB {
|
||
v, _ := c.Get("X-Span-ID")
|
||
spanId := fmt.Sprintf("%v", v)
|
||
ctx := context.WithValue(context.Background(), "X-Span-ID", spanId)
|
||
d := "one"
|
||
if len(dbName) != 0 {
|
||
d = dbName[0]
|
||
}
|
||
return DBS[d].WithContext(ctx)
|
||
}
|
||
|
||
func GetDB2(dbName ...string) *gorm.DB {
|
||
d := "one"
|
||
if len(dbName) != 0 {
|
||
d = dbName[0]
|
||
}
|
||
return DBS[d]
|
||
}
|
||
|
||
func InitDBS(env string, conf []DbConfig) {
|
||
DBS = map[string]*gorm.DB{}
|
||
if len(conf) == 1 {
|
||
conf[0].Name = "one"
|
||
}
|
||
for _, config := range conf {
|
||
InitDB(env, config)
|
||
}
|
||
}
|
||
|
||
func InitDB(env string, conf DbConfig) {
|
||
if conf.Name == "" {
|
||
DBS = nil
|
||
fmt.Println("db name 错误")
|
||
return
|
||
}
|
||
if conf.Host == "" {
|
||
DBS = nil
|
||
fmt.Println("未配置Db连接")
|
||
return
|
||
}
|
||
|
||
if conf.Charset == "" {
|
||
conf.Charset = "utf8mb4"
|
||
}
|
||
|
||
// 如果配置了TLS,则设置TLS连接
|
||
if conf.TLS != "" {
|
||
// 1. 读取CA证书
|
||
caCert, err := ioutil.ReadFile(conf.TLS)
|
||
if err != nil {
|
||
panic("读取CA证书失败: " + err.Error())
|
||
}
|
||
|
||
// 2. 创建证书池并添加CA证书
|
||
caCertPool := x509.NewCertPool()
|
||
if !caCertPool.AppendCertsFromPEM(caCert) {
|
||
panic("添加CA证书到池失败")
|
||
}
|
||
|
||
/* // 3. 加载客户端证书和私钥
|
||
cert, err := tls.LoadX509KeyPair("path/to/client-cert.pem", "path/to/client-key.pem")
|
||
if err != nil {
|
||
panic("加载客户端证书失败: " + err.Error())
|
||
}*/
|
||
|
||
// 4. 创建TLS配置
|
||
tlsConfig := &tls.Config{
|
||
RootCAs: caCertPool, // 信任的CA
|
||
//Certificates: []tls.Certificate{cert}, // 客户端证书
|
||
//ServerName: conf.Host,
|
||
//MinVersion: tls.VersionTLS12, // 最小TLS版本
|
||
}
|
||
|
||
// 5. 注册自定义TLS配置到MySQL驱动
|
||
mysql.RegisterTLSConfig("custom-tls", tlsConfig)
|
||
}
|
||
var dsn string
|
||
if conf.TLS != "" {
|
||
dsn = fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=%s&parseTime=True&loc=Local&tls=custom-tls",
|
||
conf.User, conf.Password, conf.Host, conf.Port, conf.Database, conf.Charset)
|
||
} else {
|
||
dsn = fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=%s&parseTime=True&loc=Local",
|
||
conf.User, conf.Password, conf.Host, conf.Port, conf.Database, conf.Charset)
|
||
}
|
||
if env == "dev" {
|
||
db, err := gorm.Open(gormMysql.Open(dsn), &gorm.Config{
|
||
Logger: lxlog.Default.LogMode(logger.Info),
|
||
NamingStrategy: schema.NamingStrategy{
|
||
SingularTable: true, // 使用单数表名,启用该选项,此时,`User` 的表名应该是 `user`
|
||
},
|
||
})
|
||
if err != nil {
|
||
fmt.Println(err.Error())
|
||
os.Exit(-1)
|
||
}
|
||
DBS[conf.Name] = db
|
||
} else {
|
||
logger2 := lxzap.NewGormZap(zap.L())
|
||
db, err := gorm.Open(gormMysql.Open(dsn), &gorm.Config{
|
||
//Logger: logger.Default.LogMode(logger.Info),
|
||
Logger: logger2,
|
||
NamingStrategy: schema.NamingStrategy{
|
||
SingularTable: true, // 使用单数表名,启用该选项,此时,`User` 的表名应该是 `user`
|
||
},
|
||
})
|
||
if err != nil {
|
||
fmt.Println(err.Error())
|
||
os.Exit(-1)
|
||
}
|
||
DBS[conf.Name] = db
|
||
}
|
||
DBS[conf.Name] = DBS[conf.Name].Session(&gorm.Session{SkipDefaultTransaction: true})
|
||
//db.AutoMigrate(TaskData{}, Task{}, Version{})
|
||
}
|