2024-09-29 19:23:19 +03:00
package apis
import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"log/slog"
	"maps"
	"net/http"
	"time"

	validation "github.com/go-ozzo/ozzo-validation/v4"
	"github.com/pocketbase/dbx"
	"github.com/pocketbase/pocketbase/core"
	"github.com/pocketbase/pocketbase/tools/auth"
	"github.com/pocketbase/pocketbase/tools/dbutils"
	"github.com/pocketbase/pocketbase/tools/filesystem"
	"github.com/pocketbase/pocketbase/tools/security"
	"golang.org/x/oauth2"
)
func recordAuthWithOAuth2 ( e * core . RequestEvent ) error {
collection , err := findAuthCollection ( e )
if err != nil {
return err
}
if ! collection . OAuth2 . Enabled {
return e . ForbiddenError ( "The collection is not configured to allow OAuth2 authentication." , nil )
}
var fallbackAuthRecord * core . Record
if e . Auth != nil && e . Auth . Collection ( ) . Id == collection . Id {
fallbackAuthRecord = e . Auth
}
form := new ( recordOAuth2LoginForm )
form . collection = collection
if err = e . BindBody ( form ) ; err != nil {
return firstApiError ( err , e . BadRequestError ( "An error occurred while loading the submitted data." , err ) )
}
if form . RedirectUrl != "" && form . RedirectURL == "" {
e . App . Logger ( ) . Warn ( "[recordAuthWithOAuth2] redirectUrl body param is deprecated and will be removed in the future. Please replace it with redirectURL." )
form . RedirectURL = form . RedirectUrl
}
if err = form . validate ( ) ; err != nil {
return firstApiError ( err , e . BadRequestError ( "An error occurred while loading the submitted data." , err ) )
}
// exchange token for OAuth2 user info and locate existing ExternalAuth rel
// ---------------------------------------------------------------
// load provider configuration
providerConfig , ok := collection . OAuth2 . GetProviderConfig ( form . Provider )
if ! ok {
return e . InternalServerError ( "Missing or invalid provider config." , nil )
}
provider , err := providerConfig . InitProvider ( )
if err != nil {
return firstApiError ( err , e . InternalServerError ( "Failed to init provider " + form . Provider , err ) )
}
ctx , cancel := context . WithTimeout ( e . Request . Context ( ) , 30 * time . Second )
defer cancel ( )
provider . SetContext ( ctx )
provider . SetRedirectURL ( form . RedirectURL )
var opts [ ] oauth2 . AuthCodeOption
if provider . PKCE ( ) {
opts = append ( opts , oauth2 . SetAuthURLParam ( "code_verifier" , form . CodeVerifier ) )
}
// fetch token
token , err := provider . FetchToken ( form . Code , opts ... )
if err != nil {
return firstApiError ( err , e . BadRequestError ( "Failed to fetch OAuth2 token." , err ) )
}
// fetch external auth user
authUser , err := provider . FetchAuthUser ( token )
if err != nil {
return firstApiError ( err , e . BadRequestError ( "Failed to fetch OAuth2 user." , err ) )
}
var authRecord * core . Record
2024-10-05 22:01:06 +03:00
// check for existing relation with the auth collection
2024-09-29 19:23:19 +03:00
externalAuthRel , err := e . App . FindFirstExternalAuthByExpr ( dbx . HashExp {
"collectionRef" : form . collection . Id ,
"provider" : form . Provider ,
"providerId" : authUser . Id ,
} )
switch {
case err == nil && externalAuthRel != nil :
authRecord , err = e . App . FindRecordById ( form . collection , externalAuthRel . RecordRef ( ) )
if err != nil {
return err
}
case fallbackAuthRecord != nil && fallbackAuthRecord . Collection ( ) . Id == form . collection . Id :
// fallback to the logged auth record (if any)
authRecord = fallbackAuthRecord
case authUser . Email != "" :
// look for an existing auth record by the external auth record's email
authRecord , _ = e . App . FindAuthRecordByEmail ( form . collection . Id , authUser . Email )
}
// ---------------------------------------------------------------
event := new ( core . RecordAuthWithOAuth2RequestEvent )
event . RequestEvent = e
event . Collection = collection
event . ProviderName = form . Provider
event . ProviderClient = provider
event . OAuth2User = authUser
event . CreateData = form . CreateData
event . Record = authRecord
event . IsNewRecord = authRecord == nil
return e . App . OnRecordAuthWithOAuth2Request ( ) . Trigger ( event , func ( e * core . RecordAuthWithOAuth2RequestEvent ) error {
if err := oauth2Submit ( e , externalAuthRel ) ; err != nil {
return firstApiError ( err , e . BadRequestError ( "Failed to authenticate." , err ) )
}
meta := struct {
* auth . AuthUser
IsNew bool ` json:"isNew" `
} {
AuthUser : e . OAuth2User ,
IsNew : e . IsNewRecord ,
}
return RecordAuthResponse ( e . RequestEvent , e . Record , core . MFAMethodOAuth2 , meta )
} )
}
// -------------------------------------------------------------------
// recordOAuth2LoginForm holds the request body of the OAuth2 auth endpoint.
type recordOAuth2LoginForm struct {
	// collection is the auth collection the form is validated against.
	// It is unexported and is assigned by the handler, not bound from the body.
	collection *core.Collection

	// CreateData contains optional extra fields that are used when a new
	// auth record has to be created (no existing OAuth2-linked account).
	CreateData map[string]any `form:"createData" json:"createData"`

	// Provider is the name of the OAuth2 client provider (e.g. "google").
	Provider string `form:"provider" json:"provider"`

	// Code is the authorization code returned by the initial request.
	Code string `form:"code" json:"code"`

	// CodeVerifier is the optional PKCE code verifier matching the
	// code_challenge sent with the initial request.
	CodeVerifier string `form:"codeVerifier" json:"codeVerifier"`

	// RedirectURL is the redirect url sent with the initial request.
	RedirectURL string `form:"redirectURL" json:"redirectURL"`

	// RedirectUrl is the old (deprecated) spelling of RedirectURL,
	// accepted only for backward compatibility.
	//
	// @todo remove after dropping v0.22 support (use RedirectURL instead).
	RedirectUrl string `form:"redirectUrl" json:"redirectUrl"`
}
// validate ensures that the provider, code and redirect url are submitted and
// that the provider name has a config in the form collection.
func (form *recordOAuth2LoginForm) validate() error {
	providerRules := []validation.Rule{
		validation.Required,
		validation.By(form.checkProviderName),
	}

	return validation.ValidateStruct(form,
		validation.Field(&form.Provider, providerRules...),
		validation.Field(&form.Code, validation.Required),
		validation.Field(&form.RedirectURL, validation.Required),
	)
}
// checkProviderName is a validation rule that fails when the collection OAuth2
// options have no provider config for the submitted name.
// A non-string value is checked as an empty name.
func (form *recordOAuth2LoginForm) checkProviderName(value any) error {
	name, _ := value.(string)

	if _, exists := form.collection.OAuth2.GetProviderConfig(name); exists {
		return nil
	}

	return validation.NewError(
		"validation_invalid_provider",
		fmt.Sprintf("Provider with name %q is missing or is not enabled.", name),
	).SetParams(map[string]any{"name": name})
}
// oldCanAssignUsername reports whether the OAuth2 username may be copied into
// the collection's mapped username field when creating a new record.
//
// It returns false when:
//   - the mapped field has a single-column unique index and a record with the
//     same value already exists;
//   - the mapped field is not a text field;
//   - the value doesn't pass the text field plain value validation.
func oldCanAssignUsername(txApp core.App, collection *core.Collection, username string) bool {
	fieldName := collection.OAuth2.MappedFields.Username

	if dbutils.HasSingleColumnUniqueIndex(fieldName, collection.Indexes) {
		_, findErr := txApp.FindFirstRecordByData(collection, fieldName, username)
		if findErr == nil {
			return false // a record with this username already exists
		}
	}

	textField, isText := collection.Fields.GetByName(fieldName).(*core.TextField)
	if !isText || textField == nil {
		return false
	}

	return textField.ValidatePlainValue(username) == nil
}
func oauth2Submit ( e * core . RecordAuthWithOAuth2RequestEvent , optExternalAuth * core . ExternalAuth ) error {
return e . App . RunInTransaction ( func ( txApp core . App ) error {
if e . Record == nil {
// extra check to prevent creating a superuser record via
// OAuth2 in case the method is used by another action
if e . Collection . Name == core . CollectionNameSuperusers {
return errors . New ( "superusers are not allowed to sign-up with OAuth2" )
}
payload := maps . Clone ( e . CreateData )
if payload == nil {
payload = map [ string ] any { }
}
payload [ core . FieldNameEmail ] = e . OAuth2User . Email
// set a random password if none is set
if v , _ := payload [ core . FieldNamePassword ] . ( string ) ; v == "" {
payload [ core . FieldNamePassword ] = security . RandomString ( 30 )
payload [ core . FieldNamePassword + "Confirm" ] = payload [ core . FieldNamePassword ]
}
// map known fields (unless the field was explicitly submitted as part of CreateData)
if _ , ok := payload [ e . Collection . OAuth2 . MappedFields . Id ] ; ! ok && e . Collection . OAuth2 . MappedFields . Id != "" {
payload [ e . Collection . OAuth2 . MappedFields . Id ] = e . OAuth2User . Id
}
if _ , ok := payload [ e . Collection . OAuth2 . MappedFields . Name ] ; ! ok && e . Collection . OAuth2 . MappedFields . Name != "" {
payload [ e . Collection . OAuth2 . MappedFields . Name ] = e . OAuth2User . Name
}
if _ , ok := payload [ e . Collection . OAuth2 . MappedFields . Username ] ; ! ok &&
// no explicit username payload value and existing OAuth2 mapping
e . Collection . OAuth2 . MappedFields . Username != "" &&
// extra checks for backward compatibility with earlier versions
oldCanAssignUsername ( txApp , e . Collection , e . OAuth2User . Username ) {
payload [ e . Collection . OAuth2 . MappedFields . Username ] = e . OAuth2User . Username
}
2024-10-14 14:31:39 +03:00
if _ , ok := payload [ e . Collection . OAuth2 . MappedFields . AvatarURL ] ; ! ok &&
// no existing OAuth2 mapping
e . Collection . OAuth2 . MappedFields . AvatarURL != "" &&
// non-empty OAuth2 avatar url
e . OAuth2User . AvatarURL != "" {
2024-09-29 19:23:19 +03:00
mappedField := e . Collection . Fields . GetByName ( e . Collection . OAuth2 . MappedFields . AvatarURL )
if mappedField != nil && mappedField . Type ( ) == core . FieldTypeFile {
// download the avatar if the mapped field is a file
avatarFile , err := func ( ) ( * filesystem . File , error ) {
ctx , cancel := context . WithTimeout ( context . Background ( ) , 15 * time . Second )
defer cancel ( )
return filesystem . NewFileFromURL ( ctx , e . OAuth2User . AvatarURL )
} ( )
if err != nil {
2024-10-14 14:31:39 +03:00
txApp . Logger ( ) . Warn ( "Failed to retrieve OAuth2 avatar" , slog . String ( "error" , err . Error ( ) ) )
} else {
payload [ e . Collection . OAuth2 . MappedFields . AvatarURL ] = avatarFile
2024-09-29 19:23:19 +03:00
}
} else {
// otherwise - assign the url string
payload [ e . Collection . OAuth2 . MappedFields . AvatarURL ] = e . OAuth2User . AvatarURL
}
}
createdRecord , err := sendOAuth2RecordCreateRequest ( txApp , e , payload )
if err != nil {
return err
}
e . Record = createdRecord
if e . Record . Email ( ) == e . OAuth2User . Email && ! e . Record . Verified ( ) {
// mark as verified as long as it matches the OAuth2 data (even if the email is empty)
e . Record . SetVerified ( true )
if err := txApp . Save ( e . Record ) ; err != nil {
return err
}
}
} else {
var needUpdate bool
isLoggedAuthRecord := e . Auth != nil &&
e . Auth . Id == e . Record . Id &&
e . Auth . Collection ( ) . Id == e . Record . Collection ( ) . Id
// set random password for users with unverified email
// (this is in case a malicious actor has registered previously with the user email)
if ! isLoggedAuthRecord && e . Record . Email ( ) != "" && ! e . Record . Verified ( ) {
e . Record . SetPassword ( security . RandomString ( 30 ) )
needUpdate = true
}
// update the existing auth record empty email if the data.OAuth2User has one
// (this is in case previously the auth record was created
// with an OAuth2 provider that didn't return an email address)
if e . Record . Email ( ) == "" && e . OAuth2User . Email != "" {
e . Record . SetEmail ( e . OAuth2User . Email )
needUpdate = true
}
// update the existing auth record verified state
// (only if the auth record doesn't have an email or the auth record email match with the one in data.OAuth2User)
if ! e . Record . Verified ( ) && ( e . Record . Email ( ) == "" || e . Record . Email ( ) == e . OAuth2User . Email ) {
e . Record . SetVerified ( true )
needUpdate = true
}
if needUpdate {
if err := txApp . Save ( e . Record ) ; err != nil {
return err
}
}
}
// create ExternalAuth relation if missing
if optExternalAuth == nil {
optExternalAuth = core . NewExternalAuth ( txApp )
optExternalAuth . SetCollectionRef ( e . Record . Collection ( ) . Id )
optExternalAuth . SetRecordRef ( e . Record . Id )
optExternalAuth . SetProvider ( e . ProviderName )
optExternalAuth . SetProviderId ( e . OAuth2User . Id )
if err := txApp . Save ( optExternalAuth ) ; err != nil {
return fmt . Errorf ( "failed to save linked rel: %w" , err )
}
}
return nil
} )
}
func sendOAuth2RecordCreateRequest ( txApp core . App , e * core . RecordAuthWithOAuth2RequestEvent , payload map [ string ] any ) ( * core . Record , error ) {
ir := & core . InternalRequest {
Method : http . MethodPost ,
URL : "/api/collections/" + e . Collection . Name + "/records" ,
Body : payload ,
}
2024-10-24 08:37:22 +03:00
var createdRecord * core . Record
response , err := processInternalRequest ( txApp , e . RequestEvent , ir , core . RequestInfoContextOAuth2 , func ( data any ) error {
createdRecord , _ = data . ( * core . Record )
2024-09-29 19:23:19 +03:00
2024-10-24 08:37:22 +03:00
return nil
} )
2024-09-29 19:23:19 +03:00
if err != nil {
return nil , err
}
2024-10-24 08:37:22 +03:00
if response . Status != http . StatusOK || createdRecord == nil {
return nil , errors . New ( "failed to create OAuth2 auth record" )
2024-09-29 19:23:19 +03:00
}
2024-10-24 08:37:22 +03:00
return createdRecord , nil
2024-09-29 19:23:19 +03:00
}