1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2025-05-21 22:33:38 +02:00
oauth2-proxy/providers/ms_entra_id_test.go

186 lines
5.6 KiB
Go
Raw Permalink Normal View History

package providers
import (
"context"
"crypto/rand"
"crypto/rsa"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/golang-jwt/jwt/v5"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
"github.com/stretchr/testify/assert"
. "github.com/onsi/gomega"
)
// TestAzureEntraOIDCProviderNewMultiTenant builds a provider against the
// multi-tenant "common" issuer, with issuer verification skipped, and checks
// that the resulting provider reports the expected display name.
func TestAzureEntraOIDCProviderNewMultiTenant(t *testing.T) {
	expect := NewWithT(t)

	oidcOpts := options.OIDCOptions{
		IssuerURL:                      "https://login.microsoftonline.com/common/v2.0",
		InsecureSkipIssuerVerification: true,
	}
	entra := NewMicrosoftEntraIDProvider(&ProviderData{}, options.Provider{OIDCConfig: oidcOpts})

	name := entra.Data().ProviderName
	expect.Expect(name).To(Equal("Microsoft Entra ID"))
}
// TestAzureEntraOIDCProviderNewSingleTenant builds a provider against a
// tenant-specific issuer URL and checks that the resulting provider reports
// the expected display name.
func TestAzureEntraOIDCProviderNewSingleTenant(t *testing.T) {
	expect := NewWithT(t)

	oidcOpts := options.OIDCOptions{
		IssuerURL: "https://login.microsoftonline.com/18014347-dd57-41a1-8191-7a1f734ea457/v2.0",
	}
	entra := NewMicrosoftEntraIDProvider(&ProviderData{}, options.Provider{OIDCConfig: oidcOpts})

	name := entra.Data().ProviderName
	expect.Expect(name).To(Equal("Microsoft Entra ID"))
}
// TestAzureEntraOIDCProviderEnrichSessionGroupOverage signs an ID token whose
// `_claim_names` claim lists "groups", points the provider at a mocked Graph
// API, and checks that EnrichSession fills session.Groups with the group IDs
// from both pages served by the mock (the second page is reached through the
// `$skiptoken` link returned with the first one).
func TestAzureEntraOIDCProviderEnrichSessionGroupOverage(t *testing.T) {
	const issuer = "https://login.microsoftonline.com/18014347-dd57-41a1-8191-7a1f734ea457/v2.0"

	// Create ID Token that indicates group overage with _claim_names
	key, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		t.Fatalf("generating RSA key: %v", err)
	}
	// Keyed fields keep the literal correct even if the struct gains or
	// reorders fields; the variable name no longer shadows the type.
	overageClaims := &claimsWithGroupOverage{
		RegisteredClaims: jwt.RegisteredClaims{Issuer: issuer},
		ClaimNames:       map[string]string{"groups": "src1"},
	}
	signedJWT, err := jwt.NewWithClaims(jwt.SigningMethodRS256, overageClaims).SignedString(key)
	if err != nil {
		t.Fatalf("signing ID token: %v", err)
	}

	session := CreateAuthorizedSession()
	session.IDToken = signedJWT
	session.Email = "mock@example.com"

	// Create provider
	provider := NewMicrosoftEntraIDProvider(&ProviderData{},
		options.Provider{OIDCConfig: options.OIDCOptions{
			IssuerURL: issuer,
		}},
	)

	// Create mocked Azure Graph server and override Graph URL
	mockedGraph := mockGraphAPI(false)
	// Release the listener when the test finishes; it was previously leaked.
	defer mockedGraph.Close()

	mockedGraphURL, err := url.Parse(mockedGraph.URL)
	if err != nil {
		t.Fatalf("parsing mocked Graph URL %q: %v", mockedGraph.URL, err)
	}
	// NOTE(review): updateURL is defined elsewhere in the package; presumably
	// it rewrites the host of the provider's Graph URL in place — confirm.
	updateURL(provider.microsoftGraphURL, mockedGraphURL.Host)

	// Test EnrichSession
	err = provider.EnrichSession(context.Background(), session)
	assert.NoError(t, err)
	assert.Contains(t, session.Groups, "85d7d600-7804-4d92-8d43-9c33c21c130c")
	assert.Contains(t, session.Groups, "916f0604-8a3b-4a69-bda9-06db11a8f0cd")
	assert.Contains(t, session.Groups, "b1aef995-6b55-4ac6-bbfe-e829810e9352", "Pagination using $skiptoken failed")
}
// TestAzureEntraOIDCProviderValidateSessionAllowedTenants configures a
// multi-tenant provider with a single allowed tenant and checks that
// ValidateSession rejects a session whose ID token issuer names a different
// tenant and accepts one whose issuer names the allowed tenant.
func TestAzureEntraOIDCProviderValidateSessionAllowedTenants(t *testing.T) {
	const allowedTenant = "85d7d600-7804-4d92-8d43-9c33c21c130c"

	// Create multi-tenant Azure Entra provider with allowed tenants
	provider := NewMicrosoftEntraIDProvider(
		&ProviderData{
			Verifier: &mockedVerifier{},
		},
		options.Provider{
			OIDCConfig: options.OIDCOptions{
				IssuerURL:                      "https://login.microsoftonline.com/common/v2.0",
				InsecureSkipIssuerVerification: true,
				InsecureSkipNonce:              true,
			},
			MicrosoftEntraIDConfig: options.MicrosoftEntraIDOptions{
				AllowedTenants: []string{allowedTenant},
			},
		},
	)

	// One key signs every token; fail fast instead of signing with a nil key.
	key, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		t.Fatalf("generating RSA key: %v", err)
	}

	testCases := []struct {
		name   string
		tenant string
		want   bool
	}{
		{name: "invalid tenant", tenant: "invalid_tenant", want: false},
		{name: "valid tenant", tenant: allowedTenant, want: true},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			idToken := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.RegisteredClaims{
				Issuer: "https://login.microsoftonline.com/" + tc.tenant + "/v2.0",
			})
			signedJWT, err := idToken.SignedString(key)
			if err != nil {
				t.Fatalf("signing ID token: %v", err)
			}

			// Each scenario starts from a fresh authorized session.
			session := CreateAuthorizedSession()
			session.IDToken = signedJWT

			valid := provider.ValidateSession(context.Background(), session)
			assert.Equal(t, tc.want, valid)
		})
	}
}
// mockGraphAPI starts an httptest server that imitates the Microsoft Graph
// group-membership endpoint used by these tests.
//
// When noGroupMemberPermissions is true every request is answered with
// 401 Unauthorized. Otherwise a GET on /v1.0/me/transitiveMemberOf returns a
// first page holding two group IDs plus an "@odata.nextLink" pointing back at
// the same host with a $skiptoken, and a GET carrying a $skiptoken query
// parameter returns the second page holding one more group ID. Any other
// request gets the default empty 200 response.
//
// The caller owns the returned server and is responsible for closing it.
func mockGraphAPI(noGroupMemberPermissions bool) *httptest.Server {
	const membershipPath = "/v1.0/me/transitiveMemberOf"

	// Second page (pagination)
	const secondPage = `{
"@odata.context": "https://graph.microsoft.com/v1.0/$metadata#directoryObjects(id)",
"value": [
{
"@odata.type": "#microsoft.graph.group",
"id": "b1aef995-6b55-4ac6-bbfe-e829810e9352"
}
]
}`

	// First page (pagination); %s receives the host of the incoming request
	// so the next link leads back to this same mock server.
	const firstPageFormat = `{
"@odata.context": "https://graph.microsoft.com/v1.0/$metadata#directoryObjects(id)",
"@odata.nextLink": "http://%s/v1.0/me/transitiveMemberOf?$select=id&$top=2&$skiptoken=TEST_TOKEN",
"value": [
{
"@odata.type": "#microsoft.graph.group",
"id": "85d7d600-7804-4d92-8d43-9c33c21c130c"
},
{
"@odata.type": "#microsoft.graph.group",
"id": "916f0604-8a3b-4a69-bda9-06db11a8f0cd"
}
]
}`

	serve := func(w http.ResponseWriter, r *http.Request) {
		if noGroupMemberPermissions {
			w.WriteHeader(http.StatusUnauthorized)
			return
		}
		// Anything but a GET on the membership path falls through untouched.
		if r.URL.Path != membershipPath || r.Method != http.MethodGet {
			return
		}

		skipTokens := r.URL.Query()["$skiptoken"]
		if len(skipTokens) == 0 {
			fmt.Fprintf(w, firstPageFormat, r.Host)
			return
		}
		w.Write([]byte(secondPage))
	}

	return httptest.NewServer(http.HandlerFunc(serve))
}
// claimsWithGroupOverage is the JWT claim set used by the group-overage test:
// the registered claims plus a `_claim_names` entry, which is left out of the
// encoded token when ClaimNames is nil.
type claimsWithGroupOverage struct {
	jwt.RegisteredClaims
	ClaimNames any `json:"_claim_names,omitempty"`
}

// Valid always reports the claims as valid.
//
// NOTE(review): golang-jwt v5 appears to use Validate() for custom claim
// checks, so this may be a leftover v4-style hook — confirm it is still needed.
func (*claimsWithGroupOverage) Valid() error {
	return nil
}
// mockedVerifier is a stub ID token verifier for tests; it holds no state.
type mockedVerifier struct{}

// Verify ignores both arguments and returns a nil token together with a nil
// error, so callers relying on it see verification as having succeeded
// without getting a token back.
func (*mockedVerifier) Verify(ctx context.Context, rawIDToken string) (*oidc.IDToken, error) {
	return nil, nil
}