126 lines
3.0 KiB
Go
126 lines
3.0 KiB
Go
package sqlite
|
||
|
||
import (
	"context"
	"encoding/json"
	"fmt"
	"strings"
	"time"

	"git.kingecg.top/kingecg/gomog/internal/database"
	"git.kingecg.top/kingecg/gomog/pkg/types"
	_ "github.com/mattn/go-sqlite3"
)
|
||
|
||
// SQLiteAdapter SQLite 数据库适配器
|
||
type SQLiteAdapter struct {
|
||
*database.BaseAdapter
|
||
}
|
||
|
||
// NewSQLiteAdapter 创建 SQLite 适配器
|
||
func NewSQLiteAdapter() *SQLiteAdapter {
|
||
return &SQLiteAdapter{
|
||
BaseAdapter: database.NewBaseAdapter("sqlite3"),
|
||
}
|
||
}
|
||
|
||
// Connect 连接 SQLite 数据库
|
||
func (a *SQLiteAdapter) Connect(ctx context.Context, dsn string) error {
|
||
// SQLite 需要启用 JSON1 扩展(大多数构建已默认包含)
|
||
if err := a.BaseAdapter.Connect(ctx, dsn); err != nil {
|
||
return err
|
||
}
|
||
|
||
// 设置 SQLite 特定的 PRAGMA
|
||
_, err := a.GetDB().Exec("PRAGMA journal_mode = WAL")
|
||
return err
|
||
}
|
||
|
||
// CreateCollection 创建集合(SQLite 表)
|
||
func (a *SQLiteAdapter) CreateCollection(ctx context.Context, name string) error {
|
||
// SQLite 使用 CHECK 约束验证 JSON
|
||
query := fmt.Sprintf(`
|
||
CREATE TABLE IF NOT EXISTS %s (
|
||
id TEXT PRIMARY KEY,
|
||
data TEXT NOT NULL CHECK(json_valid(data)),
|
||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||
)`, name)
|
||
|
||
_, err := a.GetDB().ExecContext(ctx, query)
|
||
return err
|
||
}
|
||
|
||
// CollectionExists 检查集合是否存在
|
||
func (a *SQLiteAdapter) CollectionExists(ctx context.Context, name string) (bool, error) {
|
||
query := `SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=?`
|
||
var count int
|
||
err := a.GetDB().QueryRowContext(ctx, query, name).Scan(&count)
|
||
if err != nil {
|
||
return false, err
|
||
}
|
||
return count > 0, nil
|
||
}
|
||
|
||
// FindAll 查询所有文档(使用 SQLite JSON 函数)
|
||
func (a *SQLiteAdapter) FindAll(ctx context.Context, collection string) ([]types.Document, error) {
|
||
query := fmt.Sprintf("SELECT id, data, created_at, updated_at FROM %s", collection)
|
||
rows, err := a.GetDB().QueryContext(ctx, query)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer rows.Close()
|
||
|
||
var docs []types.Document
|
||
for rows.Next() {
|
||
var doc types.Document
|
||
var jsonData []byte
|
||
err := rows.Scan(&doc.ID, &jsonData, &doc.CreatedAt, &doc.UpdatedAt)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
if err := json.Unmarshal(jsonData, &doc.Data); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
docs = append(docs, doc)
|
||
}
|
||
|
||
return docs, rows.Err()
|
||
}
|
||
|
||
// InsertMany 批量插入(SQLite 优化版本)
|
||
func (a *SQLiteAdapter) InsertMany(ctx context.Context, collection string, docs []types.Document) error {
|
||
tx, err := a.GetDB().BeginTx(ctx, nil)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
defer tx.Rollback()
|
||
|
||
for _, doc := range docs {
|
||
jsonData, err := json.Marshal(doc.Data)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
query := fmt.Sprintf(
|
||
"INSERT INTO %s (id, data, created_at, updated_at) VALUES (?, json(?), ?, ?)",
|
||
collection,
|
||
)
|
||
|
||
now := doc.CreatedAt
|
||
if now.IsZero() {
|
||
now = doc.UpdatedAt
|
||
}
|
||
if now.IsZero() {
|
||
now = doc.UpdatedAt
|
||
}
|
||
|
||
_, err = tx.ExecContext(ctx, query, doc.ID, string(jsonData), now, now)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
}
|
||
|
||
return tx.Commit()
|
||
}
|