From 5185abd7147f4b271639dc2b849c40da9dbf212e Mon Sep 17 00:00:00 2001 From: Tim Voronov Date: Sat, 26 Jul 2025 14:55:53 -0400 Subject: [PATCH] Refactor `LoopCollectCompiler`: replace `CollectorSpec` with `Collector`, improve grouping and aggregation handling, update emitter methods, standardize key loading logic, streamline function naming and argument handling, add `Count` function to collections, and enhance code comments for readability. --- .../internal/core/aggregate_selector.go | 44 ---- pkg/compiler/internal/core/collector.go | 80 +++++++ .../internal/core/collector_aggregation.go | 60 +++++ pkg/compiler/internal/core/collector_spec.go | 80 ------- pkg/compiler/internal/core/emitter.go | 16 +- .../internal/core/emitter_extension.go | 4 + pkg/compiler/internal/core/label.go | 4 + pkg/compiler/internal/core/registers.go | 6 +- pkg/compiler/internal/helpers.go | 7 +- pkg/compiler/internal/loop_collect.go | 41 ++-- pkg/compiler/internal/loop_collect_agg.go | 223 ++++++++---------- pkg/compiler/internal/loop_collect_grp.go | 4 +- pkg/compiler/internal/loop_collect_prj.go | 2 +- pkg/stdlib/collections/count.go | 19 ++ pkg/stdlib/collections/count_distinct.go | 8 +- pkg/stdlib/collections/include.go | 15 +- pkg/stdlib/collections/lib.go | 14 +- pkg/stdlib/collections/reverse.go | 12 +- .../vm/vm_for_in_collect_agg_test.go | 32 +++ 19 files changed, 357 insertions(+), 314 deletions(-) delete mode 100644 pkg/compiler/internal/core/aggregate_selector.go create mode 100644 pkg/compiler/internal/core/collector.go create mode 100644 pkg/compiler/internal/core/collector_aggregation.go delete mode 100644 pkg/compiler/internal/core/collector_spec.go create mode 100644 pkg/stdlib/collections/count.go diff --git a/pkg/compiler/internal/core/aggregate_selector.go b/pkg/compiler/internal/core/aggregate_selector.go deleted file mode 100644 index ad8b1aad..00000000 --- a/pkg/compiler/internal/core/aggregate_selector.go +++ /dev/null @@ -1,44 +0,0 @@ -package core - -import ( - "github.com/MontFerret/ferret/pkg/runtime" - "github.com/MontFerret/ferret/pkg/vm" -) - -type AggregateSelector struct { - name runtime.String - args int - funcName runtime.String - protectedCall bool - register vm.Operand -} - -func NewAggregateSelector(name runtime.String, args int, funcName runtime.String, protectedCall bool, register vm.Operand) *AggregateSelector { - return &AggregateSelector{ - name: name, - register: register, - args: args, - funcName: funcName, - protectedCall: protectedCall, - } -} - -func (s *AggregateSelector) Name() runtime.String { - return s.name -} - -func (s *AggregateSelector) Args() int { - return s.args -} - -func (s *AggregateSelector) FuncName() runtime.String { - return s.funcName -} - -func (s *AggregateSelector) ProtectedCall() bool { - return s.protectedCall -} - -func (s *AggregateSelector) Register() vm.Operand { - return s.register -} diff --git a/pkg/compiler/internal/core/collector.go b/pkg/compiler/internal/core/collector.go new file mode 100644 index 00000000..7e1a1fc0 --- /dev/null +++ b/pkg/compiler/internal/core/collector.go @@ -0,0 +1,80 @@ +package core + +import "github.com/MontFerret/ferret/pkg/vm" + +type ( + CollectorType int + + Collector struct { + typ CollectorType + dst vm.Operand + projection *CollectorProjection + groupSelectors []*CollectSelector + aggregation *CollectorAggregation + } +) + +const ( + CollectorTypeCounter CollectorType = iota + CollectorTypeKey + CollectorTypeKeyCounter + CollectorTypeKeyGroup +) + +func NewCollector(type_ CollectorType, dst vm.Operand, projection *CollectorProjection, groupSelectors []*CollectSelector, aggregation *CollectorAggregation) *Collector { + return &Collector{ + typ: type_, + dst: dst, + projection: projection, + groupSelectors: groupSelectors, + aggregation: aggregation, + } +} + +func DetermineCollectorType(withGrouping, withAggregation, withProjection, withCounter bool) CollectorType { + if withGrouping { + if withCounter { + return CollectorTypeKeyCounter + } + + return CollectorTypeKeyGroup + } + + if withAggregation { + return CollectorTypeKeyGroup + } + + return CollectorTypeCounter +} + +func (c *Collector) Type() CollectorType { + return c.typ +} + +func (c *Collector) Destination() vm.Operand { + return c.dst +} + +func (c *Collector) Projection() *CollectorProjection { + return c.projection +} + +func (c *Collector) GroupSelectors() []*CollectSelector { + return c.groupSelectors +} + +func (c *Collector) Aggregation() *CollectorAggregation { + return c.aggregation +} + +func (c *Collector) HasProjection() bool { + return c.projection != nil +} + +func (c *Collector) HasGrouping() bool { + return len(c.groupSelectors) > 0 +} + +func (c *Collector) HasAggregation() bool { + return c.aggregation != nil +} diff --git a/pkg/compiler/internal/core/collector_aggregation.go b/pkg/compiler/internal/core/collector_aggregation.go new file mode 100644 index 00000000..93a61014 --- /dev/null +++ b/pkg/compiler/internal/core/collector_aggregation.go @@ -0,0 +1,60 @@ +package core + +import ( + "github.com/MontFerret/ferret/pkg/runtime" + "github.com/MontFerret/ferret/pkg/vm" +) + +type ( + CollectorAggregation struct { + state vm.Operand + selector []*AggregateSelector + } + + AggregateSelector struct { + name runtime.String + args int + funcName runtime.String + protectedCall bool + } +) + +func NewCollectorAggregation(state vm.Operand, selector []*AggregateSelector) *CollectorAggregation { + return &CollectorAggregation{ + state: state, + selector: selector, + } +} + +func (c *CollectorAggregation) State() vm.Operand { + return c.state +} + +func (c *CollectorAggregation) Selectors() []*AggregateSelector { + return c.selector +} + +func NewAggregateSelector(name runtime.String, args int, funcName runtime.String, protectedCall bool) *AggregateSelector { + return &AggregateSelector{ + name: name, + args: args, + funcName: funcName, + protectedCall: protectedCall, + } +} + +func (s *AggregateSelector) Name() runtime.String { + return s.name +} + +func (s *AggregateSelector) Args() int { + return s.args +} + +func (s *AggregateSelector) FuncName() runtime.String { + return s.funcName +} + +func (s *AggregateSelector) ProtectedCall() bool { + return s.protectedCall +} diff --git a/pkg/compiler/internal/core/collector_spec.go b/pkg/compiler/internal/core/collector_spec.go deleted file mode 100644 index efc4b1b9..00000000 --- a/pkg/compiler/internal/core/collector_spec.go +++ /dev/null @@ -1,80 +0,0 @@ -package core - -import "github.com/MontFerret/ferret/pkg/vm" - -type ( - CollectorType int - - CollectorSpec struct { - typ CollectorType - dst vm.Operand - projection *CollectorProjection - groupSelectors []*CollectSelector - aggregationSelectors []*AggregateSelector - } -) - -const ( - CollectorTypeCounter CollectorType = iota - CollectorTypeKey - CollectorTypeKeyCounter - CollectorTypeKeyGroup -) - -func NewCollectorSpec(type_ CollectorType, dst vm.Operand, projection *CollectorProjection, groupSelectors []*CollectSelector, aggregationSelectors []*AggregateSelector) *CollectorSpec { - return &CollectorSpec{ - typ: type_, - dst: dst, - projection: projection, - groupSelectors: groupSelectors, - aggregationSelectors: aggregationSelectors, - } -} - -func DetermineCollectorType(withGrouping, withAggregation, withProjection, withCounter bool) CollectorType { - if withGrouping { - if withCounter { - return CollectorTypeKeyCounter - } - - return CollectorTypeKeyGroup - } - - if withAggregation { - return CollectorTypeKeyGroup - } - - return CollectorTypeCounter -} - -func (c *CollectorSpec) Type() CollectorType { - return c.typ -} - -func (c *CollectorSpec) Destination() vm.Operand { - return c.dst -} - -func (c *CollectorSpec) Projection() *CollectorProjection { - return c.projection -} - -func (c *CollectorSpec) GroupSelectors() []*CollectSelector { - return c.groupSelectors -} - -func (c *CollectorSpec) AggregationSelectors() []*AggregateSelector { - return c.aggregationSelectors -} - -func (c *CollectorSpec) HasProjection() bool { - return c.projection != nil -} - -func (c *CollectorSpec) HasGrouping() bool { - return len(c.groupSelectors) > 0 -} - -func (c *CollectorSpec) HasAggregation() bool { - return len(c.aggregationSelectors) > 0 -} diff --git a/pkg/compiler/internal/core/emitter.go b/pkg/compiler/internal/core/emitter.go index 0b3ce697..4faa97a9 100644 --- a/pkg/compiler/internal/core/emitter.go +++ b/pkg/compiler/internal/core/emitter.go @@ -103,17 +103,17 @@ func (e *Emitter) Emit(op vm.Opcode) { e.EmitABC(op, 0, 0, 0) } -// EmitA emits an opcode with a single destination register argument. +// EmitA emits an opcode with a single destination value argument. func (e *Emitter) EmitA(op vm.Opcode, dest vm.Operand) { e.EmitABC(op, dest, 0, 0) } -// EmitAB emits an opcode with a destination register and a single source register argument. +// EmitAB emits an opcode with a destination value and a single source value argument. func (e *Emitter) EmitAB(op vm.Opcode, dest, src1 vm.Operand) { e.EmitABC(op, dest, src1, 0) } -// EmitAb emits an opcode with a destination register and a boolean argument. +// EmitAb emits an opcode with a destination value and a boolean argument. func (e *Emitter) EmitAb(op vm.Opcode, dest vm.Operand, arg bool) { var src1 vm.Operand @@ -124,7 +124,7 @@ func (e *Emitter) EmitAb(op vm.Opcode, dest vm.Operand, arg bool) { e.EmitABC(op, dest, src1, 0) } -// EmitAx emits an opcode with a destination register and a custom argument. +// EmitAx emits an opcode with a destination value and a custom argument. func (e *Emitter) EmitAx(op vm.Opcode, dest vm.Operand, arg int) { e.EmitABC(op, dest, vm.Operand(arg), 0) } @@ -134,7 +134,7 @@ func (e *Emitter) EmitAxy(op vm.Opcode, dest vm.Operand, arg1, agr2 int) { e.EmitABC(op, dest, vm.Operand(arg1), vm.Operand(agr2)) } -// EmitAs emits an opcode with a destination register and a sequence of registers. +// EmitAs emits an opcode with a destination value and a sequence of registers. func (e *Emitter) EmitAs(op vm.Opcode, dest vm.Operand, seq RegisterSequence) { if seq != nil { src1 := seq[0] @@ -145,12 +145,12 @@ func (e *Emitter) EmitAs(op vm.Opcode, dest vm.Operand, seq RegisterSequence) { } } -// EmitABx emits an opcode with a destination and source register and a custom argument. +// EmitABx emits an opcode with a destination and source value and a custom argument. func (e *Emitter) EmitABx(op vm.Opcode, dest vm.Operand, src vm.Operand, arg int) { e.EmitABC(op, dest, src, vm.Operand(arg)) } -// EmitABC emits an opcode with a destination register and two source register arguments. +// EmitABC emits an opcode with a destination value and two source value arguments. func (e *Emitter) EmitABC(op vm.Opcode, dest, src1, src2 vm.Operand) { e.instructions = append(e.instructions, vm.Instruction{ Opcode: op, @@ -267,7 +267,7 @@ func (e *Emitter) insertInstruction(label Label, ins vm.Instruction) { pos, ok := e.LabelPosition(label) if !ok { - panic(fmt.Errorf("label not marked: %d", label)) + panic(fmt.Errorf("label not marked: %s", label)) } // Insert instruction at position diff --git a/pkg/compiler/internal/core/emitter_extension.go b/pkg/compiler/internal/core/emitter_extension.go index b4bc4dae..4bef3b3a 100644 --- a/pkg/compiler/internal/core/emitter_extension.go +++ b/pkg/compiler/internal/core/emitter_extension.go @@ -108,6 +108,10 @@ func (e *Emitter) EmitLoadIndex(dst, arr, idx vm.Operand) { e.EmitABC(vm.OpLoadIndex, dst, arr, idx) } +func (e *Emitter) EmitLoadKey(dst, obj, key vm.Operand) { + e.EmitABC(vm.OpLoadKey, dst, obj, key) +} + func (e *Emitter) EmitLoadProperty(dst, obj, prop vm.Operand) { e.EmitABC(vm.OpLoadProperty, dst, obj, prop) } diff --git a/pkg/compiler/internal/core/label.go b/pkg/compiler/internal/core/label.go index 3aadf639..7746ee5b 100644 --- a/pkg/compiler/internal/core/label.go +++ b/pkg/compiler/internal/core/label.go @@ -14,3 +14,7 @@ type ( field int } ) + +func (l Label) String() string { + return l.name +} diff --git a/pkg/compiler/internal/core/registers.go b/pkg/compiler/internal/core/registers.go index a43a625e..f36c8ee7 100644 --- a/pkg/compiler/internal/core/registers.go +++ b/pkg/compiler/internal/core/registers.go @@ -46,7 +46,7 @@ func (ra *RegisterAllocator) Allocate(typ RegisterType) vm.Operand { return reg } - // New register + // New value reg := ra.next ra.next++ @@ -59,13 +59,13 @@ func (ra *RegisterAllocator) Allocate(typ RegisterType) vm.Operand { } func (ra *RegisterAllocator) Free(reg vm.Operand) { - //info, ok := ra.all[reg] + //info, ok := ra.all[state] //if !ok || !info.allocated { // return // double-free or unknown //} // //info.allocated = false - //ra.freelist[info.typ] = append(ra.freelist[info.typ], reg) + //ra.freelist[info.typ] = append(ra.freelist[info.typ], state) } func (ra *RegisterAllocator) AllocateSequence(count int) RegisterSequence { diff --git a/pkg/compiler/internal/helpers.go b/pkg/compiler/internal/helpers.go index 5fb53156..fe308821 100644 --- a/pkg/compiler/internal/helpers.go +++ b/pkg/compiler/internal/helpers.go @@ -26,11 +26,16 @@ func loadConstantTo(ctx *CompilerContext, constant runtime.Value, reg vm.Operand } func loadIndex(ctx *CompilerContext, dst, arr vm.Operand, idx int) { - idxReg := loadConstant(ctx, runtime.NewInt(idx)) + idxReg := loadConstant(ctx, runtime.Int(idx)) ctx.Emitter.EmitLoadIndex(dst, arr, idxReg) ctx.Registers.Free(idxReg) } +func loadKey(ctx *CompilerContext, dst, obj vm.Operand, key string) { + keyReg := loadConstant(ctx, runtime.String(key)) + ctx.Emitter.EmitLoadKey(dst, obj, keyReg) +} + func sortDirection(dir antlr.TerminalNode) runtime.SortDirection { if dir == nil { return runtime.SortDirectionAsc diff --git a/pkg/compiler/internal/loop_collect.go b/pkg/compiler/internal/loop_collect.go index 4bc60a3a..5b39e377 100644 --- a/pkg/compiler/internal/loop_collect.go +++ b/pkg/compiler/internal/loop_collect.go @@ -26,38 +26,38 @@ func (c *LoopCollectCompiler) Compile(ctx fql.ICollectClauseContext) { c.compileLoop(scope) } -// compileCollector processes the COLLECT clause components and creates a CollectorSpec. +// compileCollector processes the COLLECT clause components and creates a Collector. // This function handles the initialization of grouping, aggregation, and projection operations, // and sets up the appropriate collector type based on the COLLECT clause structure. -func (c *LoopCollectCompiler) compileCollector(ctx fql.ICollectClauseContext) *core.CollectorSpec { +func (c *LoopCollectCompiler) compileCollector(ctx fql.ICollectClauseContext) *core.Collector { // Extract all components of the COLLECT clause - grouping := ctx.CollectGrouping() - projection := ctx.CollectGroupProjection() - counter := ctx.CollectCounter() - aggregation := ctx.CollectAggregator() + groupingCtx := ctx.CollectGrouping() + projectionCtx := ctx.CollectGroupProjection() + counterCtx := ctx.CollectCounter() + aggregationCtx := ctx.CollectAggregator() // We gather keys and values for the collector. - kv, groupSelectors := c.initializeGrouping(grouping) + kv, groupSelectors := c.initializeGrouping(groupingCtx) // Determine the collector type based on the presence of different COLLECT components - collectorType := core.DetermineCollectorType(len(groupSelectors) > 0, aggregation != nil, projection != nil, counter != nil) + collectorType := core.DetermineCollectorType(len(groupSelectors) > 0, aggregationCtx != nil, projectionCtx != nil, counterCtx != nil) // We replace DataSet initialization with Collector initialization loop := c.ctx.Loops.Current() dst := loop.PatchDestinationAx(c.ctx.Registers, c.ctx.Emitter, vm.OpDataSetCollector, int(collectorType)) - var aggregationSelectors []*core.AggregateSelector + var aggregation *core.CollectorAggregation - // Initialize aggregation if present in the COLLECT clause - if aggregation != nil { - aggregationSelectors = c.initializeAggregation(aggregation, dst, kv, len(groupSelectors) > 0) + // Initialize aggregationCtx if present in the COLLECT clause + if aggregationCtx != nil { + aggregation = c.initializeAggregation(aggregationCtx, dst, kv, len(groupSelectors) > 0) } - // Initialize projection for group variables or counters - groupProjection := c.initializeProjection(kv, projection, counter) + // Initialize projectionCtx for group variables or counters + projection := c.initializeProjection(kv, projectionCtx, counterCtx) // Create the collector specification with all components - spec := core.NewCollectorSpec(collectorType, dst, groupProjection, groupSelectors, aggregationSelectors) + spec := core.NewCollector(collectorType, dst, projection, groupSelectors, aggregation) // Finalize the collector setup c.finalizeCollector(dst, kv, spec) @@ -72,7 +72,7 @@ func (c *LoopCollectCompiler) compileCollector(ctx fql.ICollectClauseContext) *c // finalizeCollector completes the collector setup by pushing key-value pairs to the collector // and emitting finalization instructions for the current loop. // The behavior varies based on whether grouping and aggregation are used. -func (c *LoopCollectCompiler) finalizeCollector(dst vm.Operand, kv *core.KV, spec *core.CollectorSpec) { +func (c *LoopCollectCompiler) finalizeCollector(dst vm.Operand, kv *core.KV, spec *core.Collector) { loop := c.ctx.Loops.Current() // If we do not use grouping but use aggregation, we do not need to push the key and value @@ -93,7 +93,7 @@ func (c *LoopCollectCompiler) finalizeCollector(dst vm.Operand, kv *core.KV, spe // compileLoop processes the loop operations based on the collector specification. // It handles different combinations of grouping, aggregation, and projection operations, // ensuring that the appropriate VM instructions are generated for each case. -func (c *LoopCollectCompiler) compileLoop(spec *core.CollectorSpec) { +func (c *LoopCollectCompiler) compileLoop(spec *core.Collector) { loop := c.ctx.Loops.Current() // If we are using a projection, we need to ensure the loop is set to ForInLoop @@ -123,17 +123,16 @@ func (c *LoopCollectCompiler) compileLoop(spec *core.CollectorSpec) { // Process aggregation if present if spec.HasAggregation() { - c.unpackGroupedValues(spec) - c.compileAggregation(spec) + c.finalizeAggregation(spec) } // Process grouping if present if spec.HasGrouping() { - c.compileGrouping(spec) + c.finalizeGrouping(spec) } // We finalize projection only if we have a projection and no aggregation - // Because if we have aggregation, we finalize it in the compileAggregation method. + // Because if we have aggregation, we finalize it in the finalizeAggregation method. if spec.HasProjection() && !spec.HasAggregation() { c.finalizeProjection(spec, loop.Value) } diff --git a/pkg/compiler/internal/loop_collect_agg.go b/pkg/compiler/internal/loop_collect_agg.go index 32552cab..46ad1b2b 100644 --- a/pkg/compiler/internal/loop_collect_agg.go +++ b/pkg/compiler/internal/loop_collect_agg.go @@ -14,55 +14,35 @@ import ( // For grouped aggregations, it compiles the selectors and packs them with the loop value. // For global aggregations, it pushes the selectors directly to the collector. // Returns a slice of AggregateSelectors that describe the aggregation operations. -func (c *LoopCollectCompiler) initializeAggregation(ctx fql.ICollectAggregatorContext, dst vm.Operand, kv *core.KV, withGrouping bool) []*core.AggregateSelector { +func (c *LoopCollectCompiler) initializeAggregation(ctx fql.ICollectAggregatorContext, dst vm.Operand, kv *core.KV, withGrouping bool) *core.CollectorAggregation { + loop := c.ctx.Loops.Current() selectors := ctx.AllCollectAggregateSelector() - var compiledSelectors []*core.AggregateSelector // If we have grouping, we need to pack the selectors into the collector value if withGrouping { + // TODO: We need to figure out how to free the aggregator register later + aggregator := c.ctx.Registers.Allocate(core.State) + // We create a separate collector for aggregation in grouped mode + c.ctx.Emitter.InsertAx(loop.StartLabel, vm.OpDataSetCollector, aggregator, int(core.CollectorTypeKeyGroup)) + // Compile selectors for grouped aggregation - compiledSelectors = c.compileGroupedAggregationSelectors(selectors) + aggregateSelectors := c.initializeGroupedAggregationSelectors(selectors, kv, aggregator) - // Pack the selectors into the collector value along with the loop value - c.packGroupedValues(kv, compiledSelectors) - } else { - // For global aggregation, we just push the selectors into the global collector - compiledSelectors = c.compileGlobalAggregationSelectors(selectors, dst) + return core.NewCollectorAggregation(aggregator, aggregateSelectors) } - return compiledSelectors + // For global aggregation, we just push the selectors into the global collector + aggregateSelectors := c.initializeGlobalAggregationSelectors(selectors, dst) + + return core.NewCollectorAggregation(dst, aggregateSelectors) } -// packGroupedValues combines the loop value with aggregation selector values into a single array. -// This is used for grouped aggregations to keep all values together for each group. -// The first element of the array is the loop value, followed by the aggregation selector values. -func (c *LoopCollectCompiler) packGroupedValues(kv *core.KV, selectors []*core.AggregateSelector) { - // Allocate a sequence of registers for the array elements - // We need one extra register for the loop value (hence +1) - seq := c.ctx.Registers.AllocateSequence(len(selectors) + 1) - - // Move the loop value to the first position in the sequence - c.ctx.Emitter.EmitMove(seq[0], kv.Value) - - // Move each selector value to its position in the sequence - for i, selector := range selectors { - c.ctx.Emitter.EmitMove(seq[i+1], selector.Register()) - // Free the selector register as we no longer need it - c.ctx.Registers.Free(selector.Register()) - } - - // Create an array from the sequence and store it in the kv.Value register - // This replaces the original loop value with an array containing both - // the loop value and all selector values - c.ctx.Emitter.EmitArray(kv.Value, seq) -} - -// compileGroupedAggregationSelectors processes aggregation selectors for grouped aggregations. +// initializeGroupedAggregationSelectors processes aggregation selectors for grouped aggregations. // It compiles each selector's function call expression and arguments, and creates AggregateSelector objects. // For selectors with multiple arguments, it packs them into an array. // Returns a slice of AggregateSelectors that describe the aggregation operations. -func (c *LoopCollectCompiler) compileGroupedAggregationSelectors(selectors []fql.ICollectAggregateSelectorContext) []*core.AggregateSelector { - wrappedSelectors := make([]*core.AggregateSelector, 0, len(selectors)) +func (c *LoopCollectCompiler) initializeGroupedAggregationSelectors(selectors []fql.ICollectAggregateSelectorContext, kv *core.KV, dst vm.Operand) []*core.AggregateSelector { + wrappedSelectors := make([]*core.AggregateSelector, len(selectors)) for i := 0; i < len(selectors); i++ { selector := selectors[i] @@ -78,16 +58,21 @@ func (c *LoopCollectCompiler) compileGroupedAggregationSelectors(selectors []fql panic("No arguments provided for the function call in the aggregate selector") } - var selectorArg vm.Operand - if len(args) > 1 { - // For multiple arguments, pack them into an array - selectorArg = c.ctx.Registers.Allocate(core.Temp) - c.ctx.Emitter.EmitArray(selectorArg, args) - c.ctx.Registers.FreeSequence(args) + // For multiple arguments, push each one with an indexed key + for y := 0; y < len(args); y++ { + // Create a key with format "name:index" + key := c.loadSelectorKey(kv.Key, name, y) + // Push the key-value pair to the collector + c.ctx.Emitter.EmitPushKV(dst, key, args[y]) + c.ctx.Registers.Free(key) + } } else { - // For a single argument, use it directly - selectorArg = args[0] + // For a single argument, use the selector name as the key + key := c.loadSelectorKey(kv.Key, name, -1) + // Push the key-value pair to the collector + c.ctx.Emitter.EmitPushKV(dst, key, args[0]) + c.ctx.Registers.Free(key) } // Get the function name and check if it's a protected call (with TRY) @@ -96,17 +81,20 @@ func (c *LoopCollectCompiler) compileGroupedAggregationSelectors(selectors []fql isProtected := fce.ErrorOperator() != nil // Create an AggregateSelector with all the information needed to process it later - wrappedSelectors = append(wrappedSelectors, core.NewAggregateSelector(name, len(args), funcName, isProtected, selectorArg)) + wrappedSelectors[i] = core.NewAggregateSelector(name, len(args), funcName, isProtected) + + // Free the argument registers + c.ctx.Registers.FreeSequence(args) } return wrappedSelectors } -// compileGlobalAggregationSelectors processes aggregation selectors for global (non-grouped) aggregations. +// initializeGlobalAggregationSelectors processes aggregation selectors for global (non-grouped) aggregations. // It compiles each selector's function call expression and arguments, and pushes them directly to the collector. // For selectors with multiple arguments, it uses indexed keys to store each argument separately. // Returns a slice of AggregateSelectors that describe the aggregation operations. -func (c *LoopCollectCompiler) compileGlobalAggregationSelectors(selectors []fql.ICollectAggregateSelectorContext, dst vm.Operand) []*core.AggregateSelector { +func (c *LoopCollectCompiler) initializeGlobalAggregationSelectors(selectors []fql.ICollectAggregateSelectorContext, dst vm.Operand) []*core.AggregateSelector { wrappedSelectors := make([]*core.AggregateSelector, 0, len(selectors)) for i := 0; i < len(selectors); i++ { @@ -127,7 +115,7 @@ func (c *LoopCollectCompiler) compileGlobalAggregationSelectors(selectors []fql. // For multiple arguments, push each one with an indexed key for y := 0; y < len(args); y++ { // Create a key with format "name:index" - key := c.loadAggregationArgKey(name, y) + key := c.loadGlobalSelectorKey(name, y) // Push the key-value pair to the collector c.ctx.Emitter.EmitPushKV(dst, key, args[y]) c.ctx.Registers.Free(key) @@ -147,7 +135,7 @@ func (c *LoopCollectCompiler) compileGlobalAggregationSelectors(selectors []fql. // For global aggregation, we don't need to store the register in the selector // as the values are already pushed to the collector - wrappedSelectors = append(wrappedSelectors, core.NewAggregateSelector(name, len(args), funcName, isProtected, vm.NoopOperand)) + wrappedSelectors = append(wrappedSelectors, core.NewAggregateSelector(name, len(args), funcName, isProtected)) // Free the argument registers c.ctx.Registers.FreeSequence(args) @@ -156,81 +144,32 @@ func (c *LoopCollectCompiler) compileGlobalAggregationSelectors(selectors []fql. return wrappedSelectors } -// unpackGroupedValues extracts values from the packed array created during grouped aggregation. -// It loads the loop value from index 0 and each aggregation selector value from subsequent indices. -// This is only needed for grouped aggregations, so it returns early if there's no grouping. -func (c *LoopCollectCompiler) unpackGroupedValues(spec *core.CollectorSpec) { - // Skip if there's no grouping - if !spec.HasGrouping() { - return - } - - loop := c.ctx.Loops.Current() - // Allocate a temporary register for the loop value - valReg := c.ctx.Registers.Allocate(core.Temp) - - // Load the original loop value from index 0 of the array - loadIndex(c.ctx, valReg, loop.Value, 0) - - // Load each aggregation selector value from its index in the array - for i, selector := range spec.AggregationSelectors() { - loadIndex(c.ctx, selector.Register(), loop.Value, i+1) - } - - // Free the temporary register - c.ctx.Registers.Free(valReg) -} - -// compileAggregation processes the aggregation operations based on the collector specification. +// finalizeAggregation processes the aggregation operations based on the collector specification. // It delegates to either grouped or global aggregation compilation based on whether grouping is used. -func (c *LoopCollectCompiler) compileAggregation(spec *core.CollectorSpec) { +func (c *LoopCollectCompiler) finalizeAggregation(spec *core.Collector) { if spec.HasGrouping() { // For aggregations with grouping - c.compileGroupedAggregation(spec) + c.finalizeGroupedAggregation(spec) } else { // For global aggregations without grouping - c.compileGlobalAggregation(spec) + c.finalizeGlobalAggregation(spec) } } -// compileGroupedAggregation handles grouped aggregation operations. +// finalizeGroupedAggregation handles grouped aggregation operations. // This function is currently commented out in the original code, likely because // the functionality is implemented differently or is being refactored. // The commented code shows the intended approach for handling grouped aggregations. -func (c *LoopCollectCompiler) compileGroupedAggregation(spec *core.CollectorSpec) { - //parentLoop := c.ctx.Loops.Current() - //// We need to allocate a temporary accumulator to store aggregation results - //selectors := ctx.AllCollectAggregateSelector() - //accumulator := c.ctx.Registers.Allocate(core.Temp) - //c.ctx.Emitter.EmitAx(vm.OpDataSetCollector, accumulator, int(core.CollectorTypeKeyGroup)) - // - //loop := c.ctx.Loops.NewForInLoop(core.TemporalLoop, false) - //loop.Src = c.ctx.Registers.Allocate(core.Temp) - // - //// Now we iterate over the grouped items - //parentLoop.EmitValue(loop.Src, c.ctx.Emitter) - // - //// Nested scope for aggregators - //c.ctx.Symbols.EnterScope() - //loop.DeclareValueVar(parentLoop.ValueName, c.ctx.Symbols) - //loop.EmitInitialization(c.ctx.Registers, c.ctx.Emitter, c.ctx.Loops.Depth()) - // - //// Add value selectors to the accumulators - //argsPkg := c.compileGroupedAggregationSelectors(selectors, accumulator) - // - //loop.EmitFinalization(c.ctx.Emitter) - //c.ctx.Symbols.ExitScope() - // - //// Now we can iterate over the selectors and execute the aggregation functions by passing the accumulators - //// And define variables for each accumulator result - //c.compileAggregationFuncCalls(selectors, accumulator, argsPkg) - //c.ctx.Registers.Free(accumulator) +func (c *LoopCollectCompiler) finalizeGroupedAggregation(spec *core.Collector) { + for i, selector := range spec.Aggregation().Selectors() { + c.compileGroupedAggregationFuncCall(selector, spec.Aggregation().State(), i) + } } -// compileGlobalAggregation handles global (non-grouped) aggregation operations. +// finalizeGlobalAggregation handles global (non-grouped) aggregation operations. // It creates a new loop with a single iteration to process the aggregation results. // This approach allows the aggregation to be processed in a consistent way with other operations. -func (c *LoopCollectCompiler) compileGlobalAggregation(spec *core.CollectorSpec) { +func (c *LoopCollectCompiler) finalizeGlobalAggregation(spec *core.Collector) { // At this point, the previous loop is finalized, so we can pop it and free its registers prevLoop := c.ctx.Loops.Pop() c.ctx.Registers.Free(prevLoop.Key) @@ -261,17 +200,17 @@ func (c *LoopCollectCompiler) compileGlobalAggregation(spec *core.CollectorSpec) loop.EmitInitialization(c.ctx.Registers, c.ctx.Emitter, c.ctx.Loops.Depth()) // Process the aggregation function calls using the values from the previous loop's collector - c.compileAggregationFuncCalls(spec, prevLoop.Dst) + c.compileGlobalAggregationFuncCalls(spec, prevLoop.Dst) // Free the previous loop's destination register c.ctx.Registers.Free(prevLoop.Dst) } -// compileAggregationFuncCalls processes the aggregation function calls for the selectors. +// compileGlobalAggregationFuncCalls processes the aggregation function calls for the selectors. // It loads the arguments from the aggregator, calls the aggregation functions, // and assigns the results to local variables. // It also handles the case where there are no records in the aggregator by loading NONE values. -func (c *LoopCollectCompiler) compileAggregationFuncCalls(spec *core.CollectorSpec, aggregator vm.Operand) { +func (c *LoopCollectCompiler) compileGlobalAggregationFuncCalls(spec *core.Collector, aggregator vm.Operand) { // Gets the number of records in the accumulator cond := c.ctx.Registers.Allocate(core.Temp) c.ctx.Emitter.EmitAB(vm.OpLength, cond, aggregator) @@ -285,7 +224,7 @@ func (c *LoopCollectCompiler) compileAggregationFuncCalls(spec *core.CollectorSp // We skip the key retrieval and function call if there are no records in the accumulator c.ctx.Emitter.EmitJumpIfTrue(cond, elseLabel) - selectors := spec.AggregationSelectors() + selectors := spec.Aggregation().Selectors() selectorVarRegs := make([]vm.Operand, len(selectors)) // Process each aggregation selector @@ -298,7 +237,7 @@ func (c *LoopCollectCompiler) compileAggregationFuncCalls(spec *core.CollectorSp args = c.ctx.Registers.AllocateSequence(selector.Args()) for y, reg := range args { - argKeyReg := c.loadAggregationArgKey(selector.Name(), y) + argKeyReg := c.loadGlobalSelectorKey(selector.Name(), y) c.ctx.Emitter.EmitABC(vm.OpLoadKey, reg, aggregator, argKeyReg) c.ctx.Registers.Free(argKeyReg) } @@ -347,22 +286,60 @@ func (c *LoopCollectCompiler) compileAggregationFuncCalls(spec *core.CollectorSp c.ctx.Registers.Free(cond) } -// compileAggregationFuncCall processes a single aggregation function call. -// It declares a local variable for the aggregation result and loads the value from the selector register. -// This is used for grouped aggregations where the selector values are stored in registers. -func (c *LoopCollectCompiler) compileAggregationFuncCall(selector *core.AggregateSelector) { +func (c *LoopCollectCompiler) compileGroupedAggregationFuncCall(selector *core.AggregateSelector, aggregator vm.Operand, idx int) { + loop := c.ctx.Loops.Current() // Declare a local variable with the selector name - varReg := c.ctx.Symbols.DeclareLocal(selector.Name().String(), core.TypeUnknown) - // Load the value from index 1 of the selector register (index 0 is the original value) - loadIndex(c.ctx, varReg, selector.Register(), 1) + valReg := c.ctx.Symbols.DeclareLocal(selector.Name().String(), core.TypeUnknown) + + var args core.RegisterSequence + + // We need to unpack arguments from the aggregator + if selector.Args() > 1 { + // For multiple arguments, allocate a sequence and load each argument by its indexed key + args = c.ctx.Registers.AllocateSequence(selector.Args()) + + for y, reg := range args { + key := c.loadSelectorKey(loop.Key, selector.Name(), y) + c.ctx.Emitter.EmitABC(vm.OpLoadKey, reg, aggregator, key) + c.ctx.Registers.Free(key) + } + } else { + // For a single argument, load it directly using the selector name as key + key := c.loadSelectorKey(loop.Key, selector.Name(), -1) + value := c.ctx.Registers.Allocate(core.Temp) + c.ctx.Emitter.EmitABC(vm.OpLoadKey, value, aggregator, key) + args = core.RegisterSequence{value} + c.ctx.Registers.Free(key) + } + + resArg := c.ctx.ExprCompiler.CompileFunctionCallByNameWith(selector.FuncName(), selector.ProtectedCall(), args) + + c.ctx.Emitter.EmitMove(valReg, resArg) } -// loadAggregationArgKey creates a key for an aggregation argument by combining the selector name and argument index. +// loadGlobalSelectorKey creates a key for an aggregation argument by combining the selector name and argument index. // This is used for global aggregations with multiple arguments to store each argument separately. // Returns a register containing the key as a string constant. -func (c *LoopCollectCompiler) loadAggregationArgKey(selector runtime.String, arg int) vm.Operand { +func (c *LoopCollectCompiler) loadGlobalSelectorKey(selector runtime.String, arg int) vm.Operand { // Create a key with format "selectorName:argIndex" argKey := selector.String() + ":" + strconv.Itoa(arg) // Load the key as a string constant return loadConstant(c.ctx, runtime.String(argKey)) } + +func (c *LoopCollectCompiler) loadSelectorKey(key vm.Operand, selector runtime.String, arg int) vm.Operand { + selectorKey := c.ctx.Registers.Allocate(core.Temp) + selectorName := loadConstant(c.ctx, selector) + + c.ctx.Emitter.EmitABC(vm.OpAdd, selectorKey, key, selectorName) + + if arg >= 0 { + selectorIndex := loadConstant(c.ctx, runtime.String(strconv.Itoa(arg))) + c.ctx.Emitter.EmitABC(vm.OpAdd, selectorKey, selectorKey, selectorIndex) + c.ctx.Registers.Free(selectorIndex) + } + + c.ctx.Registers.Free(selectorName) + + return selectorKey +} diff --git a/pkg/compiler/internal/loop_collect_grp.go b/pkg/compiler/internal/loop_collect_grp.go index fc7d51a3..da6fbe4e 100644 --- a/pkg/compiler/internal/loop_collect_grp.go +++ b/pkg/compiler/internal/loop_collect_grp.go @@ -94,9 +94,9 @@ func (c *LoopCollectCompiler) compileGroupKeys(ctx fql.ICollectGroupingContext) return kvKeyReg, collectSelectors } -// compileGrouping processes the group selectors and creates local variables for them. +// finalizeGrouping processes the group selectors and creates local variables for them. // It handles both multiple selectors (as array elements) and single selectors differently. -func (c *LoopCollectCompiler) compileGrouping(spec *core.CollectorSpec) { +func (c *LoopCollectCompiler) finalizeGrouping(spec *core.Collector) { loop := c.ctx.Loops.Current() if len(spec.GroupSelectors()) > 1 { diff --git a/pkg/compiler/internal/loop_collect_prj.go b/pkg/compiler/internal/loop_collect_prj.go index ba5b6112..7c3f0884 100644 --- a/pkg/compiler/internal/loop_collect_prj.go +++ b/pkg/compiler/internal/loop_collect_prj.go @@ -37,7 +37,7 @@ func (c *LoopCollectCompiler) initializeProjection(kv *core.KV, projection fql.I // finalizeProjection completes the projection setup by creating and assigning local variables. // It handles different behaviors based on whether grouping and aggregation are used. // Returns the register containing the projected value. -func (c *LoopCollectCompiler) finalizeProjection(spec *core.CollectorSpec, aggregator vm.Operand) vm.Operand { +func (c *LoopCollectCompiler) finalizeProjection(spec *core.Collector, aggregator vm.Operand) vm.Operand { loop := c.ctx.Loops.Current() varName := spec.Projection().VariableName() diff --git a/pkg/stdlib/collections/count.go b/pkg/stdlib/collections/count.go new file mode 100644 index 00000000..6328de82 --- /dev/null +++ b/pkg/stdlib/collections/count.go @@ -0,0 +1,19 @@ +package collections + +import ( + "context" + + "github.com/MontFerret/ferret/pkg/runtime" + "github.com/MontFerret/ferret/pkg/runtime/core" +) + +// COUNT computes the number of distinct elements in the given collection and returns the count as an integer. +func Count(ctx context.Context, arg core.Value) (core.Value, error) { + collection, err := runtime.CastCollection(arg) + + if err != nil { + return runtime.ZeroInt, err + } + + return collection.Length(ctx) +} diff --git a/pkg/stdlib/collections/count_distinct.go b/pkg/stdlib/collections/count_distinct.go index 9baa5e9b..7625753d 100644 --- a/pkg/stdlib/collections/count_distinct.go +++ b/pkg/stdlib/collections/count_distinct.go @@ -8,12 +8,8 @@ import ( ) // COUNT_DISTINCT computes the number of distinct elements in the given collection and returns the count as an integer. -func CountDistinct(ctx context.Context, args ...core.Value) (core.Value, error) { - if err := runtime.ValidateArgs(args, 1, 1); err != nil { - return runtime.None, err - } - - collection, err := runtime.CastCollection(args[0]) +func CountDistinct(ctx context.Context, arg core.Value) (core.Value, error) { + collection, err := runtime.CastCollection(arg) if err != nil { return runtime.ZeroInt, err diff --git a/pkg/stdlib/collections/include.go b/pkg/stdlib/collections/include.go index c75ecfe6..b9d75f21 100644 --- a/pkg/stdlib/collections/include.go +++ b/pkg/stdlib/collections/include.go @@ -10,16 +10,11 @@ import ( // @param {String | Any[] | hashMap | Iterable} haystack - The value container. // @param {Any} needle - The target value to assert. // @return {Boolean} - A boolean value that indicates whether a container contains a given value. -func Includes(ctx context.Context, args ...runtime.Value) (runtime.Value, error) { - err := runtime.ValidateArgs(args, 2, 2) - - if err != nil { - return runtime.None, err - } - +func Includes(ctx context.Context, arg1, arg2 runtime.Value) (runtime.Value, error) { + var err error var result runtime.Boolean - haystack := args[0] - needle := args[1] + haystack := arg1 + needle := arg2 switch v := haystack.(type) { case runtime.String: @@ -67,5 +62,5 @@ func Includes(ctx context.Context, args ...runtime.Value) (runtime.Value, error) ) } - return result, nil + return result, err } diff --git a/pkg/stdlib/collections/lib.go b/pkg/stdlib/collections/lib.go index d91f58ab..fc2fa042 100644 --- a/pkg/stdlib/collections/lib.go +++ b/pkg/stdlib/collections/lib.go @@ -5,10 +5,12 @@ import ( ) func RegisterLib(ns runtime.Namespace) error { - return ns.RegisterFunctions( - runtime.NewFunctionsFromMap(map[string]runtime.Function{ - "COUNT_DISTINCT": CountDistinct, - "INCLUDES": Includes, - "REVERSE": Reverse, - })) + return ns.RegisterFunctions(runtime. + NewFunctionsBuilder(). + Set1("COUNT_DISTINCT", CountDistinct). + Set1("COUNT", Count). + Set2("INCLUDES", Includes). + Set1("REVERSE", Reverse). + Build(), + ) } diff --git a/pkg/stdlib/collections/reverse.go b/pkg/stdlib/collections/reverse.go index 2fbf0665..6e19ef37 100644 --- a/pkg/stdlib/collections/reverse.go +++ b/pkg/stdlib/collections/reverse.go @@ -9,14 +9,8 @@ import ( // REVERSE returns the reverse of a given string or array value. // @param {String | Any[]} value - The string or array to reverse. // @return {String | Any[]} - A reversed version of a given value. -func Reverse(ctx context.Context, args ...runtime.Value) (runtime.Value, error) { - err := runtime.ValidateArgs(args, 1, 1) - - if err != nil { - return runtime.EmptyString, err - } - - switch col := args[0].(type) { +func Reverse(ctx context.Context, arg runtime.Value) (runtime.Value, error) { + switch col := arg.(type) { case runtime.String: runes := []rune(string(col)) size := len(runes) @@ -48,6 +42,6 @@ func Reverse(ctx context.Context, args ...runtime.Value) (runtime.Value, error) return result, nil default: - return runtime.None, runtime.TypeErrorOf(args[0], runtime.TypeList, runtime.TypeString) + return runtime.None, runtime.TypeErrorOf(arg, runtime.TypeList, runtime.TypeString) } } diff --git a/test/integration/vm/vm_for_in_collect_agg_test.go b/test/integration/vm/vm_for_in_collect_agg_test.go index 33109af3..55b26587 100644 --- a/test/integration/vm/vm_for_in_collect_agg_test.go +++ b/test/integration/vm/vm_for_in_collect_agg_test.go @@ -6,6 +6,38 @@ import ( func TestCollectAggregate(t *testing.T) { RunUseCases(t, []UseCase{ + DebugCaseArray(` + LET users = [ + { + active: true, + age: null, + gender: "m", + married: true + }, + { + active: true, + age: 25, + gender: "f", + married: false + }, + { + active: true, + age: null, + gender: "m", + married: false + } + ] + FOR u IN users + COLLECT gender = u.gender + AGGREGATE userCount = COUNT(u) + RETURN { + gender, + userCount, + } + `, []any{ + map[string]any{"gender": "f", "userCount": 1}, + map[string]any{"gender": "m", "userCount": 2}, + }), SkipCaseArray(` LET users = [ {