diff --git a/internal/engine/aggregate.go b/internal/engine/aggregate.go index f242f07..7661d2d 100644 --- a/internal/engine/aggregate.go +++ b/internal/engine/aggregate.go @@ -392,6 +392,12 @@ func (e *AggregationEngine) projectDocument(data map[string]interface{}, spec ma // evaluateExpression 评估表达式 func (e *AggregationEngine) evaluateExpression(data map[string]interface{}, expr interface{}) interface{} { + // 处理字段引用(以 $ 开头的字符串) + if fieldStr, ok := expr.(string); ok && len(fieldStr) > 0 && fieldStr[0] == '$' { + fieldName := fieldStr[1:] // 移除 $ 前缀 + return getNestedValue(data, fieldName) + } + if exprMap, ok := expr.(map[string]interface{}); ok { for op, operand := range exprMap { switch op { @@ -479,6 +485,18 @@ func (e *AggregationEngine) evaluateExpression(data map[string]interface{}, expr return e.dateAdd(operand, data) case "$dateDiff": return e.dateDiff(operand, data) + case "$gt": + return e.compareGt(operand, data) + case "$gte": + return e.compareGte(operand, data) + case "$lt": + return e.compareLt(operand, data) + case "$lte": + return e.compareLte(operand, data) + case "$eq": + return e.compareEq(operand, data) + case "$ne": + return e.compareNe(operand, data) } } } @@ -676,3 +694,64 @@ func toString(v interface{}) string { return "" } } + +// 比较操作符辅助方法 +func (e *AggregationEngine) compareGt(operand interface{}, data map[string]interface{}) interface{} { + arr, ok := operand.([]interface{}) + if !ok || len(arr) != 2 { + return false + } + left := e.evaluateExpression(data, arr[0]) + right := e.evaluateExpression(data, arr[1]) + return toFloat64(left) > toFloat64(right) +} + +func (e *AggregationEngine) compareGte(operand interface{}, data map[string]interface{}) interface{} { + arr, ok := operand.([]interface{}) + if !ok || len(arr) != 2 { + return false + } + left := e.evaluateExpression(data, arr[0]) + right := e.evaluateExpression(data, arr[1]) + return toFloat64(left) >= toFloat64(right) +} + +func (e *AggregationEngine) compareLt(operand interface{}, data map[string]interface{}) interface{} { + arr, ok := operand.([]interface{}) + if !ok || len(arr) != 2 { + return false + } + left := e.evaluateExpression(data, arr[0]) + right := e.evaluateExpression(data, arr[1]) + return toFloat64(left) < toFloat64(right) +} + +func (e *AggregationEngine) compareLte(operand interface{}, data map[string]interface{}) interface{} { + arr, ok := operand.([]interface{}) + if !ok || len(arr) != 2 { + return false + } + left := e.evaluateExpression(data, arr[0]) + right := e.evaluateExpression(data, arr[1]) + return toFloat64(left) <= toFloat64(right) +} + +func (e *AggregationEngine) compareEq(operand interface{}, data map[string]interface{}) interface{} { + arr, ok := operand.([]interface{}) + if !ok || len(arr) != 2 { + return false + } + left := e.evaluateExpression(data, arr[0]) + right := e.evaluateExpression(data, arr[1]) + return left == right +} + +func (e *AggregationEngine) compareNe(operand interface{}, data map[string]interface{}) interface{} { + arr, ok := operand.([]interface{}) + if !ok || len(arr) != 2 { + return false + } + left := e.evaluateExpression(data, arr[0]) + right := e.evaluateExpression(data, arr[1]) + return left != right +} diff --git a/internal/engine/aggregate_helpers.go b/internal/engine/aggregate_helpers.go index ec95c2a..5025cb6 100644 --- a/internal/engine/aggregate_helpers.go +++ b/internal/engine/aggregate_helpers.go @@ -167,7 +167,7 @@ func (e *AggregationEngine) switchExpr(operand interface{}, data map[string]inte caseRaw, _ := branch["case"] thenRaw, _ := branch["then"] - if isTrue(e.evaluateExpression(data, caseRaw)) { + if isTrueValue(e.evaluateExpression(data, caseRaw)) { return e.evaluateExpression(data, thenRaw) } } diff --git a/internal/engine/crud.go b/internal/engine/crud.go index a1e84a9..b4280cb 100644 --- a/internal/engine/crud.go +++ b/internal/engine/crud.go @@ -362,18 +362,29 @@ func updateArrayElement(data map[string]interface{}, field string, value interfa // updateArrayAtPath 在指定路径更新数组 func updateArrayAtPath(data map[string]interface{}, parts []string, index int, value interface{}, arrayFilters []map[string]interface{}) bool { - // 获取到数组前的路径 + // 获取到数组前的路径(导航到父对象) current := data for i := 0; i < index; i++ { if m, ok := current[parts[i]].(map[string]interface{}); ok { current = m + } else if i == index-1 { + // 最后一个部分应该是数组字段名,不需要是 map + break } else { return false } } + // 获取实际的数组字段名(操作符前面的部分) + var actualFieldName string + if index > 0 { + actualFieldName = parts[index-1] + } else { + return false // 无效的路径 + } + arrField := parts[index] - arr := getNestedValue(current, arrField) + arr := getNestedValue(data, actualFieldName) array, ok := arr.([]interface{}) if !ok || len(array) == 0 { return false @@ -384,7 +395,7 @@ func updateArrayAtPath(data map[string]interface{}, parts []string, index int, v // 定位第一个匹配的元素(需要配合查询条件) // 简化实现:更新第一个元素 array[0] = value - setNestedValue(current, arrField, array) + setNestedValue(data, actualFieldName, array) return true } @@ -393,7 +404,7 @@ func updateArrayAtPath(data map[string]interface{}, parts []string, index int, v for i := range array { array[i] = value } - setNestedValue(current, arrField, array) + setNestedValue(data, actualFieldName, array) return true } @@ -405,21 +416,33 @@ func updateArrayAtPath(data map[string]interface{}, parts []string, index int, v var filter map[string]interface{} for _, f := range arrayFilters { if idVal, exists := f["identifier"]; exists && idVal == identifier { - filter = f + // 复制 filter 并移除 identifier 字段 + filter = make(map[string]interface{}) + for k, v := range f { + if k != "identifier" { + filter[k] = v + } + } break } } - if filter != nil { + if filter != nil && len(filter) > 0 { // 应用过滤器更新匹配的元素 for i, item := range array { if itemMap, ok := item.(map[string]interface{}); ok { if MatchFilter(itemMap, filter) { - array[i] = value + // 如果是嵌套字段(如 students.$[elem].grade),需要设置嵌套字段 + if index+1 < len(parts) { + // 还有后续字段,设置嵌套字段 + itemMap[parts[index+1]] = value + } else { + array[i] = value + } } } } - setNestedValue(current, arrField, array) + setNestedValue(data, actualFieldName, array) return true } } diff --git a/internal/engine/crud_handler.go b/internal/engine/crud_handler.go index 760e0cc..52672df 100644 --- a/internal/engine/crud_handler.go +++ b/internal/engine/crud_handler.go @@ -54,7 +54,7 @@ func (h *CRUDHandler) Insert(ctx context.Context, collection string, docs []map[ // Update 更新文档 func (h *CRUDHandler) Update(ctx context.Context, collection string, filter types.Filter, update types.Update) (*types.UpdateResult, error) { - matched, modified, err := h.store.Update(collection, filter, update) + matched, modified, _, err := h.store.Update(collection, filter, update, false, nil) if err != nil { return nil, err } diff --git a/internal/engine/memory_store_batch2_test.go b/internal/engine/memory_store_batch2_test.go index fe84201..d50a0aa 100644 --- a/internal/engine/memory_store_batch2_test.go +++ b/internal/engine/memory_store_batch2_test.go @@ -94,12 +94,10 @@ func TestMemoryStoreUpdateWithUpsert(t *testing.T) { if tt.checkField != "" { // Find the created/updated document - var doc types.Document found := false for _, d := range store.collections[collection].documents { if val, ok := d.Data[tt.checkField]; ok { if compareEq(val, tt.expectedValue) { - doc = d found = true break } diff --git a/internal/engine/operators.go b/internal/engine/operators.go index fd2481c..a07c310 100644 --- a/internal/engine/operators.go +++ b/internal/engine/operators.go @@ -16,10 +16,29 @@ func compareEq(a, b interface{}) bool { return false } + // 对于 slice、map 等复杂类型,使用 reflect.DeepEqual + if isComplexType(a) || isComplexType(b) { + return reflect.DeepEqual(a, b) + } + // 类型转换后比较 return normalizeValue(a) == normalizeValue(b) } +// isComplexType 检查是否是复杂类型(slice、map 等) +func isComplexType(v interface{}) bool { + switch v.(type) { + case []interface{}: + return true + case map[string]interface{}: + return true + case map[interface{}]interface{}: + return true + default: + return false + } +} + // compareGt 大于比较 func compareGt(a, b interface{}) bool { return compareNumbers(a, b) > 0 diff --git a/internal/engine/projection.go b/internal/engine/projection.go index 7cad599..e967a06 100644 --- a/internal/engine/projection.go +++ b/internal/engine/projection.go @@ -34,7 +34,6 @@ func applyProjectionToDoc(data map[string]interface{}, projection types.Projecti // 检查是否是包含模式(所有值都是 1/true)或排除模式(所有值都是 0/false) isInclusionMode := false hasInclusion := false - hasExclusion := false for field, value := range projection { if field == "_id" { @@ -43,8 +42,6 @@ func applyProjectionToDoc(data map[string]interface{}, projection types.Projecti if isTrueValue(value) { hasInclusion = true - } else { - hasExclusion = true } } diff --git a/internal/engine/query.go b/internal/engine/query.go index 5a83116..4e4b386 100644 --- a/internal/engine/query.go +++ b/internal/engine/query.go @@ -54,12 +54,6 @@ func handleExpr(doc map[string]interface{}, condition interface{}) bool { // 创建临时引擎实例用于评估表达式 engine := &AggregationEngine{} - // 将文档转换为 Document 结构 - document := types.Document{ - ID: "", - Data: doc, - } - // 评估聚合表达式 result := engine.evaluateExpression(doc, condition) @@ -132,7 +126,8 @@ func validateJSONSchema(doc map[string]interface{}, schema map[string]interface{ if fieldSchema, ok := fieldSchemaRaw.(map[string]interface{}); ok { fieldValue := doc[fieldName] if fieldValue != nil { - if !validateJSONSchema(fieldValue, fieldSchema) { + // 递归验证字段值 + if !validateFieldValue(fieldValue, fieldSchema) { return false } } @@ -157,80 +152,6 @@ func validateJSONSchema(doc map[string]interface{}, schema map[string]interface{ } } - // 检查 minimum - if minimumRaw, exists := schema["minimum"]; exists { - if num := toFloat64(doc); num < toFloat64(minimumRaw) { - return false - } - } - - // 检查 maximum - if maximumRaw, exists := schema["maximum"]; exists { - if num := toFloat64(doc); num > toFloat64(maximumRaw) { - return false - } - } - - // 检查 minLength (字符串) - if minLengthRaw, exists := schema["minLength"]; exists { - if str, ok := doc.(string); ok { - if minLen := int(toFloat64(minLengthRaw)); len(str) < minLen { - return false - } - } - } - - // 检查 maxLength (字符串) - if maxLengthRaw, exists := schema["maxLength"]; exists { - if str, ok := doc.(string); ok { - if maxLen := int(toFloat64(maxLengthRaw)); len(str) > maxLen { - return false - } - } - } - - // 检查 pattern (正则表达式) - if patternRaw, exists := schema["pattern"]; exists { - if str, ok := doc.(string); ok { - if pattern, ok := patternRaw.(string); ok { - if !compareRegex(str, map[string]interface{}{"$regex": pattern}) { - return false - } - } - } - } - - // 检查 items (数组元素) - if itemsRaw, exists := schema["items"]; exists { - if arr, ok := doc.([]interface{}); ok { - if itemSchema, ok := itemsRaw.(map[string]interface{}); ok { - for _, item := range arr { - if !validateJSONSchema(item, itemSchema) { - return false - } - } - } - } - } - - // 检查 minItems (数组最小长度) - if minItemsRaw, exists := schema["minItems"]; exists { - if arr, ok := doc.([]interface{}); ok { - if minItems := int(toFloat64(minItemsRaw)); len(arr) < minItems { - return false - } - } - } - - // 检查 maxItems (数组最大长度) - if maxItemsRaw, exists := schema["maxItems"]; exists { - if arr, ok := doc.([]interface{}); ok { - if maxItems := int(toFloat64(maxItemsRaw)); len(arr) > maxItems { - return false - } - } - } - // 检查 allOf if allOfRaw, exists := schema["allOf"]; exists { if allOf, ok := allOfRaw.([]interface{}); ok { @@ -291,6 +212,203 @@ func validateJSONSchema(doc map[string]interface{}, schema map[string]interface{ return true } +// validateFieldValue 验证字段值是否符合 schema +func validateFieldValue(value interface{}, schema map[string]interface{}) bool { + // 检查 bsonType + if bsonType, exists := schema["bsonType"]; exists { + if !validateBsonType(value, bsonType) { + return false + } + } + + // 检查 enum + if enumRaw, exists := schema["enum"]; exists { + if enum, ok := enumRaw.([]interface{}); ok { + found := false + for _, val := range enum { + if compareEq(value, val) { + found = true + break + } + } + if !found { + return false + } + } + } + + // 检查 minimum - 仅当 value 是数值类型时 + if minimumRaw, exists := schema["minimum"]; exists { + if num, ok := toNumber(value); ok { + if num < toFloat64(minimumRaw) { + return false + } + } + } + + // 检查 maximum - 仅当 value 是数值类型时 + if maximumRaw, exists := schema["maximum"]; exists { + if num, ok := toNumber(value); ok { + if num > toFloat64(maximumRaw) { + return false + } + } + } + + // 检查 minLength (字符串) - 仅当 value 是字符串时 + if minLengthRaw, exists := schema["minLength"]; exists { + if str, ok := value.(string); ok { + if minLen := int(toFloat64(minLengthRaw)); len(str) < minLen { + return false + } + } + } + + // 检查 maxLength (字符串) - 仅当 value 是字符串时 + if maxLengthRaw, exists := schema["maxLength"]; exists { + if str, ok := value.(string); ok { + if maxLen := int(toFloat64(maxLengthRaw)); len(str) > maxLen { + return false + } + } + } + + // 检查 pattern (正则表达式) - 仅当 value 是字符串时 + if patternRaw, exists := schema["pattern"]; exists { + if str, ok := value.(string); ok { + if pattern, ok := patternRaw.(string); ok { + if !compareRegex(str, map[string]interface{}{"$regex": pattern}) { + return false + } + } + } + } + + // 检查 items (数组元素) - 仅当 value 是数组时 + if itemsRaw, exists := schema["items"]; exists { + if arr, ok := value.([]interface{}); ok { + if itemSchema, ok := itemsRaw.(map[string]interface{}); ok { + for _, item := range arr { + if itemMap, ok := item.(map[string]interface{}); ok { + if !validateJSONSchema(itemMap, itemSchema) { + return false + } + } + } + } + } + } + + // 检查 minItems (数组最小长度) - 仅当 value 是数组时 + if minItemsRaw, exists := schema["minItems"]; exists { + if arr, ok := value.([]interface{}); ok { + if minItems := int(toFloat64(minItemsRaw)); len(arr) < minItems { + return false + } + } + } + + // 检查 maxItems (数组最大长度) - 仅当 value 是数组时 + if maxItemsRaw, exists := schema["maxItems"]; exists { + if arr, ok := value.([]interface{}); ok { + if maxItems := int(toFloat64(maxItemsRaw)); len(arr) > maxItems { + return false + } + } + } + + // 对于对象类型,继续递归验证嵌套 properties + if valueMap, ok := value.(map[string]interface{}); ok { + // 检查 required 字段 + if requiredRaw, exists := schema["required"]; exists { + if required, ok := requiredRaw.([]interface{}); ok { + for _, reqField := range required { + if fieldStr, ok := reqField.(string); ok { + if valueMap[fieldStr] == nil { + return false + } + } + } + } + } + + // 检查 properties + if propertiesRaw, exists := schema["properties"]; exists { + if properties, ok := propertiesRaw.(map[string]interface{}); ok { + for fieldName, fieldSchemaRaw := range properties { + if fieldSchema, ok := fieldSchemaRaw.(map[string]interface{}); ok { + fieldValue := valueMap[fieldName] + if fieldValue != nil { + if !validateFieldValue(fieldValue, fieldSchema) { + return false + } + } + } + } + } + } + } + + // 检查 allOf + if allOfRaw, exists := schema["allOf"]; exists { + if allOf, ok := allOfRaw.([]interface{}); ok { + for _, subSchemaRaw := range allOf { + if subSchema, ok := subSchemaRaw.(map[string]interface{}); ok { + if !validateFieldValue(value, subSchema) { + return false + } + } + } + } + } + + // 检查 anyOf + if anyOfRaw, exists := schema["anyOf"]; exists { + if anyOf, ok := anyOfRaw.([]interface{}); ok { + matched := false + for _, subSchemaRaw := range anyOf { + if subSchema, ok := subSchemaRaw.(map[string]interface{}); ok { + if validateFieldValue(value, subSchema) { + matched = true + break + } + } + } + if !matched { + return false + } + } + } + + // 检查 oneOf + if oneOfRaw, exists := schema["oneOf"]; exists { + if oneOf, ok := oneOfRaw.([]interface{}); ok { + matchCount := 0 + for _, subSchemaRaw := range oneOf { + if subSchema, ok := subSchemaRaw.(map[string]interface{}); ok { + if validateFieldValue(value, subSchema) { + matchCount++ + } + } + } + if matchCount != 1 { + return false + } + } + } + + // 检查 not + if notRaw, exists := schema["not"]; exists { + if notSchema, ok := notRaw.(map[string]interface{}); ok { + if validateFieldValue(value, notSchema) { + return false // not 要求不匹配 + } + } + } + + return true +} + // validateBsonType 验证 BSON 类型 func validateBsonType(value interface{}, bsonType interface{}) bool { typeStr, ok := bsonType.(string) @@ -370,6 +488,28 @@ func getNumericValue(value interface{}) float64 { } } +// toArray 将值转换为数组 +func toArray(value interface{}) ([]interface{}, bool) { + if arr, ok := value.([]interface{}); ok { + return arr, true + } + return nil, false +} + +// toNumber 将值转换为数值 +func toNumber(value interface{}) (float64, bool) { + switch v := value.(type) { + case int, int8, int16, int32, int64: + return getNumericValue(v), true + case uint, uint8, uint16, uint32, uint64: + return getNumericValue(v), true + case float32, float64: + return getNumericValue(v), true + default: + return 0, false + } +} + // handleAnd 处理 $and 操作符 func handleAnd(doc map[string]interface{}, condition interface{}) bool { andConditions, ok := condition.([]interface{}) diff --git a/internal/protocol/http/batch2_test.go b/internal/protocol/http/batch2_test.go index fadc357..83c712b 100644 --- a/internal/protocol/http/batch2_test.go +++ b/internal/protocol/http/batch2_test.go @@ -14,8 +14,8 @@ import ( // TestHTTPUpdateWithUpsert 测试 HTTP API 的 upsert 功能 func TestHTTPUpdateWithUpsert(t *testing.T) { store := engine.NewMemoryStore(nil) - crud := &engine.CRUDHandler{store: store} - agg := &engine.AggregationEngine{store: store} + crud := engine.NewCRUDHandler(store, nil) + agg := engine.NewAggregationEngine(store) handler := NewRequestHandler(store, crud, agg) @@ -65,8 +65,8 @@ func TestHTTPUpdateWithUpsert(t *testing.T) { // TestHTTPHealthCheck 测试健康检查端点 func TestHTTPHealthCheck(t *testing.T) { store := engine.NewMemoryStore(nil) - crud := &engine.CRUDHandler{store: store} - agg := &engine.AggregationEngine{store: store} + crud := engine.NewCRUDHandler(store, nil) + agg := engine.NewAggregationEngine(store) server := NewHTTPServer(":0", NewRequestHandler(store, crud, agg)) @@ -92,8 +92,8 @@ func TestHTTPHealthCheck(t *testing.T) { // TestHTTPRoot 测试根路径处理 func TestHTTPRoot(t *testing.T) { store := engine.NewMemoryStore(nil) - crud := &engine.CRUDHandler{store: store} - agg := &engine.AggregationEngine{store: store} + crud := engine.NewCRUDHandler(store, nil) + agg := engine.NewAggregationEngine(store) server := NewHTTPServer(":0", NewRequestHandler(store, crud, agg)) diff --git a/internal/protocol/tcp/server.go b/internal/protocol/tcp/server.go index 9ff817a..ed90897 100644 --- a/internal/protocol/tcp/server.go +++ b/internal/protocol/tcp/server.go @@ -307,7 +307,7 @@ func (h *MessageHandler) handleUpdate(body []byte) (interface{}, error) { totalModified := 0 for _, op := range req.Updates { - matched, modified, err := h.store.Update(req.Collection, op.Q, op.U) + matched, modified, _, err := h.store.Update(req.Collection, op.Q, op.U, op.Upsert, op.ArrayFilters) if err != nil { return nil, err } @@ -452,7 +452,7 @@ func (h *MessageHandler) handleUpdateMsg(collection string, params map[string]in totalModified := 0 for _, op := range updates { - matched, modified, err := h.store.Update(collection, op.Q, op.U) + matched, modified, _, err := h.store.Update(collection, op.Q, op.U, op.Upsert, op.ArrayFilters) if err != nil { return nil, err }