From 8731b7d71fcca019eee7a6befae9c1ea35c254cc Mon Sep 17 00:00:00 2001 From: Tim Voronov Date: Wed, 25 Jun 2025 14:10:41 -0400 Subject: [PATCH] Refactor loop and collect compilation; replace `lc` and `cc` parameters with `c` for consistency, unify naming conventions across loop and collect compilers, and improve method clarity and readability. --- pkg/compiler/internal/core/loop.go | 26 ++++ pkg/compiler/internal/loop.go | 174 +++++++++++----------- pkg/compiler/internal/loop_collect.go | 134 ++++++++--------- pkg/compiler/internal/loop_collect_agg.go | 124 +++++++-------- pkg/compiler/internal/loop_sort.go | 24 +-- test/integration/vm/vm_for_nested_test.go | 100 +++++++++++++ 6 files changed, 344 insertions(+), 238 deletions(-) diff --git a/pkg/compiler/internal/core/loop.go b/pkg/compiler/internal/core/loop.go index 8d803bc2..6c7d3c42 100644 --- a/pkg/compiler/internal/core/loop.go +++ b/pkg/compiler/internal/core/loop.go @@ -103,6 +103,32 @@ func (l *Loop) EmitFinalization(emitter *Emitter) { emitter.PatchJump(l.Jump) } +func (l *Loop) PatchDestinationAx(alloc *RegisterAllocator, emitter *Emitter, op vm.Opcode, arg int) vm.Operand { + if l.Allocate { + emitter.PatchSwapAx(l.Pos, op, l.Dst, arg) + + return l.Dst + } + + tmp := alloc.Allocate(Temp) + emitter.PatchInsertAx(l.Pos, op, tmp, arg) + l.Jump++ + return tmp +} + +func (l *Loop) PatchDestinationAxy(alloc *RegisterAllocator, emitter *Emitter, op vm.Opcode, arg1, arg2 int) vm.Operand { + if l.Allocate { + emitter.PatchSwapAxy(l.Pos, op, l.Dst, arg1, arg2) + + return l.Dst + } + + tmp := alloc.Allocate(Temp) + emitter.PatchInsertAxy(l.Pos, op, tmp, arg1, arg2) + l.Jump++ + return tmp +} + func (l *Loop) canDeclareVar(name string) bool { return name != "" && name != IgnorePseudoVariable } diff --git a/pkg/compiler/internal/loop.go b/pkg/compiler/internal/loop.go index 1fa94401..504f8a59 100644 --- a/pkg/compiler/internal/loop.go +++ b/pkg/compiler/internal/loop.go @@ -17,189 +17,189 @@ func NewLoopCompiler(ctx *CompilerContext) *LoopCompiler { return &LoopCompiler{ctx: ctx} } -func (lc *LoopCompiler) Compile(ctx fql.IForExpressionContext) vm.Operand { - returnRuleCtx := lc.compileInitialization(ctx) +func (c *LoopCompiler) Compile(ctx fql.IForExpressionContext) vm.Operand { + returnRuleCtx := c.compileInitialization(ctx) // body if body := ctx.AllForExpressionBody(); body != nil && len(body) > 0 { for _, b := range body { - if c := b.ForExpressionStatement(); c != nil { - lc.compileForExpressionStatement(c) - } else if c := b.ForExpressionClause(); c != nil { - lc.compileForExpressionClause(c) + if ec := b.ForExpressionStatement(); ec != nil { + c.compileForExpressionStatement(ec) + } else if ec := b.ForExpressionClause(); ec != nil { + c.compileForExpressionClause(ec) } } } - return lc.compileFinalization(returnRuleCtx) + return c.compileFinalization(returnRuleCtx) } -func (lc *LoopCompiler) compileInitialization(ctx fql.IForExpressionContext) antlr.RuleContext { +func (c *LoopCompiler) compileInitialization(ctx fql.IForExpressionContext) antlr.RuleContext { var distinct bool var returnRuleCtx antlr.RuleContext var loopType core.LoopType returnCtx := ctx.ForExpressionReturn() - if c := returnCtx.ReturnExpression(); c != nil { - returnRuleCtx = c - distinct = c.Distinct() != nil + if re := returnCtx.ReturnExpression(); re != nil { + returnRuleCtx = re + distinct = re.Distinct() != nil loopType = core.NormalLoop - } else if c := returnCtx.ForExpression(); c != nil { - returnRuleCtx = c + } else if fe := returnCtx.ForExpression(); fe != nil { + returnRuleCtx = fe loopType = core.PassThroughLoop } - src := lc.compileForExpressionSource(ctx.ForExpressionSource()) - loop := lc.ctx.Loops.CreateFor(loopType, src, distinct) - lc.ctx.Loops.Push(loop) - lc.ctx.Symbols.EnterScope() + src := c.compileForExpressionSource(ctx.ForExpressionSource()) + loop := c.ctx.Loops.CreateFor(loopType, src, distinct) + c.ctx.Loops.Push(loop) + c.ctx.Symbols.EnterScope() if val := ctx.GetValueVariable(); val != nil { - loop.DeclareValueVar(val.GetText(), lc.ctx.Symbols) + loop.DeclareValueVar(val.GetText(), c.ctx.Symbols) } if ctr := ctx.GetCounterVariable(); ctr != nil { - loop.DeclareKeyVar(ctr.GetText(), lc.ctx.Symbols) + loop.DeclareKeyVar(ctr.GetText(), c.ctx.Symbols) } - loop.EmitInitialization(lc.ctx.Registers, lc.ctx.Emitter) + loop.EmitInitialization(c.ctx.Registers, c.ctx.Emitter) return returnRuleCtx } -func (lc *LoopCompiler) compileFinalization(ctx antlr.RuleContext) vm.Operand { - loop := lc.ctx.Loops.Current() +func (c *LoopCompiler) compileFinalization(ctx antlr.RuleContext) vm.Operand { + loop := c.ctx.Loops.Current() // RETURN if loop.Type != core.PassThroughLoop { - c := ctx.(*fql.ReturnExpressionContext) - expReg := lc.ctx.ExprCompiler.Compile(c.Expression()) + re := ctx.(*fql.ReturnExpressionContext) + expReg := c.ctx.ExprCompiler.Compile(re.Expression()) - lc.ctx.Emitter.EmitAB(vm.OpPush, loop.Dst, expReg) + c.ctx.Emitter.EmitAB(vm.OpPush, loop.Dst, expReg) } else if ctx != nil { - if c, ok := ctx.(*fql.ForExpressionContext); ok { - lc.Compile(c) + if fe, ok := ctx.(*fql.ForExpressionContext); ok { + c.Compile(fe) } } - loop.EmitFinalization(lc.ctx.Emitter) - lc.ctx.Symbols.ExitScope() - lc.ctx.Loops.Pop() + loop.EmitFinalization(c.ctx.Emitter) + c.ctx.Symbols.ExitScope() + c.ctx.Loops.Pop() // TODO: Free operands return loop.Dst } -func (lc *LoopCompiler) compileForExpressionSource(ctx fql.IForExpressionSourceContext) vm.Operand { - if c := ctx.FunctionCallExpression(); c != nil { - return lc.ctx.ExprCompiler.CompileFunctionCallExpression(c) +func (c *LoopCompiler) compileForExpressionSource(ctx fql.IForExpressionSourceContext) vm.Operand { + if fce := ctx.FunctionCallExpression(); fce != nil { + return c.ctx.ExprCompiler.CompileFunctionCallExpression(fce) } - if c := ctx.MemberExpression(); c != nil { - return lc.ctx.ExprCompiler.CompileMemberExpression(c) + if me := ctx.MemberExpression(); me != nil { + return c.ctx.ExprCompiler.CompileMemberExpression(me) } - if c := ctx.Variable(); c != nil { - return lc.ctx.ExprCompiler.CompileVariable(c) + if v := ctx.Variable(); v != nil { + return c.ctx.ExprCompiler.CompileVariable(v) } - if c := ctx.Param(); c != nil { - return lc.ctx.ExprCompiler.CompileParam(c) + if p := ctx.Param(); p != nil { + return c.ctx.ExprCompiler.CompileParam(p) } - if c := ctx.RangeOperator(); c != nil { - return lc.ctx.ExprCompiler.CompileRangeOperator(c) + if ro := ctx.RangeOperator(); ro != nil { + return c.ctx.ExprCompiler.CompileRangeOperator(ro) } - if c := ctx.ArrayLiteral(); c != nil { - return lc.ctx.LiteralCompiler.CompileArrayLiteral(c) + if al := ctx.ArrayLiteral(); al != nil { + return c.ctx.LiteralCompiler.CompileArrayLiteral(al) } - if c := ctx.ObjectLiteral(); c != nil { - return lc.ctx.LiteralCompiler.CompileObjectLiteral(c) + if ol := ctx.ObjectLiteral(); ol != nil { + return c.ctx.LiteralCompiler.CompileObjectLiteral(ol) } panic(runtime.Error(core.ErrUnexpectedToken, ctx.GetText())) } -func (lc *LoopCompiler) compileForExpressionStatement(ctx fql.IForExpressionStatementContext) { - if c := ctx.VariableDeclaration(); c != nil { - _ = lc.ctx.StmtCompiler.CompileVariableDeclaration(c) - } else if c := ctx.FunctionCallExpression(); c != nil { - _ = lc.ctx.ExprCompiler.CompileFunctionCallExpression(c) +func (c *LoopCompiler) compileForExpressionStatement(ctx fql.IForExpressionStatementContext) { + if vd := ctx.VariableDeclaration(); vd != nil { + _ = c.ctx.StmtCompiler.CompileVariableDeclaration(vd) + } else if fce := ctx.FunctionCallExpression(); fce != nil { + _ = c.ctx.ExprCompiler.CompileFunctionCallExpression(fce) } // TODO: Free register if needed } -func (lc *LoopCompiler) compileForExpressionClause(ctx fql.IForExpressionClauseContext) { - if c := ctx.LimitClause(); c != nil { - lc.compileLimitClause(c) - } else if c := ctx.FilterClause(); c != nil { - lc.compileFilterClause(c) - } else if c := ctx.SortClause(); c != nil { - lc.compileSortClause(c) - } else if c := ctx.CollectClause(); c != nil { - lc.compileCollectClause(c) +func (c *LoopCompiler) compileForExpressionClause(ctx fql.IForExpressionClauseContext) { + if lc := ctx.LimitClause(); lc != nil { + c.compileLimitClause(lc) + } else if fc := ctx.FilterClause(); fc != nil { + c.compileFilterClause(fc) + } else if sc := ctx.SortClause(); sc != nil { + c.compileSortClause(sc) + } else if cc := ctx.CollectClause(); cc != nil { + c.compileCollectClause(cc) } } -func (lc *LoopCompiler) compileLimitClause(ctx fql.ILimitClauseContext) { +func (c *LoopCompiler) compileLimitClause(ctx fql.ILimitClauseContext) { clauses := ctx.AllLimitClauseValue() if len(clauses) == 1 { - lc.compileLimit(lc.compileLimitClauseValue(clauses[0])) + c.compileLimit(c.compileLimitClauseValue(clauses[0])) } else { - lc.compileOffset(lc.compileLimitClauseValue(clauses[0])) - lc.compileLimit(lc.compileLimitClauseValue(clauses[1])) + c.compileOffset(c.compileLimitClauseValue(clauses[0])) + c.compileLimit(c.compileLimitClauseValue(clauses[1])) } } -func (lc *LoopCompiler) compileLimitClauseValue(ctx fql.ILimitClauseValueContext) vm.Operand { - if c := ctx.Param(); c != nil { - return lc.ctx.ExprCompiler.CompileParam(c) +func (c *LoopCompiler) compileLimitClauseValue(ctx fql.ILimitClauseValueContext) vm.Operand { + if pm := ctx.Param(); pm != nil { + return c.ctx.ExprCompiler.CompileParam(pm) } - if c := ctx.IntegerLiteral(); c != nil { - return lc.ctx.LiteralCompiler.CompileIntegerLiteral(c) + if il := ctx.IntegerLiteral(); il != nil { + return c.ctx.LiteralCompiler.CompileIntegerLiteral(il) } - if c := ctx.Variable(); c != nil { - return lc.ctx.ExprCompiler.CompileVariable(c) + if vb := ctx.Variable(); vb != nil { + return c.ctx.ExprCompiler.CompileVariable(vb) } - if c := ctx.MemberExpression(); c != nil { - return lc.ctx.ExprCompiler.CompileMemberExpression(c) + if me := ctx.MemberExpression(); me != nil { + return c.ctx.ExprCompiler.CompileMemberExpression(me) } - if c := ctx.FunctionCallExpression(); c != nil { - return lc.ctx.ExprCompiler.CompileFunctionCallExpression(c) + if fce := ctx.FunctionCallExpression(); fce != nil { + return c.ctx.ExprCompiler.CompileFunctionCallExpression(fce) } panic(runtime.Error(core.ErrUnexpectedToken, ctx.GetText())) } -func (lc *LoopCompiler) compileLimit(src vm.Operand) { - state := lc.ctx.Registers.Allocate(core.State) - lc.ctx.Emitter.EmitABx(vm.OpIterLimit, state, src, lc.ctx.Loops.Current().Jump) +func (c *LoopCompiler) compileLimit(src vm.Operand) { + state := c.ctx.Registers.Allocate(core.State) + c.ctx.Emitter.EmitABx(vm.OpIterLimit, state, src, c.ctx.Loops.Current().Jump) } -func (lc *LoopCompiler) compileOffset(src vm.Operand) { - state := lc.ctx.Registers.Allocate(core.State) - lc.ctx.Emitter.EmitABx(vm.OpIterSkip, state, src, lc.ctx.Loops.Current().Jump) +func (c *LoopCompiler) compileOffset(src vm.Operand) { + state := c.ctx.Registers.Allocate(core.State) + c.ctx.Emitter.EmitABx(vm.OpIterSkip, state, src, c.ctx.Loops.Current().Jump) } -func (lc *LoopCompiler) compileFilterClause(ctx fql.IFilterClauseContext) { - src := lc.ctx.ExprCompiler.Compile(ctx.Expression()) - lc.ctx.Emitter.EmitJumpIfFalse(src, lc.ctx.Loops.Current().Jump) +func (c *LoopCompiler) compileFilterClause(ctx fql.IFilterClauseContext) { + src := c.ctx.ExprCompiler.Compile(ctx.Expression()) + c.ctx.Emitter.EmitJumpIfFalse(src, c.ctx.Loops.Current().Jump) } -func (lc *LoopCompiler) compileSortClause(ctx fql.ISortClauseContext) { - lc.ctx.LoopSortCompiler.Compile(ctx) +func (c *LoopCompiler) compileSortClause(ctx fql.ISortClauseContext) { + c.ctx.LoopSortCompiler.Compile(ctx) } -func (lc *LoopCompiler) compileCollectClause(ctx fql.ICollectClauseContext) { - lc.ctx.LoopCollectCompiler.Compile(ctx) +func (c *LoopCompiler) compileCollectClause(ctx fql.ICollectClauseContext) { + c.ctx.LoopCollectCompiler.Compile(ctx) } diff --git a/pkg/compiler/internal/loop_collect.go b/pkg/compiler/internal/loop_collect.go index b4e1deaf..9b74b5a9 100644 --- a/pkg/compiler/internal/loop_collect.go +++ b/pkg/compiler/internal/loop_collect.go @@ -17,22 +17,22 @@ func NewCollectCompiler(ctx *CompilerContext) *LoopCollectCompiler { return &LoopCollectCompiler{ctx: ctx} } -func (cc *LoopCollectCompiler) Compile(ctx fql.ICollectClauseContext) { +func (c *LoopCollectCompiler) Compile(ctx fql.ICollectClauseContext) { aggregator := ctx.CollectAggregator() - kv, groupSelectors := cc.compileCollect(ctx, aggregator != nil) + kv, groupSelectors := c.compileCollect(ctx, aggregator != nil) // Aggregation loop if aggregator != nil { - cc.compileAggregation(aggregator, len(groupSelectors) > 0) + c.compileAggregation(aggregator, len(groupSelectors) > 0) } if len(groupSelectors) > 0 { // Now we are defining new variables for the group selectors - cc.compileGroupSelectorVariables(groupSelectors, kv, aggregator != nil) + c.compileGroupSelectorVariables(groupSelectors, kv, aggregator != nil) } } -func (cc *LoopCollectCompiler) compileCollect(ctx fql.ICollectClauseContext, aggregation bool) (*core.KV, []fql.ICollectSelectorContext) { +func (c *LoopCollectCompiler) compileCollect(ctx fql.ICollectClauseContext, aggregation bool) (*core.KV, []fql.ICollectSelectorContext) { grouping := ctx.CollectGrouping() counter := ctx.CollectCounter() @@ -40,10 +40,10 @@ func (cc *LoopCollectCompiler) compileCollect(ctx fql.ICollectClauseContext, agg return core.NewKV(vm.NoopOperand, vm.NoopOperand), nil } - loop := cc.ctx.Loops.Current() + loop := c.ctx.Loops.Current() - kv, groupSelectors := cc.initializeCollector(grouping) - projectionVariableName, collectorType := cc.initializeProjection(ctx, loop, kv, counter, grouping != nil) + kv, groupSelectors := c.initializeCollector(grouping) + projectionVarName, collectorType := c.initializeProjection(ctx, loop, kv, counter, grouping != nil) // If we use aggregators, we need to collect group items by key if aggregation && collectorType != core.CollectorTypeKeyGroup { @@ -51,67 +51,67 @@ func (cc *LoopCollectCompiler) compileCollect(ctx fql.ICollectClauseContext, agg collectorType = core.CollectorTypeKeyGroup } - cc.finalizeCollector(loop, collectorType, kv) + c.finalizeCollector(loop, collectorType, kv) // If the projection is used, we allocate a new register for the variable and put the iterator's value into it - if projectionVariableName != "" { + if projectionVarName != "" { // Now we need to expand group variables from the dataset - loop.DeclareValueVar(projectionVariableName, cc.ctx.Symbols) - loop.EmitInitialization(cc.ctx.Registers, cc.ctx.Emitter) + loop.DeclareValueVar(projectionVarName, c.ctx.Symbols) + loop.EmitInitialization(c.ctx.Registers, c.ctx.Emitter) - loop.EmitKey(kv.Value, cc.ctx.Emitter) + loop.EmitKey(kv.Value, c.ctx.Emitter) } else { - loop.EmitInitialization(cc.ctx.Registers, cc.ctx.Emitter) - loop.EmitKey(kv.Key, cc.ctx.Emitter) + loop.EmitInitialization(c.ctx.Registers, c.ctx.Emitter) + loop.EmitKey(kv.Key, c.ctx.Emitter) } return kv, groupSelectors } // initializeKeyValue creates the KeyValue pair for collection, handling both grouping and value setup. -func (cc *LoopCollectCompiler) initializeCollector(grouping fql.ICollectGroupingContext) (*core.KV, []fql.ICollectSelectorContext) { +func (c *LoopCollectCompiler) initializeCollector(grouping fql.ICollectGroupingContext) (*core.KV, []fql.ICollectSelectorContext) { var groupSelectors []fql.ICollectSelectorContext kv := core.NewKV(vm.NoopOperand, vm.NoopOperand) - loop := cc.ctx.Loops.Current() + loop := c.ctx.Loops.Current() // Handle grouping key if present if grouping != nil { - keyReg, selectors := cc.compileGrouping(grouping) + keyReg, selectors := c.compileGrouping(grouping) kv.Key = keyReg groupSelectors = selectors } // Setup value register and emit value from current loop - kv.Value = cc.ctx.Registers.Allocate(core.Temp) - loop.EmitValue(kv.Value, cc.ctx.Emitter) + kv.Value = c.ctx.Registers.Allocate(core.Temp) + loop.EmitValue(kv.Value, c.ctx.Emitter) return kv, groupSelectors } -func (cc *LoopCollectCompiler) finalizeCollector(loop *core.Loop, collectorType core.CollectorType, kv *core.KV) { +func (c *LoopCollectCompiler) finalizeCollector(loop *core.Loop, collectorType core.CollectorType, kv *core.KV) { // We replace DataSet initialization with Collector initialization - cc.ctx.Emitter.PatchSwapAx(loop.Pos, vm.OpDataSetCollector, loop.Dst, int(collectorType)) - cc.ctx.Emitter.EmitABC(vm.OpPushKV, loop.Dst, kv.Key, kv.Value) - loop.EmitFinalization(cc.ctx.Emitter) + dst := loop.PatchDestinationAx(c.ctx.Registers, c.ctx.Emitter, vm.OpDataSetCollector, int(collectorType)) + c.ctx.Emitter.EmitABC(vm.OpPushKV, dst, kv.Key, kv.Value) + loop.EmitFinalization(c.ctx.Emitter) - cc.ctx.Emitter.EmitMove(loop.Src, loop.Dst) + c.ctx.Emitter.EmitMove(loop.Src, dst) - cc.ctx.Registers.Free(loop.Value) - cc.ctx.Registers.Free(loop.Key) + c.ctx.Registers.Free(loop.Value) + c.ctx.Registers.Free(loop.Key) loop.Value = kv.Value loop.Key = vm.NoopOperand } // initializeProjection handles the projection setup for group variables and counters. // Returns the projection variable name and the appropriate collector type. -func (cc *LoopCollectCompiler) initializeProjection(ctx fql.ICollectClauseContext, loop *core.Loop, kv *core.KV, counter fql.ICollectCounterContext, hasGrouping bool) (string, core.CollectorType) { +func (c *LoopCollectCompiler) initializeProjection(ctx fql.ICollectClauseContext, loop *core.Loop, kv *core.KV, counter fql.ICollectCounterContext, hasGrouping bool) (string, core.CollectorType) { projectionVariableName := "" collectorType := core.CollectorTypeKey // Handle group variable projection if groupVar := ctx.CollectGroupVariable(); groupVar != nil { - projectionVariableName = cc.compileGroupVariableProjection(loop, kv, groupVar) + projectionVariableName = c.compileGroupVariableProjection(loop, kv, groupVar) collectorType = core.CollectorTypeKeyGroup return projectionVariableName, collectorType } @@ -119,14 +119,14 @@ func (cc *LoopCollectCompiler) initializeProjection(ctx fql.ICollectClauseContex // Handle counter projection if counter != nil { projectionVariableName = counter.Identifier().GetText() - collectorType = cc.determineCounterCollectorType(hasGrouping) + collectorType = c.determineCounterCollectorType(hasGrouping) } return projectionVariableName, collectorType } // determineCounterCollectorType returns the appropriate collector type for counter operations. -func (cc *LoopCollectCompiler) determineCounterCollectorType(hasGrouping bool) core.CollectorType { +func (c *LoopCollectCompiler) determineCounterCollectorType(hasGrouping bool) core.CollectorType { if hasGrouping { return core.CollectorTypeKeyCounter } @@ -134,7 +134,7 @@ func (cc *LoopCollectCompiler) determineCounterCollectorType(hasGrouping bool) c return core.CollectorTypeCounter } -func (cc *LoopCollectCompiler) compileGrouping(ctx fql.ICollectGroupingContext) (vm.Operand, []fql.ICollectSelectorContext) { +func (c *LoopCollectCompiler) compileGrouping(ctx fql.ICollectGroupingContext) (vm.Operand, []fql.ICollectSelectorContext) { selectors := ctx.AllCollectSelector() if len(selectors) == 0 { @@ -146,41 +146,41 @@ func (cc *LoopCollectCompiler) compileGrouping(ctx fql.ICollectGroupingContext) if len(selectors) > 1 { // We create a sequence of Registers for the clauses // To pack them into an array - selectorRegs := cc.ctx.Registers.AllocateSequence(len(selectors)) + selectorRegs := c.ctx.Registers.AllocateSequence(len(selectors)) for i, selector := range selectors { - reg := cc.ctx.ExprCompiler.Compile(selector.Expression()) - cc.ctx.Emitter.EmitAB(vm.OpMove, selectorRegs[i], reg) + reg := c.ctx.ExprCompiler.Compile(selector.Expression()) + c.ctx.Emitter.EmitAB(vm.OpMove, selectorRegs[i], reg) // Free the register after moving its value to the sequence register - cc.ctx.Registers.Free(reg) + c.ctx.Registers.Free(reg) } - kvKeyReg = cc.ctx.Registers.Allocate(core.Temp) - cc.ctx.Emitter.EmitAs(vm.OpList, kvKeyReg, selectorRegs) - cc.ctx.Registers.FreeSequence(selectorRegs) + kvKeyReg = c.ctx.Registers.Allocate(core.Temp) + c.ctx.Emitter.EmitAs(vm.OpList, kvKeyReg, selectorRegs) + c.ctx.Registers.FreeSequence(selectorRegs) } else { - kvKeyReg = cc.ctx.ExprCompiler.Compile(selectors[0].Expression()) + kvKeyReg = c.ctx.ExprCompiler.Compile(selectors[0].Expression()) } return kvKeyReg, selectors } // compileGroupVariableProjection processes group variable projections (both default and custom). -func (cc *LoopCollectCompiler) compileGroupVariableProjection(loop *core.Loop, kv *core.KV, groupVar fql.ICollectGroupVariableContext) string { +func (c *LoopCollectCompiler) compileGroupVariableProjection(loop *core.Loop, kv *core.KV, groupVar fql.ICollectGroupVariableContext) string { // Handle default projection (identifier) if identifier := groupVar.Identifier(); identifier != nil { - return cc.compileDefaultGroupProjection(loop, kv, identifier, groupVar.CollectGroupVariableKeeper()) + return c.compileDefaultGroupProjection(loop, kv, identifier, groupVar.CollectGroupVariableKeeper()) } // Handle custom projection (selector expression) if selector := groupVar.CollectSelector(); selector != nil { - return cc.compileCustomGroupProjection(loop, kv, selector) + return c.compileCustomGroupProjection(loop, kv, selector) } return "" } -func (cc *LoopCollectCompiler) compileGroupSelectorVariables(selectors []fql.ICollectSelectorContext, kv *core.KV, isAggregation bool) { +func (c *LoopCollectCompiler) compileGroupSelectorVariables(selectors []fql.ICollectSelectorContext, kv *core.KV, isAggregation bool) { if len(selectors) > 1 { variables := make([]vm.Operand, len(selectors)) @@ -188,7 +188,7 @@ func (cc *LoopCollectCompiler) compileGroupSelectorVariables(selectors []fql.ICo name := selector.Identifier().GetText() if variables[i] == vm.NoopOperand { - variables[i] = cc.ctx.Symbols.DeclareLocal(name) + variables[i] = c.ctx.Symbols.DeclareLocal(name) } reg := kv.Value @@ -197,69 +197,69 @@ func (cc *LoopCollectCompiler) compileGroupSelectorVariables(selectors []fql.ICo reg = kv.Key } - cc.ctx.Emitter.EmitABC(vm.OpLoadIndex, variables[i], reg, loadConstant(cc.ctx, runtime.Int(i))) + c.ctx.Emitter.EmitABC(vm.OpLoadIndex, variables[i], reg, loadConstant(c.ctx, runtime.Int(i))) } // Free the register after moving its value to the variable for _, reg := range variables { - cc.ctx.Registers.Free(reg) + c.ctx.Registers.Free(reg) } } else { // Get the variable name name := selectors[0].Identifier().GetText() // Define a variable for each selector - varReg := cc.ctx.Symbols.DeclareLocal(name) - reg := cc.selectGroupKey(isAggregation, kv) + varReg := c.ctx.Symbols.DeclareLocal(name) + reg := c.selectGroupKey(isAggregation, kv) // If we have a single selector, we can just move the value - cc.ctx.Emitter.EmitAB(vm.OpMove, varReg, reg) + c.ctx.Emitter.EmitAB(vm.OpMove, varReg, reg) } } -func (cc *LoopCollectCompiler) compileDefaultGroupProjection(loop *core.Loop, kv *core.KV, identifier antlr.TerminalNode, keeper fql.ICollectGroupVariableKeeperContext) string { +func (c *LoopCollectCompiler) compileDefaultGroupProjection(loop *core.Loop, kv *core.KV, identifier antlr.TerminalNode, keeper fql.ICollectGroupVariableKeeperContext) string { if keeper == nil { - seq := cc.ctx.Registers.AllocateSequence(2) // Key and Value for Map + seq := c.ctx.Registers.AllocateSequence(2) // Key and Value for Map // TODO: Review this. It's quite a questionable ArrangoDB feature of wrapping group items by a nested object // We will keep it for now for backward compatibility. - loadConstantTo(cc.ctx, runtime.String(loop.ValueName), seq[0]) // Map key - cc.ctx.Emitter.EmitAB(vm.OpMove, seq[1], kv.Value) // Map value - cc.ctx.Emitter.EmitAs(vm.OpMap, kv.Value, seq) + loadConstantTo(c.ctx, runtime.String(loop.ValueName), seq[0]) // Map key + c.ctx.Emitter.EmitAB(vm.OpMove, seq[1], kv.Value) // Map value + c.ctx.Emitter.EmitAs(vm.OpMap, kv.Value, seq) - cc.ctx.Registers.FreeSequence(seq) + c.ctx.Registers.FreeSequence(seq) } else { variables := keeper.AllIdentifier() - seq := cc.ctx.Registers.AllocateSequence(len(variables) * 2) + seq := c.ctx.Registers.AllocateSequence(len(variables) * 2) for i, j := 0, 0; i < len(variables); i, j = i+1, j+2 { varName := variables[i].GetText() - loadConstantTo(cc.ctx, runtime.String(varName), seq[j]) + loadConstantTo(c.ctx, runtime.String(varName), seq[j]) - variable, _, found := cc.ctx.Symbols.Resolve(varName) + variable, _, found := c.ctx.Symbols.Resolve(varName) if !found { panic("variable not found: " + varName) } - cc.ctx.Emitter.EmitAB(vm.OpMove, seq[j+1], variable) + c.ctx.Emitter.EmitAB(vm.OpMove, seq[j+1], variable) } - cc.ctx.Emitter.EmitAs(vm.OpMap, kv.Value, seq) - cc.ctx.Registers.FreeSequence(seq) + c.ctx.Emitter.EmitAs(vm.OpMap, kv.Value, seq) + c.ctx.Registers.FreeSequence(seq) } return identifier.GetText() } -func (cc *LoopCollectCompiler) compileCustomGroupProjection(_ *core.Loop, kv *core.KV, selector fql.ICollectSelectorContext) string { - selectorReg := cc.ctx.ExprCompiler.Compile(selector.Expression()) - cc.ctx.Emitter.EmitMove(kv.Value, selectorReg) - cc.ctx.Registers.Free(selectorReg) +func (c *LoopCollectCompiler) compileCustomGroupProjection(_ *core.Loop, kv *core.KV, selector fql.ICollectSelectorContext) string { + selectorReg := c.ctx.ExprCompiler.Compile(selector.Expression()) + c.ctx.Emitter.EmitMove(kv.Value, selectorReg) + c.ctx.Registers.Free(selectorReg) return selector.Identifier().GetText() } -func (cc *LoopCollectCompiler) selectGroupKey(isAggregation bool, kv *core.KV) vm.Operand { +func (c *LoopCollectCompiler) selectGroupKey(isAggregation bool, kv *core.KV) vm.Operand { if isAggregation { return kv.Key } diff --git a/pkg/compiler/internal/loop_collect_agg.go b/pkg/compiler/internal/loop_collect_agg.go index 5526f58e..0ea6fb5f 100644 --- a/pkg/compiler/internal/loop_collect_agg.go +++ b/pkg/compiler/internal/loop_collect_agg.go @@ -9,114 +9,114 @@ import ( "github.com/MontFerret/ferret/pkg/vm" ) -func (cc *LoopCollectCompiler) compileAggregation(c fql.ICollectAggregatorContext, isGrouped bool) { +func (c *LoopCollectCompiler) compileAggregation(ctx fql.ICollectAggregatorContext, isGrouped bool) { if isGrouped { - cc.compileGroupedAggregation(c) + c.compileGroupedAggregation(ctx) } else { - cc.compileGlobalAggregation(c) + c.compileGlobalAggregation(ctx) } } -func (cc *LoopCollectCompiler) compileGroupedAggregation(c fql.ICollectAggregatorContext) { - parentLoop := cc.ctx.Loops.Current() +func (c *LoopCollectCompiler) compileGroupedAggregation(ctx fql.ICollectAggregatorContext) { + parentLoop := c.ctx.Loops.Current() // We need to allocate a temporary accumulator to store aggregation results - selectors := c.AllCollectAggregateSelector() - accumulator := cc.ctx.Registers.Allocate(core.Temp) - cc.ctx.Emitter.EmitAx(vm.OpDataSetCollector, accumulator, int(core.CollectorTypeKeyGroup)) + selectors := ctx.AllCollectAggregateSelector() + accumulator := c.ctx.Registers.Allocate(core.Temp) + c.ctx.Emitter.EmitAx(vm.OpDataSetCollector, accumulator, int(core.CollectorTypeKeyGroup)) - loop := cc.ctx.Loops.CreateFor(core.TemporalLoop, cc.ctx.Registers.Allocate(core.Temp), false) + loop := c.ctx.Loops.CreateFor(core.TemporalLoop, c.ctx.Registers.Allocate(core.Temp), false) // Now we iterate over the grouped items - parentLoop.EmitValue(loop.Src, cc.ctx.Emitter) + parentLoop.EmitValue(loop.Src, c.ctx.Emitter) // Nested scope for aggregators - cc.ctx.Symbols.EnterScope() - loop.DeclareValueVar(parentLoop.ValueName, cc.ctx.Symbols) - loop.EmitInitialization(cc.ctx.Registers, cc.ctx.Emitter) + c.ctx.Symbols.EnterScope() + loop.DeclareValueVar(parentLoop.ValueName, c.ctx.Symbols) + loop.EmitInitialization(c.ctx.Registers, c.ctx.Emitter) // Add value selectors to the accumulators - argsPkg := cc.compileAggregationFuncArgs(selectors, accumulator) + argsPkg := c.compileAggregationFuncArgs(selectors, accumulator) - loop.EmitFinalization(cc.ctx.Emitter) - cc.ctx.Symbols.ExitScope() + loop.EmitFinalization(c.ctx.Emitter) + c.ctx.Symbols.ExitScope() // Now we can iterate over the selectors and execute the aggregation functions by passing the accumulators // And define variables for each accumulator result - cc.compileAggregationFuncCall(selectors, accumulator, argsPkg) - cc.ctx.Registers.Free(accumulator) + c.compileAggregationFuncCall(selectors, accumulator, argsPkg) + c.ctx.Registers.Free(accumulator) } -func (cc *LoopCollectCompiler) compileGlobalAggregation(c fql.ICollectAggregatorContext) { - parentLoop := cc.ctx.Loops.Current() +func (c *LoopCollectCompiler) compileGlobalAggregation(ctx fql.ICollectAggregatorContext) { + parentLoop := c.ctx.Loops.Current() // we create a custom collector for aggregators - cc.ctx.Emitter.PatchSwapAx(parentLoop.Pos, vm.OpDataSetCollector, parentLoop.Dst, int(core.CollectorTypeKeyGroup)) + c.ctx.Emitter.PatchSwapAx(parentLoop.Pos, vm.OpDataSetCollector, parentLoop.Dst, int(core.CollectorTypeKeyGroup)) // Nested scope for aggregators - cc.ctx.Symbols.EnterScope() + c.ctx.Symbols.EnterScope() // Now we add value selectors to the collector - selectors := c.AllCollectAggregateSelector() - argsPkg := cc.compileAggregationFuncArgs(selectors, parentLoop.Dst) + selectors := ctx.AllCollectAggregateSelector() + argsPkg := c.compileAggregationFuncArgs(selectors, parentLoop.Dst) - parentLoop.EmitFinalization(cc.ctx.Emitter) - cc.ctx.Loops.Pop() - cc.ctx.Symbols.ExitScope() + parentLoop.EmitFinalization(c.ctx.Emitter) + c.ctx.Loops.Pop() + c.ctx.Symbols.ExitScope() // Now we can iterate over the grouped items - zero := cc.ctx.Registers.Allocate(core.Temp) - cc.ctx.Emitter.EmitA(vm.OpLoadZero, zero) + zero := c.ctx.Registers.Allocate(core.Temp) + c.ctx.Emitter.EmitA(vm.OpLoadZero, zero) // We move the aggregator to a temporary register to access it later from the new loop - aggregator := cc.ctx.Registers.Allocate(core.Temp) - cc.ctx.Emitter.EmitAB(vm.OpMove, aggregator, parentLoop.Dst) + aggregator := c.ctx.Registers.Allocate(core.Temp) + c.ctx.Emitter.EmitAB(vm.OpMove, aggregator, parentLoop.Dst) // CreateFor new loop with 1 iteration only - cc.ctx.Symbols.EnterScope() - cc.ctx.Emitter.EmitABC(vm.OpRange, parentLoop.Src, zero, zero) - loop := cc.ctx.Loops.CreateFor(core.TemporalLoop, parentLoop.Src, parentLoop.Distinct) + c.ctx.Symbols.EnterScope() + c.ctx.Emitter.EmitABC(vm.OpRange, parentLoop.Src, zero, zero) + loop := c.ctx.Loops.CreateFor(core.TemporalLoop, parentLoop.Src, parentLoop.Distinct) loop.Dst = parentLoop.Dst loop.Allocate = true - cc.ctx.Loops.Push(loop) - loop.EmitInitialization(cc.ctx.Registers, cc.ctx.Emitter) + c.ctx.Loops.Push(loop) + loop.EmitInitialization(c.ctx.Registers, c.ctx.Emitter) // We just need to take the grouped values and call aggregation functions using them as args - cc.compileAggregationFuncCall(selectors, aggregator, argsPkg) - cc.ctx.Registers.Free(aggregator) + c.compileAggregationFuncCall(selectors, aggregator, argsPkg) + c.ctx.Registers.Free(aggregator) } -func (cc *LoopCollectCompiler) compileAggregationFuncArgs(selectors []fql.ICollectAggregateSelectorContext, collector vm.Operand) []int { +func (c *LoopCollectCompiler) compileAggregationFuncArgs(selectors []fql.ICollectAggregateSelectorContext, collector vm.Operand) []int { argsPkg := make([]int, len(selectors)) for i := 0; i < len(selectors); i++ { selector := selectors[i] fcx := selector.FunctionCallExpression() - args := cc.ctx.ExprCompiler.CompileArgumentList(fcx.FunctionCall().ArgumentList()) + args := c.ctx.ExprCompiler.CompileArgumentList(fcx.FunctionCall().ArgumentList()) if len(args) == 0 { // TODO: Better error handling panic("No arguments provided for the function call in the aggregate selector") } - aggrKeyReg := loadConstant(cc.ctx, runtime.Int(i)) + aggrKeyReg := loadConstant(c.ctx, runtime.Int(i)) // we keep information about the args - whether we need to unpack them or not argsPkg[i] = len(args) if len(args) > 1 { for y, arg := range args { - argKeyReg := cc.loadAggregationArgKey(i, y) - cc.ctx.Emitter.EmitABC(vm.OpPushKV, collector, argKeyReg, arg) - cc.ctx.Registers.Free(argKeyReg) + argKeyReg := c.loadAggregationArgKey(i, y) + c.ctx.Emitter.EmitABC(vm.OpPushKV, collector, argKeyReg, arg) + c.ctx.Registers.Free(argKeyReg) } } else { - cc.ctx.Emitter.EmitABC(vm.OpPushKV, collector, aggrKeyReg, args[0]) + c.ctx.Emitter.EmitABC(vm.OpPushKV, collector, aggrKeyReg, args[0]) } - cc.ctx.Registers.Free(aggrKeyReg) - cc.ctx.Registers.FreeSequence(args) + c.ctx.Registers.Free(aggrKeyReg) + c.ctx.Registers.FreeSequence(args) } return argsPkg } -func (cc *LoopCollectCompiler) compileAggregationFuncCall(selectors []fql.ICollectAggregateSelectorContext, accumulator vm.Operand, argsPkg []int) { +func (c *LoopCollectCompiler) compileAggregationFuncCall(selectors []fql.ICollectAggregateSelectorContext, accumulator vm.Operand, argsPkg []int) { for i, selector := range selectors { argsNum := argsPkg[i] @@ -124,35 +124,35 @@ func (cc *LoopCollectCompiler) compileAggregationFuncCall(selectors []fql.IColle // We need to unpack arguments if argsNum > 1 { - args = cc.ctx.Registers.AllocateSequence(argsNum) + args = c.ctx.Registers.AllocateSequence(argsNum) for y, reg := range args { - argKeyReg := cc.loadAggregationArgKey(i, y) - cc.ctx.Emitter.EmitABC(vm.OpLoadKey, reg, accumulator, argKeyReg) + argKeyReg := c.loadAggregationArgKey(i, y) + c.ctx.Emitter.EmitABC(vm.OpLoadKey, reg, accumulator, argKeyReg) - cc.ctx.Registers.Free(argKeyReg) + c.ctx.Registers.Free(argKeyReg) } } else { - key := loadConstant(cc.ctx, runtime.Int(i)) - value := cc.ctx.Registers.Allocate(core.Temp) - cc.ctx.Emitter.EmitABC(vm.OpLoadKey, value, accumulator, key) + key := loadConstant(c.ctx, runtime.Int(i)) + value := c.ctx.Registers.Allocate(core.Temp) + c.ctx.Emitter.EmitABC(vm.OpLoadKey, value, accumulator, key) args = core.RegisterSequence{value} - cc.ctx.Registers.Free(key) + c.ctx.Registers.Free(key) } fcx := selector.FunctionCallExpression() - result := cc.ctx.ExprCompiler.CompileFunctionCallWith(fcx.FunctionCall(), fcx.ErrorOperator() != nil, args) + result := c.ctx.ExprCompiler.CompileFunctionCallWith(fcx.FunctionCall(), fcx.ErrorOperator() != nil, args) // We define the variable for the selector result in the upper scope // Since this temporary scope is only for aggregators and will be closed after the aggregation selectorVarName := selector.Identifier().GetText() - varReg := cc.ctx.Symbols.DeclareLocal(selectorVarName) - cc.ctx.Emitter.EmitAB(vm.OpMove, varReg, result) - cc.ctx.Registers.Free(result) + varReg := c.ctx.Symbols.DeclareLocal(selectorVarName) + c.ctx.Emitter.EmitAB(vm.OpMove, varReg, result) + c.ctx.Registers.Free(result) } } -func (cc *LoopCollectCompiler) loadAggregationArgKey(selector int, arg int) vm.Operand { +func (c *LoopCollectCompiler) loadAggregationArgKey(selector int, arg int) vm.Operand { argKey := strconv.Itoa(selector) + ":" + strconv.Itoa(arg) - return loadConstant(cc.ctx, runtime.String(argKey)) + return loadConstant(c.ctx, runtime.String(argKey)) } diff --git a/pkg/compiler/internal/loop_sort.go b/pkg/compiler/internal/loop_sort.go index 8d038725..48a1bce9 100644 --- a/pkg/compiler/internal/loop_sort.go +++ b/pkg/compiler/internal/loop_sort.go @@ -112,33 +112,13 @@ func (c *LoopSortCompiler) compileSorter(loop *core.Loop, clauses []fql.ISortCla encoded := runtime.EncodeSortDirections(directions) count := len(clauses) - if loop.Allocate { - c.ctx.Emitter.PatchSwapAxy(loop.Pos, vm.OpDataSetMultiSorter, loop.Dst, encoded, count) - - return loop.Dst - } - - dst := c.ctx.Registers.Allocate(core.Temp) - c.ctx.Emitter.PatchInsertAxy(loop.Pos, vm.OpDataSetMultiSorter, loop.Dst, encoded, count) - loop.Jump++ - - return dst + return loop.PatchDestinationAxy(c.ctx.Registers, c.ctx.Emitter, vm.OpDataSetMultiSorter, encoded, count) } // Single-key sorting only needs the direction dir := sortDirection(clauses[0].SortDirection()) - if loop.Allocate { - c.ctx.Emitter.PatchSwapAx(loop.Pos, vm.OpDataSetSorter, loop.Dst, int(dir)) - - return loop.Dst - } - - dst := c.ctx.Registers.Allocate(core.Temp) - c.ctx.Emitter.PatchInsertAx(loop.Pos, vm.OpDataSetSorter, dst, int(dir)) - loop.Jump++ - - return dst + return loop.PatchDestinationAx(c.ctx.Registers, c.ctx.Emitter, vm.OpDataSetSorter, int(dir)) } // finalizeSorting completes the sorting process by: diff --git a/test/integration/vm/vm_for_nested_test.go b/test/integration/vm/vm_for_nested_test.go index 0bee81a1..af5945c5 100644 --- a/test/integration/vm/vm_for_nested_test.go +++ b/test/integration/vm/vm_for_nested_test.go @@ -68,6 +68,30 @@ FOR n IN 0..1 RETURN CONCAT(s, n) `, []any{"abc0", "bar0", "foo0", "qaz0", "abc1", "bar1", "foo1", "qaz1"}), CaseArray(` +LET users = [ + { + name: "Ron", + age: 31, + gender: "m" + }, + { + name: "Angela", + age: 29, + gender: "f" + }, + { + name: "Bob", + age: 36, + gender: "m" + } +] + +FOR n IN 0..1 + FOR u IN users + SORT u.gender, u.age + RETURN CONCAT(u.name, n) +`, []any{"Angela0", "Ron0", "Bob0", "Angela1", "Ron1", "Bob1"}), + CaseArray(` LET strs = ["foo", "bar", "qaz", "abc"] FOR n IN 0..1 @@ -85,5 +109,81 @@ FOR n IN 0..1 FOR m IN 0..1 RETURN CONCAT(s, n, m) `, []any{"abc00", "abc01", "bar00", "bar01", "foo00", "foo01", "qaz00", "qaz01", "abc10", "abc11", "bar10", "bar11", "foo10", "foo11", "qaz10", "qaz11"}), + CaseArray(` +LET users = [ + { + active: true, + married: true, + age: 31, + gender: "m" + }, + { + active: true, + married: false, + age: 25, + gender: "f" + }, + { + active: true, + married: false, + age: 36, + gender: "m" + }, + { + active: false, + married: true, + age: 69, + gender: "m" + }, + { + active: true, + married: true, + age: 45, + gender: "f" + } +] +FOR n IN 0..1 + FOR i IN users + COLLECT gender = i.gender + RETURN CONCAT(gender, n) +`, []any{"f0", "m0", "f1", "m1"}), + CaseArray(` +LET users = [ + { + active: true, + married: true, + age: 31, + gender: "m" + }, + { + active: true, + married: false, + age: 25, + gender: "f" + }, + { + active: true, + married: false, + age: 36, + gender: "m" + }, + { + active: false, + married: true, + age: 69, + gender: "m" + }, + { + active: true, + married: true, + age: 45, + gender: "f" + } +] +FOR i IN users + COLLECT gender = i.gender + FOR n IN 0..1 + RETURN CONCAT(gender, n) +`, []any{"f0", "f1", "m0", "m1"}), }) }