diff --git a/pkg/compiler/internal/loop_collect_agg.go b/pkg/compiler/internal/loop_collect_agg.go index f1e90235..4496f109 100644 --- a/pkg/compiler/internal/loop_collect_agg.go +++ b/pkg/compiler/internal/loop_collect_agg.go @@ -119,6 +119,18 @@ func (c *LoopCollectCompiler) compileAggregationFuncArgs(selectors []fql.ICollec } func (c *LoopCollectCompiler) compileAggregationFuncCall(selectors []fql.ICollectAggregateSelectorContext, accumulator vm.Operand, argsPkg []int) { + // Gets the number of records in the accumulator + cond := c.ctx.Registers.Allocate(core.Temp) + c.ctx.Emitter.EmitAB(vm.OpLength, cond, accumulator) + zero := loadConstant(c.ctx, runtime.ZeroInt) + // Check if the number equals zero + c.ctx.Emitter.EmitEq(cond, cond, zero) + c.ctx.Registers.Free(zero) + // We skip the key retrieval and function call if there are no records in the accumulator + ifJump := c.ctx.Emitter.EmitJumpIfTrue(cond, core.JumpPlaceholder) + + selectorVarRegs := make([]vm.Operand, len(selectors)) + for i, selector := range selectors { argsNum := argsPkg[i] @@ -149,9 +161,20 @@ func (c *LoopCollectCompiler) compileAggregationFuncCall(selectors []fql.ICollec // Since this temporary scope is only for aggregators and will be closed after the aggregation selectorVarName := selector.Identifier().GetText() varReg := c.ctx.Symbols.DeclareLocal(selectorVarName) + selectorVarRegs[i] = varReg c.ctx.Emitter.EmitAB(vm.OpMove, varReg, result) c.ctx.Registers.Free(result) } + + elseJump := c.ctx.Emitter.EmitJump(core.JumpPlaceholder) + c.ctx.Emitter.PatchJumpNext(ifJump) + + for _, varReg := range selectorVarRegs { + c.ctx.Emitter.EmitA(vm.OpLoadNone, varReg) + } + + c.ctx.Emitter.PatchJumpNext(elseJump) + c.ctx.Registers.Free(cond) } func (c *LoopCollectCompiler) loadAggregationArgKey(selector int, arg int) vm.Operand { diff --git a/pkg/vm/internal/collector_counter.go b/pkg/vm/internal/collector_counter.go index a07d5a5f..20df2bbe 100644 --- a/pkg/vm/internal/collector_counter.go 
+++ b/pkg/vm/internal/collector_counter.go @@ -33,6 +33,10 @@ func (c *CounterCollector) Get(_ context.Context, _ runtime.Value) (runtime.Valu return c.Value, nil } +func (c *CounterCollector) Length(_ context.Context) (runtime.Int, error) { + return 1, nil +} + func (c *CounterCollector) Close() error { return nil } diff --git a/pkg/vm/internal/collector_key.go b/pkg/vm/internal/collector_key.go index 713564a0..58809ff7 100644 --- a/pkg/vm/internal/collector_key.go +++ b/pkg/vm/internal/collector_key.go @@ -64,12 +64,16 @@ func (c *KeyCollector) Get(ctx context.Context, key runtime.Value) (runtime.Valu v, ok := c.grouping[k] if !ok { - return runtime.None, runtime.ErrNotFound + return runtime.None, runtime.Errorf(runtime.ErrNotFound, "collector key: %s", k) } return v, nil } +func (c *KeyCollector) Length(ctx context.Context) (runtime.Int, error) { + return c.Value.Length(ctx) +} + func (c *KeyCollector) Close() error { val := c.Value c.Value = nil diff --git a/pkg/vm/internal/collector_key_counter.go b/pkg/vm/internal/collector_key_counter.go index c75a94ec..246dfc6f 100644 --- a/pkg/vm/internal/collector_key_counter.go +++ b/pkg/vm/internal/collector_key_counter.go @@ -113,12 +113,16 @@ func (c *KeyCounterCollector) Get(ctx context.Context, key runtime.Value) (runti v, ok := c.grouping[k] if !ok { - return runtime.None, runtime.ErrNotFound + return runtime.None, runtime.Errorf(runtime.ErrNotFound, "collector key: %s", k) } return v, nil } +func (c *KeyCounterCollector) Length(ctx context.Context) (runtime.Int, error) { + return c.Value.Length(ctx) +} + func (c *KeyCounterCollector) Close() error { val := c.Value c.Value = nil diff --git a/pkg/vm/internal/collector_key_group.go b/pkg/vm/internal/collector_key_group.go index e8e566ff..73f93062 100644 --- a/pkg/vm/internal/collector_key_group.go +++ b/pkg/vm/internal/collector_key_group.go @@ -97,6 +97,10 @@ func (c *KeyGroupCollector) Get(ctx context.Context, key runtime.Value) (runtime return v, nil } +func (c 
*KeyGroupCollector) Length(ctx context.Context) (runtime.Int, error) { + return c.Value.Length(ctx) +} + func (c *KeyGroupCollector) Close() error { val := c.Value c.Value = nil diff --git a/pkg/vm/internal/sorter.go b/pkg/vm/internal/sorter.go index 08139acd..36a2df3c 100644 --- a/pkg/vm/internal/sorter.go +++ b/pkg/vm/internal/sorter.go @@ -63,6 +63,10 @@ func (s *Sorter) Get(_ context.Context, _ runtime.Value) (runtime.Value, error) return runtime.None, runtime.ErrNotSupported } +func (s *Sorter) Length(ctx context.Context) (runtime.Int, error) { + return s.Value.Length(ctx) +} + func (s *Sorter) Close() error { val := s.Value s.Value = nil diff --git a/pkg/vm/internal/sorter_multi.go b/pkg/vm/internal/sorter_multi.go index b3cf4042..0a5247e1 100644 --- a/pkg/vm/internal/sorter_multi.go +++ b/pkg/vm/internal/sorter_multi.go @@ -73,6 +73,10 @@ func (s *MultiSorter) Get(_ context.Context, _ runtime.Value) (runtime.Value, er return runtime.None, runtime.ErrNotSupported } +func (s *MultiSorter) Length(ctx context.Context) (runtime.Int, error) { + return s.Value.Length(ctx) +} + func (s *MultiSorter) Close() error { val := s.Value s.Value = nil diff --git a/pkg/vm/internal/transformer.go b/pkg/vm/internal/transformer.go index c3c9ad15..9a7ea8ed 100644 --- a/pkg/vm/internal/transformer.go +++ b/pkg/vm/internal/transformer.go @@ -11,6 +11,7 @@ type Transformer interface { runtime.Value runtime.Iterable runtime.Keyed + runtime.Measurable io.Closer Add(ctx context.Context, key, value runtime.Value) error diff --git a/pkg/vm/opcode.go b/pkg/vm/opcode.go index c270b373..b6e4c500 100644 --- a/pkg/vm/opcode.go +++ b/pkg/vm/opcode.go @@ -8,7 +8,6 @@ const ( OpJump OpJumpIfFalse OpJumpIfTrue - OpJumpIfEmpty // Register Operations OpMove // Move a value from register A to register B @@ -112,8 +111,6 @@ func (op Opcode) String() string { return "JMPF" case OpJumpIfTrue: return "JMPT" - case OpJumpIfEmpty: - return "JMPE" // Register Operations case OpMove: diff --git 
a/pkg/vm/vm.go b/pkg/vm/vm.go index 49ca4f1a..d9cc0da6 100644 --- a/pkg/vm/vm.go +++ b/pkg/vm/vm.go @@ -84,23 +84,6 @@ loop: if runtime.ToBoolean(reg[src1]) { vm.pc = int(dst) } - case OpJumpIfEmpty: - val, ok := reg[src1].(runtime.Measurable) - - if ok { - size, err := val.Length(ctx) - - if err != nil { - return nil, err - } - - if size == 0 { - vm.pc = int(dst) - } - } else { - // If the value is not measurable, we consider it empty - vm.pc = int(dst) - } case OpAdd: reg[dst] = internal.Add(ctx, reg[src1], reg[src2]) case OpSub: diff --git a/test/integration/vm/vm_for_collect_agg_additional_test.go b/test/integration/vm/vm_for_collect_agg_additional_test.go deleted file mode 100644 index 64bf9203..00000000 --- a/test/integration/vm/vm_for_collect_agg_additional_test.go +++ /dev/null @@ -1,211 +0,0 @@ -package vm_test - -import ( - "testing" - - . "github.com/MontFerret/ferret/test/integration/base" -) - -func TestCollectAggregateAdditional(t *testing.T) { - RunUseCases(t, []UseCase{ - // Test 1: Multiple aggregation functions with complex expressions - - // Test 2: Nested FOR loops with COLLECT AGGREGATE - - // Test 3: Empty array handling - - // Test 4: Null value handling - CaseArray(` - LET users = [ - { - active: true, - age: null, - gender: "m", - married: true - }, - { - active: true, - age: 25, - gender: "f", - married: false - }, - { - active: true, - age: null, - gender: "m", - married: false - } - ] - FOR u IN users - COLLECT gender = u.gender - AGGREGATE minAge = MIN(u.age), maxAge = MAX(u.age) - RETURN { - gender, - minAge, - maxAge - } - `, []any{ - map[string]any{"gender": "f", "minAge": 25, "maxAge": 25}, - map[string]any{"gender": "m", "minAge": nil, "maxAge": nil}, - }, "Should handle null values in aggregation"), - - // Test 5: Multiple grouping keys with aggregation - CaseArray(` - LET users = [ - { - active: true, - age: 31, - gender: "m", - married: true, - department: "IT" - }, - { - active: true, - age: 25, - gender: "f", - married: false, 
- department: "Marketing" - }, - { - active: true, - age: 36, - gender: "m", - married: false, - department: "IT" - }, - { - active: false, - age: 69, - gender: "m", - married: true, - department: "Management" - }, - { - active: true, - age: 45, - gender: "f", - married: true, - department: "Marketing" - } - ] - FOR u IN users - COLLECT - department = u.department, - gender = u.gender - AGGREGATE - minAge = MIN(u.age), - maxAge = MAX(u.age) - RETURN { - department, - gender, - minAge, - maxAge - } - `, []any{ - map[string]any{"department": "IT", "gender": "m", "minAge": 31, "maxAge": 36}, - map[string]any{"department": "Management", "gender": "m", "minAge": 69, "maxAge": 69}, - map[string]any{"department": "Marketing", "gender": "f", "minAge": 25, "maxAge": 45}, - }, "Should aggregate with multiple grouping keys"), - - // Test 6: Aggregation with conditional expressions - CaseArray(` - LET users = [ - { - active: true, - age: 31, - gender: "m", - married: true, - salary: 75000 - }, - { - active: true, - age: 25, - gender: "f", - married: false, - salary: 60000 - }, - { - active: true, - age: 36, - gender: "m", - married: false, - salary: 80000 - }, - { - active: false, - age: 69, - gender: "m", - married: true, - salary: 95000 - }, - { - active: true, - age: 45, - gender: "f", - married: true, - salary: 70000 - } - ] - FOR u IN users - COLLECT gender = u.gender - AGGREGATE - activeCount = SUM(u.active ? 1 : 0), - marriedCount = SUM(u.married ? 1 : 0), - highSalaryCount = SUM(u.salary > 70000 ? 
1 : 0) - RETURN { - gender, - activeCount, - marriedCount, - highSalaryCount - } - `, []any{ - map[string]any{ - "gender": "f", - "activeCount": 2, - "marriedCount": 1, - "highSalaryCount": 0, - }, - map[string]any{ - "gender": "m", - "activeCount": 2, - "marriedCount": 2, - "highSalaryCount": 2, - }, - }, "Should aggregate with conditional expressions"), - - // Test 7: Aggregation with array operations - CaseArray(` - LET users = [ - { - name: "John", - skills: ["JavaScript", "Python", "Go"] - }, - { - name: "Jane", - skills: ["Java", "C++", "Python"] - }, - { - name: "Bob", - skills: ["Go", "Rust"] - }, - { - name: "Alice", - skills: ["JavaScript", "TypeScript"] - } - ] - FOR u IN users - COLLECT AGGREGATE - allSkills = UNION(u.skills), - uniqueSkillCount = COUNT_DISTINCT(u.skills) - RETURN { - allSkills: SORTED(allSkills), - uniqueSkillCount - } - `, []any{ - map[string]any{ - "allSkills": []any{"C++", "Go", "Java", "JavaScript", "Python", "Rust", "TypeScript"}, - "uniqueSkillCount": 7, - }, - }, "Should aggregate with array operations"), - }) -} diff --git a/test/integration/vm/vm_for_collect_agg_test.go b/test/integration/vm/vm_for_collect_agg_test.go index 6b7bcbec..4c22026a 100644 --- a/test/integration/vm/vm_for_collect_agg_test.go +++ b/test/integration/vm/vm_for_collect_agg_test.go @@ -9,14 +9,38 @@ import ( func TestCollectAggregate(t *testing.T) { RunUseCases(t, []UseCase{ CaseArray(` - LET users = [] + LET users = [ + { + active: true, + age: null, + gender: "m", + married: true + }, + { + active: true, + age: 25, + gender: "f", + married: false + }, + { + active: true, + age: null, + gender: "m", + married: false + } + ] FOR u IN users - COLLECT AGGREGATE minAge = MIN(u.age), maxAge = MAX(u.age) + COLLECT gender = u.gender + AGGREGATE minAge = MIN(u.age), maxAge = MAX(u.age) RETURN { + gender, minAge, maxAge } - `, []any{}, "Should handle empty arrays gracefully"), + `, []any{ + map[string]any{"gender": "f", "minAge": 25, "maxAge": 25}, + 
map[string]any{"gender": "m", "minAge": nil, "maxAge": nil}, + }, "Should handle null values in aggregation"), CaseArray(` LET users = [ { @@ -63,6 +87,126 @@ FOR u IN users map[string]any{"genderGroup": "f", "minAge": 25, "maxAge": 45}, map[string]any{"genderGroup": "m", "minAge": 31, "maxAge": 69}, }, "Should collect and aggregate values by a single key"), + CaseArray(` + LET users = [ + { + active: true, + age: 31, + gender: "m", + married: true, + department: "IT" + }, + { + active: true, + age: 25, + gender: "f", + married: false, + department: "Marketing" + }, + { + active: true, + age: 36, + gender: "m", + married: false, + department: "IT" + }, + { + active: false, + age: 69, + gender: "m", + married: true, + department: "Management" + }, + { + active: true, + age: 45, + gender: "f", + married: true, + department: "Marketing" + } + ] + FOR u IN users + COLLECT + department = u.department, + gender = u.gender + AGGREGATE + minAge = MIN(u.age), + maxAge = MAX(u.age) + RETURN { + department, + gender, + minAge, + maxAge + } + `, []any{ + map[string]any{"department": "IT", "gender": "m", "minAge": 31, "maxAge": 36}, + map[string]any{"department": "Management", "gender": "m", "minAge": 69, "maxAge": 69}, + map[string]any{"department": "Marketing", "gender": "f", "minAge": 25, "maxAge": 45}, + }, "Should aggregate with multiple grouping keys"), + CaseArray(` + LET users = [ + { + active: true, + age: 31, + gender: "m", + married: true, + salary: 75000 + }, + { + active: true, + age: 25, + gender: "f", + married: false, + salary: 60000 + }, + { + active: true, + age: 36, + gender: "m", + married: false, + salary: 80000 + }, + { + active: false, + age: 69, + gender: "m", + married: true, + salary: 95000 + }, + { + active: true, + age: 45, + gender: "f", + married: true, + salary: 70000 + } + ] + FOR u IN users + COLLECT gender = u.gender + AGGREGATE + activeCount = SUM(u.active ? 1 : 0), + marriedCount = SUM(u.married ? 
1 : 0), + highSalaryCount = SUM(u.salary > 70000 ? 1 : 0) + RETURN { + gender, + activeCount, + marriedCount, + highSalaryCount + } + `, []any{ + map[string]any{ + "gender": "f", + "activeCount": 2, + "marriedCount": 1, + "highSalaryCount": 0, + }, + map[string]any{ + "gender": "m", + "activeCount": 2, + "marriedCount": 2, + "highSalaryCount": 2, + }, + }, "Should aggregate with conditional expressions"), CaseArray(` LET users = [ { @@ -102,9 +246,20 @@ FOR u IN users minAge, maxAge } - `, []any{ - map[string]any{"minAge": 25, "maxAge": 69}, - }, "Should collect and aggregate values without grouping"), + `, + []any{map[string]any{"minAge": 25, "maxAge": 69}}, + "Should collect and aggregate values without grouping"), + CaseArray(` + LET users = [] + FOR u IN users + COLLECT AGGREGATE minAge = MIN(u.age), maxAge = MAX(u.age) + RETURN { + minAge, + maxAge + } + `, + []any{map[string]any{"minAge": nil, "maxAge": nil}}, + "Should handle empty arrays gracefully"), CaseArray(` LET users = [ { @@ -322,5 +477,38 @@ FOR u IN users "employeeCount": 2, }, }, "Should aggregate multiple values with complex expressions"), + CaseArray(` + LET users = [ + { + name: "John", + skills: ["JavaScript", "Python", "Go"] + }, + { + name: "Jane", + skills: ["Java", "C++", "Python"] + }, + { + name: "Bob", + skills: ["Go", "Rust"] + }, + { + name: "Alice", + skills: ["JavaScript", "TypeScript"] + } + ] + FOR u IN users + COLLECT AGGREGATE + allSkills = UNION(u.skills), + uniqueSkillCount = COUNT_DISTINCT(u.skills) + RETURN { + allSkills: SORTED(allSkills), + uniqueSkillCount + } + `, []any{ + map[string]any{ + "allSkills": []any{"C++", "Go", "Java", "JavaScript", "Python", "Rust", "TypeScript"}, + "uniqueSkillCount": 7, + }, + }, "Should aggregate with array operations"), }) } diff --git a/test/integration/vm/vm_for_collect_test.go b/test/integration/vm/vm_for_collect_test.go index 74cf4239..14a26a0e 100644 --- a/test/integration/vm/vm_for_collect_test.go +++ 
b/test/integration/vm/vm_for_collect_test.go @@ -87,6 +87,14 @@ func TestForCollect(t *testing.T) { RETURN {x, gender} `, "Should not have access to variables defined before COLLECT"), CaseArray(` + LET users = [] + FOR i IN users + COLLECT gender = i.gender + RETURN gender + `, + []any{}, + "Should handle empty arrays gracefully"), + CaseArray(` LET users = [ { active: true,