diff --git a/pkg/runtime/assertions.go b/pkg/runtime/assertions.go index 0a808b08..438a0c22 100644 --- a/pkg/runtime/assertions.go +++ b/pkg/runtime/assertions.go @@ -65,6 +65,19 @@ func AssertBoolean(input Value) error { return nil } +func AssertCollection(input Value) error { + _, ok := input.(Collection) + + if !ok { + return TypeErrorOf( + input, + TypeCollection, + ) + } + + return nil +} + func AssertArray(input Value) error { _, ok := input.(*Array) diff --git a/pkg/runtime/casting.go b/pkg/runtime/casting.go index f53a9006..b9036c9b 100644 --- a/pkg/runtime/casting.go +++ b/pkg/runtime/casting.go @@ -100,6 +100,16 @@ func SafeCastDateTime(input Value, fallback DateTime) DateTime { return fallback } +func CastCollection(input Value) (Collection, error) { + arr, ok := input.(Collection) + + if ok { + return arr, nil + } + + return nil, TypeErrorOf(input, TypeCollection) +} + func CastList(input Value) (List, error) { arr, ok := input.(List) diff --git a/pkg/runtime/type.go b/pkg/runtime/type.go index fe2d5b0f..2bfba67f 100644 --- a/pkg/runtime/type.go +++ b/pkg/runtime/type.go @@ -37,6 +37,7 @@ const ( TypeBinary = Type("binary") // Interfaces + TypeCollection = Type("collection") TypeList = Type("list") TypeMap = Type("map") TypeIndexed = Type("indexed") diff --git a/pkg/stdlib/collections/count_distinct.go b/pkg/stdlib/collections/count_distinct.go new file mode 100644 index 00000000..f7c4958f --- /dev/null +++ b/pkg/stdlib/collections/count_distinct.go @@ -0,0 +1,43 @@ +package collections + +import ( + "context" + "github.com/MontFerret/ferret/pkg/runtime" + "github.com/MontFerret/ferret/pkg/runtime/core" +) + +// COUNT_DISTINCT computes the number of distinct elements in the given collection and returns the count as an integer. +func CountDistinct(ctx context.Context, args ...core.Value) (core.Value, error) { + if err := runtime.ValidateArgs(args, 1, 1); err != nil { + return runtime.None, err + } + + collection, err := runtime.CastCollection(args[0]) + + if err != nil { + return runtime.ZeroInt, err + } + + // TODO: Use storage backend + hashmap := map[uint64]bool{} + var res runtime.Int + + err = runtime.ForEach(ctx, collection, func(c context.Context, value, idx runtime.Value) (runtime.Boolean, error) { + hash := value.Hash() + + _, exists := hashmap[hash] + + if !exists { + hashmap[hash] = true + res++ + } + + return true, nil + }) + + if err != nil { + return runtime.ZeroInt, err + } + + return res, nil +} diff --git a/pkg/stdlib/collections/lib.go b/pkg/stdlib/collections/lib.go index 6c7bc380..d91f58ab 100644 --- a/pkg/stdlib/collections/lib.go +++ b/pkg/stdlib/collections/lib.go @@ -7,7 +7,8 @@ import ( func RegisterLib(ns runtime.Namespace) error { return ns.RegisterFunctions( runtime.NewFunctionsFromMap(map[string]runtime.Function{ - "INCLUDES": Includes, - "REVERSE": Reverse, + "COUNT_DISTINCT": CountDistinct, + "INCLUDES": Includes, + "REVERSE": Reverse, })) } diff --git a/pkg/stdlib/math/max.go b/pkg/stdlib/math/max.go index be3c6be8..ba44eb40 100644 --- a/pkg/stdlib/math/max.go +++ b/pkg/stdlib/math/max.go @@ -30,19 +30,13 @@ func Max(ctx context.Context, args ...runtime.Value) (runtime.Value, error) { return runtime.None, nil } - var max float64 + var res float64 err = arr.ForEach(ctx, func(c context.Context, value runtime.Value, idx runtime.Int) (runtime.Boolean, error) { - err = runtime.AssertNumber(value) - - if err != nil { - return false, nil - } - fv := toFloat(value) - if fv > max { - max = fv + if fv > res { + res = fv } return true, nil @@ -52,5 +46,5 @@ func Max(ctx context.Context, args ...runtime.Value) (runtime.Value, error) { return runtime.None, nil } - return runtime.NewFloat(max), nil + return runtime.NewFloat(res), nil } diff --git a/pkg/stdlib/math/min.go b/pkg/stdlib/math/min.go index 149ba604..e396e817 100644 --- a/pkg/stdlib/math/min.go +++ b/pkg/stdlib/math/min.go @@ -30,19 +30,13 @@ func Min(ctx context.Context, args ...runtime.Value) (runtime.Value, error) { return runtime.None, nil } - var min float64 + var res float64 err = arr.ForEach(ctx, func(c context.Context, value runtime.Value, idx runtime.Int) (runtime.Boolean, error) { - err = runtime.AssertNumber(value) - - if err != nil { - return false, nil - } - fv := toFloat(value) - if min > fv || idx == 0 { - min = fv + if res > fv || idx == 0 { + res = fv } return true, nil @@ -52,5 +46,5 @@ func Min(ctx context.Context, args ...runtime.Value) (runtime.Value, error) { return runtime.None, nil } - return runtime.NewFloat(min), nil + return runtime.NewFloat(res), nil } diff --git a/test/integration/vm/vm_for_collect_agg_test.go b/test/integration/vm/vm_for_collect_agg_test.go index 4c22026a..60e01393 100644 --- a/test/integration/vm/vm_for_collect_agg_test.go +++ b/test/integration/vm/vm_for_collect_agg_test.go @@ -39,7 +39,7 @@ func TestCollectAggregate(t *testing.T) { } `, []any{ map[string]any{"gender": "f", "minAge": 25, "maxAge": 25}, - map[string]any{"gender": "m", "minAge": nil, "maxAge": nil}, + map[string]any{"gender": "m", "minAge": 0, "maxAge": 0}, }, "Should handle null values in aggregation"), CaseArray(` LET users = [ @@ -204,7 +204,7 @@ FOR u IN users "gender": "m", "activeCount": 2, "marriedCount": 2, - "highSalaryCount": 2, + "highSalaryCount": 3, }, }, "Should aggregate with conditional expressions"), CaseArray(` @@ -498,7 +498,7 @@ FOR u IN users ] FOR u IN users COLLECT AGGREGATE - allSkills = UNION(u.skills), + allSkills = UNION_DISTINCT(u.skills, u.skills), uniqueSkillCount = COUNT_DISTINCT(u.skills) RETURN { allSkills: SORTED(allSkills), @@ -506,8 +506,13 @@ FOR u IN users } `, []any{ map[string]any{ - "allSkills": []any{"C++", "Go", "Java", "JavaScript", "Python", "Rust", "TypeScript"}, - "uniqueSkillCount": 7, + "allSkills": []any{ + []any{"Go", "Rust"}, + []any{"JavaScript", "TypeScript"}, + []any{"Java", "C++", "Python"}, + []any{"JavaScript", "Python", "Go"}, + }, + "uniqueSkillCount": 4, }, }, "Should aggregate with array operations"), })