71 lines
1.6 KiB
Go
71 lines
1.6 KiB
Go
package mongo
|
||
|
||
import (
	"context"
	"fmt"

	"common/config"
	"common/log"

	"go.mongodb.org/mongo-driver/mongo"
	"go.mongodb.org/mongo-driver/mongo/options"
)
|
||
|
||
// DBName is the name under which Init registers a client; it is the key of
// the config map passed to Init. GetDB also uses it as the database name.
type DBName string
|
||
|
||
// clients is the registry of MongoDB clients keyed by DBName. Init fills it,
// GetClient (and therefore GetDB) reads it, and Close disconnects its clients.
// NOTE(review): access is not synchronized; this appears to assume Init/Close
// never run concurrently with lookups — confirm with callers.
var clients = make(map[DBName]*mongo.Client)
|
||
|
||
func Init(cfg map[string]*config.MongoConfig) error {
|
||
for name, oneConfig := range cfg {
|
||
if client, err := initOneClient(oneConfig); err != nil {
|
||
return err
|
||
} else {
|
||
clients[DBName(name)] = client
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func initOneClient(cfg *config.MongoConfig) (*mongo.Client, error) {
|
||
opts := options.Client().
|
||
ApplyURI(cfg.URI)
|
||
|
||
client, err := mongo.Connect(context.Background(), opts)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
return client, nil
|
||
}
|
||
|
||
// GetClient 返回 mongo.Client,你可以通过 .Database("xxx") 获取具体数据库
|
||
func GetClient(dbName DBName) *mongo.Client {
|
||
if c, ok := clients[dbName]; ok {
|
||
return c
|
||
}
|
||
log.Errorf("mongo client %s not found", dbName)
|
||
return nil
|
||
}
|
||
|
||
// GetDB 是便捷方法,直接返回 *mongo.Database(假设 dbName 对应数据库名)
|
||
// 如果你的配置中 dbName 和实际数据库名一致,可这样用
|
||
func GetDB(dbName DBName) *mongo.Database {
|
||
client := GetClient(dbName)
|
||
if client == nil {
|
||
return nil
|
||
}
|
||
// 假设配置中的 key 就是数据库名;若需分离,可在 config 中加字段
|
||
return client.Database(string(dbName))
|
||
}
|
||
|
||
func Close() error {
|
||
if clients == nil {
|
||
return nil
|
||
}
|
||
for name, client := range clients {
|
||
if client != nil {
|
||
if err := client.Disconnect(context.Background()); err != nil {
|
||
log.Errorf("close mongo client %s error: %v", name, err)
|
||
}
|
||
}
|
||
}
|
||
return nil
|
||
}
|