mirror of
https://github.com/MontFerret/ferret.git
synced 2025-08-13 19:52:52 +02:00
Refactor collector and aggregator handling; add `Get` and `Close` methods to collectors and sorters, restructure aggregation logic with multi-argument support, and update loop compilation to optimize resource management and key-value integration
This commit is contained in:
@@ -2,8 +2,9 @@ package core
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/MontFerret/ferret/pkg/vm"
|
||||
"strings"
|
||||
|
||||
"github.com/MontFerret/ferret/pkg/vm"
|
||||
)
|
||||
|
||||
type LoopTable struct {
|
||||
|
@@ -1,6 +1,8 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/MontFerret/ferret/pkg/compiler/internal/core"
|
||||
"github.com/MontFerret/ferret/pkg/parser/fql"
|
||||
"github.com/MontFerret/ferret/pkg/runtime"
|
||||
@@ -17,9 +19,11 @@ func (cc *LoopCollectCompiler) compileAggregation(c fql.ICollectAggregatorContex
|
||||
|
||||
func (cc *LoopCollectCompiler) compileGroupedAggregation(c fql.ICollectAggregatorContext) {
|
||||
parentLoop := cc.ctx.Loops.Current()
|
||||
// We need to allocate a temporary accumulators to store aggregation results
|
||||
// We need to allocate a temporary accumulator to store aggregation results
|
||||
selectors := c.AllCollectAggregateSelector()
|
||||
accums := cc.initAggrAccumulators(selectors)
|
||||
accumulator := cc.ctx.Registers.Allocate(core.Temp)
|
||||
cc.ctx.Emitter.EmitAx(vm.OpDataSetCollector, accumulator, int(core.CollectorTypeKeyGroup))
|
||||
|
||||
loop := cc.ctx.Loops.CreateFor(core.TemporalLoop, cc.ctx.Registers.Allocate(core.Temp), false)
|
||||
|
||||
// Now we iterate over the grouped items
|
||||
@@ -31,26 +35,15 @@ func (cc *LoopCollectCompiler) compileGroupedAggregation(c fql.ICollectAggregato
|
||||
loop.EmitInitialization(cc.ctx.Registers, cc.ctx.Emitter)
|
||||
|
||||
// Add value selectors to the accumulators
|
||||
cc.collectAggregationFuncArgs(selectors, func(i int, resultReg vm.Operand) {
|
||||
cc.ctx.Emitter.EmitAB(vm.OpPush, accums[i], resultReg)
|
||||
})
|
||||
argsPkg := cc.compileAggregationFuncArgs(selectors, accumulator)
|
||||
|
||||
loop.EmitFinalization(cc.ctx.Emitter)
|
||||
cc.ctx.Symbols.ExitScope()
|
||||
|
||||
// Now we can iterate over the selectors and execute the aggregation functions by passing the accumulators
|
||||
// And define variables for each accumulator result
|
||||
cc.compileAggregationFuncCall(selectors, func(i int, _ string) core.RegisterSequence {
|
||||
return core.RegisterSequence{accums[i]}
|
||||
}, func(i int) {})
|
||||
|
||||
// Free the registers for accumulators
|
||||
for _, reg := range accums {
|
||||
cc.ctx.Registers.Free(reg)
|
||||
}
|
||||
|
||||
// Free the register for the iterator value
|
||||
// cc.ctx.Registers.Free(aggrIterVal)
|
||||
cc.compileAggregationFuncCall(selectors, accumulator, argsPkg)
|
||||
cc.ctx.Registers.Free(accumulator)
|
||||
}
|
||||
|
||||
func (cc *LoopCollectCompiler) compileGlobalAggregation(c fql.ICollectAggregatorContext) {
|
||||
@@ -60,14 +53,9 @@ func (cc *LoopCollectCompiler) compileGlobalAggregation(c fql.ICollectAggregator
|
||||
|
||||
// Nested scope for aggregators
|
||||
cc.ctx.Symbols.EnterScope()
|
||||
// Now we add value selectors to the accumulators
|
||||
// Now we add value selectors to the collector
|
||||
selectors := c.AllCollectAggregateSelector()
|
||||
cc.collectAggregationFuncArgs(selectors, func(i int, resultReg vm.Operand) {
|
||||
aggrKeyName := selectors[i].Identifier().GetText()
|
||||
aggrKeyReg := loadConstant(cc.ctx, runtime.String(aggrKeyName))
|
||||
cc.ctx.Emitter.EmitABC(vm.OpPushKV, parentLoop.Dst, aggrKeyReg, resultReg)
|
||||
cc.ctx.Registers.Free(aggrKeyReg)
|
||||
})
|
||||
argsPkg := cc.compileAggregationFuncArgs(selectors, parentLoop.Dst)
|
||||
|
||||
parentLoop.EmitFinalization(cc.ctx.Emitter)
|
||||
cc.ctx.Loops.Pop()
|
||||
@@ -90,27 +78,13 @@ func (cc *LoopCollectCompiler) compileGlobalAggregation(c fql.ICollectAggregator
|
||||
loop.EmitInitialization(cc.ctx.Registers, cc.ctx.Emitter)
|
||||
|
||||
// We just need to take the grouped values and call aggregation functions using them as args
|
||||
var key vm.Operand
|
||||
var value vm.Operand
|
||||
cc.compileAggregationFuncCall(selectors, func(i int, selectorVarName string) core.RegisterSequence {
|
||||
// We execute the function call with the accumulator as an argument
|
||||
key = loadConstant(cc.ctx, runtime.String(selectorVarName))
|
||||
value = cc.ctx.Registers.Allocate(core.Temp)
|
||||
cc.ctx.Emitter.EmitABC(vm.OpLoadKey, value, aggregator, key)
|
||||
|
||||
return core.RegisterSequence{value}
|
||||
}, func(_ int) {
|
||||
cc.ctx.Registers.Free(value)
|
||||
cc.ctx.Registers.Free(key)
|
||||
})
|
||||
|
||||
cc.compileAggregationFuncCall(selectors, aggregator, argsPkg)
|
||||
cc.ctx.Registers.Free(aggregator)
|
||||
|
||||
// Free the register for the iterator value
|
||||
// cc.ctx.Registers.Free(aggrIterVal)
|
||||
}
|
||||
|
||||
func (cc *LoopCollectCompiler) collectAggregationFuncArgs(selectors []fql.ICollectAggregateSelectorContext, collector func(int, vm.Operand)) {
|
||||
func (cc *LoopCollectCompiler) compileAggregationFuncArgs(selectors []fql.ICollectAggregateSelectorContext, collector vm.Operand) []int {
|
||||
argsPkg := make([]int, len(selectors))
|
||||
|
||||
for i := 0; i < len(selectors); i++ {
|
||||
selector := selectors[i]
|
||||
fcx := selector.FunctionCallExpression()
|
||||
@@ -121,63 +95,64 @@ func (cc *LoopCollectCompiler) collectAggregationFuncArgs(selectors []fql.IColle
|
||||
panic("No arguments provided for the function call in the aggregate selector")
|
||||
}
|
||||
|
||||
aggrKeyReg := loadConstant(cc.ctx, runtime.Int(i))
|
||||
// we keep information about the args - whether we need to unpack them or not
|
||||
argsPkg[i] = len(args)
|
||||
|
||||
if len(args) > 1 {
|
||||
// TODO: Better error handling
|
||||
panic("Too many arguments")
|
||||
for y, arg := range args {
|
||||
argKeyReg := cc.loadAggregationArgKey(i, y)
|
||||
cc.ctx.Emitter.EmitABC(vm.OpPushKV, collector, argKeyReg, arg)
|
||||
cc.ctx.Registers.Free(argKeyReg)
|
||||
}
|
||||
} else {
|
||||
cc.ctx.Emitter.EmitABC(vm.OpPushKV, collector, aggrKeyReg, args[0])
|
||||
}
|
||||
|
||||
resultReg := args[0]
|
||||
collector(i, resultReg)
|
||||
cc.ctx.Registers.Free(resultReg)
|
||||
cc.ctx.Registers.Free(aggrKeyReg)
|
||||
cc.ctx.Registers.FreeSequence(args)
|
||||
}
|
||||
|
||||
return argsPkg
|
||||
}
|
||||
|
||||
func (cc *LoopCollectCompiler) compileAggregationFuncCall(selectors []fql.ICollectAggregateSelectorContext, provider func(int, string) core.RegisterSequence, cleanup func(int)) {
|
||||
func (cc *LoopCollectCompiler) compileAggregationFuncCall(selectors []fql.ICollectAggregateSelectorContext, accumulator vm.Operand, argsPkg []int) {
|
||||
for i, selector := range selectors {
|
||||
fcx := selector.FunctionCallExpression()
|
||||
// We won't make any checks here, as we already did it before
|
||||
selectorVarName := selector.Identifier().GetText()
|
||||
argsNum := argsPkg[i]
|
||||
|
||||
result := cc.ctx.ExprCompiler.CompileFunctionCallWith(fcx.FunctionCall(), fcx.ErrorOperator() != nil, provider(i, selectorVarName))
|
||||
var args core.RegisterSequence
|
||||
|
||||
// We need to unpack arguments
|
||||
if argsNum > 1 {
|
||||
args = cc.ctx.Registers.AllocateSequence(argsNum)
|
||||
|
||||
for y, reg := range args {
|
||||
argKeyReg := cc.loadAggregationArgKey(i, y)
|
||||
cc.ctx.Emitter.EmitABC(vm.OpLoadKey, reg, accumulator, argKeyReg)
|
||||
|
||||
cc.ctx.Registers.Free(argKeyReg)
|
||||
}
|
||||
} else {
|
||||
key := loadConstant(cc.ctx, runtime.Int(i))
|
||||
value := cc.ctx.Registers.Allocate(core.Temp)
|
||||
cc.ctx.Emitter.EmitABC(vm.OpLoadKey, value, accumulator, key)
|
||||
args = core.RegisterSequence{value}
|
||||
cc.ctx.Registers.Free(key)
|
||||
}
|
||||
|
||||
fcx := selector.FunctionCallExpression()
|
||||
result := cc.ctx.ExprCompiler.CompileFunctionCallWith(fcx.FunctionCall(), fcx.ErrorOperator() != nil, args)
|
||||
|
||||
// We define the variable for the selector result in the upper scope
|
||||
// Since this temporary scope is only for aggregators and will be closed after the aggregation
|
||||
selectorVarName := selector.Identifier().GetText()
|
||||
varReg := cc.ctx.Symbols.DeclareLocal(selectorVarName)
|
||||
cc.ctx.Emitter.EmitAB(vm.OpMove, varReg, result)
|
||||
cc.ctx.Registers.Free(result)
|
||||
|
||||
cleanup(i)
|
||||
}
|
||||
}
|
||||
|
||||
func (cc *LoopCollectCompiler) initAggrAccumulators(selectors []fql.ICollectAggregateSelectorContext) []vm.Operand {
|
||||
accums := make([]vm.Operand, len(selectors))
|
||||
|
||||
// First of all, we allocate registers for accumulators
|
||||
accums = make([]vm.Operand, len(selectors))
|
||||
|
||||
// We need to allocate a register for each accumulator
|
||||
for i := 0; i < len(selectors); i++ {
|
||||
reg := cc.ctx.Registers.Allocate(core.Temp)
|
||||
accums[i] = reg
|
||||
|
||||
// TODO: Select persistent List type, we do not know how many items we will have
|
||||
cc.ctx.Emitter.EmitA(vm.OpList, reg)
|
||||
}
|
||||
|
||||
return accums
|
||||
}
|
||||
|
||||
func (cc *LoopCollectCompiler) emitPushToAggrAccumulators(accums []vm.Operand, selectors []fql.ICollectAggregateSelectorContext, loop *core.Loop) {
|
||||
for i, selector := range selectors {
|
||||
fcx := selector.FunctionCallExpression()
|
||||
args := cc.ctx.ExprCompiler.CompileArgumentList(fcx.FunctionCall().ArgumentList())
|
||||
|
||||
if len(args) != 1 {
|
||||
panic("aggregate function must have exactly one argument")
|
||||
}
|
||||
|
||||
cc.ctx.Emitter.EmitAB(vm.OpPush, accums[i], args[0])
|
||||
cc.ctx.Registers.Free(args[0])
|
||||
}
|
||||
func (cc *LoopCollectCompiler) loadAggregationArgKey(selector int, arg int) vm.Operand {
|
||||
argKey := strconv.Itoa(selector) + ":" + strconv.Itoa(arg)
|
||||
return loadConstant(cc.ctx, runtime.String(argKey))
|
||||
}
|
@@ -28,3 +28,11 @@ func (c *CounterCollector) Add(_ context.Context, _, _ runtime.Value) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CounterCollector) Get(_ context.Context, _ runtime.Value) (runtime.Value, error) {
|
||||
return c.Value, nil
|
||||
}
|
||||
|
||||
func (c *CounterCollector) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
@@ -2,6 +2,7 @@ package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
|
||||
"github.com/MontFerret/ferret/pkg/runtime"
|
||||
)
|
||||
@@ -52,3 +53,31 @@ func (c *KeyCollector) Add(ctx context.Context, key, _ runtime.Value) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *KeyCollector) Get(ctx context.Context, key runtime.Value) (runtime.Value, error) {
|
||||
k, err := Stringify(ctx, key)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
v, ok := c.grouping[k]
|
||||
|
||||
if !ok {
|
||||
return runtime.None, runtime.ErrNotFound
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func (c *KeyCollector) Close() error {
|
||||
val := c.Value
|
||||
c.Value = nil
|
||||
c.grouping = nil
|
||||
|
||||
if closer := val.(io.Closer); closer != nil {
|
||||
return closer.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@@ -2,6 +2,7 @@ package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
|
||||
"github.com/MontFerret/ferret/pkg/runtime"
|
||||
)
|
||||
@@ -101,3 +102,31 @@ func (c *KeyCounterCollector) Add(ctx context.Context, key, _ runtime.Value) err
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *KeyCounterCollector) Get(ctx context.Context, key runtime.Value) (runtime.Value, error) {
|
||||
k, err := Stringify(ctx, key)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
v, ok := c.grouping[k]
|
||||
|
||||
if !ok {
|
||||
return runtime.None, runtime.ErrNotFound
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func (c *KeyCounterCollector) Close() error {
|
||||
val := c.Value
|
||||
c.Value = nil
|
||||
c.grouping = nil
|
||||
|
||||
if closer := val.(io.Closer); closer != nil {
|
||||
return closer.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@@ -2,6 +2,7 @@ package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
|
||||
"github.com/MontFerret/ferret/pkg/runtime"
|
||||
)
|
||||
@@ -21,10 +22,6 @@ func NewKeyGroupCollector() Transformer {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *KeyGroupCollector) Get(_ context.Context, key runtime.Value) (runtime.Value, error) {
|
||||
return c.grouping[key.String()], nil
|
||||
}
|
||||
|
||||
func (c *KeyGroupCollector) Iterate(ctx context.Context) (runtime.Iterator, error) {
|
||||
if !c.sorted {
|
||||
if err := c.sort(ctx); err != nil {
|
||||
@@ -83,3 +80,31 @@ func (c *KeyGroupCollector) sort(ctx context.Context) error {
|
||||
return comp
|
||||
})
|
||||
}
|
||||
|
||||
func (c *KeyGroupCollector) Get(ctx context.Context, key runtime.Value) (runtime.Value, error) {
|
||||
k, err := Stringify(ctx, key)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
v, ok := c.grouping[k]
|
||||
|
||||
if !ok {
|
||||
return runtime.None, runtime.ErrNotFound
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func (c *KeyGroupCollector) Close() error {
|
||||
val := c.Value
|
||||
c.Value = nil
|
||||
c.grouping = nil
|
||||
|
||||
if closer := val.(io.Closer); closer != nil {
|
||||
return closer.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@@ -2,9 +2,10 @@ package internal_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/MontFerret/ferret/pkg/vm/internal"
|
||||
"testing"
|
||||
|
||||
"github.com/MontFerret/ferret/pkg/vm/internal"
|
||||
|
||||
"github.com/MontFerret/ferret/pkg/runtime"
|
||||
|
||||
. "github.com/smartystreets/goconvey/convey"
|
||||
|
@@ -2,6 +2,7 @@ package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
|
||||
"github.com/MontFerret/ferret/pkg/runtime"
|
||||
)
|
||||
@@ -57,3 +58,18 @@ func (s *Sorter) sort(ctx context.Context) error {
|
||||
return -comp
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Sorter) Get(_ context.Context, _ runtime.Value) (runtime.Value, error) {
|
||||
return runtime.None, runtime.ErrNotSupported
|
||||
}
|
||||
|
||||
func (s *Sorter) Close() error {
|
||||
val := s.Value
|
||||
s.Value = nil
|
||||
|
||||
if closer := val.(io.Closer); closer != nil {
|
||||
return closer.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@@ -2,6 +2,7 @@ package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
|
||||
"github.com/MontFerret/ferret/pkg/runtime"
|
||||
)
|
||||
@@ -67,3 +68,18 @@ func (s *MultiSorter) sort(ctx context.Context) error {
|
||||
return 0
|
||||
})
|
||||
}
|
||||
|
||||
func (s *MultiSorter) Get(_ context.Context, _ runtime.Value) (runtime.Value, error) {
|
||||
return runtime.None, runtime.ErrNotSupported
|
||||
}
|
||||
|
||||
func (s *MultiSorter) Close() error {
|
||||
val := s.Value
|
||||
s.Value = nil
|
||||
|
||||
if closer := val.(io.Closer); closer != nil {
|
||||
return closer.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@@ -2,6 +2,7 @@ package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
|
||||
"github.com/MontFerret/ferret/pkg/runtime"
|
||||
)
|
||||
@@ -9,6 +10,8 @@ import (
|
||||
type Transformer interface {
|
||||
runtime.Value
|
||||
runtime.Iterable
|
||||
runtime.Keyed
|
||||
io.Closer
|
||||
|
||||
Add(ctx context.Context, key, value runtime.Value) error
|
||||
}
|
||||
|
@@ -192,7 +192,6 @@ loop:
|
||||
start := int(src1)
|
||||
end := int(src1) + size
|
||||
|
||||
// Iterate over registers starting from src1 and up to the src2
|
||||
for i := start; i < end; i++ {
|
||||
_ = arr.Add(ctx, reg[i])
|
||||
}
|
||||
|
@@ -96,7 +96,46 @@ FOR u IN users
|
||||
`, []any{
|
||||
map[string]any{"minAge": 25, "maxAge": 69},
|
||||
}, "Should collect and aggregate values without grouping"),
|
||||
SkipCaseArray(`
|
||||
CaseArray(`
|
||||
LET users = [
|
||||
{
|
||||
active: true,
|
||||
married: true,
|
||||
age: 31,
|
||||
gender: "m"
|
||||
},
|
||||
{
|
||||
active: true,
|
||||
married: false,
|
||||
age: 25,
|
||||
gender: "f"
|
||||
},
|
||||
{
|
||||
active: true,
|
||||
married: false,
|
||||
age: 36,
|
||||
gender: "m"
|
||||
},
|
||||
{
|
||||
active: false,
|
||||
married: true,
|
||||
age: 69,
|
||||
gender: "m"
|
||||
},
|
||||
{
|
||||
active: true,
|
||||
married: true,
|
||||
age: 45,
|
||||
gender: "f"
|
||||
}
|
||||
]
|
||||
FOR u IN users
|
||||
COLLECT AGGREGATE ages = UNION(u.age, u.age)
|
||||
RETURN { ages }
|
||||
`, []any{
|
||||
map[string]any{"ages": []any{31, 25, 36, 69, 45, 31, 25, 36, 69, 45}},
|
||||
}, "Should call aggregation functions with more than one argument"),
|
||||
CaseArray(`
|
||||
LET users = [
|
||||
{
|
||||
active: true,
|
||||
@@ -131,16 +170,15 @@ LET users = [
|
||||
]
|
||||
FOR u IN users
|
||||
COLLECT genderGroup = u.gender
|
||||
AGGREGATE minAge = MIN(u.age), maxAge = MAX(u.age)
|
||||
AGGREGATE ages = UNION(u.age, u.age)
|
||||
|
||||
RETURN {
|
||||
genderGroup,
|
||||
minAge,
|
||||
maxAge
|
||||
ages,
|
||||
}
|
||||
`, []any{
|
||||
map[string]any{"genderGroup": "f", "minAge": 25, "maxAge": 45},
|
||||
map[string]any{"genderGroup": "m", "minAge": 31, "maxAge": 69},
|
||||
map[string]any{"genderGroup": "f", "ages": []any{25, 45, 25, 45}},
|
||||
map[string]any{"genderGroup": "m", "ages": []any{31, 36, 69, 31, 36, 69}},
|
||||
}, "Should collect and aggregate values by a single key"),
|
||||
})
|
||||
}
|
||||
|
Reference in New Issue
Block a user