From 8595a803843f9be7b90c21378e0ec72a8f900817 Mon Sep 17 00:00:00 2001 From: Krzysztof Grodzicki Date: Thu, 23 Aug 2018 21:21:37 +0200 Subject: [PATCH] refactor: extract aws session creation, add tests To be able to write some tests in an easy way session creation logic has been extracted. Added tests for configuration and different providers. #754 --- internal/pipeline/s3/awssession.go | 87 ++++++++ internal/pipeline/s3/awssession_test.go | 194 ++++++++++++++++++ internal/pipeline/s3/s3.go | 31 +-- internal/pipeline/s3/testdata/config.ini | 4 + internal/pipeline/s3/testdata/credentials.ini | 12 ++ 5 files changed, 302 insertions(+), 26 deletions(-) create mode 100644 internal/pipeline/s3/awssession.go create mode 100644 internal/pipeline/s3/awssession_test.go create mode 100644 internal/pipeline/s3/testdata/config.ini create mode 100644 internal/pipeline/s3/testdata/credentials.ini diff --git a/internal/pipeline/s3/awssession.go b/internal/pipeline/s3/awssession.go new file mode 100644 index 000000000..57cfd6739 --- /dev/null +++ b/internal/pipeline/s3/awssession.go @@ -0,0 +1,87 @@ +package s3 + +import ( + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/credentials/stscreds" + "github.com/aws/aws-sdk-go/aws/session" +) + +type SessionBuilder interface { + Config(*aws.Config) SessionBuilder + Profile(string) SessionBuilder + Options(*session.Options) SessionBuilder + Endpoint(string) SessionBuilder + S3ForcePathStyle(bool) SessionBuilder + Build() *session.Session +} + +type sessionBuilder struct { + awsConfig *aws.Config + profile string + options *session.Options + endpoint *string + forcePathStyle *bool +} + +func (sb *sessionBuilder) Config(c *aws.Config) SessionBuilder { + sb.awsConfig = c + return sb +} + +func (sb *sessionBuilder) Profile(p string) SessionBuilder { + sb.profile = p + return sb +} + +func (sb *sessionBuilder) Endpoint(e string) SessionBuilder { + sb.endpoint = aws.String(e) + return sb +} + +func (sb *sessionBuilder) S3ForcePathStyle(b bool) SessionBuilder { + sb.forcePathStyle = aws.Bool(b) + return sb +} + +func (sb *sessionBuilder) Options(o *session.Options) SessionBuilder { + sb.options = o + return sb +} + +func (sb *sessionBuilder) Build() *session.Session { + if sb.awsConfig == nil { + sb.awsConfig = aws.NewConfig() + } + + if sb.endpoint != nil { + sb.awsConfig.Endpoint = sb.endpoint + sb.awsConfig.S3ForcePathStyle = sb.forcePathStyle + } + + sb.awsConfig.Credentials = credentials.NewChainCredentials([]credentials.Provider{ + &credentials.EnvProvider{}, + &credentials.SharedCredentialsProvider{ + Profile: sb.profile, + }, + }) + + _, err := sb.awsConfig.Credentials.Get() + if err == nil { + return session.Must(session.NewSession(sb.awsConfig)) + } else { + if sb.options == nil { + sb.options = &session.Options{ + AssumeRoleTokenProvider: stscreds.StdinTokenProvider, + SharedConfigState: session.SharedConfigEnable, + Profile: sb.profile, + } + } + + return session.Must(session.NewSessionWithOptions(*sb.options)) + } +} + +func newSessionBuilder() SessionBuilder { + return &sessionBuilder{} +} diff --git a/internal/pipeline/s3/awssession_test.go b/internal/pipeline/s3/awssession_test.go new file mode 100644 index 000000000..bfaa24b6a --- /dev/null +++ b/internal/pipeline/s3/awssession_test.go @@ -0,0 +1,194 @@ +package s3 + +import ( + "testing" + + "fmt" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/stretchr/testify/assert" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "time" +) + +func setEnv() { + os.Setenv("AWS_ACCESS_KEY_ID", "accessKey") + os.Setenv("AWS_SECRET_ACCESS_KEY", "secret") +} + +func Test_awsSession(t *testing.T) { + type args struct { + profile string + } + + tests := []struct { + name string + args args + wantValidSession bool + want *session.Session + before func() + after func() + expectToken string + endpoint string + S3ForcePathStyle bool + }{ + { + name: "test endpoint", + before: setEnv, + endpoint: "test", + }, + { + name: "test S3ForcePathStyle", + before: setEnv, + S3ForcePathStyle: true, + }, + { + name: "test env provider", + args: args{ + profile: "test1", + }, + before: setEnv, + }, + { + name: "test default shared credentials provider", + before: func() { + os.Clearenv() + os.Setenv("AWS_SHARED_CREDENTIALS_FILE", filepath.Join("testdata", "credentials.ini")) + }, + expectToken: "token", + }, + { + name: "test default shared credentials provider", + before: func() { + os.Clearenv() + os.Setenv("AWS_SHARED_CREDENTIALS_FILE", filepath.Join("testdata", "credentials.ini")) + }, + expectToken: "token", + }, + { + name: "test profile with shared credentials provider", + before: func() { + os.Clearenv() + os.Setenv("AWS_SHARED_CREDENTIALS_FILE", filepath.Join("testdata", "credentials.ini")) + }, + args: args{ + profile: "no_token", + }, + expectToken: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + os.Clearenv() + defer os.Clearenv() + if tt.before != nil { + tt.before() + } + + builder := newSessionBuilder() + builder.Profile(tt.args.profile) + builder.Endpoint(tt.endpoint) + builder.S3ForcePathStyle(tt.S3ForcePathStyle) + sess := builder.Build() + assert.NotNil(t, sess) + + creds, err := sess.Config.Credentials.Get() + assert.Nil(t, err) + + assert.Equal(t, "accessKey", creds.AccessKeyID, "Expect access key ID to match") + assert.Equal(t, "secret", creds.SecretAccessKey, "Expect secret access key to match") + assert.Equal(t, tt.expectToken, creds.SessionToken, "Expect token to match") + + assert.Equal(t, aws.String(tt.endpoint), sess.Config.Endpoint, "Expect endpoint to match") + assert.Equal(t, aws.Bool(tt.S3ForcePathStyle), sess.Config.S3ForcePathStyle, "Expect S3ForcePathStyle to match") + }) + } +} + +const assumeRoleRespMsg = ` + + + + arn:aws:sts::account_id:assumed-role/role/session_name + AKID:session_name + + + AKID + SECRET + SESSION_TOKEN + %s + + + + request-id + + +` + +func Test_awsSession_mfa(t *testing.T) { + os.Clearenv() + defer os.Clearenv() + os.Setenv("AWS_SHARED_CREDENTIALS_FILE", filepath.Join("testdata", "credentials.ini")) + os.Setenv("AWS_CONFIG_FILE", filepath.Join("testdata", "config.ini")) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, r.FormValue("SerialNumber"), "arn:aws:iam::1111111111:mfa/test") + assert.Equal(t, r.FormValue("TokenCode"), "tokencode") + + w.Write([]byte(fmt.Sprintf(assumeRoleRespMsg, time.Now().Add(15*time.Minute).Format("2006-01-02T15:04:05Z")))) + })) + + customProviderCalled := false + + options := &session.Options{ + Profile: "cloudformation@flowlab-dev", + Config: aws.Config{ + Region: aws.String("eu-west-1"), + Endpoint: aws.String(server.URL), + DisableSSL: aws.Bool(true), + }, + SharedConfigState: session.SharedConfigEnable, + AssumeRoleTokenProvider: func() (string, error) { + customProviderCalled = true + return "tokencode", nil + }, + } + + builder := newSessionBuilder() + builder.Profile("cloudformation@flowlab-dev") + builder.Options(options) + sess := builder.Build() + + creds, err := sess.Config.Credentials.Get() + assert.NoError(t, err) + assert.True(t, customProviderCalled) + assert.Contains(t, creds.ProviderName, "AssumeRoleProvider") +} + +func Test_awsSession_fail(t *testing.T) { + tests := []struct { + name string + }{ + { + name: "should fail with no credentials", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + os.Clearenv() + defer os.Clearenv() + + builder := newSessionBuilder() + sess := builder.Build() + assert.NotNil(t, sess) + + _, err := sess.Config.Credentials.Get() + assert.NotNil(t, err) + }) + } +} diff --git a/internal/pipeline/s3/s3.go b/internal/pipeline/s3/s3.go index 08d70c582..2a5be7e10 100644 --- a/internal/pipeline/s3/s3.go +++ b/internal/pipeline/s3/s3.go @@ -7,9 +7,6 @@ import ( "github.com/apex/log" "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/credentials" - "github.com/aws/aws-sdk-go/aws/credentials/stscreds" - "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/s3" "github.com/goreleaser/goreleaser/internal/artifact" "github.com/goreleaser/goreleaser/internal/pipeline" @@ -63,31 +60,13 @@ func (Pipe) Run(ctx *context.Context) error { } func upload(ctx *context.Context, conf config.S3) error { - var awsConfig = aws.NewConfig() - // TODO: add a test for this + builder := newSessionBuilder() + builder.Profile(conf.Profile) if conf.Endpoint != "" { - awsConfig.Endpoint = aws.String(conf.Endpoint) - awsConfig.S3ForcePathStyle = aws.Bool(true) - } - awsConfig.Credentials = credentials.NewChainCredentials([]credentials.Provider{ - &credentials.EnvProvider{}, - &credentials.SharedCredentialsProvider{ - Profile: conf.Profile, - }, - }) - - _, err := awsConfig.Credentials.Get() - var sess *session.Session - if err == nil { - sess = session.Must(session.NewSession(awsConfig)) - } else { - // Specify profile and assume an IAM role with MFA prompting for token code on stdin - sess = session.Must(session.NewSessionWithOptions(session.Options{ - AssumeRoleTokenProvider: stscreds.StdinTokenProvider, - SharedConfigState: session.SharedConfigEnable, - Profile: conf.Profile, - })) + builder.Endpoint(conf.Endpoint) + builder.S3ForcePathStyle(true) } + sess := builder.Build() svc := s3.New(sess, &aws.Config{ Region: aws.String(conf.Region), diff --git a/internal/pipeline/s3/testdata/config.ini b/internal/pipeline/s3/testdata/config.ini new file mode 100644 index 000000000..96193f33b --- /dev/null +++ b/internal/pipeline/s3/testdata/config.ini @@ -0,0 +1,4 @@ +[profile cloudformation@flowlab-dev] +role_arn = arn:aws:iam::1111111111:role/CloudFormation +source_profile = user +mfa_serial = arn:aws:iam::1111111111:mfa/test diff --git a/internal/pipeline/s3/testdata/credentials.ini b/internal/pipeline/s3/testdata/credentials.ini new file mode 100644 index 000000000..1d56ae7c5 --- /dev/null +++ b/internal/pipeline/s3/testdata/credentials.ini @@ -0,0 +1,12 @@ +[default] +aws_access_key_id = accessKey +aws_secret_access_key = secret +aws_session_token = token + +[no_token] +aws_access_key_id = accessKey +aws_secret_access_key = secret + +[user] +aws_access_key_id = accesKey +aws_secret_access_key = secret