243 lines
5.5 KiB
Go
243 lines
5.5 KiB
Go
package database
|
||
|
||
import (
	"context"
	"database/sql"
	"encoding/json"
	"fmt"
	"strings"
	"time"

	"git.kingecg.top/kingecg/gomog/pkg/types"
)
|
||
|
||
// BaseAdapter 基础适配器实现
|
||
type BaseAdapter struct {
|
||
db *sql.DB
|
||
driverName string
|
||
}
|
||
|
||
// NewBaseAdapter 创建基础适配器
|
||
func NewBaseAdapter(driverName string) *BaseAdapter {
|
||
return &BaseAdapter{
|
||
driverName: driverName,
|
||
}
|
||
}
|
||
|
||
// GetDB exposes the underlying *sql.DB so that concrete adapters built on
// BaseAdapter can run dialect-specific queries. It returns nil until a
// handle has been assigned (normally by Connect).
func (a *BaseAdapter) GetDB() *sql.DB {
	handle := a.db
	return handle
}
|
||
|
||
// Connect opens a connection pool for the adapter's driver using dsn and
// verifies it with a ping that honors ctx.
//
// The handle is stored on the adapter only after the ping succeeds. If the
// ping fails, the freshly opened *sql.DB is closed and discarded, so a failed
// Connect does not leave an open, unusable pool behind. In that case a.db
// keeps whatever value it had before. A handle from an earlier successful
// Connect is replaced without being closed; call Close before reconnecting.
// Errors from sql.Open and PingContext are returned unwrapped.
func (a *BaseAdapter) Connect(ctx context.Context, dsn string) error {
	db, err := sql.Open(a.driverName, dsn)
	if err != nil {
		return err
	}
	if err := db.PingContext(ctx); err != nil {
		// Best-effort cleanup: the ping failure is the error worth reporting.
		_ = db.Close()
		return err
	}
	a.db = db
	return nil
}
|
||
|
||
// Close 关闭连接
|
||
func (a *BaseAdapter) Close() error {
|
||
if a.db != nil {
|
||
return a.db.Close()
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// Ping checks that the database is reachable, honoring ctx for
// cancellation and deadlines. a.db must already be set (for example by
// Connect); this method does not guard against a nil handle.
func (a *BaseAdapter) Ping(ctx context.Context) error {
	handle := a.db
	return handle.PingContext(ctx)
}
|
||
|
||
// CreateCollection creates the table that backs a collection if it does not
// already exist. Every collection shares one layout:
//
//	id          TEXT PRIMARY KEY
//	data        JSON NOT NULL (the document body)
//	created_at  TIMESTAMP DEFAULT CURRENT_TIMESTAMP
//	updated_at  TIMESTAMP DEFAULT CURRENT_TIMESTAMP
//
// The DDL is built as a single clean string so that no stray layout
// characters end up in the statement sent to the database.
//
// NOTE(review): name is interpolated into the SQL unquoted, so it must be a
// trusted, valid identifier.
func (a *BaseAdapter) CreateCollection(ctx context.Context, name string) error {
	query := fmt.Sprintf(
		"CREATE TABLE IF NOT EXISTS %s ("+
			"id TEXT PRIMARY KEY, "+
			"data JSON NOT NULL, "+
			"created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, "+
			"updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)",
		name,
	)

	_, err := a.db.ExecContext(ctx, query)
	return err
}
|
||
|
||
// DropCollection removes the table backing the collection. Because it uses
// DROP TABLE IF EXISTS, dropping a collection that does not exist is not an
// error.
//
// NOTE(review): name is interpolated into the SQL unquoted, so it must be a
// trusted, valid identifier.
func (a *BaseAdapter) DropCollection(ctx context.Context, name string) error {
	stmt := fmt.Sprintf("DROP TABLE IF EXISTS %s", name)
	if _, err := a.db.ExecContext(ctx, stmt); err != nil {
		return err
	}
	return nil
}
|
||
|
||
// CollectionExists is a placeholder in the base adapter. Checking whether a
// table exists means querying the database's system catalog, which differs
// between databases, so concrete adapters must provide their own
// implementation. This version always returns false and ErrNotImplemented.
func (a *BaseAdapter) CollectionExists(ctx context.Context, name string) (bool, error) {
	// The base adapter has no dialect knowledge, so it cannot answer.
	return false, ErrNotImplemented
}
|
||
|
||
// InsertMany inserts docs into the collection's table inside one
// transaction, using a single prepared INSERT statement.
//
// For each document, the data column holds json.Marshal(doc.Data). Both
// created_at and updated_at are set to time.Now() at the moment that row is
// written; the document's own timestamp fields are not used. If beginning
// the transaction, preparing, encoding or inserting fails, that error is
// returned and the deferred Rollback discards the rows inserted so far.
//
// NOTE(review): collection is interpolated into the SQL unquoted, so it must
// be a trusted identifier.
func (a *BaseAdapter) InsertMany(ctx context.Context, collection string, docs []types.Document) error {
	tx, err := a.db.BeginTx(ctx, nil)
	if err != nil {
		return err
	}
	// After a successful Commit, Rollback is a no-op (it returns sql.ErrTxDone).
	defer tx.Rollback()

	insertSQL := fmt.Sprintf("INSERT INTO %s (id, data, created_at, updated_at) VALUES (?, ?, ?, ?)", collection)
	stmt, err := tx.PrepareContext(ctx, insertSQL)
	if err != nil {
		return err
	}
	defer stmt.Close()

	for i := range docs {
		payload, encErr := json.Marshal(docs[i].Data)
		if encErr != nil {
			return encErr
		}
		ts := time.Now()
		if _, execErr := stmt.ExecContext(ctx, docs[i].ID, payload, ts, ts); execErr != nil {
			return execErr
		}
	}

	return tx.Commit()
}
|
||
|
||
// UpdateMany 批量更新文档
|
||
func (a *BaseAdapter) UpdateMany(ctx context.Context, collection string, ids []string, update types.Update) error {
|
||
tx, err := a.db.BeginTx(ctx, nil)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
defer tx.Rollback()
|
||
|
||
// 构建更新 SQL
|
||
setClauses := make([]string, 0)
|
||
args := make([]interface{}, 0)
|
||
|
||
// 处理 $set
|
||
for field, value := range update.Set {
|
||
setClauses = append(setClauses, fmt.Sprintf("json_set(data, '$.%s', ?)", field))
|
||
args = append(args, toJSONString(value))
|
||
}
|
||
|
||
// 处理 $unset
|
||
for field := range update.Unset {
|
||
// SQLite/PostgreSQL 移除 JSON 字段的方式不同,这里简化处理
|
||
// 实际实现中需要根据具体数据库调整
|
||
setClauses = append(setClauses, fmt.Sprintf("json_remove(data, '$.%s')", field))
|
||
}
|
||
|
||
if len(setClauses) == 0 {
|
||
return nil
|
||
}
|
||
|
||
// 为每个 ID 执行更新
|
||
for _, id := range ids {
|
||
updateArgs := append([]interface{}{time.Now()}, args...)
|
||
updateArgs = append(updateArgs, id)
|
||
|
||
query := fmt.Sprintf(
|
||
"UPDATE %s SET data = %s, updated_at = ? WHERE id = ?",
|
||
collection,
|
||
setClauses[0], // 简化:只处理第一个 set 子句
|
||
)
|
||
|
||
_, err = tx.ExecContext(ctx, query, updateArgs...)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
}
|
||
|
||
return tx.Commit()
|
||
}
|
||
|
||
// DeleteMany deletes the documents whose id is in ids with a single
// DELETE ... WHERE id IN (?, ?, ...) statement, binding one placeholder per
// id. An empty ids slice is a no-op that returns nil without touching the
// database.
//
// NOTE(review): collection is interpolated into the SQL unquoted, so it must
// be a trusted identifier. Very large ids slices may exceed the driver's
// bound-parameter limit (e.g. SQLite's SQLITE_MAX_VARIABLE_NUMBER), so
// callers may need to batch.
func (a *BaseAdapter) DeleteMany(ctx context.Context, collection string, ids []string) error {
	if len(ids) == 0 {
		return nil
	}

	// One "?" per id, joined with commas. The slice must be joined
	// explicitly: formatting it with %s would render "[? ? ?]", which is not
	// valid SQL.
	placeholders := make([]string, len(ids))
	args := make([]interface{}, len(ids))
	for i, id := range ids {
		placeholders[i] = "?"
		args[i] = id
	}

	query := fmt.Sprintf(
		"DELETE FROM %s WHERE id IN (%s)",
		collection,
		strings.Join(placeholders, ", "),
	)

	_, err := a.db.ExecContext(ctx, query, args...)
	return err
}
|
||
|
||
// FindAll returns every row of the collection's table as a types.Document.
// The data column is decoded from JSON into doc.Data, and the timestamp
// columns are scanned into CreatedAt/UpdatedAt.
//
// An empty table yields a nil slice and a nil error. A query, scan or decode
// error aborts immediately with a nil slice. If row iteration itself fails,
// the documents read so far are returned together with rows.Err().
//
// NOTE(review): collection is interpolated into the SQL unquoted, so it must
// be a trusted identifier.
func (a *BaseAdapter) FindAll(ctx context.Context, collection string) ([]types.Document, error) {
	rows, err := a.db.QueryContext(ctx, fmt.Sprintf("SELECT id, data, created_at, updated_at FROM %s", collection))
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var result []types.Document
	for rows.Next() {
		var (
			d   types.Document
			raw []byte
		)
		if scanErr := rows.Scan(&d.ID, &raw, &d.CreatedAt, &d.UpdatedAt); scanErr != nil {
			return nil, scanErr
		}
		if decErr := json.Unmarshal(raw, &d.Data); decErr != nil {
			return nil, decErr
		}
		result = append(result, d)
	}

	return result, rows.Err()
}
|
||
|
||
// BeginTx starts a database transaction with default options (nil
// *sql.TxOptions) and returns it wrapped in a baseTransaction. On failure it
// returns a nil Transaction and the driver's error.
func (a *BaseAdapter) BeginTx(ctx context.Context) (Transaction, error) {
	sqlTx, err := a.db.BeginTx(ctx, nil)
	if err != nil {
		return nil, err
	}
	wrapped := &baseTransaction{tx: sqlTx}
	return wrapped, nil
}
|
||
|
||
// baseTransaction 基础事务实现
|
||
type baseTransaction struct {
|
||
tx *sql.Tx
|
||
}
|
||
|
||
func (t *baseTransaction) Commit() error {
|
||
return t.tx.Commit()
|
||
}
|
||
|
||
func (t *baseTransaction) Rollback() error {
|
||
return t.tx.Rollback()
|
||
}
|
||
|
||
// toJSONString returns the JSON encoding of v as a string. A nil interface
// value encodes as "null".
//
// Marshal errors (for example channels or functions, which encoding/json
// cannot represent) are deliberately not reported. In that case the result
// is the empty string, because the function's signature has no error return.
func toJSONString(v interface{}) string {
	if v == nil {
		return "null"
	}
	encoded, err := json.Marshal(v)
	if err != nil {
		return ""
	}
	return string(encoded)
}
|