package postgres

import (
	"context"
	"encoding/json"
	"fmt"
	"time"

	"git.kingecg.top/kingecg/gomog/internal/database"
	"git.kingecg.top/kingecg/gomog/pkg/types"

	// Blank import registers the lib/pq driver with database/sql.
	_ "github.com/lib/pq"
)

// PostgresAdapter is the PostgreSQL database adapter. It embeds
// *database.BaseAdapter for connection handling and provides
// PostgreSQL-specific implementations of the collection operations, storing
// each collection as a table whose documents live in a JSONB column.
//
// NOTE(review): collection names are interpolated directly into SQL as table
// identifiers (identifiers cannot be bound as parameters). Callers must pass
// trusted/validated names; otherwise this is an SQL-injection vector. Quoting
// them (e.g. pq.QuoteIdentifier) would close that hole but would also make
// names case-sensitive, so existing tables need checking before switching.
type PostgresAdapter struct {
	*database.BaseAdapter
}

// NewPostgresAdapter creates a PostgreSQL adapter whose BaseAdapter is
// constructed with the name "postgres" (presumably the database/sql driver
// name registered by lib/pq — confirm against database.NewBaseAdapter).
func NewPostgresAdapter() *PostgresAdapter {
	return &PostgresAdapter{
		BaseAdapter: database.NewBaseAdapter("postgres"),
	}
}

// Connect connects through the embedded BaseAdapter using dsn and then sets
// the session time zone to UTC.
//
// NOTE(review): if GetDB returns a *sql.DB pool, SET only affects the single
// pooled connection that executes it; other connections keep the server's
// default time zone. Putting timezone=UTC in the DSN would apply it to all.
func (a *PostgresAdapter) Connect(ctx context.Context, dsn string) error {
	if err := a.BaseAdapter.Connect(ctx, dsn); err != nil {
		return err
	}

	// Use the caller's context so its deadline/cancellation also applies here.
	_, err := a.GetDB().ExecContext(ctx, "SET timezone = 'UTC'")
	return err
}

// CreateCollection creates the table backing collection name if it does not
// already exist. Documents are stored as JSONB (binary JSON, which PostgreSQL
// can index and query more efficiently than plain JSON text).
func (a *PostgresAdapter) CreateCollection(ctx context.Context, name string) error {
	query := fmt.Sprintf(`
		CREATE TABLE IF NOT EXISTS %s (
			id TEXT PRIMARY KEY,
			data JSONB NOT NULL,
			created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
			updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
		)`, name)

	_, err := a.GetDB().ExecContext(ctx, query)
	return err
}

// CollectionExists reports whether a table for collection name exists in the
// "public" schema.
func (a *PostgresAdapter) CollectionExists(ctx context.Context, name string) (bool, error) {
	query := `SELECT COUNT(*) FROM information_schema.tables WHERE table_schema = 'public' AND table_name = $1`

	// CreateCollection uses name as an unquoted identifier, which PostgreSQL
	// folds to lower case, and information_schema records the folded form.
	// Compare against that form so mixed-case names are found.
	var count int
	err := a.GetDB().QueryRowContext(ctx, query, foldUnquotedIdentifier(name)).Scan(&count)
	if err != nil {
		return false, err
	}
	return count > 0, nil
}

// foldUnquotedIdentifier mirrors how PostgreSQL case-folds an unquoted
// identifier in a multi-byte (e.g. UTF8) database: ASCII 'A'-'Z' become
// 'a'-'z' and every other byte is left unchanged.
func foldUnquotedIdentifier(name string) string {
	b := []byte(name)
	for i, c := range b {
		if 'A' <= c && c <= 'Z' {
			b[i] = c + ('a' - 'A')
		}
	}
	return string(b)
}

// FindAll returns every document stored in collection.
//
// The JSONB column is read back as text (data::text) and decoded with
// encoding/json into Document.Data. No ORDER BY is applied, so row order is
// whatever PostgreSQL returns. An empty collection yields a nil slice; on any
// error the returned slice is nil.
func (a *PostgresAdapter) FindAll(ctx context.Context, collection string) ([]types.Document, error) {
	query := fmt.Sprintf("SELECT id, data::text, created_at, updated_at FROM %s", collection)
	rows, err := a.GetDB().QueryContext(ctx, query)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var docs []types.Document
	for rows.Next() {
		var (
			doc      types.Document
			jsonData string
		)
		if err := rows.Scan(&doc.ID, &jsonData, &doc.CreatedAt, &doc.UpdatedAt); err != nil {
			return nil, err
		}
		if err := json.Unmarshal([]byte(jsonData), &doc.Data); err != nil {
			return nil, fmt.Errorf("decode document %v: %w", doc.ID, err)
		}
		docs = append(docs, doc)
	}
	// rows.Err reports an error that ended iteration early; don't hand back
	// a partial result alongside it.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return docs, nil
}

// InsertMany inserts docs into collection inside a single transaction: if any
// insert fails, the transaction is rolled back and the error returned.
//
// Both created_at and updated_at are set to the same timestamp: the
// document's CreatedAt, else its UpdatedAt, else the time the batch started.
// An empty docs slice is a no-op.
func (a *PostgresAdapter) InsertMany(ctx context.Context, collection string, docs []types.Document) error {
	if len(docs) == 0 {
		return nil
	}

	tx, err := a.GetDB().BeginTx(ctx, nil)
	if err != nil {
		return err
	}
	// Rollback after a successful Commit is a harmless no-op.
	defer tx.Rollback()

	// The statement text is the same for every document.
	query := fmt.Sprintf(
		"INSERT INTO %s (id, data, created_at, updated_at) VALUES ($1, $2::jsonb, $3, $4)",
		collection,
	)
	batchTime := time.Now()

	for _, doc := range docs {
		jsonData, err := json.Marshal(doc.Data)
		if err != nil {
			return fmt.Errorf("marshal document %v: %w", doc.ID, err)
		}

		ts := doc.CreatedAt
		if ts.IsZero() {
			ts = doc.UpdatedAt
		}
		if ts.IsZero() {
			// Neither timestamp set: use the current time rather than storing
			// the zero time (0001-01-01).
			ts = batchTime
		}

		if _, err := tx.ExecContext(ctx, query, doc.ID, string(jsonData), ts, ts); err != nil {
			return err
		}
	}

	return tx.Commit()
}

// UpdateMany applies update to every document in collection whose id is in
// ids, inside a single transaction, and sets updated_at on each touched row.
//
//   - $set is applied with the JSONB "||" operator, a shallow top-level merge:
//     a key such as "a.b" is stored as a literal top-level key, not a path.
//   - $unset removes each listed top-level key with the JSONB "-" operator.
//
// An empty ids slice is a no-op.
func (a *PostgresAdapter) UpdateMany(ctx context.Context, collection string, ids []string, update types.Update) error {
	if len(ids) == 0 {
		return nil
	}

	// Build the update expression and its bound arguments up front, so an
	// encoding failure is reported before any database round trip.
	updateExpr := "data"
	args := make([]interface{}, 0, len(update.Unset)+1)
	argIndex := 1

	if len(update.Set) > 0 {
		setJSON, err := json.Marshal(update.Set)
		if err != nil {
			return fmt.Errorf("marshal $set: %w", err)
		}
		updateExpr = fmt.Sprintf("%s || $%d::jsonb", updateExpr, argIndex)
		args = append(args, string(setJSON))
		argIndex++
	}

	for field := range update.Unset {
		// Explicit ::text selects the "jsonb - text" (remove key) operator.
		updateExpr = fmt.Sprintf("%s - $%d::text", updateExpr, argIndex)
		args = append(args, field)
		argIndex++
	}

	tx, err := a.GetDB().BeginTx(ctx, nil)
	if err != nil {
		return err
	}
	// Rollback after a successful Commit is a harmless no-op.
	defer tx.Rollback()

	// Identical for every id; only the trailing updated_at/id args change.
	query := fmt.Sprintf(
		"UPDATE %s SET data = %s, updated_at = $%d WHERE id = $%d",
		collection, updateExpr, argIndex, argIndex+1,
	)

	// Dedicated slice with two trailing slots, overwritten per id. This avoids
	// append(args, ...) writing into args' spare capacity on every iteration.
	finalArgs := make([]interface{}, len(args)+2)
	copy(finalArgs, args)
	tsSlot, idSlot := len(args), len(args)+1
	finalArgs[tsSlot] = time.Now() // one timestamp for the whole batch

	for _, id := range ids {
		finalArgs[idSlot] = id
		if _, err := tx.ExecContext(ctx, query, finalArgs...); err != nil {
			return err
		}
	}

	return tx.Commit()
}