gomog/internal/engine/aggregate.go

758 lines
18 KiB
Go

package engine
import (
	"fmt"
	"reflect"
	"sort"
	"strings"

	"git.kingecg.top/kingecg/gomog/pkg/errors"
	"git.kingecg.top/kingecg/gomog/pkg/types"
)
// AggregationEngine runs aggregation pipelines over the documents of a
// collection held in a MemoryStore.
type AggregationEngine struct {
// store is the backing store; collections are read via GetAllDocuments.
store *MemoryStore
}
// NewAggregationEngine returns an aggregation engine that reads its input
// documents from store.
func NewAggregationEngine(store *MemoryStore) *AggregationEngine {
	engine := new(AggregationEngine)
	engine.store = store
	return engine
}
// Execute runs pipeline against every document of collection.
//
// The stages are applied in order, each one receiving the output of the
// previous one. An error while loading the collection is returned as-is; an
// error raised by a stage is wrapped with errors.ErrAggregationError.
func (e *AggregationEngine) Execute(collection string, pipeline []types.AggregateStage) ([]types.Document, error) {
	current, err := e.store.GetAllDocuments(collection)
	if err != nil {
		return nil, err
	}

	for i := range pipeline {
		next, stageErr := e.executeStage(pipeline[i], current)
		if stageErr != nil {
			return nil, errors.Wrap(stageErr, errors.ErrAggregationError, "aggregation failed")
		}
		current = next
	}
	return current, nil
}
// executeStage dispatches a single pipeline stage to its handler, selected by
// the stage operator name (stage.Stage), passing stage.Spec and the current
// documents.
//
// A stage whose name is not recognised is skipped: the input documents are
// returned unchanged and no error is reported.
func (e *AggregationEngine) executeStage(stage types.AggregateStage, docs []types.Document) ([]types.Document, error) {
	var handler func(interface{}, []types.Document) ([]types.Document, error)

	switch stage.Stage {
	case "$match":
		handler = e.executeMatch
	case "$group":
		handler = e.executeGroup
	case "$sort":
		handler = e.executeSort
	case "$project":
		handler = e.executeProject
	case "$limit":
		handler = e.executeLimit
	case "$skip":
		handler = e.executeSkip
	case "$unwind":
		handler = e.executeUnwind
	case "$lookup":
		handler = e.executeLookup
	case "$count":
		handler = e.executeCount
	case "$addFields", "$set":
		handler = e.executeAddFields
	case "$unset":
		handler = e.executeUnset
	case "$facet":
		handler = e.executeFacet
	case "$sample":
		handler = e.executeSample
	case "$bucket":
		handler = e.executeBucket
	default:
		// Unknown stage: pass the documents through untouched.
		return docs, nil
	}

	return handler(stage.Spec, docs)
}
// executeMatch implements the $match stage: it keeps the documents whose Data
// satisfies the filter according to MatchFilter.
//
// A spec that is not a map[string]interface{} is ignored and docs is returned
// unchanged.
func (e *AggregationEngine) executeMatch(spec interface{}, docs []types.Document) ([]types.Document, error) {
	filter, isMap := spec.(map[string]interface{})
	if !isMap {
		return docs, nil
	}

	var matched []types.Document
	for i := range docs {
		if !MatchFilter(docs[i].Data, filter) {
			continue
		}
		matched = append(matched, docs[i])
	}
	return matched, nil
}
// executeGroup implements the $group stage.
//
// Documents are bucketed by the string key that getGroupKey derives from the
// "_id" entry of the spec (only a string "_id" is read; anything else yields
// the empty key, i.e. a single group). Each bucket is reduced with
// aggregateGroup and emitted as one document whose ID is the group key. The
// "_id" field of the output is set only for non-empty keys.
//
// Groups are emitted in the order in which their key was first seen in docs,
// so the output is deterministic instead of following Go's randomised map
// iteration order.
//
// A spec that is not a map[string]interface{} is ignored and docs is returned
// unchanged.
func (e *AggregationEngine) executeGroup(spec interface{}, docs []types.Document) ([]types.Document, error) {
	groupSpec, ok := spec.(map[string]interface{})
	if !ok {
		return docs, nil
	}

	// Group-by field reference, e.g. "$category".
	idField, _ := groupSpec["_id"].(string)

	groups := make(map[string][]types.Document)
	var order []string // group keys in first-seen order
	for _, doc := range docs {
		key := e.getGroupKey(doc, idField)
		if _, seen := groups[key]; !seen {
			order = append(order, key)
		}
		groups[key] = append(groups[key], doc)
	}

	var results []types.Document
	for _, key := range order {
		aggregated := e.aggregateGroup(groupSpec, groups[key])
		if key != "" {
			aggregated["_id"] = key
		}
		results = append(results, types.Document{
			ID:   key,
			Data: aggregated,
		})
	}
	return results, nil
}
// getGroupKey returns the grouping key of doc for a "$field" reference.
//
// The empty string is returned when field is empty, does not start with '$',
// or resolves to nil in the document. Any other value is converted with
// toString.
func (e *AggregationEngine) getGroupKey(doc types.Document, field string) string {
	if len(field) > 0 && field[0] == '$' {
		// Strip the '$' prefix before resolving the path.
		if value := getNestedValue(doc.Data, field[1:]); value != nil {
			return toString(value)
		}
	}
	return ""
}
// aggregateGroup reduces the documents of one group according to groupSpec.
//
// Every entry of groupSpec other than "_id" whose value is a map is treated as
// {<accumulator>: <operand>}. Supported accumulators are $sum, $avg, $min,
// $max, $count, $first, $last, $push and $addToSet; unknown accumulators are
// ignored. Operands are resolved per document with getFieldValue. The result
// maps each output field name to its accumulated value.
//
// $addToSet keeps the first occurrence of each distinct value, in document
// order. Distinctness is decided on a textual rendering of the value's dynamic
// type and content, so slice and map values are supported (using them as Go
// map keys would panic).
func (e *AggregationEngine) aggregateGroup(groupSpec map[string]interface{}, docs []types.Document) map[string]interface{} {
	result := make(map[string]interface{})
	for field, expr := range groupSpec {
		if field == "_id" {
			continue
		}
		exprMap, ok := expr.(map[string]interface{})
		if !ok {
			continue
		}
		for op, operand := range exprMap {
			switch op {
			case "$sum":
				result[field] = e.sum(docs, operand)
			case "$avg":
				result[field] = e.avg(docs, operand)
			case "$min":
				result[field] = e.min(docs, operand)
			case "$max":
				result[field] = e.max(docs, operand)
			case "$count":
				result[field] = len(docs)
			case "$first":
				if len(docs) > 0 {
					result[field] = e.getFieldValue(docs[0], operand)
				}
			case "$last":
				if len(docs) > 0 {
					result[field] = e.getFieldValue(docs[len(docs)-1], operand)
				}
			case "$push":
				values := make([]interface{}, 0, len(docs))
				for _, doc := range docs {
					values = append(values, e.getFieldValue(doc, operand))
				}
				result[field] = values
			case "$addToSet":
				seen := make(map[string]struct{}, len(docs))
				values := make([]interface{}, 0, len(docs))
				for _, doc := range docs {
					v := e.getFieldValue(doc, operand)
					// %T keeps values of different dynamic types apart;
					// %#v renders composites (maps in key-sorted order).
					key := fmt.Sprintf("%T:%#v", v, v)
					if _, dup := seen[key]; dup {
						continue
					}
					seen[key] = struct{}{}
					values = append(values, v)
				}
				result[field] = values
			}
		}
	}
	return result
}
// sum adds up, over all docs, the operand resolved with getFieldValue and
// converted with toFloat64. It returns 0 for an empty slice.
func (e *AggregationEngine) sum(docs []types.Document, field interface{}) float64 {
	var acc float64
	for i := range docs {
		acc += toFloat64(e.getFieldValue(docs[i], field))
	}
	return acc
}
// avg returns sum(docs, field) divided by the number of documents, or 0 when
// docs is empty. Every document counts towards the divisor.
func (e *AggregationEngine) avg(docs []types.Document, field interface{}) float64 {
	count := len(docs)
	if count == 0 {
		return 0
	}
	total := e.sum(docs, field)
	return total / float64(count)
}
// min returns the smallest value of the operand across docs, after conversion
// with toFloat64. It returns 0 when docs is empty.
func (e *AggregationEngine) min(docs []types.Document, field interface{}) float64 {
	if len(docs) == 0 {
		return 0
	}
	lowest := toFloat64(e.getFieldValue(docs[0], field))
	for i := 1; i < len(docs); i++ {
		if candidate := toFloat64(e.getFieldValue(docs[i], field)); candidate < lowest {
			lowest = candidate
		}
	}
	return lowest
}
// max returns the largest value of the operand across docs, after conversion
// with toFloat64. It returns 0 when docs is empty.
func (e *AggregationEngine) max(docs []types.Document, field interface{}) float64 {
	if len(docs) == 0 {
		return 0
	}
	highest := toFloat64(e.getFieldValue(docs[0], field))
	for i := 1; i < len(docs); i++ {
		if candidate := toFloat64(e.getFieldValue(docs[i], field)); candidate > highest {
			highest = candidate
		}
	}
	return highest
}
// getFieldValue resolves an accumulator operand against doc.
//
// A string starting with '$' is a field reference and is looked up in doc.Data
// with getNestedValue (prefix removed). Any other string, and any non-string
// value, is a literal and is returned as given.
func (e *AggregationEngine) getFieldValue(doc types.Document, field interface{}) interface{} {
	name, isString := field.(string)
	if !isString {
		return field
	}
	if len(name) == 0 || name[0] != '$' {
		return name
	}
	return getNestedValue(doc.Data, name[1:])
}
// executeSort implements the $sort stage and returns a sorted copy of docs;
// the input slice is not reordered.
//
// spec maps field paths to a direction: a negative number sorts descending,
// anything else ascending. Directions given as int, int64 or float64 are
// honoured; other types fall back to ascending.
//
// The spec arrives as a Go map, so the order in which the caller listed the
// sort keys is not available here. Keys are therefore applied in
// lexicographic order of their field path, which gives every comparison of
// one sort the same key precedence. Ranging over the map inside the
// comparator would use a different precedence per call and produce an
// inconsistent ordering for multi-key sorts.
//
// A spec that is not a map[string]interface{} is ignored and docs is returned
// unchanged.
func (e *AggregationEngine) executeSort(spec interface{}, docs []types.Document) ([]types.Document, error) {
	sortSpec, ok := spec.(map[string]interface{})
	if !ok {
		return docs, nil
	}

	type sortKey struct {
		field string
		dir   int
	}
	keys := make([]sortKey, 0, len(sortSpec))
	for field, direction := range sortSpec {
		dir := 1
		switch d := direction.(type) {
		case int:
			dir = d
		case int64:
			dir = int(d)
		case float64:
			dir = int(d)
		}
		keys = append(keys, sortKey{field: field, dir: dir})
	}
	sort.Slice(keys, func(i, j int) bool { return keys[i].field < keys[j].field })

	// Sort a copy so the caller's slice keeps its order.
	sorted := make([]types.Document, len(docs))
	copy(sorted, docs)
	sort.Slice(sorted, func(i, j int) bool {
		for _, k := range keys {
			cmp := compareValues(
				getNestedValue(sorted[i].Data, k.field),
				getNestedValue(sorted[j].Data, k.field),
			)
			if cmp == 0 {
				continue // tie on this key, consult the next one
			}
			if k.dir < 0 {
				return cmp > 0
			}
			return cmp < 0
		}
		return false
	})
	return sorted, nil
}
// compareDocs reports whether a sorts before b under sortFields, which maps a
// field path to a direction (negative = descending, otherwise ascending).
//
// Fields are consulted in lexicographic order of their path so that repeated
// calls use the same key precedence; ranging over the map directly would pick
// a random precedence per call and make the comparator inconsistent. The
// first field on which the documents differ (per compareValues) decides;
// false is returned when they tie on every field.
func (e *AggregationEngine) compareDocs(a, b types.Document, sortFields map[string]int) bool {
	fields := make([]string, 0, len(sortFields))
	for field := range sortFields {
		fields = append(fields, field)
	}
	if len(fields) > 1 {
		sort.Strings(fields)
	}

	for _, field := range fields {
		cmp := compareValues(getNestedValue(a.Data, field), getNestedValue(b.Data, field))
		if cmp == 0 {
			continue
		}
		if sortFields[field] < 0 {
			return cmp > 0
		}
		return cmp < 0
	}
	return false
}
// compareValues orders two values and returns -1, 0 or 1.
//
// nil sorts before any non-nil value and two nils are equal. Two strings are
// compared lexically (byte-wise), so sorting on a text field produces an
// alphabetical order. Every other combination is compared numerically after
// conversion with toFloat64.
func compareValues(a, b interface{}) int {
	if a == nil && b == nil {
		return 0
	}
	if a == nil {
		return -1
	}
	if b == nil {
		return 1
	}

	// String vs string: lexical order rather than a numeric conversion.
	if sa, ok := a.(string); ok {
		if sb, ok := b.(string); ok {
			switch {
			case sa < sb:
				return -1
			case sa > sb:
				return 1
			default:
				return 0
			}
		}
	}

	// Numeric comparison.
	numA := toFloat64(a)
	numB := toFloat64(b)
	switch {
	case numA < numB:
		return -1
	case numA > numB:
		return 1
	default:
		return 0
	}
}
// executeProject implements the $project stage: every document is replaced by
// the result of projectDocument on its Data, keeping the document ID.
//
// A spec that is not a map[string]interface{} is ignored and docs is returned
// unchanged.
func (e *AggregationEngine) executeProject(spec interface{}, docs []types.Document) ([]types.Document, error) {
	fields, isMap := spec.(map[string]interface{})
	if !isMap {
		return docs, nil
	}

	var projected []types.Document
	for i := range docs {
		src := docs[i]
		projected = append(projected, types.Document{
			ID:   src.ID,
			Data: e.projectDocument(src.Data, fields),
		})
	}
	return projected, nil
}
// projectDocument builds the projected form of data according to spec.
//
// Each spec entry is handled by the type of its value:
//   - a boolean or numeric flag (or nil): a truthy flag (per isTrue) copies
//     the field from data, a falsy flag omits it; "_id" is read directly from
//     data["_id"], other fields through getNestedValue;
//   - anything else (a "$field" reference string, an operator map, ...): the
//     value is an expression and the field is set to the result of
//     evaluateExpression.
//
// Only fields named in spec appear in the result. Telling flags and
// expressions apart by type is required because isTrue reports true for any
// non-flag value, so a truthiness test alone can never select the expression
// branch.
func (e *AggregationEngine) projectDocument(data map[string]interface{}, spec map[string]interface{}) map[string]interface{} {
	result := make(map[string]interface{}, len(spec))
	for field, include := range spec {
		switch include.(type) {
		case nil, bool, int, int32, int64, float32, float64:
			// Inclusion / exclusion flag.
			if !isTrue(include) {
				continue
			}
			if field == "_id" {
				result["_id"] = data["_id"]
			} else {
				result[field] = getNestedValue(data, field)
			}
		default:
			// Computed field.
			result[field] = e.evaluateExpression(data, include)
		}
	}
	return result
}
// evaluateExpression evaluates an aggregation expression against data.
//
//   - A string starting with '$' is a field reference, resolved with
//     getNestedValue after removing the prefix.
//   - A map is an operator expression {<op>: <operand>}; the first recognised
//     operator encountered while ranging over the map is evaluated and its
//     result returned. Unrecognised operators are skipped.
//   - Anything else (and a map without a recognised operator) is a literal and
//     is returned unchanged.
//
// $size evaluates its operand as an expression and returns the length of the
// resulting []interface{}, or 0 when the operand does not produce one. For
// backward compatibility a string operand without the '$' prefix is still
// looked up as a field path.
func (e *AggregationEngine) evaluateExpression(data map[string]interface{}, expr interface{}) interface{} {
	// Field reference ("$path").
	if fieldStr, ok := expr.(string); ok && len(fieldStr) > 0 && fieldStr[0] == '$' {
		return getNestedValue(data, fieldStr[1:])
	}

	exprMap, ok := expr.(map[string]interface{})
	if !ok {
		return expr
	}

	for op, operand := range exprMap {
		switch op {
		case "$concat":
			return e.concat(operand, data)
		case "$substr", "$substring":
			return e.substr(operand, data)
		case "$toUpper":
			str := e.getFieldValueStr(types.Document{Data: data}, operand)
			return strings.ToUpper(str)
		case "$toLower":
			str := e.getFieldValueStr(types.Document{Data: data}, operand)
			return strings.ToLower(str)
		case "$add":
			return e.add(operand, data)
		case "$multiply":
			return e.multiply(operand, data)
		case "$divide":
			return e.divide(operand, data)
		case "$subtract":
			return e.subtract(operand, data)
		case "$abs":
			return e.abs(operand, data)
		case "$ceil":
			return e.ceil(operand, data)
		case "$floor":
			return e.floor(operand, data)
		case "$round":
			return e.round(operand, data)
		case "$sqrt":
			return e.sqrt(operand, data)
		case "$pow":
			return e.pow(operand, data)
		case "$size":
			value := e.evaluateExpression(data, operand)
			if name, isStr := operand.(string); isStr && !strings.HasPrefix(name, "$") {
				// Legacy form: bare field path without the '$' prefix.
				value = getNestedValue(data, name)
			}
			if arr, isArr := value.([]interface{}); isArr {
				return len(arr)
			}
			return 0
		case "$ifNull":
			return e.ifNull(operand, data)
		case "$cond":
			return e.cond(operand, data)
		case "$switch":
			return e.switchExpr(operand, data)
		case "$trim":
			return e.trim(operand, data)
		case "$ltrim":
			return e.ltrim(operand, data)
		case "$rtrim":
			return e.rtrim(operand, data)
		case "$split":
			return e.split(operand, data)
		case "$replaceAll":
			return e.replaceAll(operand, data)
		case "$strcasecmp":
			return e.strcasecmp(operand, data)
		case "$filter":
			return e.filter(operand, data)
		case "$map":
			return e.mapArr(operand, data)
		case "$concatArrays":
			return e.concatArrays(operand, data)
		case "$slice":
			return e.slice(operand, data)
		case "$mergeObjects":
			return e.mergeObjects(operand, data)
		case "$objectToArray":
			return e.objectToArray(operand, data)
		case "$year":
			return e.year(operand, data)
		case "$month":
			return e.month(operand, data)
		case "$dayOfMonth":
			return e.dayOfMonth(operand, data)
		case "$hour":
			return e.hour(operand, data)
		case "$minute":
			return e.minute(operand, data)
		case "$second":
			return e.second(operand, data)
		case "$dateToString":
			return e.dateToString(operand, data)
		case "$dateAdd":
			return e.dateAdd(operand, data)
		case "$dateDiff":
			return e.dateDiff(operand, data)
		case "$gt":
			return e.compareGt(operand, data)
		case "$gte":
			return e.compareGte(operand, data)
		case "$lt":
			return e.compareLt(operand, data)
		case "$lte":
			return e.compareLte(operand, data)
		case "$eq":
			return e.compareEq(operand, data)
		case "$ne":
			return e.compareNe(operand, data)
		}
	}
	return expr
}
// executeLimit implements the $limit stage: it returns the first n documents,
// where n is spec given as int, int64 or float64 (truncated).
//
// docs is returned unchanged when n is not positive, is not one of those
// types, or is at least len(docs).
func (e *AggregationEngine) executeLimit(spec interface{}, docs []types.Document) ([]types.Document, error) {
	var n int
	switch v := spec.(type) {
	case int:
		n = v
	case int64:
		n = int(v)
	case float64:
		n = int(v)
	}

	if n > 0 && n < len(docs) {
		return docs[:n], nil
	}
	return docs, nil
}
// executeSkip implements the $skip stage: it drops the first n documents,
// where n is spec given as int, int64 or float64 (truncated).
//
// docs is returned unchanged when n is not positive or not one of those
// types; an empty, non-nil slice is returned when n is at least len(docs).
func (e *AggregationEngine) executeSkip(spec interface{}, docs []types.Document) ([]types.Document, error) {
	var n int
	switch v := spec.(type) {
	case int:
		n = v
	case int64:
		n = int(v)
	case float64:
		n = int(v)
	}

	switch {
	case n <= 0:
		return docs, nil
	case n >= len(docs):
		return []types.Document{}, nil
	default:
		return docs[n:], nil
	}
}
// executeUnwind implements the $unwind stage.
//
// spec is either the field path as a "$path" string, or a map with a "path"
// string and an optional boolean "preserveNullAndEmptyArrays". If no path
// starting with '$' is given, docs is returned unchanged.
//
// For each input document the value at the path decides the output:
//   - missing/nil or an empty array: the document is dropped, or kept as-is
//     when preserveNullAndEmptyArrays is true;
//   - a non-empty []interface{}: one output document per element, each a deep
//     copy of the input in which the array is replaced by that element;
//   - any other value: the document is passed through unchanged, i.e. the
//     value is treated as a single-element array, as MongoDB does.
func (e *AggregationEngine) executeUnwind(spec interface{}, docs []types.Document) ([]types.Document, error) {
	var path string
	var preserveNull bool
	switch s := spec.(type) {
	case string:
		path = s
	case map[string]interface{}:
		if p, ok := s["path"].(string); ok {
			path = p
		}
		if pn, ok := s["preserveNullAndEmptyArrays"].(bool); ok {
			preserveNull = pn
		}
	}
	if path == "" || path[0] != '$' {
		return docs, nil
	}
	fieldPath := path[1:]

	var results []types.Document
	for _, doc := range docs {
		value := getNestedValue(doc.Data, fieldPath)
		array, isArray := value.([]interface{})

		switch {
		case value == nil, isArray && len(array) == 0:
			if preserveNull {
				results = append(results, doc)
			}
		case !isArray:
			// Scalar / object value: behaves like a one-element array.
			results = append(results, doc)
		default:
			for _, item := range array {
				newData := deepCopyMap(doc.Data)
				setNestedValue(newData, fieldPath, item)
				results = append(results, types.Document{
					ID:   doc.ID,
					Data: newData,
				})
			}
		}
	}
	return results, nil
}
// executeLookup implements the equality form of the $lookup stage.
//
// spec must be a map with string entries "from" (foreign collection),
// "localField", "foreignField" and "as" (output field). For every input
// document, the Data of each foreign document whose foreignField equals the
// input's localField (per compareEq) is collected and stored under "as" in a
// deep copy of the input document.
//
// The joined value is always a non-nil []interface{} (empty when nothing
// matches). This is the array representation that $unwind and $size in this
// engine type-assert on, so the output of $lookup can be fed to them.
//
// docs is returned unchanged when spec is not a map, when "from" or "as" is
// missing, or when the foreign collection cannot be loaded (best effort: the
// load error is deliberately not propagated).
func (e *AggregationEngine) executeLookup(spec interface{}, docs []types.Document) ([]types.Document, error) {
	lookupSpec, ok := spec.(map[string]interface{})
	if !ok {
		return docs, nil
	}
	from, _ := lookupSpec["from"].(string)
	localField, _ := lookupSpec["localField"].(string)
	foreignField, _ := lookupSpec["foreignField"].(string)
	as, _ := lookupSpec["as"].(string)
	if from == "" || as == "" {
		return docs, nil
	}

	// Load the documents of the joined collection.
	foreignDocs, err := e.store.GetAllDocuments(from)
	if err != nil {
		return docs, nil // ignore the error and carry on without joining
	}

	var results []types.Document
	for _, doc := range docs {
		localValue := getNestedValue(doc.Data, localField)

		matches := make([]interface{}, 0)
		for _, foreignDoc := range foreignDocs {
			foreignValue := getNestedValue(foreignDoc.Data, foreignField)
			if compareEq(localValue, foreignValue) {
				matches = append(matches, foreignDoc.Data)
			}
		}

		newData := deepCopyMap(doc.Data)
		newData[as] = matches
		results = append(results, types.Document{
			ID:   doc.ID,
			Data: newData,
		})
	}
	return results, nil
}
// executeCount implements the $count stage: it returns a single document with
// ID "count" whose only field holds len(docs).
//
// The field is named by spec when spec is a string, and "count" otherwise.
func (e *AggregationEngine) executeCount(spec interface{}, docs []types.Document) ([]types.Document, error) {
	name := "count"
	if s, isString := spec.(string); isString {
		name = s
	}

	summary := types.Document{
		ID:   "count",
		Data: map[string]interface{}{name: len(docs)},
	}
	return []types.Document{summary}, nil
}
// Helper functions.

// isTrue reports whether v is a truthy projection flag.
//
// A bool is returned as-is and a number is true when non-zero. All integer
// and float widths listed below are recognised, so that for example an int64
// zero is false just like an int zero. Every other value, including nil,
// strings and maps, is reported as true.
func isTrue(v interface{}) bool {
	switch val := v.(type) {
	case bool:
		return val
	case int:
		return val != 0
	case int32:
		return val != 0
	case int64:
		return val != 0
	case float32:
		return val != 0
	case float64:
		return val != 0
	}
	return true
}
// isFalse is the exact negation of isTrue.
func isFalse(v interface{}) bool {
	if isTrue(v) {
		return false
	}
	return true
}
// toString renders a scalar as text.
//
// Strings are returned unchanged, integers (int, int32, int64) are formatted
// in decimal and float64 values with fmt's %v verb. Any other type yields the
// empty string.
//
// Integers must be formatted rather than converted with string(rune(n)):
// that conversion yields the Unicode character with code point n (65 -> "A"),
// not the number's digits.
func toString(v interface{}) string {
	switch val := v.(type) {
	case string:
		return val
	case int, int32, int64:
		return fmt.Sprintf("%d", val)
	case float64:
		return fmt.Sprintf("%v", val)
	default:
		return ""
	}
}
// Helper methods for the comparison operators.

// compareGt evaluates {$gt: [a, b]}: both operands are evaluated as
// expressions against data, converted with toFloat64, and compared with >.
// It returns false when operand is not a two-element []interface{}.
func (e *AggregationEngine) compareGt(operand interface{}, data map[string]interface{}) interface{} {
	pair, isArray := operand.([]interface{})
	if isArray && len(pair) == 2 {
		lhs, rhs := e.evaluateExpression(data, pair[0]), e.evaluateExpression(data, pair[1])
		return toFloat64(lhs) > toFloat64(rhs)
	}
	return false
}
// compareGte evaluates {$gte: [a, b]}: both operands are evaluated as
// expressions against data, converted with toFloat64, and compared with >=.
// It returns false when operand is not a two-element []interface{}.
func (e *AggregationEngine) compareGte(operand interface{}, data map[string]interface{}) interface{} {
	pair, isArray := operand.([]interface{})
	if isArray && len(pair) == 2 {
		lhs, rhs := e.evaluateExpression(data, pair[0]), e.evaluateExpression(data, pair[1])
		return toFloat64(lhs) >= toFloat64(rhs)
	}
	return false
}
// compareLt evaluates {$lt: [a, b]}: both operands are evaluated as
// expressions against data, converted with toFloat64, and compared with <.
// It returns false when operand is not a two-element []interface{}.
func (e *AggregationEngine) compareLt(operand interface{}, data map[string]interface{}) interface{} {
	pair, isArray := operand.([]interface{})
	if isArray && len(pair) == 2 {
		lhs, rhs := e.evaluateExpression(data, pair[0]), e.evaluateExpression(data, pair[1])
		return toFloat64(lhs) < toFloat64(rhs)
	}
	return false
}
// compareLte evaluates {$lte: [a, b]}: both operands are evaluated as
// expressions against data, converted with toFloat64, and compared with <=.
// It returns false when operand is not a two-element []interface{}.
func (e *AggregationEngine) compareLte(operand interface{}, data map[string]interface{}) interface{} {
	pair, isArray := operand.([]interface{})
	if isArray && len(pair) == 2 {
		lhs, rhs := e.evaluateExpression(data, pair[0]), e.evaluateExpression(data, pair[1])
		return toFloat64(lhs) <= toFloat64(rhs)
	}
	return false
}
// compareEq evaluates {$eq: [a, b]}: both operands are evaluated as
// expressions against data and tested for equality.
//
// Equality uses reflect.DeepEqual rather than the == operator: applying == to
// two interface values that both hold a slice or a map panics at run time,
// whereas DeepEqual compares such arrays and objects element by element. For
// scalar values the outcome is the same as with ==.
//
// It returns false when operand is not a two-element []interface{}.
func (e *AggregationEngine) compareEq(operand interface{}, data map[string]interface{}) interface{} {
	arr, ok := operand.([]interface{})
	if !ok || len(arr) != 2 {
		return false
	}
	left := e.evaluateExpression(data, arr[0])
	right := e.evaluateExpression(data, arr[1])
	return reflect.DeepEqual(left, right)
}
// compareNe evaluates {$ne: [a, b]}: both operands are evaluated as
// expressions against data and tested for inequality.
//
// Inequality is the negation of reflect.DeepEqual rather than the != operator:
// applying != to two interface values that both hold a slice or a map panics
// at run time, whereas DeepEqual compares such arrays and objects element by
// element. For scalar values the outcome is the same as with !=.
//
// It returns false when operand is not a two-element []interface{}.
func (e *AggregationEngine) compareNe(operand interface{}, data map[string]interface{}) interface{} {
	arr, ok := operand.([]interface{})
	if !ok || len(arr) != 2 {
		return false
	}
	left := e.evaluateExpression(data, arr[0])
	right := e.evaluateExpression(data, arr[1])
	return !reflect.DeepEqual(left, right)
}