From f859c4ead7ac62013dc0467d18ead7dafa28c99e Mon Sep 17 00:00:00 2001 From: Tim Voronov Date: Tue, 24 Jun 2025 18:19:22 -0400 Subject: [PATCH] Refactor collector and aggregator handling; add `Get` and `Close` methods to collectors and sorters, restructure aggregation logic with multi-argument support, and update loop compilation to optimize resource management and key-value integration --- pkg/compiler/internal/core/loops.go | 3 +- ...op_collect_aggr.go => loop_collect_agg.go} | 141 +++++++----------- pkg/vm/internal/collector_counter.go | 8 + pkg/vm/internal/collector_key.go | 29 ++++ pkg/vm/internal/collector_key_counter.go | 29 ++++ pkg/vm/internal/collector_key_group.go | 33 +++- pkg/vm/internal/range_iter_test.go | 3 +- pkg/vm/internal/sorter.go | 16 ++ pkg/vm/internal/sorter_multi.go | 16 ++ pkg/vm/internal/transformer.go | 3 + pkg/vm/vm.go | 1 - .../integration/vm/vm_for_collect_agg_test.go | 50 ++++++- 12 files changed, 236 insertions(+), 96 deletions(-) rename pkg/compiler/internal/{loop_collect_aggr.go => loop_collect_agg.go} (53%) diff --git a/pkg/compiler/internal/core/loops.go b/pkg/compiler/internal/core/loops.go index ffb812cd..a5ed8e4d 100644 --- a/pkg/compiler/internal/core/loops.go +++ b/pkg/compiler/internal/core/loops.go @@ -2,8 +2,9 @@ package core import ( "fmt" - "github.com/MontFerret/ferret/pkg/vm" "strings" + + "github.com/MontFerret/ferret/pkg/vm" ) type LoopTable struct { diff --git a/pkg/compiler/internal/loop_collect_aggr.go b/pkg/compiler/internal/loop_collect_agg.go similarity index 53% rename from pkg/compiler/internal/loop_collect_aggr.go rename to pkg/compiler/internal/loop_collect_agg.go index 2fa52a78..4028af96 100644 --- a/pkg/compiler/internal/loop_collect_aggr.go +++ b/pkg/compiler/internal/loop_collect_agg.go @@ -1,6 +1,8 @@ package internal import ( + "strconv" + "github.com/MontFerret/ferret/pkg/compiler/internal/core" "github.com/MontFerret/ferret/pkg/parser/fql" "github.com/MontFerret/ferret/pkg/runtime" @@ -17,9 +19,11 @@ func (cc *LoopCollectCompiler) compileAggregation(c fql.ICollectAggregatorContex func (cc *LoopCollectCompiler) compileGroupedAggregation(c fql.ICollectAggregatorContext) { parentLoop := cc.ctx.Loops.Current() - // We need to allocate a temporary accumulators to store aggregation results + // We need to allocate a temporary accumulator to store aggregation results selectors := c.AllCollectAggregateSelector() - accums := cc.initAggrAccumulators(selectors) + accumulator := cc.ctx.Registers.Allocate(core.Temp) + cc.ctx.Emitter.EmitAx(vm.OpDataSetCollector, accumulator, int(core.CollectorTypeKeyGroup)) + loop := cc.ctx.Loops.CreateFor(core.TemporalLoop, cc.ctx.Registers.Allocate(core.Temp), false) // Now we iterate over the grouped items @@ -31,26 +35,15 @@ func (cc *LoopCollectCompiler) compileGroupedAggregation(c fql.ICollectAggregato loop.EmitInitialization(cc.ctx.Registers, cc.ctx.Emitter) // Add value selectors to the accumulators - cc.collectAggregationFuncArgs(selectors, func(i int, resultReg vm.Operand) { - cc.ctx.Emitter.EmitAB(vm.OpPush, accums[i], resultReg) - }) + argsPkg := cc.compileAggregationFuncArgs(selectors, accumulator) loop.EmitFinalization(cc.ctx.Emitter) cc.ctx.Symbols.ExitScope() // Now we can iterate over the selectors and execute the aggregation functions by passing the accumulators // And define variables for each accumulator result - cc.compileAggregationFuncCall(selectors, func(i int, _ string) core.RegisterSequence { - return core.RegisterSequence{accums[i]} - }, func(i int) {}) - - // Free the registers for accumulators - for _, reg := range accums { - cc.ctx.Registers.Free(reg) - } - - // Free the register for the iterator value - // cc.ctx.Registers.Free(aggrIterVal) + cc.compileAggregationFuncCall(selectors, accumulator, argsPkg) + cc.ctx.Registers.Free(accumulator) } func (cc *LoopCollectCompiler) compileGlobalAggregation(c fql.ICollectAggregatorContext) { @@ -60,14 +53,9 @@ func (cc *LoopCollectCompiler) compileGlobalAggregation(c fql.ICollectAggregator // Nested scope for aggregators cc.ctx.Symbols.EnterScope() - // Now we add value selectors to the accumulators + // Now we add value selectors to the collector selectors := c.AllCollectAggregateSelector() - cc.collectAggregationFuncArgs(selectors, func(i int, resultReg vm.Operand) { - aggrKeyName := selectors[i].Identifier().GetText() - aggrKeyReg := loadConstant(cc.ctx, runtime.String(aggrKeyName)) - cc.ctx.Emitter.EmitABC(vm.OpPushKV, parentLoop.Dst, aggrKeyReg, resultReg) - cc.ctx.Registers.Free(aggrKeyReg) - }) + argsPkg := cc.compileAggregationFuncArgs(selectors, parentLoop.Dst) parentLoop.EmitFinalization(cc.ctx.Emitter) cc.ctx.Loops.Pop() @@ -90,27 +78,13 @@ func (cc *LoopCollectCompiler) compileGlobalAggregation(c fql.ICollectAggregator loop.EmitInitialization(cc.ctx.Registers, cc.ctx.Emitter) // We just need to take the grouped values and call aggregation functions using them as args - var key vm.Operand - var value vm.Operand - cc.compileAggregationFuncCall(selectors, func(i int, selectorVarName string) core.RegisterSequence { - // We execute the function call with the accumulator as an argument - key = loadConstant(cc.ctx, runtime.String(selectorVarName)) - value = cc.ctx.Registers.Allocate(core.Temp) - cc.ctx.Emitter.EmitABC(vm.OpLoadKey, value, aggregator, key) - - return core.RegisterSequence{value} - }, func(_ int) { - cc.ctx.Registers.Free(value) - cc.ctx.Registers.Free(key) - }) - + cc.compileAggregationFuncCall(selectors, aggregator, argsPkg) cc.ctx.Registers.Free(aggregator) - - // Free the register for the iterator value - // cc.ctx.Registers.Free(aggrIterVal) } -func (cc *LoopCollectCompiler) collectAggregationFuncArgs(selectors []fql.ICollectAggregateSelectorContext, collector func(int, vm.Operand)) { +func (cc *LoopCollectCompiler) compileAggregationFuncArgs(selectors []fql.ICollectAggregateSelectorContext, collector vm.Operand) []int { + argsPkg := make([]int, len(selectors)) + for i := 0; i < len(selectors); i++ { selector := selectors[i] fcx := selector.FunctionCallExpression() @@ -121,63 +95,64 @@ func (cc *LoopCollectCompiler) collectAggregationFuncArgs(selectors []fql.IColle panic("No arguments provided for the function call in the aggregate selector") } + aggrKeyReg := loadConstant(cc.ctx, runtime.Int(i)) + // we keep information about the args - whether we need to unpack them or not + argsPkg[i] = len(args) + if len(args) > 1 { - // TODO: Better error handling - panic("Too many arguments") + for y, arg := range args { + argKeyReg := cc.loadAggregationArgKey(i, y) + cc.ctx.Emitter.EmitABC(vm.OpPushKV, collector, argKeyReg, arg) + cc.ctx.Registers.Free(argKeyReg) + } + } else { + cc.ctx.Emitter.EmitABC(vm.OpPushKV, collector, aggrKeyReg, args[0]) } - resultReg := args[0] - collector(i, resultReg) - cc.ctx.Registers.Free(resultReg) + cc.ctx.Registers.Free(aggrKeyReg) + cc.ctx.Registers.FreeSequence(args) } + + return argsPkg } -func (cc *LoopCollectCompiler) compileAggregationFuncCall(selectors []fql.ICollectAggregateSelectorContext, provider func(int, string) core.RegisterSequence, cleanup func(int)) { +func (cc *LoopCollectCompiler) compileAggregationFuncCall(selectors []fql.ICollectAggregateSelectorContext, accumulator vm.Operand, argsPkg []int) { for i, selector := range selectors { - fcx := selector.FunctionCallExpression() - // We won't make any checks here, as we already did it before - selectorVarName := selector.Identifier().GetText() + argsNum := argsPkg[i] - result := cc.ctx.ExprCompiler.CompileFunctionCallWith(fcx.FunctionCall(), fcx.ErrorOperator() != nil, provider(i, selectorVarName)) + var args core.RegisterSequence + + // We need to unpack arguments + if argsNum > 1 { + args = cc.ctx.Registers.AllocateSequence(argsNum) + + for y, reg := range args { + argKeyReg := cc.loadAggregationArgKey(i, y) + cc.ctx.Emitter.EmitABC(vm.OpLoadKey, reg, accumulator, argKeyReg) + + cc.ctx.Registers.Free(argKeyReg) + } + } else { + key := loadConstant(cc.ctx, runtime.Int(i)) + value := cc.ctx.Registers.Allocate(core.Temp) + cc.ctx.Emitter.EmitABC(vm.OpLoadKey, value, accumulator, key) + args = core.RegisterSequence{value} + cc.ctx.Registers.Free(key) + } + + fcx := selector.FunctionCallExpression() + result := cc.ctx.ExprCompiler.CompileFunctionCallWith(fcx.FunctionCall(), fcx.ErrorOperator() != nil, args) // We define the variable for the selector result in the upper scope // Since this temporary scope is only for aggregators and will be closed after the aggregation + selectorVarName := selector.Identifier().GetText() varReg := cc.ctx.Symbols.DeclareLocal(selectorVarName) cc.ctx.Emitter.EmitAB(vm.OpMove, varReg, result) cc.ctx.Registers.Free(result) - - cleanup(i) } } -func (cc *LoopCollectCompiler) initAggrAccumulators(selectors []fql.ICollectAggregateSelectorContext) []vm.Operand { - accums := make([]vm.Operand, len(selectors)) - - // First of all, we allocate registers for accumulators - accums = make([]vm.Operand, len(selectors)) - - // We need to allocate a register for each accumulator - for i := 0; i < len(selectors); i++ { - reg := cc.ctx.Registers.Allocate(core.Temp) - accums[i] = reg - - // TODO: Select persistent List type, we do not know how many items we will have - cc.ctx.Emitter.EmitA(vm.OpList, reg) - } - - return accums -} - -func (cc *LoopCollectCompiler) emitPushToAggrAccumulators(accums []vm.Operand, selectors []fql.ICollectAggregateSelectorContext, loop *core.Loop) { - for i, selector := range selectors { - fcx := selector.FunctionCallExpression() - args := cc.ctx.ExprCompiler.CompileArgumentList(fcx.FunctionCall().ArgumentList()) - - if len(args) != 1 { - panic("aggregate function must have exactly one argument") - } - - cc.ctx.Emitter.EmitAB(vm.OpPush, accums[i], args[0]) - cc.ctx.Registers.Free(args[0]) - } +func (cc *LoopCollectCompiler) loadAggregationArgKey(selector int, arg int) vm.Operand { + argKey := strconv.Itoa(selector) + ":" + strconv.Itoa(arg) + return loadConstant(cc.ctx, runtime.String(argKey)) } diff --git a/pkg/vm/internal/collector_counter.go b/pkg/vm/internal/collector_counter.go index b56c5300..a07d5a5f 100644 --- a/pkg/vm/internal/collector_counter.go +++ b/pkg/vm/internal/collector_counter.go @@ -28,3 +28,11 @@ func (c *CounterCollector) Add(_ context.Context, _, _ runtime.Value) error { return nil } + +func (c *CounterCollector) Get(_ context.Context, _ runtime.Value) (runtime.Value, error) { + return c.Value, nil +} + +func (c *CounterCollector) Close() error { + return nil +} diff --git a/pkg/vm/internal/collector_key.go b/pkg/vm/internal/collector_key.go index e5eaed9f..713564a0 100644 --- a/pkg/vm/internal/collector_key.go +++ b/pkg/vm/internal/collector_key.go @@ -2,6 +2,7 @@ package internal import ( "context" + "io" "github.com/MontFerret/ferret/pkg/runtime" ) @@ -52,3 +53,31 @@ func (c *KeyCollector) Add(ctx context.Context, key, _ runtime.Value) error { return nil } + +func (c *KeyCollector) Get(ctx context.Context, key runtime.Value) (runtime.Value, error) { + k, err := Stringify(ctx, key) + + if err != nil { + return nil, err + } + + v, ok := c.grouping[k] + + if !ok { + return runtime.None, runtime.ErrNotFound + } + + return v, nil +} + +func (c *KeyCollector) Close() error { + val := c.Value + c.Value = nil + c.grouping = nil + + if closer := val.(io.Closer); closer != nil { + return closer.Close() + } + + return nil +} diff --git a/pkg/vm/internal/collector_key_counter.go b/pkg/vm/internal/collector_key_counter.go index 64afdd2c..c75a94ec 100644 --- a/pkg/vm/internal/collector_key_counter.go +++ b/pkg/vm/internal/collector_key_counter.go @@ -2,6 +2,7 @@ package internal import ( "context" + "io" "github.com/MontFerret/ferret/pkg/runtime" ) @@ -101,3 +102,31 @@ func (c *KeyCounterCollector) Add(ctx context.Context, key, _ runtime.Value) err return nil } + +func (c *KeyCounterCollector) Get(ctx context.Context, key runtime.Value) (runtime.Value, error) { + k, err := Stringify(ctx, key) + + if err != nil { + return nil, err + } + + v, ok := c.grouping[k] + + if !ok { + return runtime.None, runtime.ErrNotFound + } + + return v, nil +} + +func (c *KeyCounterCollector) Close() error { + val := c.Value + c.Value = nil + c.grouping = nil + + if closer := val.(io.Closer); closer != nil { + return closer.Close() + } + + return nil +} diff --git a/pkg/vm/internal/collector_key_group.go b/pkg/vm/internal/collector_key_group.go index 006f29b1..e7535571 100644 --- a/pkg/vm/internal/collector_key_group.go +++ b/pkg/vm/internal/collector_key_group.go @@ -2,6 +2,7 @@ package internal import ( "context" + "io" "github.com/MontFerret/ferret/pkg/runtime" ) @@ -21,10 +22,6 @@ func NewKeyGroupCollector() Transformer { } } -func (c *KeyGroupCollector) Get(_ context.Context, key runtime.Value) (runtime.Value, error) { - return c.grouping[key.String()], nil -} - func (c *KeyGroupCollector) Iterate(ctx context.Context) (runtime.Iterator, error) { if !c.sorted { if err := c.sort(ctx); err != nil { @@ -83,3 +80,31 @@ func (c *KeyGroupCollector) sort(ctx context.Context) error { return comp }) } + +func (c *KeyGroupCollector) Get(ctx context.Context, key runtime.Value) (runtime.Value, error) { + k, err := Stringify(ctx, key) + + if err != nil { + return nil, err + } + + v, ok := c.grouping[k] + + if !ok { + return runtime.None, runtime.ErrNotFound + } + + return v, nil +} + +func (c *KeyGroupCollector) Close() error { + val := c.Value + c.Value = nil + c.grouping = nil + + if closer := val.(io.Closer); closer != nil { + return closer.Close() + } + + return nil +} diff --git a/pkg/vm/internal/range_iter_test.go b/pkg/vm/internal/range_iter_test.go index ddf7ca5b..3683c614 100644 --- a/pkg/vm/internal/range_iter_test.go +++ b/pkg/vm/internal/range_iter_test.go @@ -2,9 +2,10 @@ package internal_test import ( "context" - "github.com/MontFerret/ferret/pkg/vm/internal" "testing" + "github.com/MontFerret/ferret/pkg/vm/internal" + "github.com/MontFerret/ferret/pkg/runtime" . "github.com/smartystreets/goconvey/convey" diff --git a/pkg/vm/internal/sorter.go b/pkg/vm/internal/sorter.go index 63a064f2..08139acd 100644 --- a/pkg/vm/internal/sorter.go +++ b/pkg/vm/internal/sorter.go @@ -2,6 +2,7 @@ package internal import ( "context" + "io" "github.com/MontFerret/ferret/pkg/runtime" ) @@ -57,3 +58,18 @@ func (s *Sorter) sort(ctx context.Context) error { return -comp }) } + +func (s *Sorter) Get(_ context.Context, _ runtime.Value) (runtime.Value, error) { + return runtime.None, runtime.ErrNotSupported +} + +func (s *Sorter) Close() error { + val := s.Value + s.Value = nil + + if closer := val.(io.Closer); closer != nil { + return closer.Close() + } + + return nil +} diff --git a/pkg/vm/internal/sorter_multi.go b/pkg/vm/internal/sorter_multi.go index 1943d6a8..b3cf4042 100644 --- a/pkg/vm/internal/sorter_multi.go +++ b/pkg/vm/internal/sorter_multi.go @@ -2,6 +2,7 @@ package internal import ( "context" + "io" "github.com/MontFerret/ferret/pkg/runtime" ) @@ -67,3 +68,18 @@ func (s *MultiSorter) sort(ctx context.Context) error { return 0 }) } + +func (s *MultiSorter) Get(_ context.Context, _ runtime.Value) (runtime.Value, error) { + return runtime.None, runtime.ErrNotSupported +} + +func (s *MultiSorter) Close() error { + val := s.Value + s.Value = nil + + if closer := val.(io.Closer); closer != nil { + return closer.Close() + } + + return nil +} diff --git a/pkg/vm/internal/transformer.go b/pkg/vm/internal/transformer.go index cd5bc512..c3c9ad15 100644 --- a/pkg/vm/internal/transformer.go +++ b/pkg/vm/internal/transformer.go @@ -2,6 +2,7 @@ package internal import ( "context" + "io" "github.com/MontFerret/ferret/pkg/runtime" ) @@ -9,6 +10,8 @@ import ( type Transformer interface { runtime.Value runtime.Iterable + runtime.Keyed + io.Closer Add(ctx context.Context, key, value runtime.Value) error } diff --git a/pkg/vm/vm.go b/pkg/vm/vm.go index c5f24f4a..cbfaa47e 100644 --- a/pkg/vm/vm.go +++ b/pkg/vm/vm.go @@ -192,7 +192,6 @@ loop: start := int(src1) end := int(src1) + size - // Iterate over registers starting from src1 and up to the src2 for i := start; i < end; i++ { _ = arr.Add(ctx, reg[i]) } diff --git a/test/integration/vm/vm_for_collect_agg_test.go b/test/integration/vm/vm_for_collect_agg_test.go index 6e53b409..0ce1f761 100644 --- a/test/integration/vm/vm_for_collect_agg_test.go +++ b/test/integration/vm/vm_for_collect_agg_test.go @@ -96,7 +96,46 @@ FOR u IN users `, []any{ map[string]any{"minAge": 25, "maxAge": 69}, }, "Should collect and aggregate values without grouping"), - SkipCaseArray(` + CaseArray(` +LET users = [ + { + active: true, + married: true, + age: 31, + gender: "m" + }, + { + active: true, + married: false, + age: 25, + gender: "f" + }, + { + active: true, + married: false, + age: 36, + gender: "m" + }, + { + active: false, + married: true, + age: 69, + gender: "m" + }, + { + active: true, + married: true, + age: 45, + gender: "f" + } + ] + FOR u IN users + COLLECT AGGREGATE ages = UNION(u.age, u.age) + RETURN { ages } +`, []any{ + map[string]any{"ages": []any{31, 25, 36, 69, 45, 31, 25, 36, 69, 45}}, + }, "Should call aggregation functions with more than one argument"), + CaseArray(` LET users = [ { active: true, @@ -131,16 +170,15 @@ LET users = [ ] FOR u IN users COLLECT genderGroup = u.gender - AGGREGATE minAge = MIN(u.age), maxAge = MAX(u.age) + AGGREGATE ages = UNION(u.age, u.age) RETURN { genderGroup, - minAge, - maxAge + ages, } `, []any{ - map[string]any{"genderGroup": "f", "minAge": 25, "maxAge": 45}, - map[string]any{"genderGroup": "m", "minAge": 31, "maxAge": 69}, + map[string]any{"genderGroup": "f", "ages": []any{25, 45, 25, 45}}, + map[string]any{"genderGroup": "m", "ages": []any{31, 36, 69, 31, 36, 69}}, }, "Should collect and aggregate values by a single key"), }) }