1
0
mirror of https://github.com/MontFerret/ferret.git synced 2025-08-13 19:52:52 +02:00

Refactor collector and aggregator handling; add Get and Close methods to collectors and sorters, restructure aggregation logic with multi-argument support, and update loop compilation to optimize resource management and key-value integration

This commit is contained in:
Tim Voronov
2025-06-24 18:19:22 -04:00
parent 4de53a5665
commit f859c4ead7
12 changed files with 236 additions and 96 deletions

View File

@@ -2,8 +2,9 @@ package core
import (
"fmt"
"github.com/MontFerret/ferret/pkg/vm"
"strings"
"github.com/MontFerret/ferret/pkg/vm"
)
type LoopTable struct {

View File

@@ -1,6 +1,8 @@
package internal
import (
"strconv"
"github.com/MontFerret/ferret/pkg/compiler/internal/core"
"github.com/MontFerret/ferret/pkg/parser/fql"
"github.com/MontFerret/ferret/pkg/runtime"
@@ -17,9 +19,11 @@ func (cc *LoopCollectCompiler) compileAggregation(c fql.ICollectAggregatorContex
func (cc *LoopCollectCompiler) compileGroupedAggregation(c fql.ICollectAggregatorContext) {
parentLoop := cc.ctx.Loops.Current()
// We need to allocate a temporary accumulators to store aggregation results
// We need to allocate a temporary accumulator to store aggregation results
selectors := c.AllCollectAggregateSelector()
accums := cc.initAggrAccumulators(selectors)
accumulator := cc.ctx.Registers.Allocate(core.Temp)
cc.ctx.Emitter.EmitAx(vm.OpDataSetCollector, accumulator, int(core.CollectorTypeKeyGroup))
loop := cc.ctx.Loops.CreateFor(core.TemporalLoop, cc.ctx.Registers.Allocate(core.Temp), false)
// Now we iterate over the grouped items
@@ -31,26 +35,15 @@ func (cc *LoopCollectCompiler) compileGroupedAggregation(c fql.ICollectAggregato
loop.EmitInitialization(cc.ctx.Registers, cc.ctx.Emitter)
// Add value selectors to the accumulators
cc.collectAggregationFuncArgs(selectors, func(i int, resultReg vm.Operand) {
cc.ctx.Emitter.EmitAB(vm.OpPush, accums[i], resultReg)
})
argsPkg := cc.compileAggregationFuncArgs(selectors, accumulator)
loop.EmitFinalization(cc.ctx.Emitter)
cc.ctx.Symbols.ExitScope()
// Now we can iterate over the selectors and execute the aggregation functions by passing the accumulators
// And define variables for each accumulator result
cc.compileAggregationFuncCall(selectors, func(i int, _ string) core.RegisterSequence {
return core.RegisterSequence{accums[i]}
}, func(i int) {})
// Free the registers for accumulators
for _, reg := range accums {
cc.ctx.Registers.Free(reg)
}
// Free the register for the iterator value
// cc.ctx.Registers.Free(aggrIterVal)
cc.compileAggregationFuncCall(selectors, accumulator, argsPkg)
cc.ctx.Registers.Free(accumulator)
}
func (cc *LoopCollectCompiler) compileGlobalAggregation(c fql.ICollectAggregatorContext) {
@@ -60,14 +53,9 @@ func (cc *LoopCollectCompiler) compileGlobalAggregation(c fql.ICollectAggregator
// Nested scope for aggregators
cc.ctx.Symbols.EnterScope()
// Now we add value selectors to the accumulators
// Now we add value selectors to the collector
selectors := c.AllCollectAggregateSelector()
cc.collectAggregationFuncArgs(selectors, func(i int, resultReg vm.Operand) {
aggrKeyName := selectors[i].Identifier().GetText()
aggrKeyReg := loadConstant(cc.ctx, runtime.String(aggrKeyName))
cc.ctx.Emitter.EmitABC(vm.OpPushKV, parentLoop.Dst, aggrKeyReg, resultReg)
cc.ctx.Registers.Free(aggrKeyReg)
})
argsPkg := cc.compileAggregationFuncArgs(selectors, parentLoop.Dst)
parentLoop.EmitFinalization(cc.ctx.Emitter)
cc.ctx.Loops.Pop()
@@ -90,27 +78,13 @@ func (cc *LoopCollectCompiler) compileGlobalAggregation(c fql.ICollectAggregator
loop.EmitInitialization(cc.ctx.Registers, cc.ctx.Emitter)
// We just need to take the grouped values and call aggregation functions using them as args
var key vm.Operand
var value vm.Operand
cc.compileAggregationFuncCall(selectors, func(i int, selectorVarName string) core.RegisterSequence {
// We execute the function call with the accumulator as an argument
key = loadConstant(cc.ctx, runtime.String(selectorVarName))
value = cc.ctx.Registers.Allocate(core.Temp)
cc.ctx.Emitter.EmitABC(vm.OpLoadKey, value, aggregator, key)
return core.RegisterSequence{value}
}, func(_ int) {
cc.ctx.Registers.Free(value)
cc.ctx.Registers.Free(key)
})
cc.compileAggregationFuncCall(selectors, aggregator, argsPkg)
cc.ctx.Registers.Free(aggregator)
// Free the register for the iterator value
// cc.ctx.Registers.Free(aggrIterVal)
}
func (cc *LoopCollectCompiler) collectAggregationFuncArgs(selectors []fql.ICollectAggregateSelectorContext, collector func(int, vm.Operand)) {
func (cc *LoopCollectCompiler) compileAggregationFuncArgs(selectors []fql.ICollectAggregateSelectorContext, collector vm.Operand) []int {
argsPkg := make([]int, len(selectors))
for i := 0; i < len(selectors); i++ {
selector := selectors[i]
fcx := selector.FunctionCallExpression()
@@ -121,63 +95,64 @@ func (cc *LoopCollectCompiler) collectAggregationFuncArgs(selectors []fql.IColle
panic("No arguments provided for the function call in the aggregate selector")
}
aggrKeyReg := loadConstant(cc.ctx, runtime.Int(i))
// we keep information about the args - whether we need to unpack them or not
argsPkg[i] = len(args)
if len(args) > 1 {
// TODO: Better error handling
panic("Too many arguments")
for y, arg := range args {
argKeyReg := cc.loadAggregationArgKey(i, y)
cc.ctx.Emitter.EmitABC(vm.OpPushKV, collector, argKeyReg, arg)
cc.ctx.Registers.Free(argKeyReg)
}
} else {
cc.ctx.Emitter.EmitABC(vm.OpPushKV, collector, aggrKeyReg, args[0])
}
resultReg := args[0]
collector(i, resultReg)
cc.ctx.Registers.Free(resultReg)
cc.ctx.Registers.Free(aggrKeyReg)
cc.ctx.Registers.FreeSequence(args)
}
return argsPkg
}
func (cc *LoopCollectCompiler) compileAggregationFuncCall(selectors []fql.ICollectAggregateSelectorContext, provider func(int, string) core.RegisterSequence, cleanup func(int)) {
func (cc *LoopCollectCompiler) compileAggregationFuncCall(selectors []fql.ICollectAggregateSelectorContext, accumulator vm.Operand, argsPkg []int) {
for i, selector := range selectors {
fcx := selector.FunctionCallExpression()
// We won't make any checks here, as we already did it before
selectorVarName := selector.Identifier().GetText()
argsNum := argsPkg[i]
result := cc.ctx.ExprCompiler.CompileFunctionCallWith(fcx.FunctionCall(), fcx.ErrorOperator() != nil, provider(i, selectorVarName))
var args core.RegisterSequence
// We need to unpack arguments
if argsNum > 1 {
args = cc.ctx.Registers.AllocateSequence(argsNum)
for y, reg := range args {
argKeyReg := cc.loadAggregationArgKey(i, y)
cc.ctx.Emitter.EmitABC(vm.OpLoadKey, reg, accumulator, argKeyReg)
cc.ctx.Registers.Free(argKeyReg)
}
} else {
key := loadConstant(cc.ctx, runtime.Int(i))
value := cc.ctx.Registers.Allocate(core.Temp)
cc.ctx.Emitter.EmitABC(vm.OpLoadKey, value, accumulator, key)
args = core.RegisterSequence{value}
cc.ctx.Registers.Free(key)
}
fcx := selector.FunctionCallExpression()
result := cc.ctx.ExprCompiler.CompileFunctionCallWith(fcx.FunctionCall(), fcx.ErrorOperator() != nil, args)
// We define the variable for the selector result in the upper scope
// Since this temporary scope is only for aggregators and will be closed after the aggregation
selectorVarName := selector.Identifier().GetText()
varReg := cc.ctx.Symbols.DeclareLocal(selectorVarName)
cc.ctx.Emitter.EmitAB(vm.OpMove, varReg, result)
cc.ctx.Registers.Free(result)
cleanup(i)
}
}
func (cc *LoopCollectCompiler) initAggrAccumulators(selectors []fql.ICollectAggregateSelectorContext) []vm.Operand {
accums := make([]vm.Operand, len(selectors))
// First of all, we allocate registers for accumulators
accums = make([]vm.Operand, len(selectors))
// We need to allocate a register for each accumulator
for i := 0; i < len(selectors); i++ {
reg := cc.ctx.Registers.Allocate(core.Temp)
accums[i] = reg
// TODO: Select persistent List type, we do not know how many items we will have
cc.ctx.Emitter.EmitA(vm.OpList, reg)
}
return accums
}
func (cc *LoopCollectCompiler) emitPushToAggrAccumulators(accums []vm.Operand, selectors []fql.ICollectAggregateSelectorContext, loop *core.Loop) {
for i, selector := range selectors {
fcx := selector.FunctionCallExpression()
args := cc.ctx.ExprCompiler.CompileArgumentList(fcx.FunctionCall().ArgumentList())
if len(args) != 1 {
panic("aggregate function must have exactly one argument")
}
cc.ctx.Emitter.EmitAB(vm.OpPush, accums[i], args[0])
cc.ctx.Registers.Free(args[0])
}
// loadAggregationArgKey builds the string key "<selector>:<arg>" from the
// selector index and the argument index, and loads it as a constant via
// loadConstant. The returned operand is whatever loadConstant yields for
// that string constant.
func (cc *LoopCollectCompiler) loadAggregationArgKey(selector int, arg int) vm.Operand {
	key := runtime.String(strconv.Itoa(selector) + ":" + strconv.Itoa(arg))

	return loadConstant(cc.ctx, key)
}

View File

@@ -28,3 +28,11 @@ func (c *CounterCollector) Add(_ context.Context, _, _ runtime.Value) error {
return nil
}
// Get returns the collector's Value together with a nil error.
// Both the context and the key argument are ignored, so every key yields
// the same result.
func (c *CounterCollector) Get(_ context.Context, _ runtime.Value) (runtime.Value, error) {
return c.Value, nil
}
// Close is a no-op for CounterCollector: it touches no state and always
// returns nil.
func (c *CounterCollector) Close() error {
return nil
}

View File

@@ -2,6 +2,7 @@ package internal
import (
"context"
"io"
"github.com/MontFerret/ferret/pkg/runtime"
)
@@ -52,3 +53,31 @@ func (c *KeyCollector) Add(ctx context.Context, key, _ runtime.Value) error {
return nil
}
// Get looks up the entry stored in the grouping map under the given key.
//
// The key is first converted with Stringify; if that conversion fails, its
// error is returned with a nil value. When no entry exists for the converted
// key, runtime.None and runtime.ErrNotFound are returned.
func (c *KeyCollector) Get(ctx context.Context, key runtime.Value) (runtime.Value, error) {
	name, err := Stringify(ctx, key)
	if err != nil {
		return nil, err
	}

	if found, exists := c.grouping[name]; exists {
		return found, nil
	}

	return runtime.None, runtime.ErrNotFound
}
// Close drops the collector's references to its Value and grouping map and,
// when the held value implements io.Closer, closes it and returns its error.
// Values that are nil or do not implement io.Closer are simply released and
// nil is returned.
func (c *KeyCollector) Close() error {
	val := c.Value
	c.Value = nil
	c.grouping = nil

	// The comma-ok form is required here: a single-value type assertion
	// panics when val is nil or does not implement io.Closer.
	if closer, ok := val.(io.Closer); ok {
		return closer.Close()
	}

	return nil
}

View File

@@ -2,6 +2,7 @@ package internal
import (
"context"
"io"
"github.com/MontFerret/ferret/pkg/runtime"
)
@@ -101,3 +102,31 @@ func (c *KeyCounterCollector) Add(ctx context.Context, key, _ runtime.Value) err
return nil
}
// Get returns the entry that the grouping map holds for the given key.
//
// Stringify turns the key into the map key; a failure there is returned
// as-is alongside a nil value. A key with no entry produces runtime.None
// and runtime.ErrNotFound.
func (c *KeyCounterCollector) Get(ctx context.Context, key runtime.Value) (runtime.Value, error) {
	groupKey, err := Stringify(ctx, key)
	if err != nil {
		return nil, err
	}

	entry, present := c.grouping[groupKey]
	if present {
		return entry, nil
	}

	return runtime.None, runtime.ErrNotFound
}
// Close drops the collector's references to its Value and grouping map and,
// when the held value implements io.Closer, closes it and returns its error.
// Values that are nil or do not implement io.Closer are simply released and
// nil is returned.
func (c *KeyCounterCollector) Close() error {
	val := c.Value
	c.Value = nil
	c.grouping = nil

	// The comma-ok form is required here: a single-value type assertion
	// panics when val is nil or does not implement io.Closer.
	if closer, ok := val.(io.Closer); ok {
		return closer.Close()
	}

	return nil
}

View File

@@ -2,6 +2,7 @@ package internal
import (
"context"
"io"
"github.com/MontFerret/ferret/pkg/runtime"
)
@@ -21,10 +22,6 @@ func NewKeyGroupCollector() Transformer {
}
}
func (c *KeyGroupCollector) Get(_ context.Context, key runtime.Value) (runtime.Value, error) {
return c.grouping[key.String()], nil
}
func (c *KeyGroupCollector) Iterate(ctx context.Context) (runtime.Iterator, error) {
if !c.sorted {
if err := c.sort(ctx); err != nil {
@@ -83,3 +80,31 @@ func (c *KeyGroupCollector) sort(ctx context.Context) error {
return comp
})
}
// Get fetches the entry recorded in the grouping map for the given key.
//
// The lookup key is produced by Stringify; when Stringify fails, the error
// is handed back with a nil value. If the map has no entry for the key, the
// result is runtime.None with runtime.ErrNotFound.
func (c *KeyGroupCollector) Get(ctx context.Context, key runtime.Value) (runtime.Value, error) {
	lookup, err := Stringify(ctx, key)
	if err != nil {
		return nil, err
	}

	if group, hit := c.grouping[lookup]; hit {
		return group, nil
	}

	return runtime.None, runtime.ErrNotFound
}
// Close drops the collector's references to its Value and grouping map and,
// when the held value implements io.Closer, closes it and returns its error.
// Values that are nil or do not implement io.Closer are simply released and
// nil is returned.
func (c *KeyGroupCollector) Close() error {
	val := c.Value
	c.Value = nil
	c.grouping = nil

	// The comma-ok form is required here: a single-value type assertion
	// panics when val is nil or does not implement io.Closer.
	if closer, ok := val.(io.Closer); ok {
		return closer.Close()
	}

	return nil
}

View File

@@ -2,9 +2,10 @@ package internal_test
import (
"context"
"github.com/MontFerret/ferret/pkg/vm/internal"
"testing"
"github.com/MontFerret/ferret/pkg/vm/internal"
"github.com/MontFerret/ferret/pkg/runtime"
. "github.com/smartystreets/goconvey/convey"

View File

@@ -2,6 +2,7 @@ package internal
import (
"context"
"io"
"github.com/MontFerret/ferret/pkg/runtime"
)
@@ -57,3 +58,18 @@ func (s *Sorter) sort(ctx context.Context) error {
return -comp
})
}
// Get is not supported by Sorter: regardless of the context and key passed
// in, it returns runtime.None together with runtime.ErrNotSupported.
func (s *Sorter) Get(_ context.Context, _ runtime.Value) (runtime.Value, error) {
return runtime.None, runtime.ErrNotSupported
}
// Close drops the sorter's reference to its Value and, when that value
// implements io.Closer, closes it and returns its error. Values that are nil
// or do not implement io.Closer are simply released and nil is returned.
func (s *Sorter) Close() error {
	val := s.Value
	s.Value = nil

	// The comma-ok form is required here: a single-value type assertion
	// panics when val is nil or does not implement io.Closer.
	if closer, ok := val.(io.Closer); ok {
		return closer.Close()
	}

	return nil
}

View File

@@ -2,6 +2,7 @@ package internal
import (
"context"
"io"
"github.com/MontFerret/ferret/pkg/runtime"
)
@@ -67,3 +68,18 @@ func (s *MultiSorter) sort(ctx context.Context) error {
return 0
})
}
// Get is not supported by MultiSorter: regardless of the context and key
// passed in, it returns runtime.None together with runtime.ErrNotSupported.
func (s *MultiSorter) Get(_ context.Context, _ runtime.Value) (runtime.Value, error) {
return runtime.None, runtime.ErrNotSupported
}
// Close drops the sorter's reference to its Value and, when that value
// implements io.Closer, closes it and returns its error. Values that are nil
// or do not implement io.Closer are simply released and nil is returned.
func (s *MultiSorter) Close() error {
	val := s.Value
	s.Value = nil

	// The comma-ok form is required here: a single-value type assertion
	// panics when val is nil or does not implement io.Closer.
	if closer, ok := val.(io.Closer); ok {
		return closer.Close()
	}

	return nil
}

View File

@@ -2,6 +2,7 @@ package internal
import (
"context"
"io"
"github.com/MontFerret/ferret/pkg/runtime"
)
@@ -9,6 +10,8 @@ import (
// Transformer is the contract shared by the collectors and sorters in this
// package. It embeds runtime.Value, runtime.Iterable, runtime.Keyed and
// io.Closer, and adds Add for feeding entries in.
// NOTE(review): runtime.Keyed presumably supplies the Get(ctx, key) method the
// implementations define — confirm against the runtime package.
type Transformer interface {
runtime.Value
runtime.Iterable
runtime.Keyed
io.Closer
// Add passes one key/value pair to the transformer. How the pair is used
// is up to the implementation; some ignore one or both arguments
// (CounterCollector.Add, for example, ignores both).
Add(ctx context.Context, key, value runtime.Value) error
}

View File

@@ -192,7 +192,6 @@ loop:
start := int(src1)
end := int(src1) + size
// Iterate over registers starting from src1 and up to the src2
for i := start; i < end; i++ {
_ = arr.Add(ctx, reg[i])
}

View File

@@ -96,7 +96,46 @@ FOR u IN users
`, []any{
map[string]any{"minAge": 25, "maxAge": 69},
}, "Should collect and aggregate values without grouping"),
SkipCaseArray(`
CaseArray(`
LET users = [
{
active: true,
married: true,
age: 31,
gender: "m"
},
{
active: true,
married: false,
age: 25,
gender: "f"
},
{
active: true,
married: false,
age: 36,
gender: "m"
},
{
active: false,
married: true,
age: 69,
gender: "m"
},
{
active: true,
married: true,
age: 45,
gender: "f"
}
]
FOR u IN users
COLLECT AGGREGATE ages = UNION(u.age, u.age)
RETURN { ages }
`, []any{
map[string]any{"ages": []any{31, 25, 36, 69, 45, 31, 25, 36, 69, 45}},
}, "Should call aggregation functions with more than one argument"),
CaseArray(`
LET users = [
{
active: true,
@@ -131,16 +170,15 @@ LET users = [
]
FOR u IN users
COLLECT genderGroup = u.gender
AGGREGATE minAge = MIN(u.age), maxAge = MAX(u.age)
AGGREGATE ages = UNION(u.age, u.age)
RETURN {
genderGroup,
minAge,
maxAge
ages,
}
`, []any{
map[string]any{"genderGroup": "f", "minAge": 25, "maxAge": 45},
map[string]any{"genderGroup": "m", "minAge": 31, "maxAge": 69},
map[string]any{"genderGroup": "f", "ages": []any{25, 45, 25, 45}},
map[string]any{"genderGroup": "m", "ages": []any{31, 36, 69, 31, 36, 69}},
}, "Should collect and aggregate values by a single key"),
})
}