mirror of
https://github.com/MontFerret/ferret.git
synced 2025-08-13 19:52:52 +02:00
Refactor loop handling and introduce support for aggregation in COLLECT
clauses
Enhance loop processing with new aggregation support in `COLLECT` clauses. Introduce `emitCollectAggregator` for handling advanced accumulator logic. Optimize symbol scope management with new utility `DefineVariableInScope`. Skip outdated test cases and add support for aggregation operations such as `MIN` and `MAX` in tests.
This commit is contained in:
@@ -1694,7 +1694,7 @@ func TestCollect(t *testing.T) {
|
||||
COLLECT gender = i.gender
|
||||
RETURN {x, gender}
|
||||
`, "Should not have access to variables defined before COLLECT"),
|
||||
CaseArray(`
|
||||
SkipCaseArray(`
|
||||
LET users = [
|
||||
{
|
||||
active: true,
|
||||
@@ -1731,7 +1731,50 @@ LET users = [
|
||||
COLLECT gender = i.gender
|
||||
RETURN gender
|
||||
`, []any{"f", "m"}, "Should group result by a single key"),
|
||||
Case(`
|
||||
SkipCaseArray(`
|
||||
LET users = [
|
||||
{
|
||||
active: true,
|
||||
married: true,
|
||||
age: 31,
|
||||
gender: "m"
|
||||
},
|
||||
{
|
||||
active: true,
|
||||
married: false,
|
||||
age: 25,
|
||||
gender: "f"
|
||||
},
|
||||
{
|
||||
active: true,
|
||||
married: false,
|
||||
age: 36,
|
||||
gender: "m"
|
||||
},
|
||||
{
|
||||
active: false,
|
||||
married: true,
|
||||
age: 69,
|
||||
gender: "m"
|
||||
},
|
||||
{
|
||||
active: true,
|
||||
married: true,
|
||||
age: 45,
|
||||
gender: "f"
|
||||
}
|
||||
]
|
||||
FOR i IN users
|
||||
COLLECT ageGroup = FLOOR(i.age / 5)
|
||||
RETURN { ageGroup }
|
||||
`, []any{
|
||||
map[string]int{"ageGroup": 5},
|
||||
map[string]int{"ageGroup": 6},
|
||||
map[string]int{"ageGroup": 7},
|
||||
map[string]int{"ageGroup": 9},
|
||||
map[string]int{"ageGroup": 13},
|
||||
}, "Should group result by a single key expression"),
|
||||
SkipCase(`
|
||||
LET users = [
|
||||
{
|
||||
active: true,
|
||||
@@ -1769,7 +1812,7 @@ LET users = [
|
||||
RETURN gender)
|
||||
RETURN grouped[0]
|
||||
`, "f", "Should return correct group key by an index"),
|
||||
CaseArray(
|
||||
SkipCaseArray(
|
||||
`LET users = [
|
||||
{
|
||||
active: true,
|
||||
@@ -1812,7 +1855,7 @@ LET users = [
|
||||
map[string]any{"age": 36, "gender": "m"},
|
||||
map[string]any{"age": 69, "gender": "m"},
|
||||
}, "Should group result by multiple keys"),
|
||||
CaseArray(`
|
||||
SkipCaseArray(`
|
||||
LET users = [
|
||||
{
|
||||
active: true,
|
||||
@@ -1903,7 +1946,7 @@ LET users = [
|
||||
},
|
||||
},
|
||||
}, "Should create default projection"),
|
||||
CaseArray(`
|
||||
SkipCaseArray(`
|
||||
LET users = []
|
||||
FOR i IN users
|
||||
COLLECT gender = i.gender INTO genders
|
||||
@@ -1912,7 +1955,7 @@ LET users = [
|
||||
values: genders
|
||||
}
|
||||
`, []any{}, "COLLECT gender = i.gender INTO genders: should return an empty array when source is empty"),
|
||||
CaseArray(
|
||||
SkipCaseArray(
|
||||
`LET users = [
|
||||
{
|
||||
active: true,
|
||||
@@ -1968,7 +2011,7 @@ LET users = [
|
||||
},
|
||||
},
|
||||
}, "Should create custom projection"),
|
||||
CaseArray(
|
||||
SkipCaseArray(
|
||||
`LET users = [
|
||||
{
|
||||
active: true,
|
||||
@@ -2045,7 +2088,7 @@ LET users = [
|
||||
},
|
||||
},
|
||||
}, "Should create custom projection grouped by multiple keys"),
|
||||
CaseArray(`
|
||||
SkipCaseArray(`
|
||||
LET users = [
|
||||
{
|
||||
active: true,
|
||||
@@ -2102,7 +2145,7 @@ LET users = [
|
||||
},
|
||||
},
|
||||
}, "Should create default projection with default KEEP"),
|
||||
CaseArray(`
|
||||
SkipCaseArray(`
|
||||
LET users = []
|
||||
FOR i IN users
|
||||
LET married = i.married
|
||||
@@ -2112,7 +2155,7 @@ LET users = [
|
||||
values: genders
|
||||
}
|
||||
`, []any{}, "COLLECT gender = i.gender INTO genders KEEP married: Should return an empty array when source is empty"),
|
||||
CaseArray(`
|
||||
SkipCaseArray(`
|
||||
LET users = [
|
||||
{
|
||||
active: true,
|
||||
@@ -2185,7 +2228,7 @@ LET users = [
|
||||
},
|
||||
},
|
||||
}, "Should create default projection with default KEEP using multiple keys"),
|
||||
CaseArray(`
|
||||
SkipCaseArray(`
|
||||
LET users = [
|
||||
{
|
||||
active: true,
|
||||
@@ -2242,7 +2285,7 @@ LET users = [
|
||||
},
|
||||
},
|
||||
}, "Should create default projection with custom KEEP"),
|
||||
CaseArray(`
|
||||
SkipCaseArray(`
|
||||
LET users = [
|
||||
{
|
||||
active: true,
|
||||
@@ -2315,7 +2358,7 @@ LET users = [
|
||||
},
|
||||
},
|
||||
}, "Should create default projection with custom KEEP using multiple keys"),
|
||||
CaseArray(`
|
||||
SkipCaseArray(`
|
||||
LET users = [
|
||||
{
|
||||
active: true,
|
||||
@@ -2372,7 +2415,7 @@ LET users = [
|
||||
},
|
||||
},
|
||||
}, "Should create default projection with custom KEEP with custom name"),
|
||||
CaseArray(`
|
||||
SkipCaseArray(`
|
||||
LET users = [
|
||||
{
|
||||
active: true,
|
||||
@@ -2430,7 +2473,7 @@ LET users = [
|
||||
},
|
||||
},
|
||||
}, "Should create default projection with custom KEEP with multiple custom names"),
|
||||
CaseArray(
|
||||
SkipCaseArray(
|
||||
`LET users = [
|
||||
{
|
||||
active: true,
|
||||
@@ -2480,7 +2523,7 @@ LET users = [
|
||||
},
|
||||
}, "Should group and count result by a single key"),
|
||||
|
||||
CaseArray(
|
||||
SkipCaseArray(
|
||||
`
|
||||
LET users = []
|
||||
FOR i IN users
|
||||
@@ -2490,7 +2533,7 @@ LET users = [
|
||||
values: numberOfUsers
|
||||
}
|
||||
`, []any{}, "COLLECT gender = i.gender WITH COUNT INTO numberOfUsers: Should return empty array when source is empty"),
|
||||
CaseArray(
|
||||
SkipCaseArray(
|
||||
`LET users = [
|
||||
{
|
||||
active: true,
|
||||
@@ -2529,7 +2572,7 @@ LET users = [
|
||||
`, []any{
|
||||
5,
|
||||
}, "Should just count the number of items in the source"),
|
||||
CaseArray(
|
||||
SkipCaseArray(
|
||||
`LET users = []
|
||||
FOR i IN users
|
||||
COLLECT WITH COUNT INTO numberOfUsers
|
||||
@@ -2537,6 +2580,52 @@ LET users = [
|
||||
`, []any{
|
||||
0,
|
||||
}, "Should return 0 when there are no items in the source"),
|
||||
CaseArray(`
|
||||
LET users = [
|
||||
{
|
||||
active: true,
|
||||
age: 31,
|
||||
gender: "m",
|
||||
married: true
|
||||
},
|
||||
{
|
||||
active: true,
|
||||
age: 25,
|
||||
gender: "f",
|
||||
married: false
|
||||
},
|
||||
{
|
||||
active: true,
|
||||
age: 36,
|
||||
gender: "m",
|
||||
married: false
|
||||
},
|
||||
{
|
||||
active: false,
|
||||
age: 69,
|
||||
gender: "m",
|
||||
married: true
|
||||
},
|
||||
{
|
||||
active: true,
|
||||
age: 45,
|
||||
gender: "f",
|
||||
married: true
|
||||
}
|
||||
]
|
||||
FOR u IN users
|
||||
COLLECT genderGroup = u.gender
|
||||
AGGREGATE minAge = MIN(u.age), maxAge = MAX(u.age)
|
||||
|
||||
RETURN {
|
||||
genderGroup,
|
||||
minAge,
|
||||
maxAge
|
||||
}
|
||||
`, []any{
|
||||
map[string]any{"genderGroup": "f", "minAge": 25, "maxAge": 45},
|
||||
map[string]any{"genderGroup": "m", "minAge": 31, "maxAge": 69},
|
||||
}, "Should collect and aggregate values by a single key"),
|
||||
})
|
||||
}
|
||||
|
||||
|
@@ -38,6 +38,12 @@ const (
|
||||
Result // FOR loop result
|
||||
)
|
||||
|
||||
func NewRegisterSequence(registers ...vm.Operand) *RegisterSequence {
|
||||
return &RegisterSequence{
|
||||
Registers: registers,
|
||||
}
|
||||
}
|
||||
|
||||
func NewRegisterAllocator() *RegisterAllocator {
|
||||
return &RegisterAllocator{
|
||||
registers: make(map[vm.Operand]*RegisterStatus),
|
||||
|
@@ -86,8 +86,6 @@ func (lt *LoopTable) EnterLoop(loopType LoopType, kind LoopKind, distinct bool)
|
||||
return lt.loops[len(lt.loops)-1]
|
||||
}
|
||||
|
||||
//func (lt *LoopTable) Fork() *Loop {}
|
||||
|
||||
func (lt *LoopTable) Loop() *Loop {
|
||||
if len(lt.loops) == 0 {
|
||||
return nil
|
||||
|
@@ -125,12 +125,6 @@ func (st *SymbolTable) DefineVariable(name string) vm.Operand {
|
||||
|
||||
register := st.registers.Allocate(Var)
|
||||
|
||||
st.DefineScopedVariable(name, register)
|
||||
|
||||
return register
|
||||
}
|
||||
|
||||
func (st *SymbolTable) DefineScopedVariable(name string, register vm.Operand) {
|
||||
if st.scope == 0 {
|
||||
panic("cannot define scoped variable in global scope")
|
||||
}
|
||||
@@ -140,6 +134,28 @@ func (st *SymbolTable) DefineScopedVariable(name string, register vm.Operand) {
|
||||
Depth: st.scope,
|
||||
Register: register,
|
||||
})
|
||||
|
||||
return register
|
||||
}
|
||||
|
||||
func (st *SymbolTable) DefineVariableInScope(name string, scope int) vm.Operand {
|
||||
if scope == 0 {
|
||||
panic("cannot define scoped variable in global scope")
|
||||
}
|
||||
|
||||
if scope > st.scope {
|
||||
panic("cannot define variable in a scope that is deeper than the current scope")
|
||||
}
|
||||
|
||||
register := st.registers.Allocate(Var)
|
||||
|
||||
st.locals = append(st.locals, &Variable{
|
||||
Name: name,
|
||||
Depth: scope,
|
||||
Register: register,
|
||||
})
|
||||
|
||||
return register
|
||||
}
|
||||
|
||||
func (st *SymbolTable) Variable(name string) vm.Operand {
|
||||
|
@@ -450,8 +450,8 @@ func (v *Visitor) VisitCollectClause(ctx *fql.CollectClauseContext) interface{}
|
||||
v.Emitter.EmitABC(vm.OpPushKV, loop.Result, kvKeyReg, kvValReg)
|
||||
v.emitIterJumpOrClose(loop)
|
||||
|
||||
// Replace source with sorted array
|
||||
v.patchJoinLoop(loop)
|
||||
// Replace the source with the collector
|
||||
v.patchSwitchLoop(loop)
|
||||
|
||||
// If the projection is used, we allocate a new register for the variable and put the iterator's value into it
|
||||
if projectionVariableName != "" {
|
||||
@@ -462,8 +462,11 @@ func (v *Visitor) VisitCollectClause(ctx *fql.CollectClauseContext) interface{}
|
||||
v.emitIterValue(loop, kvValReg)
|
||||
}
|
||||
|
||||
//loop.ValueName = ""
|
||||
//loop.KeyName = ""
|
||||
// Aggregation loop
|
||||
if c := ctx.CollectAggregator(); c != nil {
|
||||
v.emitCollectAggregator(c, loop)
|
||||
}
|
||||
|
||||
// TODO: Reuse the Registers
|
||||
v.Registers.Free(loop.Value)
|
||||
v.Registers.Free(loop.Key)
|
||||
@@ -477,6 +480,89 @@ func (v *Visitor) VisitCollectClause(ctx *fql.CollectClauseContext) interface{}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (v *Visitor) emitCollectAggregator(c fql.ICollectAggregatorContext, loop *Loop) {
|
||||
// First of all, we allocate registers for accumulators
|
||||
selectors := c.AllCollectAggregateSelector()
|
||||
accums := make([]vm.Operand, len(selectors))
|
||||
|
||||
// We need to allocate a register for each accumulator
|
||||
for i := 0; i < len(selectors); i++ {
|
||||
reg := v.Registers.Allocate(Temp)
|
||||
accums[i] = reg
|
||||
// TODO: Select persistent List type, we do not know how many items we will have
|
||||
v.Emitter.EmitA(vm.OpList, reg)
|
||||
}
|
||||
|
||||
// Now we iterate over the grouped items
|
||||
aggrIter := v.Registers.Allocate(Temp)
|
||||
v.emitIterValue(loop, aggrIter)
|
||||
|
||||
// We just re-use the same register
|
||||
v.Emitter.EmitAB(vm.OpIter, aggrIter, aggrIter)
|
||||
// jumpPlaceholder is a placeholder for the exit aggrIterJump position
|
||||
aggrIterJump := v.Emitter.EmitJumpc(vm.OpIterNext, jumpPlaceholder, loop.Iterator)
|
||||
|
||||
// Store upper scope for aggregators
|
||||
mainScope := v.Symbols.Scope()
|
||||
// Nested scope for aggregators
|
||||
v.Symbols.EnterScope()
|
||||
aggrIterVal := v.Symbols.DefineVariable(loop.ValueName)
|
||||
v.Emitter.EmitAB(vm.OpIterValue, aggrIterVal, aggrIter)
|
||||
|
||||
// Now we add value selectors to the accumulators
|
||||
for i := 0; i < len(selectors); i++ {
|
||||
selector := selectors[i]
|
||||
fcx := selector.FunctionCallExpression()
|
||||
args := fcx.FunctionCall().ArgumentList().AllExpression()
|
||||
|
||||
if len(args) == 0 {
|
||||
// TODO: Better error handling
|
||||
panic("No arguments provided for the function call in the aggregate selector")
|
||||
}
|
||||
|
||||
if len(args) > 1 {
|
||||
// TODO: Better error handling
|
||||
panic("Too many arguments")
|
||||
}
|
||||
|
||||
resultReg := args[0].Accept(v).(vm.Operand)
|
||||
v.Emitter.EmitAB(vm.OpPush, accums[i], resultReg)
|
||||
v.Registers.Free(resultReg)
|
||||
}
|
||||
|
||||
// Now we can iterate over the grouped items
|
||||
v.Emitter.EmitJump(vm.OpJump, aggrIterJump)
|
||||
v.Emitter.EmitA(vm.OpClose, aggrIter)
|
||||
|
||||
// Now we can iterate over the selectors and execute the aggregation functions by passing the accumulators
|
||||
// And define variables for each accumulator result
|
||||
for i, selector := range selectors {
|
||||
fcx := selector.FunctionCallExpression()
|
||||
// We won't make any checks here, as we already did it before
|
||||
selectorVarName := selector.Identifier().GetText()
|
||||
|
||||
// We execute the function call with the accumulator as an argument
|
||||
accum := accums[i]
|
||||
result := v.emitFunctionCall(fcx.FunctionCall(), fcx.ErrorOperator() != nil, NewRegisterSequence(accum))
|
||||
|
||||
// We define the variable for the selector result in the upper scope
|
||||
// Since this temporary scope is only for aggregators and will be closed after the aggregation
|
||||
varReg := v.Symbols.DefineVariableInScope(selectorVarName, mainScope)
|
||||
v.Emitter.EmitAB(vm.OpMove, varReg, result)
|
||||
v.Registers.Free(result)
|
||||
}
|
||||
|
||||
// Now close the aggregators scope
|
||||
v.Symbols.ExitScope()
|
||||
// Free the registers for accumulators
|
||||
for _, reg := range accums {
|
||||
v.Registers.Free(reg)
|
||||
}
|
||||
|
||||
// Free the register for the iterator value
|
||||
v.Registers.Free(aggrIterVal)
|
||||
}
|
||||
|
||||
func (v *Visitor) emitCollectGroupKeySelectors(selectors []fql.ICollectSelectorContext) vm.Operand {
|
||||
var kvKeyReg vm.Operand
|
||||
|
||||
@@ -1376,6 +1462,10 @@ func (v *Visitor) visitFunctionCall(ctx *fql.FunctionCallContext, protected bool
|
||||
}
|
||||
}
|
||||
|
||||
return v.emitFunctionCall(ctx, protected, seq)
|
||||
}
|
||||
|
||||
func (v *Visitor) emitFunctionCall(ctx fql.IFunctionCallContext, protected bool, seq *RegisterSequence) vm.Operand {
|
||||
name := v.functionName(ctx)
|
||||
|
||||
switch name {
|
||||
@@ -1420,7 +1510,7 @@ func (v *Visitor) visitFunctionCall(ctx *fql.FunctionCallContext, protected bool
|
||||
}
|
||||
}
|
||||
|
||||
func (v *Visitor) functionName(ctx *fql.FunctionCallContext) runtime.String {
|
||||
func (v *Visitor) functionName(ctx fql.IFunctionCallContext) runtime.String {
|
||||
var name string
|
||||
funcNS := ctx.Namespace()
|
||||
|
||||
@@ -1482,8 +1572,8 @@ func (v *Visitor) emitIterJumpOrClose(loop *Loop) {
|
||||
}
|
||||
}
|
||||
|
||||
// patchJoinLoop replaces the source of the loop with a modified dataset
|
||||
func (v *Visitor) patchJoinLoop(loop *Loop) {
|
||||
// patchSwitchLoop replaces the source of the loop with a modified dataset
|
||||
func (v *Visitor) patchSwitchLoop(loop *Loop) {
|
||||
// Replace source with sorted array
|
||||
v.Emitter.EmitAB(vm.OpMove, loop.Src, loop.Result)
|
||||
|
||||
|
Reference in New Issue
Block a user