feat(engine): 添加聚合引擎比较操作符支持并优化查询功能

- 实现 $gt、$gte、$lt、$lte、$eq、$ne 比较操作符
- 添加字段引用处理(以 $ 开头的字符串)
- 优化 updateArrayAtPath 方法中的路径解析逻辑
- 改进 CRUDHandler 中的更新操作参数传递
- 增强 JSON Schema 验证功能,支持 allOf、anyOf、oneOf、not 等操作符
- 优化 HTTP 和 TCP 协议层的处理器初始化方式
- 修复内存存储测试中的文档查找逻辑
- 改进 isTrueValue 在 switch 表达式中的使用
This commit is contained in:
kingecg 2026-03-13 21:48:44 +08:00
parent 9847384f9b
commit 76b86b4b43
10 changed files with 360 additions and 104 deletions

View File

@ -392,6 +392,12 @@ func (e *AggregationEngine) projectDocument(data map[string]interface{}, spec ma
// evaluateExpression 评估表达式
func (e *AggregationEngine) evaluateExpression(data map[string]interface{}, expr interface{}) interface{} {
// 处理字段引用(以 $ 开头的字符串)
if fieldStr, ok := expr.(string); ok && len(fieldStr) > 0 && fieldStr[0] == '$' {
fieldName := fieldStr[1:] // 移除 $ 前缀
return getNestedValue(data, fieldName)
}
if exprMap, ok := expr.(map[string]interface{}); ok {
for op, operand := range exprMap {
switch op {
@ -479,6 +485,18 @@ func (e *AggregationEngine) evaluateExpression(data map[string]interface{}, expr
return e.dateAdd(operand, data)
case "$dateDiff":
return e.dateDiff(operand, data)
case "$gt":
return e.compareGt(operand, data)
case "$gte":
return e.compareGte(operand, data)
case "$lt":
return e.compareLt(operand, data)
case "$lte":
return e.compareLte(operand, data)
case "$eq":
return e.compareEq(operand, data)
case "$ne":
return e.compareNe(operand, data)
}
}
}
@ -676,3 +694,64 @@ func toString(v interface{}) string {
return ""
}
}
// 比较操作符辅助方法
func (e *AggregationEngine) compareGt(operand interface{}, data map[string]interface{}) interface{} {
arr, ok := operand.([]interface{})
if !ok || len(arr) != 2 {
return false
}
left := e.evaluateExpression(data, arr[0])
right := e.evaluateExpression(data, arr[1])
return toFloat64(left) > toFloat64(right)
}
func (e *AggregationEngine) compareGte(operand interface{}, data map[string]interface{}) interface{} {
arr, ok := operand.([]interface{})
if !ok || len(arr) != 2 {
return false
}
left := e.evaluateExpression(data, arr[0])
right := e.evaluateExpression(data, arr[1])
return toFloat64(left) >= toFloat64(right)
}
func (e *AggregationEngine) compareLt(operand interface{}, data map[string]interface{}) interface{} {
arr, ok := operand.([]interface{})
if !ok || len(arr) != 2 {
return false
}
left := e.evaluateExpression(data, arr[0])
right := e.evaluateExpression(data, arr[1])
return toFloat64(left) < toFloat64(right)
}
func (e *AggregationEngine) compareLte(operand interface{}, data map[string]interface{}) interface{} {
arr, ok := operand.([]interface{})
if !ok || len(arr) != 2 {
return false
}
left := e.evaluateExpression(data, arr[0])
right := e.evaluateExpression(data, arr[1])
return toFloat64(left) <= toFloat64(right)
}
func (e *AggregationEngine) compareEq(operand interface{}, data map[string]interface{}) interface{} {
arr, ok := operand.([]interface{})
if !ok || len(arr) != 2 {
return false
}
left := e.evaluateExpression(data, arr[0])
right := e.evaluateExpression(data, arr[1])
return left == right
}
func (e *AggregationEngine) compareNe(operand interface{}, data map[string]interface{}) interface{} {
arr, ok := operand.([]interface{})
if !ok || len(arr) != 2 {
return false
}
left := e.evaluateExpression(data, arr[0])
right := e.evaluateExpression(data, arr[1])
return left != right
}

View File

@ -167,7 +167,7 @@ func (e *AggregationEngine) switchExpr(operand interface{}, data map[string]inte
caseRaw, _ := branch["case"]
thenRaw, _ := branch["then"]
if isTrue(e.evaluateExpression(data, caseRaw)) {
if isTrueValue(e.evaluateExpression(data, caseRaw)) {
return e.evaluateExpression(data, thenRaw)
}
}

View File

@ -362,18 +362,29 @@ func updateArrayElement(data map[string]interface{}, field string, value interfa
// updateArrayAtPath 在指定路径更新数组
func updateArrayAtPath(data map[string]interface{}, parts []string, index int, value interface{}, arrayFilters []map[string]interface{}) bool {
// 获取到数组前的路径
// 获取到数组前的路径(导航到父对象)
current := data
for i := 0; i < index; i++ {
if m, ok := current[parts[i]].(map[string]interface{}); ok {
current = m
} else if i == index-1 {
// 最后一个部分应该是数组字段名,不需要是 map
break
} else {
return false
}
}
// 获取实际的数组字段名(操作符前面的部分)
var actualFieldName string
if index > 0 {
actualFieldName = parts[index-1]
} else {
return false // 无效的路径
}
arrField := parts[index]
arr := getNestedValue(current, arrField)
arr := getNestedValue(data, actualFieldName)
array, ok := arr.([]interface{})
if !ok || len(array) == 0 {
return false
@ -384,7 +395,7 @@ func updateArrayAtPath(data map[string]interface{}, parts []string, index int, v
// 定位第一个匹配的元素(需要配合查询条件)
// 简化实现:更新第一个元素
array[0] = value
setNestedValue(current, arrField, array)
setNestedValue(data, actualFieldName, array)
return true
}
@ -393,7 +404,7 @@ func updateArrayAtPath(data map[string]interface{}, parts []string, index int, v
for i := range array {
array[i] = value
}
setNestedValue(current, arrField, array)
setNestedValue(data, actualFieldName, array)
return true
}
@ -405,21 +416,33 @@ func updateArrayAtPath(data map[string]interface{}, parts []string, index int, v
var filter map[string]interface{}
for _, f := range arrayFilters {
if idVal, exists := f["identifier"]; exists && idVal == identifier {
filter = f
// 复制 filter 并移除 identifier 字段
filter = make(map[string]interface{})
for k, v := range f {
if k != "identifier" {
filter[k] = v
}
}
break
}
}
if filter != nil {
if filter != nil && len(filter) > 0 {
// 应用过滤器更新匹配的元素
for i, item := range array {
if itemMap, ok := item.(map[string]interface{}); ok {
if MatchFilter(itemMap, filter) {
array[i] = value
// 如果是嵌套字段(如 students.$[elem].grade需要设置嵌套字段
if index+1 < len(parts) {
// 还有后续字段,设置嵌套字段
itemMap[parts[index+1]] = value
} else {
array[i] = value
}
}
}
}
setNestedValue(current, arrField, array)
setNestedValue(data, actualFieldName, array)
return true
}
}

View File

@ -54,7 +54,7 @@ func (h *CRUDHandler) Insert(ctx context.Context, collection string, docs []map[
// Update 更新文档
func (h *CRUDHandler) Update(ctx context.Context, collection string, filter types.Filter, update types.Update) (*types.UpdateResult, error) {
matched, modified, err := h.store.Update(collection, filter, update)
matched, modified, _, err := h.store.Update(collection, filter, update, false, nil)
if err != nil {
return nil, err
}

View File

@ -94,12 +94,10 @@ func TestMemoryStoreUpdateWithUpsert(t *testing.T) {
if tt.checkField != "" {
// Find the created/updated document
var doc types.Document
found := false
for _, d := range store.collections[collection].documents {
if val, ok := d.Data[tt.checkField]; ok {
if compareEq(val, tt.expectedValue) {
doc = d
found = true
break
}

View File

@ -16,10 +16,29 @@ func compareEq(a, b interface{}) bool {
return false
}
// 对于 slice、map 等复杂类型,使用 reflect.DeepEqual
if isComplexType(a) || isComplexType(b) {
return reflect.DeepEqual(a, b)
}
// 类型转换后比较
return normalizeValue(a) == normalizeValue(b)
}
// isComplexType 检查是否是复杂类型slice、map 等)
func isComplexType(v interface{}) bool {
switch v.(type) {
case []interface{}:
return true
case map[string]interface{}:
return true
case map[interface{}]interface{}:
return true
default:
return false
}
}
// compareGt 大于比较
func compareGt(a, b interface{}) bool {
return compareNumbers(a, b) > 0

View File

@ -34,7 +34,6 @@ func applyProjectionToDoc(data map[string]interface{}, projection types.Projecti
// 检查是否是包含模式(所有值都是 1/true或排除模式所有值都是 0/false
isInclusionMode := false
hasInclusion := false
hasExclusion := false
for field, value := range projection {
if field == "_id" {
@ -43,8 +42,6 @@ func applyProjectionToDoc(data map[string]interface{}, projection types.Projecti
if isTrueValue(value) {
hasInclusion = true
} else {
hasExclusion = true
}
}

View File

@ -54,12 +54,6 @@ func handleExpr(doc map[string]interface{}, condition interface{}) bool {
// 创建临时引擎实例用于评估表达式
engine := &AggregationEngine{}
// 将文档转换为 Document 结构
document := types.Document{
ID: "",
Data: doc,
}
// 评估聚合表达式
result := engine.evaluateExpression(doc, condition)
@ -132,7 +126,8 @@ func validateJSONSchema(doc map[string]interface{}, schema map[string]interface{
if fieldSchema, ok := fieldSchemaRaw.(map[string]interface{}); ok {
fieldValue := doc[fieldName]
if fieldValue != nil {
if !validateJSONSchema(fieldValue, fieldSchema) {
// 递归验证字段值
if !validateFieldValue(fieldValue, fieldSchema) {
return false
}
}
@ -157,80 +152,6 @@ func validateJSONSchema(doc map[string]interface{}, schema map[string]interface{
}
}
// 检查 minimum
if minimumRaw, exists := schema["minimum"]; exists {
if num := toFloat64(doc); num < toFloat64(minimumRaw) {
return false
}
}
// 检查 maximum
if maximumRaw, exists := schema["maximum"]; exists {
if num := toFloat64(doc); num > toFloat64(maximumRaw) {
return false
}
}
// 检查 minLength (字符串)
if minLengthRaw, exists := schema["minLength"]; exists {
if str, ok := doc.(string); ok {
if minLen := int(toFloat64(minLengthRaw)); len(str) < minLen {
return false
}
}
}
// 检查 maxLength (字符串)
if maxLengthRaw, exists := schema["maxLength"]; exists {
if str, ok := doc.(string); ok {
if maxLen := int(toFloat64(maxLengthRaw)); len(str) > maxLen {
return false
}
}
}
// 检查 pattern (正则表达式)
if patternRaw, exists := schema["pattern"]; exists {
if str, ok := doc.(string); ok {
if pattern, ok := patternRaw.(string); ok {
if !compareRegex(str, map[string]interface{}{"$regex": pattern}) {
return false
}
}
}
}
// 检查 items (数组元素)
if itemsRaw, exists := schema["items"]; exists {
if arr, ok := doc.([]interface{}); ok {
if itemSchema, ok := itemsRaw.(map[string]interface{}); ok {
for _, item := range arr {
if !validateJSONSchema(item, itemSchema) {
return false
}
}
}
}
}
// 检查 minItems (数组最小长度)
if minItemsRaw, exists := schema["minItems"]; exists {
if arr, ok := doc.([]interface{}); ok {
if minItems := int(toFloat64(minItemsRaw)); len(arr) < minItems {
return false
}
}
}
// 检查 maxItems (数组最大长度)
if maxItemsRaw, exists := schema["maxItems"]; exists {
if arr, ok := doc.([]interface{}); ok {
if maxItems := int(toFloat64(maxItemsRaw)); len(arr) > maxItems {
return false
}
}
}
// 检查 allOf
if allOfRaw, exists := schema["allOf"]; exists {
if allOf, ok := allOfRaw.([]interface{}); ok {
@ -291,6 +212,203 @@ func validateJSONSchema(doc map[string]interface{}, schema map[string]interface{
return true
}
// validateFieldValue 验证字段值是否符合 schema
func validateFieldValue(value interface{}, schema map[string]interface{}) bool {
// 检查 bsonType
if bsonType, exists := schema["bsonType"]; exists {
if !validateBsonType(value, bsonType) {
return false
}
}
// 检查 enum
if enumRaw, exists := schema["enum"]; exists {
if enum, ok := enumRaw.([]interface{}); ok {
found := false
for _, val := range enum {
if compareEq(value, val) {
found = true
break
}
}
if !found {
return false
}
}
}
// 检查 minimum - 仅当 value 是数值类型时
if minimumRaw, exists := schema["minimum"]; exists {
if num, ok := toNumber(value); ok {
if num < toFloat64(minimumRaw) {
return false
}
}
}
// 检查 maximum - 仅当 value 是数值类型时
if maximumRaw, exists := schema["maximum"]; exists {
if num, ok := toNumber(value); ok {
if num > toFloat64(maximumRaw) {
return false
}
}
}
// 检查 minLength (字符串) - 仅当 value 是字符串时
if minLengthRaw, exists := schema["minLength"]; exists {
if str, ok := value.(string); ok {
if minLen := int(toFloat64(minLengthRaw)); len(str) < minLen {
return false
}
}
}
// 检查 maxLength (字符串) - 仅当 value 是字符串时
if maxLengthRaw, exists := schema["maxLength"]; exists {
if str, ok := value.(string); ok {
if maxLen := int(toFloat64(maxLengthRaw)); len(str) > maxLen {
return false
}
}
}
// 检查 pattern (正则表达式) - 仅当 value 是字符串时
if patternRaw, exists := schema["pattern"]; exists {
if str, ok := value.(string); ok {
if pattern, ok := patternRaw.(string); ok {
if !compareRegex(str, map[string]interface{}{"$regex": pattern}) {
return false
}
}
}
}
// 检查 items (数组元素) - 仅当 value 是数组时
if itemsRaw, exists := schema["items"]; exists {
if arr, ok := value.([]interface{}); ok {
if itemSchema, ok := itemsRaw.(map[string]interface{}); ok {
for _, item := range arr {
if itemMap, ok := item.(map[string]interface{}); ok {
if !validateJSONSchema(itemMap, itemSchema) {
return false
}
}
}
}
}
}
// 检查 minItems (数组最小长度) - 仅当 value 是数组时
if minItemsRaw, exists := schema["minItems"]; exists {
if arr, ok := value.([]interface{}); ok {
if minItems := int(toFloat64(minItemsRaw)); len(arr) < minItems {
return false
}
}
}
// 检查 maxItems (数组最大长度) - 仅当 value 是数组时
if maxItemsRaw, exists := schema["maxItems"]; exists {
if arr, ok := value.([]interface{}); ok {
if maxItems := int(toFloat64(maxItemsRaw)); len(arr) > maxItems {
return false
}
}
}
// 对于对象类型,继续递归验证嵌套 properties
if valueMap, ok := value.(map[string]interface{}); ok {
// 检查 required 字段
if requiredRaw, exists := schema["required"]; exists {
if required, ok := requiredRaw.([]interface{}); ok {
for _, reqField := range required {
if fieldStr, ok := reqField.(string); ok {
if valueMap[fieldStr] == nil {
return false
}
}
}
}
}
// 检查 properties
if propertiesRaw, exists := schema["properties"]; exists {
if properties, ok := propertiesRaw.(map[string]interface{}); ok {
for fieldName, fieldSchemaRaw := range properties {
if fieldSchema, ok := fieldSchemaRaw.(map[string]interface{}); ok {
fieldValue := valueMap[fieldName]
if fieldValue != nil {
if !validateFieldValue(fieldValue, fieldSchema) {
return false
}
}
}
}
}
}
}
// 检查 allOf
if allOfRaw, exists := schema["allOf"]; exists {
if allOf, ok := allOfRaw.([]interface{}); ok {
for _, subSchemaRaw := range allOf {
if subSchema, ok := subSchemaRaw.(map[string]interface{}); ok {
if !validateFieldValue(value, subSchema) {
return false
}
}
}
}
}
// 检查 anyOf
if anyOfRaw, exists := schema["anyOf"]; exists {
if anyOf, ok := anyOfRaw.([]interface{}); ok {
matched := false
for _, subSchemaRaw := range anyOf {
if subSchema, ok := subSchemaRaw.(map[string]interface{}); ok {
if validateFieldValue(value, subSchema) {
matched = true
break
}
}
}
if !matched {
return false
}
}
}
// 检查 oneOf
if oneOfRaw, exists := schema["oneOf"]; exists {
if oneOf, ok := oneOfRaw.([]interface{}); ok {
matchCount := 0
for _, subSchemaRaw := range oneOf {
if subSchema, ok := subSchemaRaw.(map[string]interface{}); ok {
if validateFieldValue(value, subSchema) {
matchCount++
}
}
}
if matchCount != 1 {
return false
}
}
}
// 检查 not
if notRaw, exists := schema["not"]; exists {
if notSchema, ok := notRaw.(map[string]interface{}); ok {
if validateFieldValue(value, notSchema) {
return false // not 要求不匹配
}
}
}
return true
}
// validateBsonType 验证 BSON 类型
func validateBsonType(value interface{}, bsonType interface{}) bool {
typeStr, ok := bsonType.(string)
@ -370,6 +488,28 @@ func getNumericValue(value interface{}) float64 {
}
}
// toArray 将值转换为数组
func toArray(value interface{}) ([]interface{}, bool) {
if arr, ok := value.([]interface{}); ok {
return arr, true
}
return nil, false
}
// toNumber 将值转换为数值
func toNumber(value interface{}) (float64, bool) {
switch v := value.(type) {
case int, int8, int16, int32, int64:
return getNumericValue(v), true
case uint, uint8, uint16, uint32, uint64:
return getNumericValue(v), true
case float32, float64:
return getNumericValue(v), true
default:
return 0, false
}
}
// handleAnd 处理 $and 操作符
func handleAnd(doc map[string]interface{}, condition interface{}) bool {
andConditions, ok := condition.([]interface{})

View File

@ -14,8 +14,8 @@ import (
// TestHTTPUpdateWithUpsert 测试 HTTP API 的 upsert 功能
func TestHTTPUpdateWithUpsert(t *testing.T) {
store := engine.NewMemoryStore(nil)
crud := &engine.CRUDHandler{store: store}
agg := &engine.AggregationEngine{store: store}
crud := engine.NewCRUDHandler(store, nil)
agg := engine.NewAggregationEngine(store)
handler := NewRequestHandler(store, crud, agg)
@ -65,8 +65,8 @@ func TestHTTPUpdateWithUpsert(t *testing.T) {
// TestHTTPHealthCheck 测试健康检查端点
func TestHTTPHealthCheck(t *testing.T) {
store := engine.NewMemoryStore(nil)
crud := &engine.CRUDHandler{store: store}
agg := &engine.AggregationEngine{store: store}
crud := engine.NewCRUDHandler(store, nil)
agg := engine.NewAggregationEngine(store)
server := NewHTTPServer(":0", NewRequestHandler(store, crud, agg))
@ -92,8 +92,8 @@ func TestHTTPHealthCheck(t *testing.T) {
// TestHTTPRoot 测试根路径处理
func TestHTTPRoot(t *testing.T) {
store := engine.NewMemoryStore(nil)
crud := &engine.CRUDHandler{store: store}
agg := &engine.AggregationEngine{store: store}
crud := engine.NewCRUDHandler(store, nil)
agg := engine.NewAggregationEngine(store)
server := NewHTTPServer(":0", NewRequestHandler(store, crud, agg))

View File

@ -307,7 +307,7 @@ func (h *MessageHandler) handleUpdate(body []byte) (interface{}, error) {
totalModified := 0
for _, op := range req.Updates {
matched, modified, err := h.store.Update(req.Collection, op.Q, op.U)
matched, modified, _, err := h.store.Update(req.Collection, op.Q, op.U, op.Upsert, op.ArrayFilters)
if err != nil {
return nil, err
}
@ -452,7 +452,7 @@ func (h *MessageHandler) handleUpdateMsg(collection string, params map[string]in
totalModified := 0
for _, op := range updates {
matched, modified, err := h.store.Update(collection, op.Q, op.U)
matched, modified, _, err := h.store.Update(collection, op.Q, op.U, op.Upsert, op.ArrayFilters)
if err != nil {
return nil, err
}