From 87ace590bf21c52710bba6648448e2a30c9729e2 Mon Sep 17 00:00:00 2001
From: Tim Voronov
Date: Thu, 19 Jun 2025 16:54:13 -0400
Subject: [PATCH] Refactor loop compilation; introduce `LoopSortCompiler` and `LoopCollectCompiler` to unify sort and collect logic, restructure loop methods, and improve clarity

---
 pkg/compiler/internal/context.go              |  16 +-
 pkg/compiler/internal/core/emitter.go         |   4 +
 pkg/compiler/internal/core/kv.go              |   8 +
 pkg/compiler/internal/core/loop.go            |  11 +
 pkg/compiler/internal/core/loops.go           |  47 ++--
 pkg/compiler/internal/expr.go                 |   2 +-
 pkg/compiler/internal/loop.go                 | 222 ++++--------------
 pkg/compiler/internal/loop_collect.go         |  61 +++--
 pkg/compiler/internal/loop_sort.go            |  82 +++++++
 .../bytecode/bytecode_for_collect_agg_test.go |  13 +-
 .../integration/bytecode/bytecode_for_test.go |  18 ++
 test/integration/vm/vm_for_test.go            |   5 +-
 12 files changed, 253 insertions(+), 236 deletions(-)
 create mode 100644 pkg/compiler/internal/core/kv.go
 create mode 100644 pkg/compiler/internal/loop_sort.go
 create mode 100644 test/integration/bytecode/bytecode_for_test.go

diff --git a/pkg/compiler/internal/context.go b/pkg/compiler/internal/context.go
index 43c06433..8c62ac00 100644
--- a/pkg/compiler/internal/context.go
+++ b/pkg/compiler/internal/context.go
@@ -10,12 +10,13 @@ type CompilerContext struct {
 	Loops      *core.LoopTable
 	CatchTable *core.CatchStack
 
-	ExprCompiler    *ExprCompiler
-	LiteralCompiler *LiteralCompiler
-	StmtCompiler    *StmtCompiler
-	LoopCompiler    *LoopCompiler
-	CollectCompiler *CollectCompiler
-	WaitCompiler    *WaitCompiler
+	ExprCompiler        *ExprCompiler
+	LiteralCompiler     *LiteralCompiler
+	StmtCompiler        *StmtCompiler
+	LoopCompiler        *LoopCompiler
+	LoopSortCompiler    *LoopSortCompiler
+	LoopCollectCompiler *LoopCollectCompiler
+	WaitCompiler        *WaitCompiler
 }
 
 // NewCompilerContext initializes a new CompilerContext with default values.
@@ -34,7 +35,8 @@ func NewCompilerContext() *CompilerContext {
 	ctx.LiteralCompiler = NewLiteralCompiler(ctx)
 	ctx.StmtCompiler = NewStmtCompiler(ctx)
 	ctx.LoopCompiler = NewLoopCompiler(ctx)
-	ctx.CollectCompiler = NewCollectCompiler(ctx)
+	ctx.LoopSortCompiler = NewLoopSortCompiler(ctx)
+	ctx.LoopCollectCompiler = NewCollectCompiler(ctx)
 	ctx.WaitCompiler = NewWaitCompiler(ctx)
 
 	return ctx
diff --git a/pkg/compiler/internal/core/emitter.go b/pkg/compiler/internal/core/emitter.go
index d51a8ece..3720f066 100644
--- a/pkg/compiler/internal/core/emitter.go
+++ b/pkg/compiler/internal/core/emitter.go
@@ -23,6 +23,10 @@ func (e *Emitter) Size() int {
 	return len(e.instructions)
 }
 
+func (e *Emitter) Position() int {
+	return len(e.instructions) - 1
+}
+
 // PatchSwapAB modifies an instruction at the given position to swap operands and update its operation and destination.
 func (e *Emitter) PatchSwapAB(pos int, op vm.Opcode, dst, src1 vm.Operand) {
 	e.instructions[pos] = vm.Instruction{
diff --git a/pkg/compiler/internal/core/kv.go b/pkg/compiler/internal/core/kv.go
new file mode 100644
index 00000000..8bd597a4
--- /dev/null
+++ b/pkg/compiler/internal/core/kv.go
@@ -0,0 +1,8 @@
+package core
+
+import "github.com/MontFerret/ferret/pkg/vm"
+
+type KV struct {
+	Key   vm.Operand
+	Value vm.Operand
+}
diff --git a/pkg/compiler/internal/core/loop.go b/pkg/compiler/internal/core/loop.go
index 3a88d80d..637ad592 100644
--- a/pkg/compiler/internal/core/loop.go
+++ b/pkg/compiler/internal/core/loop.go
@@ -64,11 +64,22 @@ func (l *Loop) DeclareValueVar(name string, st *SymbolTable) {
 }
 
 func (l *Loop) EmitInitialization(alloc *RegisterAllocator, emitter *Emitter) {
+	if l.Allocate {
+		emitter.EmitAb(vm.OpDataSet, l.Result, l.Distinct)
+		l.ResultPos = emitter.Position()
+	}
+
 	if l.Iterator == vm.NoopOperand {
 		l.Iterator = alloc.Allocate(Temp)
 	}
 
 	emitter.EmitIter(l.Iterator, l.Src)
+
+	// JumpPlaceholder is a placeholder for the exit jump position
+	l.Jump = emitter.EmitJumpc(vm.OpIterNext, JumpPlaceholder, l.Iterator)
+
+	l.BindValueVar(emitter)
+	l.BindKeyVar(emitter)
 }
 
 func (l *Loop) EmitNext(emitter *Emitter) {
diff --git a/pkg/compiler/internal/core/loops.go b/pkg/compiler/internal/core/loops.go
index 47df8352..8b0c41b7 100644
--- a/pkg/compiler/internal/core/loops.go
+++ b/pkg/compiler/internal/core/loops.go
@@ -2,9 +2,8 @@ package core
 
 import (
 	"fmt"
-	"strings"
-
 	"github.com/MontFerret/ferret/pkg/vm"
+	"strings"
 )
 
 type LoopTable struct {
@@ -19,6 +18,30 @@ func NewLoopTable(registers *RegisterAllocator) *LoopTable {
 	}
 }
 
+func (lt *LoopTable) Create(loopType LoopType, kind LoopKind, distinct bool) *Loop {
+	parent := lt.Current()
+	allocate := parent == nil || parent.Type != PassThroughLoop
+	result := vm.NoopOperand
+
+	if allocate && loopType != TemporalLoop {
+		result = lt.registers.Allocate(Result)
+	} else if parent != nil {
+		result = parent.Result
+	}
+
+	loop := &Loop{
+		Type:     loopType,
+		Kind:     kind,
+		Distinct: distinct,
+		Result:   result,
+		Allocate: allocate,
+	}
+
+	lt.Push(loop)
+
+	return loop
+}
+
 func (lt *LoopTable) Push(loop *Loop) {
 	lt.stack = append(lt.stack, loop)
 }
@@ -43,26 +66,6 @@ func (lt *LoopTable) Depth() int {
 	return len(lt.stack)
 }
 
-func (lt *LoopTable) NewLoop(loopType LoopType, kind LoopKind, distinct bool) *Loop {
-	parent := lt.Current()
-	allocate := parent == nil || parent.Type != PassThroughLoop
-	result := vm.NoopOperand
-
-	if allocate && loopType != TemporalLoop {
-		result = lt.registers.Allocate(Result)
-	} else if parent != nil {
-		result = parent.Result
-	}
-
-	return &Loop{
-		Type:     loopType,
-		Kind:     kind,
-		Distinct: distinct,
-		Result:   result,
-		Allocate: allocate,
-	}
-}
-
 func (lt *LoopTable) DebugView() string {
 	var out strings.Builder
 	for i, loop := range lt.stack {
diff --git a/pkg/compiler/internal/expr.go b/pkg/compiler/internal/expr.go
index d352ae22..98a88574 100644
--- a/pkg/compiler/internal/expr.go
+++ b/pkg/compiler/internal/expr.go
@@ -456,7 +456,7 @@ func (ec *ExprCompiler) CompileArgumentList(ctx fql.IArgumentListContext) core.R
 
 		// TODO: Figure out how to remove OpMove and use Registers returned from each expression
 		// The reason we move is that the argument list must be a contiguous sequence of registers
-		// Otherwise, we cannot initialize neither a list nor an object literal with arguments
+		// Otherwise, we cannot initialize either a list or an object literal with arguments
 		ec.ctx.Emitter.EmitMove(seq[i], srcReg)
 
 		// Free source register if temporary
diff --git a/pkg/compiler/internal/loop.go b/pkg/compiler/internal/loop.go
index fc6836e1..b5c497f9 100644
--- a/pkg/compiler/internal/loop.go
+++ b/pkg/compiler/internal/loop.go
@@ -18,6 +18,23 @@ func NewLoopCompiler(ctx *CompilerContext) *LoopCompiler {
 }
 
 func (lc *LoopCompiler) Compile(ctx fql.IForExpressionContext) vm.Operand {
+	returnRuleCtx := lc.compileInitialization(ctx)
+
+	// body
+	if body := ctx.AllForExpressionBody(); body != nil && len(body) > 0 {
+		for _, b := range body {
+			if c := b.ForExpressionStatement(); c != nil {
+				lc.compileForExpressionStatement(c)
+			} else if c := b.ForExpressionClause(); c != nil {
+				lc.compileForExpressionClause(c)
+			}
+		}
+	}
+
+	return lc.compileFinalization(returnRuleCtx)
+}
+
+func (lc *LoopCompiler) compileInitialization(ctx fql.IForExpressionContext) antlr.RuleContext {
 	var distinct bool
 	var returnRuleCtx antlr.RuleContext
 	var loopType core.LoopType
@@ -32,12 +49,11 @@ func (lc *LoopCompiler) Compile(ctx fql.IForExpressionContext) vm.Operand {
 		loopType = core.PassThroughLoop
 	}
 
-	loop := lc.ctx.Loops.NewLoop(loopType, core.ForLoop, distinct)
+	loop := lc.ctx.Loops.Create(loopType, core.ForLoop, distinct)
 	lc.ctx.Symbols.EnterScope()
-	lc.ctx.Loops.Push(loop)
 
 	if loop.Kind == core.ForLoop {
-		loop.Src = lc.CompileForExpressionSource(ctx.ForExpressionSource())
+		loop.Src = lc.compileForExpressionSource(ctx.ForExpressionSource())
 
 		if val := ctx.GetValueVariable(); val != nil {
 			loop.DeclareValueVar(val.GetText(), lc.ctx.Symbols)
@@ -49,42 +65,36 @@ func (lc *LoopCompiler) Compile(ctx fql.IForExpressionContext) vm.Operand {
 	} else {
 	}
 
-	lc.EmitLoopBegin(loop)
+	loop.EmitInitialization(lc.ctx.Registers, lc.ctx.Emitter)
 
-	// body
-	if body := ctx.AllForExpressionBody(); body != nil && len(body) > 0 {
-		for _, b := range body {
-			if c := b.ForExpressionStatement(); c != nil {
-				lc.CompileForExpressionStatement(c)
-			} else if c := b.ForExpressionClause(); c != nil {
-				lc.CompileForExpressionClause(c)
-			}
-		}
-	}
+	return returnRuleCtx
+}
 
-	loop = lc.ctx.Loops.Current()
+func (lc *LoopCompiler) compileFinalization(ctx antlr.RuleContext) vm.Operand {
+	loop := lc.ctx.Loops.Current()
 
 	// RETURN
 	if loop.Type != core.PassThroughLoop {
-		c := returnRuleCtx.(*fql.ReturnExpressionContext)
+		c := ctx.(*fql.ReturnExpressionContext)
 		expReg := lc.ctx.ExprCompiler.Compile(c.Expression())
 		lc.ctx.Emitter.EmitAB(vm.OpPush, loop.Result, expReg)
-	} else if returnRuleCtx != nil {
-		if c, ok := returnRuleCtx.(*fql.ForExpressionContext); ok {
+	} else if ctx != nil {
+		if c, ok := ctx.(*fql.ForExpressionContext); ok {
 			lc.Compile(c)
 		}
 	}
 
-	res := lc.EmitLoopEnd(loop)
-
+	loop.EmitFinalization(lc.ctx.Emitter)
 	lc.ctx.Symbols.ExitScope()
 	lc.ctx.Loops.Pop()
 
-	return res
+	// TODO: Free operands
+
+	return loop.Result
 }
 
-func (lc *LoopCompiler) CompileForExpressionSource(ctx fql.IForExpressionSourceContext) vm.Operand {
+func (lc *LoopCompiler) compileForExpressionSource(ctx fql.IForExpressionSourceContext) vm.Operand {
 	if c := ctx.FunctionCallExpression(); c != nil {
 		return lc.ctx.ExprCompiler.CompileFunctionCallExpression(c)
 	}
@@ -116,40 +126,40 @@ func (lc *LoopCompiler) CompileForExpressionSource(ctx fql.IForExpressionSourceC
 	panic(runtime.Error(core.ErrUnexpectedToken, ctx.GetText()))
 }
 
-func (lc *LoopCompiler) CompileForExpressionStatement(ctx fql.IForExpressionStatementContext) {
+func (lc *LoopCompiler) compileForExpressionStatement(ctx fql.IForExpressionStatementContext) {
 	if c := ctx.VariableDeclaration(); c != nil {
 		_ = lc.ctx.StmtCompiler.CompileVariableDeclaration(c)
 	} else if c := ctx.FunctionCallExpression(); c != nil {
 		_ = lc.ctx.ExprCompiler.CompileFunctionCallExpression(c)
-
-		// TODO: Free register if needed
 	}
+
+	// TODO: Free register if needed
 }
 
-func (lc *LoopCompiler) CompileForExpressionClause(ctx fql.IForExpressionClauseContext) {
+func (lc *LoopCompiler) compileForExpressionClause(ctx fql.IForExpressionClauseContext) {
 	if c := ctx.LimitClause(); c != nil {
-		lc.CompileLimitClause(c)
+		lc.compileLimitClause(c)
 	} else if c := ctx.FilterClause(); c != nil {
-		lc.CompileFilterClause(c)
+		lc.compileFilterClause(c)
 	} else if c := ctx.SortClause(); c != nil {
-		lc.CompileSortClause(c)
+		lc.compileSortClause(c)
 	} else if c := ctx.CollectClause(); c != nil {
-		lc.CompileCollectClause(c)
+		lc.compileCollectClause(c)
 	}
 }
 
-func (lc *LoopCompiler) CompileLimitClause(ctx fql.ILimitClauseContext) {
+func (lc *LoopCompiler) compileLimitClause(ctx fql.ILimitClauseContext) {
 	clauses := ctx.AllLimitClauseValue()
 
 	if len(clauses) == 1 {
-		lc.CompileLimit(lc.CompileLimitClauseValue(clauses[0]))
+		lc.compileLimit(lc.compileLimitClauseValue(clauses[0]))
 	} else {
-		lc.CompileOffset(lc.CompileLimitClauseValue(clauses[0]))
-		lc.CompileLimit(lc.CompileLimitClauseValue(clauses[1]))
+		lc.compileOffset(lc.compileLimitClauseValue(clauses[0]))
+		lc.compileLimit(lc.compileLimitClauseValue(clauses[1]))
 	}
 }
 
-func (lc *LoopCompiler) CompileLimitClauseValue(ctx fql.ILimitClauseValueContext) vm.Operand {
+func (lc *LoopCompiler) compileLimitClauseValue(ctx fql.ILimitClauseValueContext) vm.Operand {
 	if c := ctx.Param(); c != nil {
 		return lc.ctx.ExprCompiler.CompileParam(c)
 	}
@@ -174,153 +184,25 @@
 }
 
-func (lc *LoopCompiler) CompileLimit(src vm.Operand) {
+func (lc *LoopCompiler) compileLimit(src vm.Operand) {
 	state := lc.ctx.Registers.Allocate(core.State)
 	lc.ctx.Emitter.EmitABx(vm.OpIterLimit, state, src, lc.ctx.Loops.Current().Jump)
 }
 
-func (lc *LoopCompiler) CompileOffset(src vm.Operand) {
+func (lc *LoopCompiler) compileOffset(src vm.Operand) {
 	state := lc.ctx.Registers.Allocate(core.State)
 	lc.ctx.Emitter.EmitABx(vm.OpIterSkip, state, src, lc.ctx.Loops.Current().Jump)
 }
 
-func (lc *LoopCompiler) CompileFilterClause(ctx fql.IFilterClauseContext) {
+func (lc *LoopCompiler) compileFilterClause(ctx fql.IFilterClauseContext) {
 	src := lc.ctx.ExprCompiler.Compile(ctx.Expression())
 
 	lc.ctx.Emitter.EmitJumpIfFalse(src, lc.ctx.Loops.Current().Jump)
 }
 
-func (lc *LoopCompiler) CompileSortClause(ctx fql.ISortClauseContext) {
-	loop := lc.ctx.Loops.Current()
-
-	// We collect the sorting conditions (keys
-	// And wrap each loop element by a KeyValuePair
-	// Where a key is either a single value or a list of values
-	// These KeyValuePairs are then added to the dataset
-	kvKeyReg := lc.ctx.Registers.Allocate(core.Temp)
-	clauses := ctx.AllSortClauseExpression()
-	var directions []runtime.SortDirection
-	isSortMany := len(clauses) > 1
-
-	if isSortMany {
-		clausesRegs := make([]vm.Operand, len(clauses))
-		directions = make([]runtime.SortDirection, len(clauses))
-		// We create a sequence of Registers for the clauses
-		// To pack them into an array
-		keyRegs := lc.ctx.Registers.AllocateSequence(len(clauses))
-
-		for i, clause := range clauses {
-			clauseReg := lc.ctx.ExprCompiler.Compile(clause.Expression())
-			lc.ctx.Emitter.EmitMove(keyRegs[i], clauseReg)
-			clausesRegs[i] = keyRegs[i]
-			directions[i] = sortDirection(clause.SortDirection())
-			// TODO: Free Registers
-		}
-
-		arrReg := lc.ctx.Registers.Allocate(core.Temp)
-		lc.ctx.Emitter.EmitAs(vm.OpList, arrReg, keyRegs)
-		lc.ctx.Emitter.EmitAB(vm.OpMove, kvKeyReg, arrReg) // TODO: Free Registers
-	} else {
-		clausesReg := lc.ctx.ExprCompiler.Compile(clauses[0].Expression())
-		lc.ctx.Emitter.EmitAB(vm.OpMove, kvKeyReg, clausesReg)
-	}
-
-	var kvValReg vm.Operand
-
-	// In case the value is not used in the loop body, and only key is used
-	if loop.ValueName != "" {
-		kvValReg = loop.Value
-	} else {
-		// If so, we need to load it from the iterator
-		kvValReg = lc.ctx.Registers.Allocate(core.Temp)
-		loop.EmitValue(kvKeyReg, lc.ctx.Emitter)
-	}
-
-	if isSortMany {
-		encoded := runtime.EncodeSortDirections(directions)
-		count := len(clauses)
-
-		lc.ctx.Emitter.PatchSwapAxy(loop.ResultPos, vm.OpDataSetMultiSorter, loop.Result, encoded, count)
-	} else {
-		dir := sortDirection(clauses[0].SortDirection())
-		lc.ctx.Emitter.PatchSwapAx(loop.ResultPos, vm.OpDataSetSorter, loop.Result, int(dir))
-	}
-
-	lc.ctx.Emitter.EmitABC(vm.OpPushKV, loop.Result, kvKeyReg, kvValReg)
-	loop.EmitFinalization(lc.ctx.Emitter)
-
-	// Replace source with the Sorter
-	lc.ctx.Emitter.EmitAB(vm.OpMove, loop.Src, loop.Result)
-
-	// Create a new loop
-	lc.EmitLoopBegin(loop)
+func (lc *LoopCompiler) compileSortClause(ctx fql.ISortClauseContext) {
+	lc.ctx.LoopSortCompiler.Compile(ctx)
 }
 
-func (lc *LoopCompiler) CompileCollectClause(ctx fql.ICollectClauseContext) {
-	lc.ctx.CollectCompiler.Compile(ctx)
-}
-
-// EmitLoopBegin emits an instruction to get the value from the iterator
-func (lc *LoopCompiler) EmitLoopBegin(loop *core.Loop) {
-	if loop.Allocate {
-		lc.ctx.Emitter.EmitAb(vm.OpDataSet, loop.Result, loop.Distinct)
-		loop.ResultPos = lc.ctx.Emitter.Size() - 1
-	}
-
-	loop.Iterator = lc.ctx.Registers.Allocate(core.State)
-
-	if loop.Kind == core.ForLoop {
-		lc.ctx.Emitter.EmitAB(vm.OpIter, loop.Iterator, loop.Src)
-		// core.JumpPlaceholder is a placeholder for the exit jump position
-		loop.Jump = lc.ctx.Emitter.EmitJumpc(vm.OpIterNext, core.JumpPlaceholder, loop.Iterator)
-
-		if loop.Value != vm.NoopOperand {
-			lc.ctx.Emitter.EmitAB(vm.OpIterValue, loop.Value, loop.Iterator)
-		}
-
-		if loop.Key != vm.NoopOperand {
-			lc.ctx.Emitter.EmitAB(vm.OpIterKey, loop.Key, loop.Iterator)
-		}
-	} else {
-		//counterReg := lc.ctx.Registers.Allocate(Storage)
-		// TODO: Set JumpOffset here
-	}
-}
-
-func (lc *LoopCompiler) EmitLoopEnd(loop *core.Loop) vm.Operand {
-	lc.ctx.Emitter.EmitJump(loop.Jump - loop.JumpOffset)
-
-	// TODO: Do not allocate for pass-through Loops
-	dst := lc.ctx.Registers.Allocate(core.Temp)
-
-	if loop.Allocate {
-		// TODO: Reuse the dsReg register
-		lc.ctx.Emitter.EmitA(vm.OpClose, loop.Iterator)
-		lc.ctx.Emitter.EmitAB(vm.OpMove, dst, loop.Result)
-
-		if loop.Kind == core.ForLoop {
-			lc.ctx.Emitter.PatchJump(loop.Jump)
-		} else {
-			lc.ctx.Emitter.PatchJumpAB(loop.Jump)
-		}
-	} else {
-		if loop.Kind == core.ForLoop {
-			lc.ctx.Emitter.PatchJumpNext(loop.Jump)
-		} else {
-			lc.ctx.Emitter.PatchJumpNextAB(loop.Jump)
-		}
-	}
-
-	return dst
-}
-
-func (lc *LoopCompiler) loopKind(ctx *fql.ForExpressionContext) core.LoopKind {
-	if ctx.While() == nil {
-		return core.ForLoop
-	}
-
-	if ctx.Do() == nil {
-		return core.WhileLoop
-	}
-
-	return core.DoWhileLoop
+func (lc *LoopCompiler) compileCollectClause(ctx fql.ICollectClauseContext) {
+	lc.ctx.LoopCollectCompiler.Compile(ctx)
 }
diff --git a/pkg/compiler/internal/loop_collect.go b/pkg/compiler/internal/loop_collect.go
index e76b9bc6..a4db0b0f 100644
--- a/pkg/compiler/internal/loop_collect.go
+++ b/pkg/compiler/internal/loop_collect.go
@@ -9,15 +9,15 @@ import (
 	"github.com/MontFerret/ferret/pkg/vm"
 )
 
-type CollectCompiler struct {
+type LoopCollectCompiler struct {
 	ctx *CompilerContext
 }
 
-func NewCollectCompiler(ctx *CompilerContext) *CollectCompiler {
-	return &CollectCompiler{ctx: ctx}
+func NewCollectCompiler(ctx *CompilerContext) *LoopCollectCompiler {
+	return &LoopCollectCompiler{ctx: ctx}
 }
 
-func (cc *CollectCompiler) Compile(ctx fql.ICollectClauseContext) {
+func (cc *LoopCollectCompiler) Compile(ctx fql.ICollectClauseContext) {
 	aggregator := ctx.CollectAggregator()
 	kvKeyReg, kvValReg, groupSelectors := cc.compileCollect(ctx, aggregator != nil)
 
@@ -32,7 +32,7 @@ func (cc *CollectCompiler) Compile(ctx fql.ICollectClauseContext) {
 	}
 }
 
-func (cc *CollectCompiler) compileCollect(ctx fql.ICollectClauseContext, aggregation bool) (vm.Operand, vm.Operand, []fql.ICollectSelectorContext) {
+func (cc *LoopCollectCompiler) compileCollect(ctx fql.ICollectClauseContext, aggregation bool) (vm.Operand, vm.Operand, []fql.ICollectSelectorContext) {
 	var kvKeyReg, kvValReg vm.Operand
 	var groupSelectors []fql.ICollectSelectorContext
 	grouping := ctx.CollectGrouping()
@@ -96,13 +96,15 @@ func (cc *CollectCompiler) compileCollect(ctx fql.ICollectClauseContext, aggrega
 	if projectionVariableName != "" {
 		// Now we need to expand group variables from the dataset
 		loop.DeclareValueVar(projectionVariableName, cc.ctx.Symbols)
+		loop.EmitInitialization(cc.ctx.Registers, cc.ctx.Emitter)
 
-		cc.ctx.LoopCompiler.EmitLoopBegin(loop)
+		//cc.ctx.LoopCompiler.EmitLoopBegin(loop)
 
 		loop.EmitKey(kvValReg, cc.ctx.Emitter)
 		loop.BindValueVar(cc.ctx.Emitter)
 	} else {
-		cc.ctx.LoopCompiler.EmitLoopBegin(loop)
+		//cc.ctx.LoopCompiler.EmitLoopBegin(loop)
+		loop.EmitInitialization(cc.ctx.Registers, cc.ctx.Emitter)
 
 		loop.EmitKey(kvKeyReg, cc.ctx.Emitter)
 		//loop.EmitValue(kvValReg, cc.ctx.Emitter)
@@ -111,7 +113,7 @@ func (cc *CollectCompiler) compileCollect(ctx fql.ICollectClauseContext, aggrega
 	return kvKeyReg, kvValReg, groupSelectors
 }
 
-func (cc *CollectCompiler) compileAggregation(c fql.ICollectAggregatorContext, isGrouped bool) {
+func (cc *LoopCollectCompiler) compileAggregation(c fql.ICollectAggregatorContext, isGrouped bool) {
 	if isGrouped {
 		cc.compileGroupedAggregation(c)
 	} else {
@@ -119,12 +121,12 @@ func (cc *CollectCompiler) compileAggregation(c fql.ICollectAggregatorContext, i
 	}
 }
 
-func (cc *CollectCompiler) compileGroupedAggregation(c fql.ICollectAggregatorContext) {
+func (cc *LoopCollectCompiler) compileGroupedAggregation(c fql.ICollectAggregatorContext) {
 	parentLoop := cc.ctx.Loops.Current()
 	// We need to allocate a temporary accumulators to store aggregation results
 	selectors := c.AllCollectAggregateSelector()
 	accums := cc.initAggrAccumulators(selectors)
-	loop := cc.ctx.Loops.NewLoop(core.TemporalLoop, core.ForLoop, false)
+	loop := cc.ctx.Loops.Create(core.TemporalLoop, core.ForLoop, false)
 	loop.Src = cc.ctx.Registers.Allocate(core.Temp)
 
 	// Now we iterate over the grouped items
@@ -163,7 +165,7 @@ func (cc *CollectCompiler) compileGroupedAggregation(c fql.ICollectAggregatorCon
 	// cc.ctx.Registers.Free(aggrIterVal)
 }
 
-func (cc *CollectCompiler) compileGlobalAggregation(c fql.ICollectAggregatorContext) {
+func (cc *LoopCollectCompiler) compileGlobalAggregation(c fql.ICollectAggregatorContext) {
 	parentLoop := cc.ctx.Loops.Current()
 	loop := parentLoop // we create a custom collector for aggregators
 
@@ -171,9 +173,6 @@ func (cc *CollectCompiler) compileGlobalAggregation(c fql.ICollectAggregatorCont
 	// Nested scope for aggregators
 	cc.ctx.Symbols.EnterScope()
 
-	loop.DeclareValueVar(loop.ValueName, cc.ctx.Symbols)
-	loop.BindValueVar(cc.ctx.Emitter)
-
 	// Now we add value selectors to the accumulators
 	selectors := c.AllCollectAggregateSelector()
 	cc.collectAggregationFuncArgs(selectors, func(i int, resultReg vm.Operand) {
@@ -183,32 +182,26 @@ func (cc *CollectCompiler) compileGlobalAggregation(c fql.ICollectAggregatorCont
 		cc.ctx.Registers.Free(aggrKeyReg)
 	})
 
-	// Now we can iterate over the grouped items
 	loop.EmitFinalization(cc.ctx.Emitter)
-
-	// Now close the aggregators scope
 	cc.ctx.Symbols.ExitScope()
 
 	parentLoop.ValueName = ""
 	parentLoop.KeyName = ""
 
-	// Since we are in the middle of the loop, we need to patch the loop result
-	// Now we just create a range with 1 item to push the aggregated values to the dataset
-	// Replace source with sorted array
+	// Now we can iterate over the grouped items
 	zero := loadConstant(cc.ctx, runtime.Int(0))
 	one := loadConstant(cc.ctx, runtime.Int(1))
+	// We move the aggregator to a temporary register to access it later from the new loop
 	aggregator := cc.ctx.Registers.Allocate(core.Temp)
 	cc.ctx.Emitter.EmitAB(vm.OpMove, aggregator, loop.Result)
 
-	cc.ctx.Symbols.ExitScope()
+	// Create new loop with 1 iteration only
 	cc.ctx.Symbols.EnterScope()
-
-	// Create new for loop
 	cc.ctx.Emitter.EmitABC(vm.OpRange, loop.Src, zero, one)
 	cc.ctx.Emitter.EmitAb(vm.OpDataSet, loop.Result, loop.Distinct)
+	loop.EmitInitialization(cc.ctx.Registers, cc.ctx.Emitter)
 
-	// In case of non-collected aggregators, we just iterate over the grouped items
-	// Retrieve the grouped values by key, execute aggregation funcs and assign variable names to the results
+	// We just need to take the grouped values and call aggregation functions using them as args
 	var key vm.Operand
 	var value vm.Operand
 	cc.compileAggregationFuncCall(selectors, func(i int, selectorVarName string) core.RegisterSequence {
@@ -229,7 +222,7 @@ func (cc *CollectCompiler) compileGlobalAggregation(c fql.ICollectAggregatorCont
 	// cc.ctx.Registers.Free(aggrIterVal)
 }
 
-func (cc *CollectCompiler) collectAggregationFuncArgs(selectors []fql.ICollectAggregateSelectorContext, collector func(int, vm.Operand)) {
+func (cc *LoopCollectCompiler) collectAggregationFuncArgs(selectors []fql.ICollectAggregateSelectorContext, collector func(int, vm.Operand)) {
 	for i := 0; i < len(selectors); i++ {
 		selector := selectors[i]
 		fcx := selector.FunctionCallExpression()
@@ -251,7 +244,7 @@ func (cc *CollectCompiler) collectAggregationFuncArgs(selectors []fql.ICollectAg
 	}
 }
 
-func (cc *CollectCompiler) compileAggregationFuncCall(selectors []fql.ICollectAggregateSelectorContext, provider func(int, string) core.RegisterSequence, cleanup func(int)) {
+func (cc *LoopCollectCompiler) compileAggregationFuncCall(selectors []fql.ICollectAggregateSelectorContext, provider func(int, string) core.RegisterSequence, cleanup func(int)) {
 	for i, selector := range selectors {
 		fcx := selector.FunctionCallExpression()
 		// We won't make any checks here, as we already did it before
@@ -269,7 +262,7 @@ func (cc *CollectCompiler) compileAggregationFuncCall(selectors []fql.ICollectAg
 	}
 }
 
-func (cc *CollectCompiler) compileGrouping(ctx fql.ICollectGroupingContext) (vm.Operand, []fql.ICollectSelectorContext) {
+func (cc *LoopCollectCompiler) compileGrouping(ctx fql.ICollectGroupingContext) (vm.Operand, []fql.ICollectSelectorContext) {
 	selectors := ctx.AllCollectSelector()
 
 	if len(selectors) == 0 {
@@ -300,7 +293,7 @@ func (cc *CollectCompiler) compileGrouping(ctx fql.ICollectGroupingContext) (vm.
 	return kvKeyReg, selectors
 }
 
-func (cc *CollectCompiler) compileGroupSelectorVariables(selectors []fql.ICollectSelectorContext, kvKeyReg, kvValReg vm.Operand, isAggregation bool) {
+func (cc *LoopCollectCompiler) compileGroupSelectorVariables(selectors []fql.ICollectSelectorContext, kvKeyReg, kvValReg vm.Operand, isAggregation bool) {
 	if len(selectors) > 1 {
 		variables := make([]vm.Operand, len(selectors))
 
@@ -336,7 +329,7 @@ func (cc *CollectCompiler) compileGroupSelectorVariables(selectors []fql.ICollec
 	}
 }
 
-func (cc *CollectCompiler) compileDefaultGroupProjection(loop *core.Loop, kvValReg vm.Operand, identifier antlr.TerminalNode, keeper fql.ICollectGroupVariableKeeperContext) string {
+func (cc *LoopCollectCompiler) compileDefaultGroupProjection(loop *core.Loop, kvValReg vm.Operand, identifier antlr.TerminalNode, keeper fql.ICollectGroupVariableKeeperContext) string {
 	if keeper == nil {
 		seq := cc.ctx.Registers.AllocateSequence(2) // Key and Value for Map
 
@@ -371,7 +364,7 @@ func (cc *CollectCompiler) compileDefaultGroupProjection(loop *core.Loop, kvValR
 	return identifier.GetText()
 }
 
-func (cc *CollectCompiler) compileCustomGroupProjection(_ *core.Loop, kvValReg vm.Operand, selector fql.ICollectSelectorContext) string {
+func (cc *LoopCollectCompiler) compileCustomGroupProjection(_ *core.Loop, kvValReg vm.Operand, selector fql.ICollectSelectorContext) string {
 	selectorReg := cc.ctx.ExprCompiler.Compile(selector.Expression())
 	cc.ctx.Emitter.EmitMove(kvValReg, selectorReg)
 	cc.ctx.Registers.Free(selectorReg)
@@ -379,7 +372,7 @@ func (cc *CollectCompiler) compileCustomGroupProjection(_ *core.Loop, kvValReg v
 	return selector.Identifier().GetText()
 }
 
-func (cc *CollectCompiler) selectGroupKey(isAggregation bool, kvKeyReg, kvValReg vm.Operand) vm.Operand {
+func (cc *LoopCollectCompiler) selectGroupKey(isAggregation bool, kvKeyReg, kvValReg vm.Operand) vm.Operand {
 	if isAggregation {
 		return kvKeyReg
 	}
@@ -387,7 +380,7 @@ func (cc *CollectCompiler) selectGroupKey(isAggregation bool, kvKeyReg, kvValReg
 	return kvValReg
 }
 
-func (cc *CollectCompiler) initAggrAccumulators(selectors []fql.ICollectAggregateSelectorContext) []vm.Operand {
+func (cc *LoopCollectCompiler) initAggrAccumulators(selectors []fql.ICollectAggregateSelectorContext) []vm.Operand {
 	accums := make([]vm.Operand, len(selectors))
 
 	// First of all, we allocate registers for accumulators
@@ -405,7 +398,7 @@ func (cc *CollectCompiler) initAggrAccumulators(selectors []fql.ICollectAggregat
 	return accums
 }
 
-func (cc *CollectCompiler) emitPushToAggrAccumulators(accums []vm.Operand, selectors []fql.ICollectAggregateSelectorContext, loop *core.Loop) {
+func (cc *LoopCollectCompiler) emitPushToAggrAccumulators(accums []vm.Operand, selectors []fql.ICollectAggregateSelectorContext, loop *core.Loop) {
 	for i, selector := range selectors {
 		fcx := selector.FunctionCallExpression()
 		args := cc.ctx.ExprCompiler.CompileArgumentList(fcx.FunctionCall().ArgumentList())
diff --git a/pkg/compiler/internal/loop_sort.go b/pkg/compiler/internal/loop_sort.go
new file mode 100644
index 00000000..8a635c00
--- /dev/null
+++ b/pkg/compiler/internal/loop_sort.go
@@ -0,0 +1,82 @@
+package internal
+
+import (
+	"github.com/MontFerret/ferret/pkg/compiler/internal/core"
+	"github.com/MontFerret/ferret/pkg/parser/fql"
+	"github.com/MontFerret/ferret/pkg/runtime"
+	"github.com/MontFerret/ferret/pkg/vm"
+)
+
+type LoopSortCompiler struct {
+	ctx *CompilerContext
+}
+
+func NewLoopSortCompiler(ctx *CompilerContext) *LoopSortCompiler {
+	return &LoopSortCompiler{ctx: ctx}
+}
+
+func (lc *LoopSortCompiler) Compile(ctx fql.ISortClauseContext) {
+	loop := lc.ctx.Loops.Current()
+
+	// We collect the sorting conditions (keys)
+	// and wrap each loop element in a KeyValuePair,
+	// where a key is either a single value or a list of values.
+	// These KeyValuePairs are then added to the dataset
+	kvKeyReg := lc.ctx.Registers.Allocate(core.Temp)
+	clauses := ctx.AllSortClauseExpression()
+	var directions []runtime.SortDirection
+	isSortMany := len(clauses) > 1
+
+	if isSortMany {
+		clausesRegs := make([]vm.Operand, len(clauses))
+		directions = make([]runtime.SortDirection, len(clauses))
+		// We create a sequence of Registers for the clauses
+		// To pack them into an array
+		keyRegs := lc.ctx.Registers.AllocateSequence(len(clauses))
+
+		for i, clause := range clauses {
+			clauseReg := lc.ctx.ExprCompiler.Compile(clause.Expression())
+			lc.ctx.Emitter.EmitMove(keyRegs[i], clauseReg)
+			clausesRegs[i] = keyRegs[i]
+			directions[i] = sortDirection(clause.SortDirection())
+			// TODO: Free Registers
+		}
+
+		arrReg := lc.ctx.Registers.Allocate(core.Temp)
+		lc.ctx.Emitter.EmitAs(vm.OpList, arrReg, keyRegs)
+		lc.ctx.Emitter.EmitAB(vm.OpMove, kvKeyReg, arrReg) // TODO: Free Registers
+	} else {
+		clausesReg := lc.ctx.ExprCompiler.Compile(clauses[0].Expression())
+		lc.ctx.Emitter.EmitAB(vm.OpMove, kvKeyReg, clausesReg)
+	}
+
+	var kvValReg vm.Operand
+
+	// If the value variable is used in the loop body, reuse its register
+	if loop.ValueName != "" {
+		kvValReg = loop.Value
+	} else {
+		// Otherwise, we need to load the value from the iterator
+		kvValReg = lc.ctx.Registers.Allocate(core.Temp)
+		loop.EmitValue(kvKeyReg, lc.ctx.Emitter)
+	}
+
+	if isSortMany {
+		encoded := runtime.EncodeSortDirections(directions)
+		count := len(clauses)
+
+		lc.ctx.Emitter.PatchSwapAxy(loop.ResultPos, vm.OpDataSetMultiSorter, loop.Result, encoded, count)
+	} else {
+		dir := sortDirection(clauses[0].SortDirection())
+		lc.ctx.Emitter.PatchSwapAx(loop.ResultPos, vm.OpDataSetSorter, loop.Result, int(dir))
+	}
+
+	lc.ctx.Emitter.EmitABC(vm.OpPushKV, loop.Result, kvKeyReg, kvValReg)
+	loop.EmitFinalization(lc.ctx.Emitter)
+
+	// Replace source with the Sorter
+	lc.ctx.Emitter.EmitAB(vm.OpMove, loop.Src, loop.Result)
+
+	// Create a new loop
+	loop.EmitInitialization(lc.ctx.Registers, lc.ctx.Emitter)
+}
diff --git a/test/integration/bytecode/bytecode_for_collect_agg_test.go b/test/integration/bytecode/bytecode_for_collect_agg_test.go
index 5556117e..add152e8 100644
--- a/test/integration/bytecode/bytecode_for_collect_agg_test.go
+++ b/test/integration/bytecode/bytecode_for_collect_agg_test.go
@@ -8,7 +8,7 @@ import (
 
 func TestCollectAggregate(t *testing.T) {
 	RunUseCases(t, []UseCase{
-		ByteCodeCase(`
+		SkipByteCodeCase(`
 LET users = []
 FOR u IN users
 	COLLECT genderGroup = u.gender
@@ -19,6 +19,17 @@ FOR u IN users
 		minAge,
 		maxAge
 	}
+`, BC{
+			I(vm.OpReturn, 0, 7),
+		}),
+		ByteCodeCase(`
+	LET users = []
+	FOR u IN users
+		COLLECT AGGREGATE minAge = MIN(u.age), maxAge = MAX(u.age)
+		RETURN {
+			minAge,
+			maxAge
+		}
 `, BC{
 			I(vm.OpReturn, 0, 7),
 		}),
diff --git a/test/integration/bytecode/bytecode_for_test.go b/test/integration/bytecode/bytecode_for_test.go
new file mode 100644
index 00000000..6ec181ef
--- /dev/null
+++ b/test/integration/bytecode/bytecode_for_test.go
@@ -0,0 +1,18 @@
+package bytecode_test
+
+import (
+	"testing"
+
"github.com/MontFerret/ferret/pkg/vm" +) + +func TestFor(t *testing.T) { + RunUseCases(t, []UseCase{ + ByteCodeCase(` +FOR i IN 1..5 + RETURN i +`, BC{ + I(vm.OpReturn, 0, 7), + }), + }) +} diff --git a/test/integration/vm/vm_for_test.go b/test/integration/vm/vm_for_test.go index b9d49bdc..53333d6b 100644 --- a/test/integration/vm/vm_for_test.go +++ b/test/integration/vm/vm_for_test.go @@ -23,7 +23,10 @@ func TestFor(t *testing.T) { FOR foo IN foo RETURN foo `, "Should not compile FOR foo IN foo"), - CaseArray("FOR i IN 1..5 RETURN i", []any{1, 2, 3, 4, 5}), + CaseArray(` +FOR i IN 1..5 + RETURN i +`, []any{1, 2, 3, 4, 5}), CaseArray( ` FOR i IN 1..5