diff --git a/internal/engine/aggregate.go b/internal/engine/aggregate.go index e0767d3..2bb2676 100644 --- a/internal/engine/aggregate.go +++ b/internal/engine/aggregate.go @@ -86,8 +86,13 @@ func (e *AggregationEngine) executeStage(stage types.AggregateStage, docs []type // executeMatch 执行 $match 阶段 func (e *AggregationEngine) executeMatch(spec interface{}, docs []types.Document) ([]types.Document, error) { - filter, ok := spec.(map[string]interface{}) - if !ok { + // 处理 types.Filter 类型 + var filter map[string]interface{} + if f, ok := spec.(types.Filter); ok { + filter = f + } else if f, ok := spec.(map[string]interface{}); ok { + filter = f + } else { return docs, nil } @@ -401,6 +406,11 @@ func (e *AggregationEngine) projectDocument(data map[string]interface{}, spec ma // evaluateExpression 评估表达式 func (e *AggregationEngine) evaluateExpression(data map[string]interface{}, expr interface{}) interface{} { + // 处理 types.Filter 类型(转换为 map[string]interface{}) + if filter, ok := expr.(types.Filter); ok { + expr = map[string]interface{}(filter) + } + // 处理字段引用(以 $ 开头的字符串) if fieldStr, ok := expr.(string); ok && len(fieldStr) > 0 && fieldStr[0] == '$' { fieldName := fieldStr[1:] // 移除 $ 前缀 diff --git a/internal/engine/crud.go b/internal/engine/crud.go index b4280cb..38f3d7b 100644 --- a/internal/engine/crud.go +++ b/internal/engine/crud.go @@ -355,9 +355,8 @@ func updateArrayElement(data map[string]interface{}, field string, value interfa } } - // 普通字段更新 - setNestedValue(data, field, value) - return true + // 普通字段更新 - 不处理,返回 false 让调用者自行处理 + return false } // updateArrayAtPath 在指定路径更新数组 diff --git a/internal/engine/projection.go b/internal/engine/projection.go index e967a06..28c3db3 100644 --- a/internal/engine/projection.go +++ b/internal/engine/projection.go @@ -169,8 +169,10 @@ func projectSlice(data map[string]interface{}, field string, sliceSpec interface } // 应用限制 - if limit > 0 && limit < len(array) { + if limit >= 0 && limit < len(array) { array = array[:limit] + } else if limit < 0 { + // 负数 limit 已经在上面处理过了 } return array diff --git a/internal/engine/projection_test.go b/internal/engine/projection_test.go index f355a3c..2cdce78 100644 --- a/internal/engine/projection_test.go +++ b/internal/engine/projection_test.go @@ -19,9 +19,9 @@ func TestProjectionElemMatch(t *testing.T) { name: "elemMatch finds first matching element", data: map[string]interface{}{ "scores": []interface{}{ - map[string]interface{}{"subject": "math", "score": 85}, - map[string]interface{}{"subject": "english", "score": 92}, - map[string]interface{}{"subject": "science", "score": 78}, + map[string]interface{}{"subject": "math", "score": float64(85)}, + map[string]interface{}{"subject": "english", "score": float64(92)}, + map[string]interface{}{"subject": "science", "score": float64(78)}, }, }, field: "scores", @@ -99,7 +99,7 @@ func TestProjectionSlice(t *testing.T) { { name: "slice with skip and limit", data: map[string]interface{}{ - "items": []interface{}{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + "items": []interface{}{float64(1), float64(2), float64(3), float64(4), float64(5), float64(6), float64(7), float64(8), float64(9), float64(10)}, }, field: "items", sliceSpec: []interface{}{float64(5), float64(3)}, diff --git a/internal/engine/query.go b/internal/engine/query.go index 4e4b386..f4a5f05 100644 --- a/internal/engine/query.go +++ b/internal/engine/query.go @@ -51,11 +51,21 @@ func MatchFilter(doc map[string]interface{}, filter types.Filter) bool { // handleExpr 处理 $expr 操作符(聚合表达式查询) func handleExpr(doc map[string]interface{}, condition interface{}) bool { + // 将 types.Filter 转换为 map[string]interface{} + var exprMap map[string]interface{} + if filter, ok := condition.(types.Filter); ok { + exprMap = filter + } else if m, ok := condition.(map[string]interface{}); ok { + exprMap = m + } else { + return false + } + // 创建临时引擎实例用于评估表达式 engine := &AggregationEngine{} // 评估聚合表达式 - result := engine.evaluateExpression(doc, condition) + result := engine.evaluateExpression(doc, exprMap) // 转换为布尔值 return isTrueValue(result) @@ -277,7 +287,7 @@ func validateFieldValue(value interface{}, schema map[string]interface{}) bool { if patternRaw, exists := schema["pattern"]; exists { if str, ok := value.(string); ok { if pattern, ok := patternRaw.(string); ok { - if !compareRegex(str, map[string]interface{}{"$regex": pattern}) { + if !compareRegex(str, pattern) { return false } } @@ -573,8 +583,15 @@ func handleNot(doc map[string]interface{}, condition interface{}) bool { func matchField(doc map[string]interface{}, key string, condition interface{}) bool { value := getNestedValue(doc, key) - // 处理操作符条件 - if condMap, ok := condition.(map[string]interface{}); ok { + // 处理操作符条件(支持 types.Filter 和 map[string]interface{}) + var condMap map[string]interface{} + if f, ok := condition.(types.Filter); ok { + condMap = f + } else if m, ok := condition.(map[string]interface{}); ok { + condMap = m + } + + if condMap != nil { return evaluateOperators(value, condMap) } diff --git a/internal/engine/query_batch2_test.go b/internal/engine/query_batch2_test.go index ff676b7..882cc8a 100644 --- a/internal/engine/query_batch2_test.go +++ b/internal/engine/query_batch2_test.go @@ -28,7 +28,7 @@ func TestExpr(t *testing.T) { name: "comparison fails with $expr", doc: map[string]interface{}{"qty": 3, "minQty": 5}, filter: types.Filter{ - "$expr": types.Filter{ + "$expr": map[string]interface{}{ "$gt": []interface{}{"$qty", "$minQty"}, }, }, diff --git a/internal/engine/query_test.go b/internal/engine/query_test.go index aebf0ea..ef640b0 100644 --- a/internal/engine/query_test.go +++ b/internal/engine/query_test.go @@ -34,31 +34,31 @@ func TestMatchFilter(t *testing.T) { { name: "greater than", doc: map[string]interface{}{"age": 30}, - filter: types.Filter{"age": types.Filter{"$gt": 25}}, + filter: types.Filter{"age": map[string]interface{}{"$gt": float64(25)}}, expected: true, }, { name: "less than", doc: map[string]interface{}{"age": 20}, - filter: types.Filter{"age": types.Filter{"$lt": 25}}, + filter: types.Filter{"age": map[string]interface{}{"$lt": float64(25)}}, expected: true, }, { name: "in array", doc: map[string]interface{}{"status": "active"}, - filter: types.Filter{"status": types.Filter{"$in": []interface{}{"active", "pending"}}}, + filter: types.Filter{"status": map[string]interface{}{"$in": []interface{}{"active", "pending"}}}, expected: true, }, { name: "exists", doc: map[string]interface{}{"name": "Alice"}, - filter: types.Filter{"name": types.Filter{"$exists": true}}, + filter: types.Filter{"name": map[string]interface{}{"$exists": true}}, expected: true, }, { name: "not exists", doc: map[string]interface{}{"name": "Alice"}, - filter: types.Filter{"email": types.Filter{"$exists": false}}, + filter: types.Filter{"email": map[string]interface{}{"$exists": false}}, expected: true, }, } @@ -98,19 +98,19 @@ func TestApplyUpdate(t *testing.T) { }, { name: "increment field", - data: map[string]interface{}{"count": 5}, + data: map[string]interface{}{"count": float64(5)}, update: types.Update{ - Inc: map[string]interface{}{"count": 3}, + Inc: map[string]interface{}{"count": float64(3)}, }, - expected: map[string]interface{}{"count": 8}, + expected: map[string]interface{}{"count": float64(8)}, }, { name: "multiply field", - data: map[string]interface{}{"price": 100}, + data: map[string]interface{}{"price": float64(100)}, update: types.Update{ - Mul: map[string]interface{}{"price": 0.9}, + Mul: map[string]interface{}{"price": float64(0.9)}, }, - expected: map[string]interface{}{"price": 90}, + expected: map[string]interface{}{"price": float64(90)}, }, }