package engine

import (
	"fmt"
	"sort"
	"strings"

	"git.kingecg.top/kingecg/gomog/pkg/errors"
	"git.kingecg.top/kingecg/gomog/pkg/types"
)

// AggregationEngine runs aggregation pipelines over the documents held in a
// MemoryStore.
type AggregationEngine struct {
	store *MemoryStore
}

// NewAggregationEngine returns an AggregationEngine that reads from store.
func NewAggregationEngine(store *MemoryStore) *AggregationEngine {
	return &AggregationEngine{store: store}
}

// Execute loads every document of collection and passes them through the
// pipeline stages in order, feeding each stage the output of the previous one.
//
// A store error is returned unchanged; a stage error is wrapped with
// errors.ErrAggregationError.
func (e *AggregationEngine) Execute(collection string, pipeline []types.AggregateStage) ([]types.Document, error) {
	docs, err := e.store.GetAllDocuments(collection)
	if err != nil {
		return nil, err
	}

	result := docs
	for _, stage := range pipeline {
		result, err = e.executeStage(stage, result)
		if err != nil {
			return nil, errors.Wrap(err, errors.ErrAggregationError, "aggregation failed")
		}
	}

	return result, nil
}

// executeStage dispatches a single pipeline stage to its handler based on
// stage.Stage. Unknown stage names are skipped: docs is returned untouched.
func (e *AggregationEngine) executeStage(stage types.AggregateStage, docs []types.Document) ([]types.Document, error) {
	switch stage.Stage {
	case "$match":
		return e.executeMatch(stage.Spec, docs)
	case "$group":
		return e.executeGroup(stage.Spec, docs)
	case "$sort":
		return e.executeSort(stage.Spec, docs)
	case "$project":
		return e.executeProject(stage.Spec, docs)
	case "$limit":
		return e.executeLimit(stage.Spec, docs)
	case "$skip":
		return e.executeSkip(stage.Spec, docs)
	case "$unwind":
		return e.executeUnwind(stage.Spec, docs)
	case "$lookup":
		return e.executeLookup(stage.Spec, docs)
	case "$count":
		return e.executeCount(stage.Spec, docs)
	case "$addFields", "$set":
		return e.executeAddFields(stage.Spec, docs)
	case "$unset":
		return e.executeUnset(stage.Spec, docs)
	case "$facet":
		return e.executeFacet(stage.Spec, docs)
	case "$sample":
		return e.executeSample(stage.Spec, docs)
	case "$bucket":
		return e.executeBucket(stage.Spec, docs)
	default:
		// Unknown stage: skip it rather than fail the whole pipeline.
		return docs, nil
	}
}

// executeMatch implements $match: it keeps the documents whose Data satisfies
// the filter according to MatchFilter. A spec that is not a map is ignored and
// docs is returned unchanged.
func (e *AggregationEngine) executeMatch(spec interface{}, docs []types.Document) ([]types.Document, error) {
	filter, ok := spec.(map[string]interface{})
	if !ok {
		return docs, nil
	}

	var results []types.Document
	for _, doc := range docs {
		if MatchFilter(doc.Data, filter) {
			results = append(results, doc)
		}
	}
	return results, nil
}

// executeGroup implements $group. Documents are bucketed by the string key
// derived from the "_id" field reference (see getGroupKey) and each bucket is
// reduced with aggregateGroup.
//
// Groups are emitted in the order in which their key was first seen, so the
// output is deterministic for a given input order. A spec that is not a map is
// ignored and docs is returned unchanged.
func (e *AggregationEngine) executeGroup(spec interface{}, docs []types.Document) ([]types.Document, error) {
	groupSpec, ok := spec.(map[string]interface{})
	if !ok {
		return docs, nil
	}

	// Only a string "_id" (a "$field" reference) is used for grouping; any
	// other "_id" value leaves idField empty and puts every doc in one group.
	idField, _ := groupSpec["_id"].(string)

	groups := make(map[string][]types.Document)
	var order []string // group keys in first-seen order
	for _, doc := range docs {
		key := e.getGroupKey(doc, idField)
		if _, seen := groups[key]; !seen {
			order = append(order, key)
		}
		groups[key] = append(groups[key], doc)
	}

	var results []types.Document
	for _, key := range order {
		aggregated := e.aggregateGroup(groupSpec, groups[key])

		// The empty key stands for "no group key"; no _id is written for it.
		if key != "" {
			aggregated["_id"] = key
		}

		results = append(results, types.Document{
			ID:   key,
			Data: aggregated,
		})
	}

	return results, nil
}

// getGroupKey returns the string grouping key of doc for a "$field" reference.
// It returns "" when field is not a "$"-prefixed reference, when the referenced
// value is nil, or when toString has no representation for the value's type.
func (e *AggregationEngine) getGroupKey(doc types.Document, field string) string {
	if field == "" || field[0] != '$' {
		return ""
	}

	value := getNestedValue(doc.Data, field[1:]) // strip the "$" prefix
	if value == nil {
		return ""
	}

	return toString(value)
}

// aggregateGroup reduces docs into a single result map following groupSpec.
// Every key other than "_id" whose value is a map of accumulator operators
// ($sum, $avg, $min, $max, $count, $first, $last, $push, $addToSet) produces
// one output field. Unknown operators and non-map values are ignored.
func (e *AggregationEngine) aggregateGroup(groupSpec map[string]interface{}, docs []types.Document) map[string]interface{} {
	result := make(map[string]interface{})

	for field, expr := range groupSpec {
		if field == "_id" {
			continue
		}

		exprMap, ok := expr.(map[string]interface{})
		if !ok {
			continue
		}

		for op, operand := range exprMap {
			switch op {
			case "$sum":
				result[field] = e.sum(docs, operand)
			case "$avg":
				result[field] = e.avg(docs, operand)
			case "$min":
				result[field] = e.min(docs, operand)
			case "$max":
				result[field] = e.max(docs, operand)
			case "$count":
				result[field] = len(docs)
			case "$first":
				if len(docs) > 0 {
					result[field] = e.getFieldValue(docs[0], operand)
				}
			case "$last":
				if len(docs) > 0 {
					result[field] = e.getFieldValue(docs[len(docs)-1], operand)
				}
			case "$push":
				values := make([]interface{}, 0, len(docs))
				for _, doc := range docs {
					values = append(values, e.getFieldValue(doc, operand))
				}
				result[field] = values
			case "$addToSet":
				// Values may be slices or maps, which cannot be used as Go
				// map keys (doing so panics). De-duplicate on a formatted
				// "type|value" key instead and keep first-seen order.
				seen := make(map[string]struct{}, len(docs))
				values := make([]interface{}, 0, len(docs))
				for _, doc := range docs {
					v := e.getFieldValue(doc, operand)
					key := fmt.Sprintf("%T|%#v", v, v)
					if _, dup := seen[key]; dup {
						continue
					}
					seen[key] = struct{}{}
					values = append(values, v)
				}
				result[field] = values
			}
		}
	}

	return result
}

// sum returns the total of field over docs, converting each value with
// toFloat64.
func (e *AggregationEngine) sum(docs []types.Document, field interface{}) float64 {
	total := 0.0
	for _, doc := range docs {
		total += toFloat64(e.getFieldValue(doc, field))
	}
	return total
}

// avg returns sum divided by the number of docs, or 0 for an empty group.
func (e *AggregationEngine) avg(docs []types.Document, field interface{}) float64 {
	if len(docs) == 0 {
		return 0
	}
	return e.sum(docs, field) / float64(len(docs))
}

// min returns the smallest toFloat64 value of field over docs, or 0 for an
// empty group.
func (e *AggregationEngine) min(docs []types.Document, field interface{}) float64 {
	if len(docs) == 0 {
		return 0
	}

	lowest := toFloat64(e.getFieldValue(docs[0], field))
	for _, doc := range docs[1:] {
		if val := toFloat64(e.getFieldValue(doc, field)); val < lowest {
			lowest = val
		}
	}
	return lowest
}

// max returns the largest toFloat64 value of field over docs, or 0 for an
// empty group.
func (e *AggregationEngine) max(docs []types.Document, field interface{}) float64 {
	if len(docs) == 0 {
		return 0
	}

	highest := toFloat64(e.getFieldValue(docs[0], field))
	for _, doc := range docs[1:] {
		if val := toFloat64(e.getFieldValue(doc, field)); val > highest {
			highest = val
		}
	}
	return highest
}

// getFieldValue resolves an accumulator operand against doc. A "$"-prefixed
// string is looked up as a field path; anything else is returned as a literal.
func (e *AggregationEngine) getFieldValue(doc types.Document, field interface{}) interface{} {
	if name, ok := field.(string); ok && len(name) > 0 && name[0] == '$' {
		return getNestedValue(doc.Data, name[1:])
	}
	return field
}

// executeSort implements $sort. It returns a sorted copy of docs; the input
// slice is not reordered.
//
// The spec arrives as a Go map, so the order in which the caller listed the
// sort keys is not available. Keys are therefore applied in alphabetical order
// of field name, which keeps multi-key sorts deterministic. The sort is stable.
// Directions may be int, int64 or float64 (negative means descending); any
// other type defaults to ascending.
func (e *AggregationEngine) executeSort(spec interface{}, docs []types.Document) ([]types.Document, error) {
	sortSpec, ok := spec.(map[string]interface{})
	if !ok {
		return docs, nil
	}

	sortFields := make(map[string]int, len(sortSpec))
	for field, direction := range sortSpec {
		dir := 1
		switch d := direction.(type) {
		case int:
			dir = d
		case int64:
			dir = int(d)
		case float64:
			dir = int(d)
		}
		sortFields[field] = dir
	}

	// Fix the key order once so every comparison uses the same priority.
	fields := orderedSortFields(sortFields)

	sorted := make([]types.Document, len(docs))
	copy(sorted, docs)

	sort.SliceStable(sorted, func(i, j int) bool {
		return lessByOrderedFields(sorted[i], sorted[j], fields, sortFields)
	})

	return sorted, nil
}

// compareDocs reports whether a sorts before b under sortFields (field name to
// direction, negative meaning descending). Fields are consulted in
// alphabetical order so the result does not depend on map iteration order.
func (e *AggregationEngine) compareDocs(a, b types.Document, sortFields map[string]int) bool {
	return lessByOrderedFields(a, b, orderedSortFields(sortFields), sortFields)
}

// orderedSortFields returns the keys of sortFields in alphabetical order.
func orderedSortFields(sortFields map[string]int) []string {
	names := make([]string, 0, len(sortFields))
	for name := range sortFields {
		names = append(names, name)
	}
	sort.Strings(names)
	return names
}

// lessByOrderedFields reports whether a sorts before b, consulting fields in
// the given order and using dirs for each field's direction. The first field
// on which the documents differ decides; equal documents yield false.
func lessByOrderedFields(a, b types.Document, fields []string, dirs map[string]int) bool {
	for _, field := range fields {
		cmp := compareValues(getNestedValue(a.Data, field), getNestedValue(b.Data, field))
		if cmp == 0 {
			continue
		}
		if dirs[field] < 0 {
			return cmp > 0
		}
		return cmp < 0
	}
	return false
}

// compareValues returns -1, 0 or 1 as a is less than, equal to or greater
// than b. nil sorts before any non-nil value; non-nil values are compared via
// toFloat64.
//
// NOTE(review): non-numeric values are compared only through toFloat64, so how
// strings order depends on that helper — confirm this is intended.
func compareValues(a, b interface{}) int {
	switch {
	case a == nil && b == nil:
		return 0
	case a == nil:
		return -1
	case b == nil:
		return 1
	}

	numA, numB := toFloat64(a), toFloat64(b)
	switch {
	case numA < numB:
		return -1
	case numA > numB:
		return 1
	}
	return 0
}

// executeProject implements $project by applying projectDocument to every
// document, preserving each document's ID. A spec that is not a map is ignored
// and docs is returned unchanged.
func (e *AggregationEngine) executeProject(spec interface{}, docs []types.Document) ([]types.Document, error) {
	projectSpec, ok := spec.(map[string]interface{})
	if !ok {
		return docs, nil
	}

	var results []types.Document
	for _, doc := range docs {
		results = append(results, types.Document{
			ID:   doc.ID,
			Data: e.projectDocument(doc.Data, projectSpec),
		})
	}
	return results, nil
}

// projectDocument builds a new map containing only what spec asks for.
//
// For each spec entry:
//   - "_id" is copied from data unless its value is falsy;
//   - a boolean or numeric value is an include/exclude flag: truthy copies the
//     field, falsy omits it;
//   - a nil value is treated as an include flag;
//   - any other value (a "$field" string, an operator map, ...) is evaluated
//     as an expression and the result stored under the field name.
func (e *AggregationEngine) projectDocument(data map[string]interface{}, spec map[string]interface{}) map[string]interface{} {
	result := make(map[string]interface{})

	for field, include := range spec {
		if field == "_id" {
			if !isFalse(include) {
				result["_id"] = data["_id"]
			}
			continue
		}

		// isTrue reports true for every non-flag value, so it cannot be used
		// to tell flags from expressions; dispatch on the type instead.
		switch include.(type) {
		case nil, bool, int, int32, int64, float64:
			if isTrue(include) {
				result[field] = getNestedValue(data, field)
			}
			// A falsy flag excludes the field: nothing to copy.
		default:
			result[field] = e.evaluateExpression(data, include)
		}
	}

	return result
}

// evaluateExpression evaluates an aggregation expression against data.
//
// A "$"-prefixed string is resolved as a field path. A map is treated as an
// operator expression and dispatched on the first recognised operator key.
// Anything else — including a map with no recognised operator — is returned
// unchanged as a literal.
func (e *AggregationEngine) evaluateExpression(data map[string]interface{}, expr interface{}) interface{} {
	// Field reference: "$path.to.field".
	if fieldStr, ok := expr.(string); ok && len(fieldStr) > 0 && fieldStr[0] == '$' {
		return getNestedValue(data, fieldStr[1:])
	}

	if exprMap, ok := expr.(map[string]interface{}); ok {
		for op, operand := range exprMap {
			switch op {
			case "$concat":
				return e.concat(operand, data)
			case "$substr", "$substring":
				return e.substr(operand, data)
			case "$toUpper":
				str := e.getFieldValueStr(types.Document{Data: data}, operand)
				return strings.ToUpper(str)
			case "$toLower":
				str := e.getFieldValueStr(types.Document{Data: data}, operand)
				return strings.ToLower(str)
			case "$add":
				return e.add(operand, data)
			case "$multiply":
				return e.multiply(operand, data)
			case "$divide":
				return e.divide(operand, data)
			case "$subtract":
				return e.subtract(operand, data)
			case "$abs":
				return e.abs(operand, data)
			case "$ceil":
				return e.ceil(operand, data)
			case "$floor":
				return e.floor(operand, data)
			case "$round":
				return e.round(operand, data)
			case "$sqrt":
				return e.sqrt(operand, data)
			case "$pow":
				return e.pow(operand, data)
			case "$size":
				// A string operand names the array field, with or without
				// the "$" prefix; any other operand is evaluated as an
				// expression (e.g. an array literal) instead of panicking on
				// a failed type assertion.
				var arr interface{}
				if name, isStr := operand.(string); isStr {
					arr = getNestedValue(data, strings.TrimPrefix(name, "$"))
				} else {
					arr = e.evaluateExpression(data, operand)
				}
				if a, ok := arr.([]interface{}); ok {
					return len(a)
				}
				return 0
			case "$ifNull":
				return e.ifNull(operand, data)
			case "$cond":
				return e.cond(operand, data)
			case "$switch":
				return e.switchExpr(operand, data)
			case "$trim":
				return e.trim(operand, data)
			case "$ltrim":
				return e.ltrim(operand, data)
			case "$rtrim":
				return e.rtrim(operand, data)
			case "$split":
				return e.split(operand, data)
			case "$replaceAll":
				return e.replaceAll(operand, data)
			case "$strcasecmp":
				return e.strcasecmp(operand, data)
			case "$filter":
				return e.filter(operand, data)
			case "$map":
				return e.mapArr(operand, data)
			case "$concatArrays":
				return e.concatArrays(operand, data)
			case "$slice":
				return e.slice(operand, data)
			case "$mergeObjects":
				return e.mergeObjects(operand, data)
			case "$objectToArray":
				return e.objectToArray(operand, data)
			case "$year":
				return e.year(operand, data)
			case "$month":
				return e.month(operand, data)
			case "$dayOfMonth":
				return e.dayOfMonth(operand, data)
			case "$hour":
				return e.hour(operand, data)
			case "$minute":
				return e.minute(operand, data)
			case "$second":
				return e.second(operand, data)
			case "$dateToString":
				return e.dateToString(operand, data)
			case "$dateAdd":
				return e.dateAdd(operand, data)
			case "$dateDiff":
				return e.dateDiff(operand, data)
			case "$gt":
				return e.compareGt(operand, data)
			case "$gte":
				return e.compareGte(operand, data)
			case "$lt":
				return e.compareLt(operand, data)
			case "$lte":
				return e.compareLte(operand, data)
			case "$eq":
				return e.compareEq(operand, data)
			case "$ne":
				return e.compareNe(operand, data)
			}
		}
	}

	return expr
}

// executeLimit implements $limit: it returns the first N documents. N may be
// int, int64 or float64. A non-positive, non-numeric or too-large limit
// returns docs unchanged.
func (e *AggregationEngine) executeLimit(spec interface{}, docs []types.Document) ([]types.Document, error) {
	limit := 0
	switch l := spec.(type) {
	case int:
		limit = l
	case int64:
		limit = int(l)
	case float64:
		limit = int(l)
	}

	if limit <= 0 || limit >= len(docs) {
		return docs, nil
	}
	return docs[:limit], nil
}

// executeSkip implements $skip: it drops the first N documents. N may be int,
// int64 or float64. A non-positive or non-numeric skip returns docs unchanged;
// skipping past the end returns an empty slice.
func (e *AggregationEngine) executeSkip(spec interface{}, docs []types.Document) ([]types.Document, error) {
	skip := 0
	switch s := spec.(type) {
	case int:
		skip = s
	case int64:
		skip = int(s)
	case float64:
		skip = int(s)
	}

	if skip <= 0 {
		return docs, nil
	}
	if skip >= len(docs) {
		return []types.Document{}, nil
	}
	return docs[skip:], nil
}

// executeUnwind implements $unwind. The spec is either a "$path" string or a
// map with "path" and optional "preserveNullAndEmptyArrays". Each document
// whose path holds a non-empty []interface{} is replaced by one deep copy per
// element with the path set to that element. Documents whose path is nil, not
// an array, or an empty array are dropped unless preserveNullAndEmptyArrays is
// true, in which case they pass through unchanged.
func (e *AggregationEngine) executeUnwind(spec interface{}, docs []types.Document) ([]types.Document, error) {
	var path string
	var preserveNull bool

	switch s := spec.(type) {
	case string:
		path = s
	case map[string]interface{}:
		if p, ok := s["path"].(string); ok {
			path = p
		}
		if pn, ok := s["preserveNullAndEmptyArrays"].(bool); ok {
			preserveNull = pn
		}
	}

	if path == "" || path[0] != '$' {
		return docs, nil
	}
	fieldPath := path[1:]

	var results []types.Document
	for _, doc := range docs {
		// A nil value fails the type assertion too, so one check covers the
		// nil, non-array and empty-array cases.
		array, ok := getNestedValue(doc.Data, fieldPath).([]interface{})
		if !ok || len(array) == 0 {
			if preserveNull {
				results = append(results, doc)
			}
			continue
		}

		for _, item := range array {
			newData := deepCopyMap(doc.Data)
			setNestedValue(newData, fieldPath, item)
			results = append(results, types.Document{
				ID:   doc.ID,
				Data: newData,
			})
		}
	}

	return results, nil
}

// executeLookup implements the equality form of $lookup. For every input
// document it collects the Data of each document in the "from" collection
// whose foreignField equals the input's localField (per the package-level
// compareEq) and stores them under the "as" field of a deep copy.
//
// The "as" field is always an array; it is empty when nothing matches. The
// stage is a no-op when the spec is malformed, when "from" or "as" is missing,
// or when the foreign collection cannot be read.
func (e *AggregationEngine) executeLookup(spec interface{}, docs []types.Document) ([]types.Document, error) {
	lookupSpec, ok := spec.(map[string]interface{})
	if !ok {
		return docs, nil
	}

	from, _ := lookupSpec["from"].(string)
	localField, _ := lookupSpec["localField"].(string)
	foreignField, _ := lookupSpec["foreignField"].(string)
	as, _ := lookupSpec["as"].(string)

	if from == "" || as == "" {
		return docs, nil
	}

	foreignDocs, err := e.store.GetAllDocuments(from)
	if err != nil {
		// Deliberately best-effort: an unreadable foreign collection leaves
		// the pipeline input untouched instead of failing the aggregation.
		return docs, nil
	}

	var results []types.Document
	for _, doc := range docs {
		localValue := getNestedValue(doc.Data, localField)

		matches := make([]map[string]interface{}, 0)
		for _, foreignDoc := range foreignDocs {
			foreignValue := getNestedValue(foreignDoc.Data, foreignField)
			if compareEq(localValue, foreignValue) {
				matches = append(matches, foreignDoc.Data)
			}
		}

		newData := deepCopyMap(doc.Data)
		newData[as] = matches
		results = append(results, types.Document{
			ID:   doc.ID,
			Data: newData,
		})
	}

	return results, nil
}

// executeCount implements $count: it returns a single document whose only
// field holds len(docs). The field is named by the string spec, or "count"
// when the spec is not a string.
func (e *AggregationEngine) executeCount(spec interface{}, docs []types.Document) ([]types.Document, error) {
	fieldName, ok := spec.(string)
	if !ok {
		fieldName = "count"
	}

	return []types.Document{
		{
			ID: "count",
			Data: map[string]interface{}{
				fieldName: len(docs),
			},
		},
	}, nil
}

// isTrue reports the truthiness of a projection-style flag. Booleans are
// returned as-is, numbers are true when non-zero, and every other value
// (including nil) counts as true.
func isTrue(v interface{}) bool {
	switch val := v.(type) {
	case bool:
		return val
	case int:
		return val != 0
	case int32:
		return val != 0
	case int64:
		return val != 0
	case float64:
		return val != 0
	}
	return true
}

// isFalse is the negation of isTrue.
func isFalse(v interface{}) bool {
	return !isTrue(v)
}

// toString renders a scalar as a string: strings are returned unchanged,
// integers in decimal, and float64 in fmt's default %v form. Any other type
// yields "".
func toString(v interface{}) string {
	switch val := v.(type) {
	case string:
		return val
	case int, int32, int64:
		// Decimal text, not string(rune(n)), which would yield the Unicode
		// character with that code point.
		return fmt.Sprintf("%d", val)
	case float64:
		return fmt.Sprintf("%v", val)
	default:
		return ""
	}
}

// evalOperandPair evaluates the two operands of a binary operator. ok is false
// when operand is not a two-element array.
func (e *AggregationEngine) evalOperandPair(operand interface{}, data map[string]interface{}) (left, right interface{}, ok bool) {
	arr, isArr := operand.([]interface{})
	if !isArr || len(arr) != 2 {
		return nil, nil, false
	}
	return e.evaluateExpression(data, arr[0]), e.evaluateExpression(data, arr[1]), true
}

// compareGt implements $gt: toFloat64(left) > toFloat64(right). It returns
// false when operand is not a two-element array.
func (e *AggregationEngine) compareGt(operand interface{}, data map[string]interface{}) interface{} {
	left, right, ok := e.evalOperandPair(operand, data)
	if !ok {
		return false
	}
	return toFloat64(left) > toFloat64(right)
}

// compareGte implements $gte: toFloat64(left) >= toFloat64(right). It returns
// false when operand is not a two-element array.
func (e *AggregationEngine) compareGte(operand interface{}, data map[string]interface{}) interface{} {
	left, right, ok := e.evalOperandPair(operand, data)
	if !ok {
		return false
	}
	return toFloat64(left) >= toFloat64(right)
}

// compareLt implements $lt: toFloat64(left) < toFloat64(right). It returns
// false when operand is not a two-element array.
func (e *AggregationEngine) compareLt(operand interface{}, data map[string]interface{}) interface{} {
	left, right, ok := e.evalOperandPair(operand, data)
	if !ok {
		return false
	}
	return toFloat64(left) < toFloat64(right)
}

// compareLte implements $lte: toFloat64(left) <= toFloat64(right). It returns
// false when operand is not a two-element array.
func (e *AggregationEngine) compareLte(operand interface{}, data map[string]interface{}) interface{} {
	left, right, ok := e.evalOperandPair(operand, data)
	if !ok {
		return false
	}
	return toFloat64(left) <= toFloat64(right)
}

// compareEq implements $eq using Go interface equality on the evaluated
// operands. It returns false when operand is not a two-element array.
//
// NOTE(review): interface == compares dynamic type as well as value (int 1 is
// not equal to float64 1) and panics if both operands hold the same
// uncomparable type such as a slice or map — confirm whether this should
// share the package-level compareEq used by $lookup.
func (e *AggregationEngine) compareEq(operand interface{}, data map[string]interface{}) interface{} {
	left, right, ok := e.evalOperandPair(operand, data)
	if !ok {
		return false
	}
	return left == right
}

// compareNe implements $ne using Go interface inequality on the evaluated
// operands. It returns false when operand is not a two-element array.
//
// NOTE(review): same caveats as compareEq regarding dynamic types and
// uncomparable operands.
func (e *AggregationEngine) compareNe(operand interface{}, data map[string]interface{}) interface{} {
	left, right, ok := e.evalOperandPair(operand, data)
	if !ok {
		return false
	}
	return left != right
}