From 38625ad059f9b153d03a522a688fc104d1b351b3 Mon Sep 17 00:00:00 2001 From: Tim Voronov Date: Fri, 6 Jun 2025 16:34:14 -0400 Subject: [PATCH] Refactor collectors and sorters; introduce Transformers Refactor `Collector` into `Transformer` for enhanced flexibility and modularity. Introduce `Sorter` and `MultiSorter` as specialized transformers to handle sorting operations. Streamline VM operations by replacing dataset-based methods with transformer logic. Add encoding/decoding utilities for multiple sorting directions. Optimize `Emit` logic and update related tests. --- pkg/compiler/internal/emitter.go | 21 +++ pkg/compiler/internal/visitor.go | 85 +++++------ pkg/{vm/internal => runtime}/box.go | 30 +++- pkg/runtime/sort_direction.go | 42 ++++++ pkg/vm/const.go | 10 -- pkg/vm/internal/collector.go | 41 +----- pkg/vm/internal/collector_counter.go | 17 ++- pkg/vm/internal/collector_key.go | 20 +-- pkg/vm/internal/collector_key_counter.go | 24 ++-- pkg/vm/internal/collector_key_group.go | 20 +-- pkg/vm/internal/dataset.go | 171 +---------------------- pkg/vm/internal/sorter.go | 59 ++++++++ pkg/vm/internal/stream.go | 4 +- pkg/vm/internal/transformer.go | 14 ++ pkg/vm/opcode.go | 27 ++-- pkg/vm/vm.go | 106 ++------------ 16 files changed, 266 insertions(+), 425 deletions(-) rename pkg/{vm/internal => runtime}/box.go (51%) create mode 100644 pkg/runtime/sort_direction.go delete mode 100644 pkg/vm/const.go create mode 100644 pkg/vm/internal/sorter.go create mode 100644 pkg/vm/internal/transformer.go diff --git a/pkg/compiler/internal/emitter.go b/pkg/compiler/internal/emitter.go index 06e0364d..7b267a9a 100644 --- a/pkg/compiler/internal/emitter.go +++ b/pkg/compiler/internal/emitter.go @@ -43,6 +43,13 @@ func (e *Emitter) EmitJumpc(op vm.Opcode, pos int, reg vm.Operand) int { return len(e.instructions) - 1 } +func (e *Emitter) PatchSwapAB(pos int, op vm.Opcode, dst, src1 vm.Operand) { + e.instructions[pos] = vm.Instruction{ + Opcode: op, + Operands: [3]vm.Operand{dst, src1, vm.NoopOperand}, + } +} + func (e *Emitter) PatchSwapAx(pos int, op vm.Opcode, dst vm.Operand, arg int) { e.instructions[pos] = vm.Instruction{ Opcode: op, @@ -50,6 +57,20 @@ func (e *Emitter) PatchSwapAx(pos int, op vm.Opcode, dst vm.Operand, arg int) { } } +func (e *Emitter) PatchSwapAxy(pos int, op vm.Opcode, dst vm.Operand, arg1, agr2 int) { + e.instructions[pos] = vm.Instruction{ + Opcode: op, + Operands: [3]vm.Operand{dst, vm.Operand(arg1), vm.Operand(agr2)}, + } +} + +func (e *Emitter) PatchSwapAs(pos int, op vm.Opcode, dst vm.Operand, seq *RegisterSequence) { + e.instructions[pos] = vm.Instruction{ + Opcode: op, + Operands: [3]vm.Operand{dst, seq.Registers[0], seq.Registers[len(seq.Registers)-1]}, + } +} + // PatchJump patches a jump opcode. func (e *Emitter) PatchJump(instr int) { e.instructions[instr].Operands[0] = vm.Operand(len(e.instructions) - 1) diff --git a/pkg/compiler/internal/visitor.go b/pkg/compiler/internal/visitor.go index 58cd0bb0..1d712fd6 100644 --- a/pkg/compiler/internal/visitor.go +++ b/pkg/compiler/internal/visitor.go @@ -411,13 +411,12 @@ func (v *Visitor) VisitCollectClause(ctx *fql.CollectClauseContext) interface{} kvValReg := v.Registers.Allocate(Temp) var groupSelectors []fql.ICollectSelectorContext var isGrouping bool - var isCounting bool grouping := ctx.CollectGrouping() if grouping != nil { isGrouping = true groupSelectors = grouping.AllCollectSelector() - kvKeyReg = v.emitGroupingKeySelectors(groupSelectors) + kvKeyReg = v.emitCollectGroupKeySelectors(groupSelectors) } v.emitIterValue(loop, kvValReg) @@ -430,17 +429,15 @@ func (v *Visitor) VisitCollectClause(ctx *fql.CollectClauseContext) interface{} if groupVar := ctx.CollectGroupVariable(); groupVar != nil { // Projection can be either a default projection (identifier) or a custom projection (selector expression) if identifier := groupVar.Identifier(); identifier != nil { - projectionVariableName = v.emitDefaultCollectGroupProjection(loop, kvValReg, identifier, groupVar.CollectGroupVariableKeeper()) + projectionVariableName = v.emitCollectDefaultGroupProjection(loop, kvValReg, identifier, groupVar.CollectGroupVariableKeeper()) } else if selector := groupVar.CollectSelector(); selector != nil { - projectionVariableName = v.emitCustomCollectGroupProjection(loop, kvValReg, selector) + projectionVariableName = v.emitCollectCustomGroupProjection(loop, kvValReg, selector) } collectorType = 3 } else if countVar := ctx.CollectCounter(); countVar != nil { projectionVariableName = v.emitCollectCountProjection(loop, kvValReg, countVar) - isCounting = true - if isGrouping { collectorType = 2 } else { @@ -449,12 +446,12 @@ func (v *Visitor) VisitCollectClause(ctx *fql.CollectClauseContext) interface{} } // We replace DataSet initialization with Collector initialization - v.Emitter.PatchSwapAx(loop.ResultPos, vm.OpCollector, loop.Result, collectorType) + v.Emitter.PatchSwapAx(loop.ResultPos, vm.OpDataSetCollector, loop.Result, collectorType) v.Emitter.EmitABC(vm.OpPushKV, loop.Result, kvKeyReg, kvValReg) v.emitIterJumpOrClose(loop) // Replace source with sorted array - v.patchLoop(loop) + v.patchJoinLoop(loop) // If the projection is used, we allocate a new register for the variable and put the iterator's value into it if projectionVariableName != "" { @@ -465,9 +462,6 @@ func (v *Visitor) VisitCollectClause(ctx *fql.CollectClauseContext) interface{} v.emitIterValue(loop, kvValReg) } - if isCounting { - } - //loop.ValueName = "" //loop.KeyName = "" // TODO: Reuse the Registers @@ -477,13 +471,13 @@ func (v *Visitor) VisitCollectClause(ctx *fql.CollectClauseContext) interface{} loop.Key = vm.NoopOperand if isGrouping { - v.emitGroupingKeySelectorVariables(groupSelectors, kvValReg) + v.emitCollectGroupKeySelectorVariables(groupSelectors, kvValReg) } return nil } -func (v *Visitor) emitGroupingKeySelectors(selectors []fql.ICollectSelectorContext) vm.Operand { +func (v *Visitor) emitCollectGroupKeySelectors(selectors []fql.ICollectSelectorContext) vm.Operand { var kvKeyReg vm.Operand if len(selectors) > 1 { @@ -508,7 +502,7 @@ func (v *Visitor) emitGroupingKeySelectors(selectors []fql.ICollectSelectorConte return kvKeyReg } -func (v *Visitor) emitGroupingKeySelectorVariables(selectors []fql.ICollectSelectorContext, kvValReg vm.Operand) { +func (v *Visitor) emitCollectGroupKeySelectorVariables(selectors []fql.ICollectSelectorContext, kvValReg vm.Operand) { if len(selectors) > 1 { variables := make([]vm.Operand, len(selectors)) @@ -536,7 +530,7 @@ func (v *Visitor) emitGroupingKeySelectorVariables(selectors []fql.ICollectSelec } } -func (v *Visitor) emitDefaultCollectGroupProjection(loop *Loop, kvValReg vm.Operand, identifier antlr.TerminalNode, keeper fql.ICollectGroupVariableKeeperContext) string { +func (v *Visitor) emitCollectDefaultGroupProjection(loop *Loop, kvValReg vm.Operand, identifier antlr.TerminalNode, keeper fql.ICollectGroupVariableKeeperContext) string { if keeper == nil { seq := v.Registers.AllocateSequence(2) // Key and Value for Map @@ -564,7 +558,7 @@ func (v *Visitor) emitDefaultCollectGroupProjection(loop *Loop, kvValReg vm.Oper return identifier.GetText() } -func (v *Visitor) emitCustomCollectGroupProjection(_ *Loop, kvValReg vm.Operand, selector fql.ICollectSelectorContext) string { +func (v *Visitor) emitCollectCustomGroupProjection(_ *Loop, kvValReg vm.Operand, selector fql.ICollectSelectorContext) string { selectorReg := selector.Expression().Accept(v).(vm.Operand) v.Emitter.EmitAB(vm.OpMove, kvValReg, selectorReg) v.Registers.Free(selectorReg) @@ -593,26 +587,21 @@ func (v *Visitor) VisitSortClause(ctx *fql.SortClauseContext) interface{} { // These KeyValuePairs are then added to the dataset kvKeyReg := v.Registers.Allocate(Temp) clauses := ctx.AllSortClauseExpression() + var directions []runtime.SortDirection isSortMany := len(clauses) > 1 - // For multi-sort - var directionRegs *RegisterSequence - if isSortMany { clausesRegs := make([]vm.Operand, len(clauses)) + directions = make([]runtime.SortDirection, len(clauses)) // We create a sequence of Registers for the clauses // To pack them into an array keyRegs := v.Registers.AllocateSequence(len(clauses)) - // We create a sequence of Registers for the directions - directionRegs = v.Registers.AllocateSequence(len(clauses)) - for i, clause := range clauses { clauseReg := clause.Accept(v).(vm.Operand) v.Emitter.EmitAB(vm.OpMove, keyRegs.Registers[i], clauseReg) clausesRegs[i] = keyRegs.Registers[i] - v.visitSortDirection(clause.SortDirection(), directionRegs.Registers[i]) - + directions[i] = v.sortDirection(clause.SortDirection()) // TODO: Free Registers } @@ -635,38 +624,38 @@ func (v *Visitor) VisitSortClause(ctx *fql.SortClauseContext) interface{} { v.emitIterValue(loop, kvValReg) } + if isSortMany { + encoded := runtime.EncodeSortDirections(directions) + count := len(clauses) + + v.Emitter.PatchSwapAxy(loop.ResultPos, vm.OpDataSetMultiSorter, loop.Result, encoded, count) + } else { + dir := v.sortDirection(clauses[0].SortDirection()) + v.Emitter.PatchSwapAx(loop.ResultPos, vm.OpDataSetSorter, loop.Result, int(dir)) + } + v.Emitter.EmitABC(vm.OpPushKV, loop.Result, kvKeyReg, kvValReg) v.emitIterJumpOrClose(loop) - if isSortMany { - v.Emitter.EmitAs(vm.OpSortMany, loop.Result, directionRegs) - } else { - directionReg := v.Registers.Allocate(Temp) - v.visitSortDirection(clauses[0].SortDirection(), directionReg) - v.Emitter.EmitAB(vm.OpSort, loop.Result, directionReg) - } - - // Replace source with sorted array + // Replace source with the Sorter v.Emitter.EmitAB(vm.OpMove, loop.Src, loop.Result) - // Create new for loop - // TODO: Reuse existing DataSet instance + // Create a new loop v.emitLoopBegin(loop) return nil } -func (v *Visitor) visitSortDirection(dir antlr.TerminalNode, dest vm.Operand) { - var val runtime.Int = vm.SortAsc - - if dir != nil { - if strings.ToLower(dir.GetText()) == "desc" { - val = vm.SortDesc - } +func (v *Visitor) sortDirection(dir antlr.TerminalNode) runtime.SortDirection { + if dir == nil { + return runtime.SortDirectionAsc } - // TODO: Free constant Registers - v.Emitter.EmitAB(vm.OpMove, dest, v.loadConstant(val)) + if strings.ToLower(dir.GetText()) == "desc" { + return runtime.SortDirectionDesc + } + + return runtime.SortDirectionAsc } func (v *Visitor) VisitSortClauseExpression(ctx *fql.SortClauseExpressionContext) interface{} { @@ -675,14 +664,14 @@ func (v *Visitor) VisitSortClauseExpression(ctx *fql.SortClauseExpressionContext func (v *Visitor) visitOffset(src1 vm.Operand) interface{} { state := v.Registers.Allocate(State) - v.Emitter.EmitABx(vm.OpSkip, state, src1, v.Loops.Loop().Jump) + v.Emitter.EmitABx(vm.OpIterSkip, state, src1, v.Loops.Loop().Jump) return state } func (v *Visitor) visitLimit(src1 vm.Operand) interface{} { state := v.Registers.Allocate(State) - v.Emitter.EmitABx(vm.OpLimit, state, src1, v.Loops.Loop().Jump) + v.Emitter.EmitABx(vm.OpIterLimit, state, src1, v.Loops.Loop().Jump) return state } @@ -1493,8 +1482,8 @@ func (v *Visitor) emitIterJumpOrClose(loop *Loop) { } } -// patchLoop replaces the source of the loop with a modified dataset -func (v *Visitor) patchLoop(loop *Loop) { +// patchJoinLoop replaces the source of the loop with a modified dataset +func (v *Visitor) patchJoinLoop(loop *Loop) { // Replace source with sorted array v.Emitter.EmitAB(vm.OpMove, loop.Src, loop.Result) diff --git a/pkg/vm/internal/box.go b/pkg/runtime/box.go similarity index 51% rename from pkg/vm/internal/box.go rename to pkg/runtime/box.go index 138e6e59..981ff61e 100644 --- a/pkg/vm/internal/box.go +++ b/pkg/runtime/box.go @@ -1,15 +1,22 @@ -package internal +package runtime import ( - "github.com/wI2L/jettison" + "hash/fnv" - "github.com/MontFerret/ferret/pkg/runtime" + "github.com/wI2L/jettison" ) +// Box is a generic wrapper for any value type. type Box[T any] struct { Value T } +func NewBox[T any](value T) *Box[T] { + return &Box[T]{ + Value: value, + } +} + func (v *Box[T]) MarshalJSON() ([]byte, error) { return jettison.MarshalOpts(v.Value, jettison.NoHTMLEscaping()) } @@ -23,9 +30,22 @@ func (v *Box[T]) Unwrap() interface{} { } func (v *Box[T]) Hash() uint64 { - panic("not supported") + h := fnv.New64a() + + _, _ = h.Write([]byte("box:")) + + data, err := v.MarshalJSON() + + if err != nil { + // TODO: Panic? + return 0 + } + + _, _ = h.Write(data) + + return h.Sum64() } -func (v *Box[T]) Copy() runtime.Value { +func (v *Box[T]) Copy() Value { return &Box[T]{Value: v.Value} } diff --git a/pkg/runtime/sort_direction.go b/pkg/runtime/sort_direction.go new file mode 100644 index 00000000..90c81836 --- /dev/null +++ b/pkg/runtime/sort_direction.go @@ -0,0 +1,42 @@ +package runtime + +// SortDirection represents the sorting direction, either ascending or descending. +type SortDirection = Int + +const ( + SortDirectionAsc SortDirection = iota // Ascending sort direction + SortDirectionDesc // Descending sort direction +) + +func NewSortDirection(direction Int) SortDirection { + if direction == 0 { + return SortDirectionAsc + } + + return SortDirectionDesc +} + +// EncodeSortDirections encodes a slice of SortDirection values into a single integer by combining their bit representations. +func EncodeSortDirections(directions []SortDirection) int { + result := 0 + + for _, dir := range directions { + result = (result << 1) | int(dir) + } + + return result +} + +// DecodeSortDirections decodes an integer into a slice of SortDirection values representing sorting directions. +// The number of decoded directions is determined by the count argument. +// Each bit of the encoded integer corresponds to a SortDirection value in the resulting slice. +func DecodeSortDirections(encoded int, count int) []SortDirection { + directions := make([]SortDirection, count) + + for i := count - 1; i >= 0; i-- { + directions[i] = SortDirection(encoded & 1) + encoded >>= 1 + } + + return directions +} diff --git a/pkg/vm/const.go b/pkg/vm/const.go deleted file mode 100644 index a8fe662f..00000000 --- a/pkg/vm/const.go +++ /dev/null @@ -1,10 +0,0 @@ -package vm - -import ( - "github.com/MontFerret/ferret/pkg/vm/internal" -) - -const ( - SortAsc = internal.SortAsc - SortDesc = internal.SortDesc -) diff --git a/pkg/vm/internal/collector.go b/pkg/vm/internal/collector.go index 3559c72d..8be74d15 100644 --- a/pkg/vm/internal/collector.go +++ b/pkg/vm/internal/collector.go @@ -1,23 +1,6 @@ package internal -import ( - "context" - - "github.com/MontFerret/ferret/pkg/runtime" -) - -type ( - CollectorType int - - Collector interface { - runtime.Value - runtime.Iterable - - Collect(ctx context.Context, key, value runtime.Value) error - } - - BaseCollector struct{} -) +type CollectorType int const ( CollectorTypeCounter CollectorType = iota @@ -26,7 +9,7 @@ const ( CollectorTypeKeyGroup ) -func NewCollector(typ CollectorType) Collector { +func NewCollector(typ CollectorType) Transformer { switch typ { case CollectorTypeCounter: return NewCounterCollector() @@ -40,23 +23,3 @@ func NewCollector(typ CollectorType) Collector { panic("unknown collector type") } } - -func (*BaseCollector) MarshalJSON() ([]byte, error) { - panic("not supported") -} - -func (*BaseCollector) String() string { - return "[Collector]" -} - -func (*BaseCollector) Unwrap() interface{} { - panic("not supported") -} - -func (*BaseCollector) Hash() uint64 { - panic("not supported") -} - -func (*BaseCollector) Copy() runtime.Value { - panic("not supported") -} diff --git a/pkg/vm/internal/collector_counter.go b/pkg/vm/internal/collector_counter.go index b0e02123..5329f707 100644 --- a/pkg/vm/internal/collector_counter.go +++ b/pkg/vm/internal/collector_counter.go @@ -7,24 +7,23 @@ import ( ) type CounterCollector struct { - *BaseCollector - - counter runtime.Int + *runtime.Box[runtime.Int] } -func NewCounterCollector() Collector { +func NewCounterCollector() Transformer { return &CounterCollector{ - BaseCollector: &BaseCollector{}, - counter: 0, + Box: &runtime.Box[runtime.Int]{ + Value: 0, + }, } } func (c *CounterCollector) Iterate(ctx context.Context) (runtime.Iterator, error) { - return runtime.NewArrayWith(c.counter).Iterate(ctx) + return runtime.NewArrayWith(c.Value).Iterate(ctx) } -func (c *CounterCollector) Collect(ctx context.Context, key, value runtime.Value) error { - c.counter++ +func (c *CounterCollector) Add(_ context.Context, _, _ runtime.Value) error { + c.Value++ return nil } diff --git a/pkg/vm/internal/collector_key.go b/pkg/vm/internal/collector_key.go index 7976cecc..8197bf29 100644 --- a/pkg/vm/internal/collector_key.go +++ b/pkg/vm/internal/collector_key.go @@ -7,33 +7,33 @@ import ( ) type KeyCollector struct { - *BaseCollector - values runtime.List + *runtime.Box[runtime.List] grouping map[string]runtime.Value sorted bool } -func NewKeyCollector() Collector { +func NewKeyCollector() Transformer { return &KeyCollector{ - BaseCollector: &BaseCollector{}, - values: runtime.NewArray(16), - grouping: make(map[string]runtime.Value), + Box: &runtime.Box[runtime.List]{ + Value: runtime.NewArray(16), + }, + grouping: make(map[string]runtime.Value), } } func (c *KeyCollector) Iterate(ctx context.Context) (runtime.Iterator, error) { if !c.sorted { - if err := runtime.SortAsc(ctx, c.values); err != nil { + if err := runtime.SortAsc(ctx, c.Value); err != nil { return nil, err } c.sorted = true } - return c.values.Iterate(ctx) + return c.Value.Iterate(ctx) } -func (c *KeyCollector) Collect(ctx context.Context, key, _ runtime.Value) error { +func (c *KeyCollector) Add(ctx context.Context, key, _ runtime.Value) error { k, err := Stringify(ctx, key) if err != nil { @@ -45,7 +45,7 @@ func (c *KeyCollector) Collect(ctx context.Context, key, _ runtime.Value) error if !exists { c.grouping[k] = runtime.None - return c.values.Add(ctx, key) + return c.Value.Add(ctx, key) } return nil diff --git a/pkg/vm/internal/collector_key_counter.go b/pkg/vm/internal/collector_key_counter.go index f31ec01a..64afdd2c 100644 --- a/pkg/vm/internal/collector_key_counter.go +++ b/pkg/vm/internal/collector_key_counter.go @@ -7,17 +7,17 @@ import ( ) type KeyCounterCollector struct { - *BaseCollector - values runtime.List + *runtime.Box[runtime.List] grouping map[string]runtime.Int sorted bool } -func NewKeyCounterCollector() Collector { +func NewKeyCounterCollector() Transformer { return &KeyCounterCollector{ - BaseCollector: &BaseCollector{}, - values: runtime.NewArray(8), - grouping: make(map[string]runtime.Int), + Box: &runtime.Box[runtime.List]{ + Value: runtime.NewArray(8), + }, + grouping: make(map[string]runtime.Int), } } @@ -30,7 +30,7 @@ func (c *KeyCounterCollector) Iterate(ctx context.Context) (runtime.Iterator, er c.sorted = true } - iter, err := c.values.Iterate(ctx) + iter, err := c.Value.Iterate(ctx) if err != nil { return nil, err @@ -40,7 +40,7 @@ func (c *KeyCounterCollector) Iterate(ctx context.Context) (runtime.Iterator, er } func (c *KeyCounterCollector) sort(ctx context.Context) error { - return runtime.SortListWith(ctx, c.values, func(first, second runtime.Value) int64 { + return runtime.SortListWith(ctx, c.Value, func(first, second runtime.Value) int64 { firstKV, firstOk := first.(*KV) secondKV, secondOk := second.(*KV) @@ -56,7 +56,7 @@ func (c *KeyCounterCollector) sort(ctx context.Context) error { }) } -func (c *KeyCounterCollector) Collect(ctx context.Context, key, _ runtime.Value) error { +func (c *KeyCounterCollector) Add(ctx context.Context, key, _ runtime.Value) error { k, err := Stringify(ctx, key) if err != nil { @@ -68,7 +68,7 @@ func (c *KeyCounterCollector) Collect(ctx context.Context, key, _ runtime.Value) var kv *KV if !exists { - size, err := c.values.Length(ctx) + size, err := c.Value.Length(ctx) if err != nil { return err @@ -77,13 +77,13 @@ func (c *KeyCounterCollector) Collect(ctx context.Context, key, _ runtime.Value) idx = size kv = NewKV(key, runtime.ZeroInt) - if err := c.values.Add(ctx, kv); err != nil { + if err := c.Value.Add(ctx, kv); err != nil { return err } c.grouping[k] = idx } else { - value, err := c.values.Get(ctx, idx) + value, err := c.Value.Get(ctx, idx) if err != nil { return err diff --git a/pkg/vm/internal/collector_key_group.go b/pkg/vm/internal/collector_key_group.go index 1f9b33b4..775001e7 100644 --- a/pkg/vm/internal/collector_key_group.go +++ b/pkg/vm/internal/collector_key_group.go @@ -7,17 +7,17 @@ import ( ) type KeyGroupCollector struct { - *BaseCollector - values runtime.List + *runtime.Box[runtime.List] grouping map[string]runtime.List sorted bool } -func NewKeyGroupCollector() Collector { +func NewKeyGroupCollector() Transformer { return &KeyGroupCollector{ - BaseCollector: &BaseCollector{}, - values: runtime.NewArray(8), - grouping: make(map[string]runtime.List), + Box: &runtime.Box[runtime.List]{ + Value: runtime.NewArray(8), + }, + grouping: make(map[string]runtime.List), } } @@ -30,7 +30,7 @@ func (c *KeyGroupCollector) Iterate(ctx context.Context) (runtime.Iterator, erro c.sorted = true } - iter, err := c.values.Iterate(ctx) + iter, err := c.Value.Iterate(ctx) if err != nil { return nil, err @@ -40,7 +40,7 @@ func (c *KeyGroupCollector) Iterate(ctx context.Context) (runtime.Iterator, erro } func (c *KeyGroupCollector) sort(ctx context.Context) error { - return runtime.SortListWith(ctx, c.values, func(first, second runtime.Value) int64 { + return runtime.SortListWith(ctx, c.Value, func(first, second runtime.Value) int64 { firstKV, firstOk := first.(*KV) secondKV, secondOk := second.(*KV) @@ -56,7 +56,7 @@ func (c *KeyGroupCollector) sort(ctx context.Context) error { }) } -func (c *KeyGroupCollector) Collect(ctx context.Context, key, value runtime.Value) error { +func (c *KeyGroupCollector) Add(ctx context.Context, key, value runtime.Value) error { k, err := Stringify(ctx, key) if err != nil { @@ -70,7 +70,7 @@ func (c *KeyGroupCollector) Collect(ctx context.Context, key, value runtime.Valu c.grouping[k] = group - err = c.values.Add(ctx, NewKV(key, group)) + err = c.Value.Add(ctx, NewKV(key, group)) if err != nil { return err diff --git a/pkg/vm/internal/dataset.go b/pkg/vm/internal/dataset.go index 2065e731..632afc40 100644 --- a/pkg/vm/internal/dataset.go +++ b/pkg/vm/internal/dataset.go @@ -9,8 +9,6 @@ import ( type DataSet struct { values runtime.List uniqueness map[uint64]bool - grouping map[string]runtime.Value - keyed bool } // TODO: Remove implementation of runtime.List interface. Add an unwrap opcode in the VM to unwrap the values. @@ -28,163 +26,6 @@ func NewDataSet(distinct bool) runtime.List { } } -func (ds *DataSet) Sort(ctx context.Context, direction runtime.Int) error { - return runtime.SortListWith(ctx, ds.values, func(first, second runtime.Value) int64 { - firstKV, firstOk := first.(*KV) - secondKV, secondOk := second.(*KV) - - var comp int64 - - if firstOk && secondOk { - comp = runtime.CompareValues(firstKV.Key, secondKV.Key) - } else { - comp = runtime.CompareValues(first, second) - } - - if direction == SortAsc { - return comp - } - - return -comp - }) -} - -func (ds *DataSet) SortMany(ctx context.Context, directions []runtime.Int) error { - return runtime.SortListWith(ctx, ds.values, func(first, second runtime.Value) int64 { - firstKV, firstOk := first.(*KV) - secondKV, secondOk := second.(*KV) - - if firstOk && secondOk { - firstKVKey := firstKV.Key.(runtime.List) - secondKVKey := secondKV.Key.(runtime.List) - - for idx, direction := range directions { - firstKey, _ := firstKVKey.Get(ctx, runtime.NewInt(idx)) - secondKey, _ := secondKVKey.Get(ctx, runtime.NewInt(idx)) - comp := runtime.CompareValues(firstKey, secondKey) - - if comp != 0 { - if direction == SortAsc { - return comp - } - - return -comp - } - } - } else { - comp := runtime.CompareValues(first, second) - - if comp != 0 { - if directions[0] == SortAsc { - return comp - } - - return -comp - } - } - - return 0 - }) -} - -func (ds *DataSet) AddKV(ctx context.Context, key, value runtime.Value) error { - can, err := ds.canAdd(ctx, value) - - if err != nil { - return err - } - - if can { - _ = ds.values.Add(ctx, NewKV(key, value)) - } - - ds.keyed = true - - return nil -} - -func (ds *DataSet) CollectK(ctx context.Context, key runtime.Value) error { - k, err := Stringify(ctx, key) - - if err != nil { - return err - } - - if ds.grouping == nil { - ds.grouping = make(map[string]runtime.Value) - } - - _, exists := ds.grouping[k] - - if !exists { - ds.grouping[k] = runtime.None - _ = ds.values.Add(ctx, NewKV(key, runtime.None)) - } - - ds.keyed = true - - return nil -} - -func (ds *DataSet) CollectKc(ctx context.Context, key runtime.Value) error { - k, err := Stringify(ctx, key) - - if err != nil { - return err - } - - if ds.grouping == nil { - ds.grouping = make(map[string]runtime.Value) - } - - group, exists := ds.grouping[k] - - if !exists { - group = NewKV(key, runtime.ZeroInt) - ds.grouping[k] = group - _ = ds.values.Add(ctx, group) - } - - kv := group.(*KV) - if count, ok := kv.Value.(runtime.Int); ok { - sum := count + 1 - kv.Value = sum - } else { - kv.Value = runtime.NewInt(1) - } - - ds.keyed = true - - return nil -} - -func (ds *DataSet) CollectKV(ctx context.Context, key, value runtime.Value) error { - k, err := Stringify(ctx, key) - - if err != nil { - return err - } - - if ds.grouping == nil { - ds.grouping = make(map[string]runtime.Value) - } - - group, exists := ds.grouping[k] - - if !exists { - group = runtime.NewArray(4) - ds.grouping[k] = group - _ = ds.values.Add(ctx, NewKV(key, group)) - } - - // TODO: Avoid type casting - _ = group.(runtime.List).Add(ctx, value) - - ds.keyed = true - - return nil -} - func (ds *DataSet) Add(ctx context.Context, item runtime.Value) error { can, err := ds.canAdd(ctx, item) @@ -204,17 +45,7 @@ func (ds *DataSet) Get(ctx context.Context, idx runtime.Int) (runtime.Value, err } func (ds *DataSet) Iterate(ctx context.Context) (runtime.Iterator, error) { - iter, err := ds.values.Iterate(ctx) - - if err != nil { - return nil, err - } - - if !ds.keyed { - return iter, nil - } - - return NewKVIterator(iter), nil + return ds.values.Iterate(ctx) } func (ds *DataSet) Length(ctx context.Context) (runtime.Int, error) { diff --git a/pkg/vm/internal/sorter.go b/pkg/vm/internal/sorter.go new file mode 100644 index 00000000..63a064f2 --- /dev/null +++ b/pkg/vm/internal/sorter.go @@ -0,0 +1,59 @@ +package internal + +import ( + "context" + + "github.com/MontFerret/ferret/pkg/runtime" +) + +type Sorter struct { + *runtime.Box[runtime.List] + direction runtime.SortDirection + sorted bool +} + +func NewSorter(direction runtime.SortDirection) Transformer { + return &Sorter{ + Box: &runtime.Box[runtime.List]{ + Value: runtime.NewArray(8), + }, + direction: direction, + } +} + +func (s *Sorter) Iterate(ctx context.Context) (runtime.Iterator, error) { + if !s.sorted { + if err := s.sort(ctx); err != nil { + return nil, err + } + + s.sorted = true + } + + iter, err := s.Value.Iterate(ctx) + + if err != nil { + return nil, err + } + + return NewKVIterator(iter), nil +} + +func (s *Sorter) Add(ctx context.Context, key, value runtime.Value) error { + return s.Value.Add(ctx, NewKV(key, value)) +} + +func (s *Sorter) sort(ctx context.Context) error { + return runtime.SortListWith(ctx, s.Value, func(first, second runtime.Value) int64 { + firstKV := first.(*KV) + secondKV := second.(*KV) + + comp := runtime.CompareValues(firstKV.Key, secondKV.Key) + + if s.direction == runtime.SortDirectionAsc { + return comp + } + + return -comp + }) +} diff --git a/pkg/vm/internal/stream.go b/pkg/vm/internal/stream.go index 443505b6..d74b016a 100644 --- a/pkg/vm/internal/stream.go +++ b/pkg/vm/internal/stream.go @@ -7,12 +7,12 @@ import ( ) type StreamValue struct { - *Box[runtime.Stream] + *runtime.Box[runtime.Stream] } func NewStreamValue(stream runtime.Stream) runtime.Value { return &StreamValue{ - Box: &Box[runtime.Stream]{ + Box: &runtime.Box[runtime.Stream]{ Value: stream, }, } diff --git a/pkg/vm/internal/transformer.go b/pkg/vm/internal/transformer.go new file mode 100644 index 00000000..cd5bc512 --- /dev/null +++ b/pkg/vm/internal/transformer.go @@ -0,0 +1,14 @@ +package internal + +import ( + "context" + + "github.com/MontFerret/ferret/pkg/runtime" +) + +type Transformer interface { + runtime.Value + runtime.Iterable + + Add(ctx context.Context, key, value runtime.Value) error +} diff --git a/pkg/vm/opcode.go b/pkg/vm/opcode.go index e1d3dbdb..1727286f 100644 --- a/pkg/vm/opcode.go +++ b/pkg/vm/opcode.go @@ -1,5 +1,6 @@ package vm +// Opcode represents an operation code used in virtual machine instruction sets. type Opcode byte const ( @@ -52,11 +53,6 @@ const ( OpRegexpPositive OpRegexpNegative - OpList // Load an array from a list of registers (ARR R2, R3 R5 - creates an array in R2 with elements from R3 to R5) - OpMap // Load an object from a list of registers (OBJ R2, R3 R5 - creates an object in R2 with elements from R3 to R5) - OpRange // Load a range from a list of registers (RNG R2, R3, R4 - creates a range in R2 with start from R3 and end at R4) - OpDataSet // Load a dataset to a register A - OpLength OpType OpClose @@ -65,6 +61,14 @@ const ( OpCall OpProtectedCall + OpList // Load an array from a list of registers (ARR R2, R3 R5 - creates an array in R2 with elements from R3 to R5) + OpMap // Load an object from a list of registers (OBJ R2, R3 R5 - creates an object in R2 with elements from R3 to R5) + OpRange // Load a range from a list of registers (RNG R2, R3, R4 - creates a range in R2 with start from R3 and end at R4) + OpDataSet // Load a dataset to a register A + OpDataSetCollector + OpDataSetSorter + OpDataSetMultiSorter + OpStream // Subscribes to a stream (SMRCV R2, R3, R4 - subscribes to a stream in R2 with a collection from R3 and optional params from R4) OpStreamIter // Consumes a stream (SMRD R2, R3 - consumes a stream in R2 with a collection from R3) @@ -72,18 +76,9 @@ const ( OpIterNext // Moves to the next element in the iterator (ITER R2, R3 - moves to the next element in the iterator in R2 with a collection from R3) OpIterValue // Returns the current value from the iterator (ITER R2, R3 - returns the current value from the iterator in R2 with a collection from R3) OpIterKey // Returns the current key from the iterator (ITER R2, R3 - returns the current key from the iterator in R2 with a collection from R3) + OpIterLimit + OpIterSkip OpPush // Adds a value to a dataset OpPushKV // Adds a key-value pair to a dataset - - OpCollector - OpCollectK // Adds a key to a group - OpCollectKc // Adds a key to a group and counts it - OpCollectKV // Adds a value to a group using key - - OpLimit - OpSkip - - OpSort // Sorts a collection of KeyValue pairs. (SORT R2, R3 - sorts a collection in R2 with a sorting direction in R3) - OpSortMany // Sorts a collection of KeyValue pairs with compound key and multiple directions. (SORT R2, R3, R4 - sorts a collection in R2 with a sorting direction from R3 to R4) ) diff --git a/pkg/vm/vm.go b/pkg/vm/vm.go index f93f49e5..99f467c2 100644 --- a/pkg/vm/vm.go +++ b/pkg/vm/vm.go @@ -361,7 +361,14 @@ loop: } case OpDataSet: reg[dst] = internal.NewDataSet(src1 == 1) - case OpCollector: + case OpDataSetSorter: + reg[dst] = internal.NewSorter(runtime.SortDirection(src1)) + case OpDataSetMultiSorter: + encoded := src1.Register() + count := src2.Register() + + reg[dst] = internal.NewMultiSorter(runtime.DecodeSortDirections(encoded, count)) + case OpDataSetCollector: reg[dst] = internal.NewCollector(internal.CollectorType(src1)) case OpPush: ds := reg[dst].(*internal.DataSet) @@ -374,52 +381,9 @@ loop: } } case OpPushKV: - var err error + tr := reg[dst].(internal.Transformer) - switch target := reg[dst].(type) { - case *internal.DataSet: - err = target.AddKV(ctx, reg[src1], reg[src2]) - case internal.Collector: - err = target.Collect(ctx, reg[src1], reg[src2]) - default: - return nil, runtime.TypeError(target, "vm.Collector") - } - - if err != nil { - if _, catch := tryCatch(vm.pc); catch { - continue - } - - return nil, err - } - case OpCollectK: - ds := reg[dst].(*internal.DataSet) - key := reg[src1] - - if err := ds.CollectK(ctx, key); err != nil { - if _, catch := tryCatch(vm.pc); catch { - continue - } - - return nil, err - } - case OpCollectKc: - ds := reg[dst].(*internal.DataSet) - key := reg[src1] - - if err := ds.CollectKc(ctx, key); err != nil { - if _, catch := tryCatch(vm.pc); catch { - continue - } - - return nil, err - } - case OpCollectKV: - ds := reg[dst].(*internal.DataSet) - key := reg[src1] - value := reg[src2] - - if err := ds.CollectKV(ctx, key, value); err != nil { + if err := tr.Add(ctx, reg[src1], reg[src2]); err != nil { if _, catch := tryCatch(vm.pc); catch { continue } @@ -467,7 +431,7 @@ loop: case OpIterKey: iterator := reg[src1].(*internal.Iterator) reg[dst] = iterator.Key() - case OpSkip: + case OpIterSkip: state := runtime.ToIntSafe(ctx, reg[dst]) threshold := runtime.ToIntSafe(ctx, reg[src1]) jump := int(src2) @@ -477,7 +441,7 @@ loop: reg[dst] = state vm.pc = jump } - case OpLimit: + case OpIterLimit: state := runtime.ToIntSafe(ctx, reg[dst]) threshold := runtime.ToIntSafe(ctx, reg[src1]) jump := int(src2) @@ -488,52 +452,6 @@ loop: } else { vm.pc = jump } - case OpSort: - var err error - dir := runtime.ToIntSafe(ctx, reg[src1]) - - switch target := reg[dst].(type) { - case *internal.DataSet: - err = target.Sort(ctx, dir) - case runtime.Sortable: - if dir == internal.SortAsc { - err = target.SortAsc(ctx) - } else { - err = target.SortDesc(ctx) - } - } - - if err != nil { - if _, catch := tryCatch(vm.pc); catch { - continue - } else { - return nil, err - } - } - case OpSortMany: - ds := reg[dst].(*internal.DataSet) - var size int - - if src1 > 0 { - size = src2.Register() - src1.Register() + 1 - } - - directions := make([]runtime.Int, 0, size) - start := int(src1) - end := int(src1) + size - - // Iterate over registers starting from src1 and up to the src2 - for i := start; i < end; i++ { - directions = append(directions, runtime.ToIntSafe(ctx, reg[i])) - } - - if err := ds.SortMany(ctx, directions); err != nil { - if _, catch := tryCatch(vm.pc); catch { - continue - } else { - return nil, err - } - } case OpStream: observable, eventName, options, err := vm.castSubscribeArgs(reg[dst], reg[src1], reg[src2])