1
0
mirror of https://github.com/MontFerret/ferret.git synced 2025-08-13 19:52:52 +02:00

Add COUNT_DISTINCT function to stdlib collections, update related integration tests, and refactor min/max calculation methods for clarity and consistency.

This commit is contained in:
Tim Voronov
2025-06-27 14:42:47 -04:00
parent f48f819607
commit 7dbcc24e04
8 changed files with 88 additions and 27 deletions

View File

@@ -65,6 +65,19 @@ func AssertBoolean(input Value) error {
return nil
}
func AssertCollection(input Value) error {
_, ok := input.(Collection)
if !ok {
return TypeErrorOf(
input,
TypeCollection,
)
}
return nil
}
func AssertArray(input Value) error {
_, ok := input.(*Array)

View File

@@ -100,6 +100,16 @@ func SafeCastDateTime(input Value, fallback DateTime) DateTime {
return fallback
}
func CastCollection(input Value) (Collection, error) {
arr, ok := input.(Collection)
if ok {
return arr, nil
}
return nil, TypeErrorOf(input, TypeCollection)
}
func CastList(input Value) (List, error) {
arr, ok := input.(List)

View File

@@ -37,6 +37,7 @@ const (
TypeBinary = Type("binary")
// Interfaces
TypeCollection = Type("collection")
TypeList = Type("list")
TypeMap = Type("map")
TypeIndexed = Type("indexed")

View File

@@ -0,0 +1,43 @@
package collections
import (
"context"
"github.com/MontFerret/ferret/pkg/runtime"
"github.com/MontFerret/ferret/pkg/runtime/core"
)
// COUNT_DISTINCT computes the number of distinct elements in the given collection and returns the count as an integer.
func CountDistinct(ctx context.Context, args ...core.Value) (core.Value, error) {
if err := runtime.ValidateArgs(args, 1, 1); err != nil {
return runtime.None, err
}
collection, err := runtime.CastCollection(args[0])
if err != nil {
return runtime.ZeroInt, err
}
// TODO: Use storage backend
hashmap := map[uint64]bool{}
var res runtime.Int
err = runtime.ForEach(ctx, collection, func(c context.Context, value, idx runtime.Value) (runtime.Boolean, error) {
hash := value.Hash()
_, exists := hashmap[hash]
if !exists {
hashmap[hash] = true
res++
}
return true, nil
})
if err != nil {
return runtime.ZeroInt, err
}
return res, nil
}

View File

@@ -7,7 +7,8 @@ import (
func RegisterLib(ns runtime.Namespace) error {
return ns.RegisterFunctions(
runtime.NewFunctionsFromMap(map[string]runtime.Function{
"INCLUDES": Includes,
"REVERSE": Reverse,
"COUNT_DISTINCT": CountDistinct,
"INCLUDES": Includes,
"REVERSE": Reverse,
}))
}

View File

@@ -30,19 +30,13 @@ func Max(ctx context.Context, args ...runtime.Value) (runtime.Value, error) {
return runtime.None, nil
}
var max float64
var res float64
err = arr.ForEach(ctx, func(c context.Context, value runtime.Value, idx runtime.Int) (runtime.Boolean, error) {
err = runtime.AssertNumber(value)
if err != nil {
return false, nil
}
fv := toFloat(value)
if fv > max {
max = fv
if fv > res {
res = fv
}
return true, nil
@@ -52,5 +46,5 @@ func Max(ctx context.Context, args ...runtime.Value) (runtime.Value, error) {
return runtime.None, nil
}
return runtime.NewFloat(max), nil
return runtime.NewFloat(res), nil
}

View File

@@ -30,19 +30,13 @@ func Min(ctx context.Context, args ...runtime.Value) (runtime.Value, error) {
return runtime.None, nil
}
var min float64
var res float64
err = arr.ForEach(ctx, func(c context.Context, value runtime.Value, idx runtime.Int) (runtime.Boolean, error) {
err = runtime.AssertNumber(value)
if err != nil {
return false, nil
}
fv := toFloat(value)
if min > fv || idx == 0 {
min = fv
if res > fv || idx == 0 {
res = fv
}
return true, nil
@@ -52,5 +46,5 @@ func Min(ctx context.Context, args ...runtime.Value) (runtime.Value, error) {
return runtime.None, nil
}
return runtime.NewFloat(min), nil
return runtime.NewFloat(res), nil
}

View File

@@ -39,7 +39,7 @@ func TestCollectAggregate(t *testing.T) {
}
`, []any{
map[string]any{"gender": "f", "minAge": 25, "maxAge": 25},
map[string]any{"gender": "m", "minAge": nil, "maxAge": nil},
map[string]any{"gender": "m", "minAge": 0, "maxAge": 0},
}, "Should handle null values in aggregation"),
CaseArray(`
LET users = [
@@ -204,7 +204,7 @@ FOR u IN users
"gender": "m",
"activeCount": 2,
"marriedCount": 2,
"highSalaryCount": 2,
"highSalaryCount": 3,
},
}, "Should aggregate with conditional expressions"),
CaseArray(`
@@ -498,7 +498,7 @@ FOR u IN users
]
FOR u IN users
COLLECT AGGREGATE
allSkills = UNION(u.skills),
allSkills = UNION_DISTINCT(u.skills, u.skills),
uniqueSkillCount = COUNT_DISTINCT(u.skills)
RETURN {
allSkills: SORTED(allSkills),
@@ -506,8 +506,13 @@ FOR u IN users
}
`, []any{
map[string]any{
"allSkills": []any{"C++", "Go", "Java", "JavaScript", "Python", "Rust", "TypeScript"},
"uniqueSkillCount": 7,
"allSkills": []any{
[]any{"Go", "Rust"},
[]any{"JavaScript", "TypeScript"},
[]any{"Java", "C++", "Python"},
[]any{"JavaScript", "Python", "Go"},
},
"uniqueSkillCount": 4,
},
}, "Should aggregate with array operations"),
})