diff --git a/CHANGELOG.md b/CHANGELOG.md index b6d49f1e..243b29b4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,7 +12,7 @@ - Upgraded to `golang-jwt/jwt/v5`. -- Added JSVM `new Timezone(name)` binding for constructing `time.Location` value ([#6219](https://github.com/pocketbase/pocketbase/discussions/6219)). +- Added support for case-insensitive password auth based on the related UNIQUE index field collation ([#6337](https://github.com/pocketbase/pocketbase/discussions/6337)). - Soft-deprecated `Record.GetUploadedFiles` in favor of `Record.GetUnsavedFiles` to minimize the ambiguities what the method do ([#6269](https://github.com/pocketbase/pocketbase/discussions/6269)). (@todo update docs to reflect the `:unsaved` getter change) @@ -22,6 +22,8 @@ - Replaced archived `github.com/AlecAivazis/survey` dependency with a simpler `osutils.YesNoPrompt(message, fallback)` helper. +- Added JSVM `new Timezone(name)` binding for constructing `time.Location` value ([#6219](https://github.com/pocketbase/pocketbase/discussions/6219)). + - Added `inflector.Camelize(str)` and `inflector.Singularize(str)` helper methods. - Other minor improvements (_replaced all `bool` exists db scans with `int` for broader drivers compatibility, use the non-transactional app instance during realtime records delete access checks to ensure that cascade deleted records with API rules relying on the parent will be resolved, updated UI dependencies, etc._) diff --git a/apis/record_auth_with_oauth2.go b/apis/record_auth_with_oauth2.go index 4f8a1273..5a9ec48d 100644 --- a/apis/record_auth_with_oauth2.go +++ b/apis/record_auth_with_oauth2.go @@ -8,6 +8,7 @@ import ( "log/slog" "maps" "net/http" + "strings" "time" validation "github.com/go-ozzo/ozzo-validation/v4" @@ -194,10 +195,20 @@ func (form *recordOAuth2LoginForm) checkProviderName(value any) error { func oldCanAssignUsername(txApp core.App, collection *core.Collection, username string) bool { // ensure that username is unique - checkUnique := dbutils.HasSingleColumnUniqueIndex(collection.OAuth2.MappedFields.Username, collection.Indexes) - if checkUnique { - if _, err := txApp.FindFirstRecordByData(collection, collection.OAuth2.MappedFields.Username, username); err == nil { - return false // already exist + index, hasUniqueue := dbutils.FindSingleColumnUniqueIndex(collection.Indexes, collection.OAuth2.MappedFields.Username) + if hasUniqueue { + var expr dbx.Expression + if strings.EqualFold(index.Columns[0].Collate, "nocase") { + // case-insensitive search + expr = dbx.NewExp("username = {:username} COLLATE NOCASE", dbx.Params{"username": username}) + } else { + expr = dbx.HashExp{"username": username} + } + + var exists int + _ = txApp.RecordQuery(collection).Select("(1)").AndWhere(expr).Limit(1).Row(&exists) + if exists > 0 { + return false } } diff --git a/apis/record_auth_with_oauth2_test.go b/apis/record_auth_with_oauth2_test.go index 2ca11f81..5be9097b 100644 --- a/apis/record_auth_with_oauth2_test.go +++ b/apis/record_auth_with_oauth2_test.go @@ -14,6 +14,7 @@ import ( "github.com/pocketbase/pocketbase/core" "github.com/pocketbase/pocketbase/tests" "github.com/pocketbase/pocketbase/tools/auth" + "github.com/pocketbase/pocketbase/tools/dbutils" "golang.org/x/oauth2" ) @@ -1210,7 +1211,7 @@ func TestRecordAuthWithOAuth2(t *testing.T) { }, }, { - Name: "creating user (with mapped OAuth2 fields and avatarURL->non-file field)", + Name: "creating user (with mapped OAuth2 fields, case-sensitive username and avatarURL->non-file field)", Method: http.MethodPost, URL: "/api/collections/users/auth-with-oauth2", Body: strings.NewReader(`{ @@ -1230,7 +1231,7 @@ func TestRecordAuthWithOAuth2(t *testing.T) { AuthUser: &auth.AuthUser{ Id: "oauth2_id", Email: "oauth2@example.com", - Username: "oauth2_username", + Username: "tESt2_username", // wouldn't match with existing because the related field index is case-sensitive Name: "oauth2_name", AvatarURL: server.URL + "/oauth2_avatar.png", }, @@ -1258,7 +1259,7 @@ func TestRecordAuthWithOAuth2(t *testing.T) { ExpectedContent: []string{ `"email":"oauth2@example.com"`, `"emailVisibility":false`, - `"username":"oauth2_username"`, + `"username":"tESt2_username"`, `"name":"http://127.`, `"verified":true`, `"avatar":""`, @@ -1294,7 +1295,7 @@ func TestRecordAuthWithOAuth2(t *testing.T) { }, }, { - Name: "creating user (with mapped OAuth2 fields and duplicated username)", + Name: "creating user (with mapped OAuth2 fields and duplicated case-insensitive username)", Method: http.MethodPost, URL: "/api/collections/users/auth-with-oauth2", Body: strings.NewReader(`{ @@ -1314,13 +1315,21 @@ func TestRecordAuthWithOAuth2(t *testing.T) { AuthUser: &auth.AuthUser{ Id: "oauth2_id", Email: "oauth2@example.com", - Username: "test2_username", + Username: "tESt2_username", Name: "oauth2_name", }, Token: &oauth2.Token{AccessToken: "abc"}, } } + // make the username index case-insensitive to ensure that case-insensitive match is used + index, ok := dbutils.FindSingleColumnUniqueIndex(usersCol.Indexes, "username") + if ok { + index.Columns[0].Collate = "nocase" + usersCol.RemoveIndex(index.IndexName) + usersCol.Indexes = append(usersCol.Indexes, index.Build()) + } + // add the test provider in the collection usersCol.MFA.Enabled = false usersCol.OAuth2.Enabled = true diff --git a/apis/record_auth_with_password.go b/apis/record_auth_with_password.go index a65bb3b7..2d60274e 100644 --- a/apis/record_auth_with_password.go +++ b/apis/record_auth_with_password.go @@ -3,10 +3,14 @@ package apis import ( "database/sql" "errors" + "slices" + "strings" validation "github.com/go-ozzo/ozzo-validation/v4" "github.com/go-ozzo/ozzo-validation/v4/is" + "github.com/pocketbase/dbx" "github.com/pocketbase/pocketbase/core" + "github.com/pocketbase/pocketbase/tools/dbutils" "github.com/pocketbase/pocketbase/tools/list" ) @@ -32,12 +36,12 @@ func recordAuthWithPassword(e *core.RequestEvent) error { var foundErr error if form.IdentityField != "" { - foundRecord, foundErr = e.App.FindFirstRecordByData(collection.Id, form.IdentityField, form.Identity) + foundRecord, foundErr = findRecordByIdentityField(e.App, collection, form.IdentityField, form.Identity) } else { // prioritize email lookup isEmail := is.EmailFormat.Validate(form.Identity) == nil if isEmail && list.ExistInSlice(core.FieldNameEmail, collection.PasswordAuth.IdentityFields) { - foundRecord, foundErr = e.App.FindAuthRecordByEmail(collection.Id, form.Identity) + foundRecord, foundErr = findRecordByIdentityField(e.App, collection, core.FieldNameEmail, form.Identity) } // search by the other identity fields @@ -47,7 +51,7 @@ func recordAuthWithPassword(e *core.RequestEvent) error { continue // no need to search by the email field if it is not an email } - foundRecord, foundErr = e.App.FindFirstRecordByData(collection.Id, name, form.Identity) + foundRecord, foundErr = findRecordByIdentityField(e.App, collection, name, form.Identity) if foundErr == nil { break } @@ -95,3 +99,31 @@ func (form *authWithPasswordForm) validate(collection *core.Collection) error { validation.Field(&form.IdentityField, validation.In(list.ToInterfaceSlice(collection.PasswordAuth.IdentityFields)...)), ) } + +func findRecordByIdentityField(app core.App, collection *core.Collection, field string, value any) (*core.Record, error) { + if !slices.Contains(collection.PasswordAuth.IdentityFields, field) { + return nil, errors.New("invalid identity field " + field) + } + + index, ok := dbutils.FindSingleColumnUniqueIndex(collection.Indexes, field) + if !ok { + return nil, errors.New("missing " + field + " unique index constraint") + } + + var expr dbx.Expression + if strings.EqualFold(index.Columns[0].Collate, "nocase") { + // case-insensitive search + expr = dbx.NewExp("[["+field+"]] = {:identity} COLLATE NOCASE", dbx.Params{"identity": value}) + } else { + expr = dbx.HashExp{field: value} + } + + record := &core.Record{} + + err := app.RecordQuery(collection).AndWhere(expr).Limit(1).One(record) + if err != nil { + return nil, err + } + + return record, nil +} diff --git a/apis/record_auth_with_password_test.go b/apis/record_auth_with_password_test.go index 75d47bd2..9cc3fb2b 100644 --- a/apis/record_auth_with_password_test.go +++ b/apis/record_auth_with_password_test.go @@ -8,11 +8,38 @@ import ( "github.com/pocketbase/pocketbase/core" "github.com/pocketbase/pocketbase/tests" + "github.com/pocketbase/pocketbase/tools/dbutils" ) func TestRecordAuthWithPassword(t *testing.T) { t.Parallel() + updateIdentityIndex := func(collectionIdOrName string, fieldCollateMap map[string]string) func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) { + return func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) { + collection, err := app.FindCollectionByNameOrId("clients") + if err != nil { + t.Fatal(err) + } + + for column, collate := range fieldCollateMap { + index, ok := dbutils.FindSingleColumnUniqueIndex(collection.Indexes, column) + if !ok { + t.Fatalf("Missing unique identityField index for column %q", column) + } + + index.Columns[0].Collate = collate + + collection.RemoveIndex(index.IndexName) + collection.Indexes = append(collection.Indexes, index.Build()) + } + + err = app.Save(collection) + if err != nil { + t.Fatalf("Failed to update identityField index: %v", err) + } + } + } + scenarios := []tests.ApiScenario{ { Name: "disabled password auth", @@ -164,6 +191,22 @@ func TestRecordAuthWithPassword(t *testing.T) { "OnMailerRecordAuthAlertSend": 1, }, }, + { + Name: "unknown explicit identityField", + Method: http.MethodPost, + URL: "/api/collections/clients/auth-with-password", + Body: strings.NewReader(`{ + "identityField": "created", + "identity":"test@example.com", + "password":"1234567890" + }`), + ExpectedStatus: 400, + ExpectedContent: []string{ + `"data":{`, + `"identityField":{"code":"validation_in_invalid"`, + }, + ExpectedEvents: map[string]int{"*": 0}, + }, { Name: "valid identity field and valid password with mismatched explicit identityField", Method: http.MethodPost, @@ -440,6 +483,141 @@ func TestRecordAuthWithPassword(t *testing.T) { }, }, + // case sensitivity checks + // ----------------------------------------------------------- + { + Name: "with explicit identityField (case-sensitive)", + Method: http.MethodPost, + URL: "/api/collections/clients/auth-with-password", + Body: strings.NewReader(`{ + "identityField": "username", + "identity":"Clients57772", + "password":"1234567890" + }`), + BeforeTestFunc: updateIdentityIndex("clients", map[string]string{"username": ""}), + ExpectedStatus: 400, + ExpectedContent: []string{`"data":{}`}, + ExpectedEvents: map[string]int{ + "*": 0, + "OnRecordAuthWithPasswordRequest": 1, + }, + }, + { + Name: "with explicit identityField (case-insensitive)", + Method: http.MethodPost, + URL: "/api/collections/clients/auth-with-password", + Body: strings.NewReader(`{ + "identityField": "username", + "identity":"Clients57772", + "password":"1234567890" + }`), + BeforeTestFunc: updateIdentityIndex("clients", map[string]string{"username": "nocase"}), + ExpectedStatus: 200, + ExpectedContent: []string{ + `"email":"test@example.com"`, + `"username":"clients57772"`, + `"token":`, + }, + NotExpectedContent: []string{ + // hidden fields + `"tokenKey"`, + `"password"`, + }, + ExpectedEvents: map[string]int{ + "*": 0, + "OnRecordAuthWithPasswordRequest": 1, + "OnRecordAuthRequest": 1, + "OnRecordEnrich": 1, + // authOrigin track + "OnModelCreate": 1, + "OnModelCreateExecute": 1, + "OnModelAfterCreateSuccess": 1, + "OnModelValidate": 1, + "OnRecordCreate": 1, + "OnRecordCreateExecute": 1, + "OnRecordAfterCreateSuccess": 1, + "OnRecordValidate": 1, + "OnMailerSend": 1, + "OnMailerRecordAuthAlertSend": 1, + }, + }, + { + Name: "without explicit identityField and non-email field (case-insensitive)", + Method: http.MethodPost, + URL: "/api/collections/clients/auth-with-password", + Body: strings.NewReader(`{ + "identity":"Clients57772", + "password":"1234567890" + }`), + BeforeTestFunc: updateIdentityIndex("clients", map[string]string{"username": "nocase"}), + ExpectedStatus: 200, + ExpectedContent: []string{ + `"email":"test@example.com"`, + `"username":"clients57772"`, + `"token":`, + }, + NotExpectedContent: []string{ + // hidden fields + `"tokenKey"`, + `"password"`, + }, + ExpectedEvents: map[string]int{ + "*": 0, + "OnRecordAuthWithPasswordRequest": 1, + "OnRecordAuthRequest": 1, + "OnRecordEnrich": 1, + // authOrigin track + "OnModelCreate": 1, + "OnModelCreateExecute": 1, + "OnModelAfterCreateSuccess": 1, + "OnModelValidate": 1, + "OnRecordCreate": 1, + "OnRecordCreateExecute": 1, + "OnRecordAfterCreateSuccess": 1, + "OnRecordValidate": 1, + "OnMailerSend": 1, + "OnMailerRecordAuthAlertSend": 1, + }, + }, + { + Name: "without explicit identityField and email field (case-insensitive)", + Method: http.MethodPost, + URL: "/api/collections/clients/auth-with-password", + Body: strings.NewReader(`{ + "identity":"tESt@example.com", + "password":"1234567890" + }`), + BeforeTestFunc: updateIdentityIndex("clients", map[string]string{"email": "nocase"}), + ExpectedStatus: 200, + ExpectedContent: []string{ + `"email":"test@example.com"`, + `"username":"clients57772"`, + `"token":`, + }, + NotExpectedContent: []string{ + // hidden fields + `"tokenKey"`, + `"password"`, + }, + ExpectedEvents: map[string]int{ + "*": 0, + "OnRecordAuthWithPasswordRequest": 1, + "OnRecordAuthRequest": 1, + "OnRecordEnrich": 1, + // authOrigin track + "OnModelCreate": 1, + "OnModelCreateExecute": 1, + "OnModelAfterCreateSuccess": 1, + "OnModelValidate": 1, + "OnRecordCreate": 1, + "OnRecordCreateExecute": 1, + "OnRecordAfterCreateSuccess": 1, + "OnRecordValidate": 1, + "OnMailerSend": 1, + "OnMailerRecordAuthAlertSend": 1, + }, + }, + // rate limit checks // ----------------------------------------------------------- { diff --git a/core/collection_model.go b/core/collection_model.go index 3841231e..e50eea34 100644 --- a/core/collection_model.go +++ b/core/collection_model.go @@ -989,7 +989,7 @@ func (c *Collection) initTokenKeyField() { } // ensure that there is a unique index for the field - if !dbutils.HasSingleColumnUniqueIndex(FieldNameTokenKey, c.Indexes) { + if _, ok := dbutils.FindSingleColumnUniqueIndex(c.Indexes, FieldNameTokenKey); !ok { c.Indexes = append(c.Indexes, fmt.Sprintf( "CREATE UNIQUE INDEX `%s` ON `%s` (`%s`)", c.fieldIndexName(FieldNameTokenKey), @@ -1015,7 +1015,7 @@ func (c *Collection) initEmailField() { } // ensure that there is a unique index for the email field - if !dbutils.HasSingleColumnUniqueIndex(FieldNameEmail, c.Indexes) { + if _, ok := dbutils.FindSingleColumnUniqueIndex(c.Indexes, FieldNameEmail); !ok { c.Indexes = append(c.Indexes, fmt.Sprintf( "CREATE UNIQUE INDEX `%s` ON `%s` (`%s`) WHERE `%s` != ''", c.fieldIndexName(FieldNameEmail), diff --git a/core/collection_validate.go b/core/collection_validate.go index 12ad7aad..ad9116b1 100644 --- a/core/collection_validate.go +++ b/core/collection_validate.go @@ -456,7 +456,7 @@ func (cv *collectionValidator) checkFieldsForUniqueIndex(value any) error { SetParams(map[string]any{"fieldName": name}) } - if !dbutils.HasSingleColumnUniqueIndex(name, cv.new.Indexes) { + if _, ok := dbutils.FindSingleColumnUniqueIndex(cv.new.Indexes, name); !ok { return validation.NewError("validation_missing_unique_constraint", "The field {{.fieldName}} doesn't have a UNIQUE constraint."). SetParams(map[string]any{"fieldName": name}) } @@ -666,7 +666,7 @@ func (cv *collectionValidator) checkIndexes(value any) error { if cv.new.IsAuth() { requiredNames := []string{FieldNameTokenKey, FieldNameEmail} for _, name := range requiredNames { - if !dbutils.HasSingleColumnUniqueIndex(name, indexes) { + if _, ok := dbutils.FindSingleColumnUniqueIndex(indexes, name); !ok { return validation.NewError( "validation_missing_required_unique_index", `Missing required unique index for field "{{.fieldName}}".`, diff --git a/core/record_field_resolver_runner.go b/core/record_field_resolver_runner.go index 89af528e..6403649a 100644 --- a/core/record_field_resolver_runner.go +++ b/core/record_field_resolver_runner.go @@ -505,7 +505,8 @@ func (r *runner) processActiveProps() (*search.ResolverResult, error) { isBackRelMultiple := backRelField.IsMultiple() if !isBackRelMultiple { // additionally check if the rel field has a single column unique index - isBackRelMultiple = !dbutils.HasSingleColumnUniqueIndex(backRelField.Name, backCollection.Indexes) + _, hasUniqueIndex := dbutils.FindSingleColumnUniqueIndex(backCollection.Indexes, backRelField.Name) + isBackRelMultiple = !hasUniqueIndex } if !isBackRelMultiple { diff --git a/core/record_query.go b/core/record_query.go index 98fd8e09..28b54ccc 100644 --- a/core/record_query.go +++ b/core/record_query.go @@ -9,6 +9,7 @@ import ( "strings" "github.com/pocketbase/dbx" + "github.com/pocketbase/pocketbase/tools/dbutils" "github.com/pocketbase/pocketbase/tools/inflector" "github.com/pocketbase/pocketbase/tools/list" "github.com/pocketbase/pocketbase/tools/search" @@ -527,20 +528,34 @@ func (app *BaseApp) FindAuthRecordByToken(token string, validTypes ...string) (* // FindAuthRecordByEmail finds the auth record associated with the provided email. // +// The email check would be case-insensitive if the related collection +// email unique index has COLLATE NOCASE specified for the email column. +// // Returns an error if it is not an auth collection or the record is not found. func (app *BaseApp) FindAuthRecordByEmail(collectionModelOrIdentifier any, email string) (*Record, error) { collection, err := getCollectionByModelOrIdentifier(app, collectionModelOrIdentifier) if err != nil { return nil, fmt.Errorf("failed to fetch auth collection: %w", err) } + if !collection.IsAuth() { return nil, fmt.Errorf("%q is not an auth collection", collection.Name) } record := &Record{} + var expr dbx.Expression + + index, ok := dbutils.FindSingleColumnUniqueIndex(collection.Indexes, FieldNameEmail) + if ok && strings.EqualFold(index.Columns[0].Collate, "nocase") { + // case-insensitive search + expr = dbx.NewExp("[["+FieldNameEmail+"]] = {:email} COLLATE NOCASE", dbx.Params{"email": email}) + } else { + expr = dbx.HashExp{FieldNameEmail: email} + } + err = app.RecordQuery(collection). - AndWhere(dbx.HashExp{FieldNameEmail: email}). + AndWhere(expr). Limit(1). One(record) if err != nil { diff --git a/core/record_query_expand.go b/core/record_query_expand.go index 79df2969..04fc016a 100644 --- a/core/record_query_expand.go +++ b/core/record_query_expand.go @@ -143,7 +143,7 @@ func (app *BaseApp) expandRecords(records []*Record, expandPath string, fetchFun MaxSelect: 2147483647, CollectionId: indirectRel.Id, } - if dbutils.HasSingleColumnUniqueIndex(indirectRelField.GetName(), indirectRel.Indexes) { + if _, ok := dbutils.FindSingleColumnUniqueIndex(indirectRel.Indexes, indirectRelField.GetName()); ok { relField.MaxSelect = 1 } relCollection = indirectRel diff --git a/core/record_query_test.go b/core/record_query_test.go index 479a79ab..fe38b406 100644 --- a/core/record_query_test.go +++ b/core/record_query_test.go @@ -11,6 +11,7 @@ import ( "github.com/pocketbase/dbx" "github.com/pocketbase/pocketbase/core" "github.com/pocketbase/pocketbase/tests" + "github.com/pocketbase/pocketbase/tools/dbutils" "github.com/pocketbase/pocketbase/tools/types" ) @@ -966,23 +967,46 @@ func TestFindAuthRecordByToken(t *testing.T) { func TestFindAuthRecordByEmail(t *testing.T) { t.Parallel() - app, _ := tests.NewTestApp() - defer app.Cleanup() - scenarios := []struct { collectionIdOrName string email string + nocaseIndex bool expectError bool }{ - {"missing", "test@example.com", true}, - {"demo2", "test@example.com", true}, - {"users", "missing@example.com", true}, - {"users", "test@example.com", false}, - {"clients", "test2@example.com", false}, + {"missing", "test@example.com", false, true}, + {"demo2", "test@example.com", false, true}, + {"users", "missing@example.com", false, true}, + {"users", "test@example.com", false, false}, + {"clients", "test2@example.com", false, false}, + // case-insensitive tests + {"clients", "TeSt2@example.com", false, true}, + {"clients", "TeSt2@example.com", true, false}, } for _, s := range scenarios { t.Run(fmt.Sprintf("%s_%s", s.collectionIdOrName, s.email), func(t *testing.T) { + app, _ := tests.NewTestApp() + defer app.Cleanup() + + collection, _ := app.FindCollectionByNameOrId(s.collectionIdOrName) + if collection != nil { + emailIndex, ok := dbutils.FindSingleColumnUniqueIndex(collection.Indexes, core.FieldNameEmail) + if ok { + if s.nocaseIndex { + emailIndex.Columns[0].Collate = "nocase" + } else { + emailIndex.Columns[0].Collate = "" + } + + collection.RemoveIndex(emailIndex.IndexName) + collection.Indexes = append(collection.Indexes, emailIndex.Build()) + err := app.Save(collection) + if err != nil { + t.Fatalf("Failed to update email index: %v", err) + } + } + } + record, err := app.FindAuthRecordByEmail(s.collectionIdOrName, s.email) hasErr := err != nil @@ -994,7 +1018,7 @@ func TestFindAuthRecordByEmail(t *testing.T) { return } - if record.Email() != s.email { + if !strings.EqualFold(record.Email(), s.email) { t.Fatalf("Expected record with email %s, got %s", s.email, record.Email()) } }) diff --git a/tools/dbutils/index.go b/tools/dbutils/index.go index c89a6fc8..d71098fd 100644 --- a/tools/dbutils/index.go +++ b/tools/dbutils/index.go @@ -21,13 +21,13 @@ type IndexColumn struct { // Index represents a single parsed SQL CREATE INDEX expression. type Index struct { - Unique bool `json:"unique"` - Optional bool `json:"optional"` SchemaName string `json:"schemaName"` IndexName string `json:"indexName"` TableName string `json:"tableName"` - Columns []IndexColumn `json:"columns"` Where string `json:"where"` + Columns []IndexColumn `json:"columns"` + Unique bool `json:"unique"` + Optional bool `json:"optional"` } // IsValid checks if the current Index contains the minimum required fields to be considered valid. @@ -193,15 +193,25 @@ func ParseIndex(createIndexExpr string) Index { return result } -// HasColumnUniqueIndex loosely checks whether the specified column has -// a single column unique index (WHERE statements are ignored). -func HasSingleColumnUniqueIndex(column string, indexes []string) bool { +// FindSingleColumnUniqueIndex returns the first matching single column unique index. +func FindSingleColumnUniqueIndex(indexes []string, column string) (Index, bool) { + var index Index + for _, idx := range indexes { - parsed := ParseIndex(idx) - if parsed.Unique && len(parsed.Columns) == 1 && strings.EqualFold(parsed.Columns[0].Name, column) { - return true + index := ParseIndex(idx) + if index.Unique && len(index.Columns) == 1 && strings.EqualFold(index.Columns[0].Name, column) { + return index, true } } - return false + return index, false +} + +// Deprecated: Use `_, ok := FindSingleColumnUniqueIndex(indexes, column)` instead. +// +// HasColumnUniqueIndex loosely checks whether the specified column has +// a single column unique index (WHERE statements are ignored). +func HasSingleColumnUniqueIndex(column string, indexes []string) bool { + _, ok := FindSingleColumnUniqueIndex(indexes, column) + return ok } diff --git a/tools/dbutils/index_test.go b/tools/dbutils/index_test.go index 03218e02..389ddcb1 100644 --- a/tools/dbutils/index_test.go +++ b/tools/dbutils/index_test.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/json" "fmt" + "strings" "testing" "github.com/pocketbase/pocketbase/tools/dbutils" @@ -312,3 +313,93 @@ func TestHasSingleColumnUniqueIndex(t *testing.T) { }) } } + +func TestFindSingleColumnUniqueIndex(t *testing.T) { + scenarios := []struct { + name string + column string + indexes []string + expected bool + }{ + { + "empty indexes", + "test", + nil, + false, + }, + { + "empty column", + "", + []string{ + "CREATE UNIQUE INDEX `index1` ON `example` (`test`)", + }, + false, + }, + { + "mismatched column", + "test", + []string{ + "CREATE UNIQUE INDEX `index1` ON `example` (`test2`)", + }, + false, + }, + { + "non unique index", + "test", + []string{ + "CREATE INDEX `index1` ON `example` (`test`)", + }, + false, + }, + { + "matching columnd and unique index", + "test", + []string{ + "CREATE UNIQUE INDEX `index1` ON `example` (`test`)", + }, + true, + }, + { + "multiple columns", + "test", + []string{ + "CREATE UNIQUE INDEX `index1` ON `example` (`test`, `test2`)", + }, + false, + }, + { + "multiple indexes", + "test", + []string{ + "CREATE UNIQUE INDEX `index1` ON `example` (`test`, `test2`)", + "CREATE UNIQUE INDEX `index2` ON `example` (`test`)", + }, + true, + }, + { + "partial unique index", + "test", + []string{ + "CREATE UNIQUE INDEX `index` ON `example` (`test`) where test != ''", + }, + true, + }, + } + + for _, s := range scenarios { + t.Run(s.name, func(t *testing.T) { + index, exists := dbutils.FindSingleColumnUniqueIndex(s.indexes, s.column) + if exists != s.expected { + t.Fatalf("Expected exists %v got %v", s.expected, exists) + } + + if !exists && len(index.Columns) > 0 { + t.Fatal("Expected index.Columns to be empty") + } + + if exists && !strings.EqualFold(index.Columns[0].Name, s.column) { + t.Fatalf("Expected to find column %q in %v", s.column, index) + } + }) + } +}