From cbc88d7a012c61dfdee9d3463dcd9cdb6b647700 Mon Sep 17 00:00:00 2001 From: Gani Georgiev Date: Fri, 18 Oct 2024 16:03:49 +0300 Subject: [PATCH] backported some of the v0.23.0 form validators, fixes and tests --- forms/realtime_subscribe.go | 1 + forms/realtime_subscribe_test.go | 83 +++++++++++++++++++++++----- forms/validators/record_data.go | 5 ++ forms/validators/record_data_test.go | 10 ++++ tools/rest/multi_binder.go | 23 +++++--- 5 files changed, 100 insertions(+), 22 deletions(-) diff --git a/forms/realtime_subscribe.go b/forms/realtime_subscribe.go index fc852fc8..def65342 100644 --- a/forms/realtime_subscribe.go +++ b/forms/realtime_subscribe.go @@ -19,5 +19,6 @@ func NewRealtimeSubscribe() *RealtimeSubscribe { func (form *RealtimeSubscribe) Validate() error { return validation.ValidateStruct(form, validation.Field(&form.ClientId, validation.Required, validation.Length(1, 255)), + validation.Field(&form.Subscriptions, validation.Length(0, 1000)), ) } diff --git a/forms/realtime_subscribe_test.go b/forms/realtime_subscribe_test.go index d4f8b1e7..d615c414 100644 --- a/forms/realtime_subscribe_test.go +++ b/forms/realtime_subscribe_test.go @@ -1,33 +1,86 @@ package forms_test import ( + "encoding/json" + "fmt" "strings" "testing" + validation "github.com/go-ozzo/ozzo-validation/v4" "github.com/pocketbase/pocketbase/forms" ) func TestRealtimeSubscribeValidate(t *testing.T) { t.Parallel() - scenarios := []struct { - clientId string - expectError bool - }{ - {"", true}, - {strings.Repeat("a", 256), true}, - {"test", false}, + validSubscriptionsLimit := make([]string, 1000) + for i := 0; i < len(validSubscriptionsLimit); i++ { + validSubscriptionsLimit[i] = fmt.Sprintf(`"%d"`, i) + } + invalidSubscriptionsLimit := make([]string, 1001) + for i := 0; i < len(invalidSubscriptionsLimit); i++ { + invalidSubscriptionsLimit[i] = fmt.Sprintf(`"%d"`, i) } - for i, s := range scenarios { - form := forms.NewRealtimeSubscribe() - form.ClientId = s.clientId + scenarios := []struct { + name string + data string + expectedErrors []string + }{ + { + "empty data", + `{}`, + []string{"clientId"}, + }, + { + "clientId > max chars limit", + `{"clientId":"` + strings.Repeat("a", 256) + `"}`, + []string{"clientId"}, + }, + { + "clientId <= max chars limit", + `{"clientId":"` + strings.Repeat("a", 255) + `"}`, + []string{}, + }, + { + "subscriptions > max limit", + `{"clientId":"test", "subscriptions":[` + strings.Join(invalidSubscriptionsLimit, ",") + `]}`, + []string{"subscriptions"}, + }, + { + "subscriptions <= max limit", + `{"clientId":"test", "subscriptions":[` + strings.Join(validSubscriptionsLimit, ",") + `]}`, + []string{}, + }, + } - err := form.Validate() + for _, s := range scenarios { + t.Run(s.name, func(t *testing.T) { + form := forms.NewRealtimeSubscribe() - hasErr := err != nil - if hasErr != s.expectError { - t.Errorf("(%d) Expected hasErr to be %v, got %v (%v)", i, s.expectError, hasErr, err) - } + err := json.Unmarshal([]byte(s.data), &form) + if err != nil { + t.Fatal(err) + } + + result := form.Validate() + + // parse errors + errs, ok := result.(validation.Errors) + if !ok && result != nil { + t.Fatalf("Failed to parse errors %v", result) + return + } + + // check errors + if len(errs) > len(s.expectedErrors) { + t.Fatalf("Expected error keys %v, got %v", s.expectedErrors, errs) + } + for _, k := range s.expectedErrors { + if _, ok := errs[k]; !ok { + t.Fatalf("Missing expected error key %q in %v", k, errs) + } + } + }) } } diff --git a/forms/validators/record_data.go b/forms/validators/record_data.go index 9c4232f8..b1d8359d 100644 --- a/forms/validators/record_data.go +++ b/forms/validators/record_data.go @@ -2,6 +2,7 @@ package validators import ( "fmt" + "math" "net/url" "regexp" "strings" @@ -159,6 +160,10 @@ func (validator *RecordDataValidator) checkNumberValue(field *schema.SchemaField return nil // nothing to check (skip zero-defaults) } + if math.IsInf(val, 0) || math.IsNaN(val) { + return validation.NewError("validation_nan", "The submitted number is not properly formatted") + } + options, _ := field.Options.(*schema.NumberOptions) if options.NoDecimal && val != float64(int64(val)) { diff --git a/forms/validators/record_data_test.go b/forms/validators/record_data_test.go index 9dc58692..029dc6ea 100644 --- a/forms/validators/record_data_test.go +++ b/forms/validators/record_data_test.go @@ -248,6 +248,16 @@ func TestRecordDataValidatorValidateNumber(t *testing.T) { nil, []string{"field2"}, }, + { + "(number) check infinities and NaN", + map[string]any{ + "field1": "Inf", + "field2": "-Inf", + "field4": "NaN", + }, + nil, + []string{"field1", "field2", "field4"}, + }, { "(number) check min constraint", map[string]any{ diff --git a/tools/rest/multi_binder.go b/tools/rest/multi_binder.go index 6fcc2e29..4df62559 100644 --- a/tools/rest/multi_binder.go +++ b/tools/rest/multi_binder.go @@ -6,10 +6,11 @@ import ( "io" "net/http" "reflect" + "regexp" + "strconv" "strings" "github.com/labstack/echo/v5" - "github.com/spf13/cast" ) // MultipartJsonKey is the key for the special multipart/form-data @@ -144,12 +145,16 @@ func bindFormData(c echo.Context, i any) error { return echo.BindBody(c, i) } +var inferNumberCharsRegex = regexp.MustCompile(`^[\-\.\d]+$`) + // In order to support more seamlessly both json and multipart/form-data requests, // the following normalization rules are applied for plain multipart string values: -// - "true" is converted to the json `true` -// - "false" is converted to the json `false` -// - numeric (non-scientific) strings are converted to json number -// - any other string (empty string too) is left as it is +// - "true" is converted to the json "true" +// - "false" is converted to the json "false" +// - numeric strings are converted to json number ONLY if the resulted +// minimal number string representation is the same as the provided raw string +// (aka. scientific notations, "Infinity", "0.0", "0001", etc. are kept as string) +// - any other string (empty string too) is left as it is func normalizeMultipartValue(raw string) any { switch raw { case "": @@ -159,8 +164,12 @@ func normalizeMultipartValue(raw string) any { case "false": return false default: - if raw[0] == '-' || (raw[0] >= '0' && raw[0] <= '9') { - if v, err := cast.ToFloat64E(raw); err == nil { + // try to convert to number + // + // note: expects the provided raw string to match exactly with the minimal string representation of the parsed float + if raw[0] == '-' || (raw[0] >= '0' && raw[0] <= '9') && inferNumberCharsRegex.Match([]byte(raw)) { + v, err := strconv.ParseFloat(raw, 64) + if err == nil && strconv.FormatFloat(v, 'f', -1, 64) == raw { return v } }