1
0
mirror of https://github.com/goreleaser/goreleaser.git synced 2025-01-06 03:13:48 +02:00

refactor: extract aws session creation, add tests

To be able to write some tests in an easy way session creation logic has
been extracted. Added tests for configuration and different providers.

#754
This commit is contained in:
Krzysztof Grodzicki 2018-08-23 21:21:37 +02:00 committed by Carlos Alexandro Becker
parent dc0e2bd766
commit 8595a80384
5 changed files with 302 additions and 26 deletions

View File

@ -0,0 +1,87 @@
package s3
import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/aws/session"
)
type SessionBuilder interface {
Config(*aws.Config) SessionBuilder
Profile(string) SessionBuilder
Options(*session.Options) SessionBuilder
Endpoint(string) SessionBuilder
S3ForcePathStyle(bool) SessionBuilder
Build() *session.Session
}
type sessionBuilder struct {
awsConfig *aws.Config
profile string
options *session.Options
endpoint *string
forcePathStyle *bool
}
func (sb *sessionBuilder) Config(c *aws.Config) SessionBuilder {
sb.awsConfig = c
return sb
}
func (sb *sessionBuilder) Profile(p string) SessionBuilder {
sb.profile = p
return sb
}
func (sb *sessionBuilder) Endpoint(e string) SessionBuilder {
sb.endpoint = aws.String(e)
return sb
}
func (sb *sessionBuilder) S3ForcePathStyle(b bool) SessionBuilder {
sb.forcePathStyle = aws.Bool(b)
return sb
}
func (sb *sessionBuilder) Options(o *session.Options) SessionBuilder {
sb.options = o
return sb
}
func (sb *sessionBuilder) Build() *session.Session {
if sb.awsConfig == nil {
sb.awsConfig = aws.NewConfig()
}
if sb.endpoint != nil {
sb.awsConfig.Endpoint = sb.endpoint
sb.awsConfig.S3ForcePathStyle = sb.forcePathStyle
}
sb.awsConfig.Credentials = credentials.NewChainCredentials([]credentials.Provider{
&credentials.EnvProvider{},
&credentials.SharedCredentialsProvider{
Profile: sb.profile,
},
})
_, err := sb.awsConfig.Credentials.Get()
if err == nil {
return session.Must(session.NewSession(sb.awsConfig))
} else {
if sb.options == nil {
sb.options = &session.Options{
AssumeRoleTokenProvider: stscreds.StdinTokenProvider,
SharedConfigState: session.SharedConfigEnable,
Profile: sb.profile,
}
}
return session.Must(session.NewSessionWithOptions(*sb.options))
}
}
func newSessionBuilder() SessionBuilder {
return &sessionBuilder{}
}

View File

@ -0,0 +1,194 @@
package s3
import (
"testing"
"fmt"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/stretchr/testify/assert"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"time"
)
func setEnv() {
os.Setenv("AWS_ACCESS_KEY_ID", "accessKey")
os.Setenv("AWS_SECRET_ACCESS_KEY", "secret")
}
func Test_awsSession(t *testing.T) {
type args struct {
profile string
}
tests := []struct {
name string
args args
wantValidSession bool
want *session.Session
before func()
after func()
expectToken string
endpoint string
S3ForcePathStyle bool
}{
{
name: "test endpoint",
before: setEnv,
endpoint: "test",
},
{
name: "test S3ForcePathStyle",
before: setEnv,
S3ForcePathStyle: true,
},
{
name: "test env provider",
args: args{
profile: "test1",
},
before: setEnv,
},
{
name: "test default shared credentials provider",
before: func() {
os.Clearenv()
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", filepath.Join("testdata", "credentials.ini"))
},
expectToken: "token",
},
{
name: "test default shared credentials provider",
before: func() {
os.Clearenv()
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", filepath.Join("testdata", "credentials.ini"))
},
expectToken: "token",
},
{
name: "test profile with shared credentials provider",
before: func() {
os.Clearenv()
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", filepath.Join("testdata", "credentials.ini"))
},
args: args{
profile: "no_token",
},
expectToken: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
os.Clearenv()
defer os.Clearenv()
if tt.before != nil {
tt.before()
}
builder := newSessionBuilder()
builder.Profile(tt.args.profile)
builder.Endpoint(tt.endpoint)
builder.S3ForcePathStyle(tt.S3ForcePathStyle)
sess := builder.Build()
assert.NotNil(t, sess)
creds, err := sess.Config.Credentials.Get()
assert.Nil(t, err)
assert.Equal(t, "accessKey", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expect secret access key to match")
assert.Equal(t, tt.expectToken, creds.SessionToken, "Expect token to match")
assert.Equal(t, aws.String(tt.endpoint), sess.Config.Endpoint, "Expect endpoint to match")
assert.Equal(t, aws.Bool(tt.S3ForcePathStyle), sess.Config.S3ForcePathStyle, "Expect S3ForcePathStyle to match")
})
}
}
const assumeRoleRespMsg = `
<AssumeRoleResponse xmlns="https://sts.amazonaws.com/doc/2011-06-15/">
<AssumeRoleResult>
<AssumedRoleUser>
<Arn>arn:aws:sts::account_id:assumed-role/role/session_name</Arn>
<AssumedRoleId>AKID:session_name</AssumedRoleId>
</AssumedRoleUser>
<Credentials>
<AccessKeyId>AKID</AccessKeyId>
<SecretAccessKey>SECRET</SecretAccessKey>
<SessionToken>SESSION_TOKEN</SessionToken>
<Expiration>%s</Expiration>
</Credentials>
</AssumeRoleResult>
<ResponseMetadata>
<RequestId>request-id</RequestId>
</ResponseMetadata>
</AssumeRoleResponse>
`
func Test_awsSession_mfa(t *testing.T) {
os.Clearenv()
defer os.Clearenv()
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", filepath.Join("testdata", "credentials.ini"))
os.Setenv("AWS_CONFIG_FILE", filepath.Join("testdata", "config.ini"))
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, r.FormValue("SerialNumber"), "arn:aws:iam::1111111111:mfa/test")
assert.Equal(t, r.FormValue("TokenCode"), "tokencode")
w.Write([]byte(fmt.Sprintf(assumeRoleRespMsg, time.Now().Add(15*time.Minute).Format("2006-01-02T15:04:05Z"))))
}))
customProviderCalled := false
options := &session.Options{
Profile: "cloudformation@flowlab-dev",
Config: aws.Config{
Region: aws.String("eu-west-1"),
Endpoint: aws.String(server.URL),
DisableSSL: aws.Bool(true),
},
SharedConfigState: session.SharedConfigEnable,
AssumeRoleTokenProvider: func() (string, error) {
customProviderCalled = true
return "tokencode", nil
},
}
builder := newSessionBuilder()
builder.Profile("cloudformation@flowlab-dev")
builder.Options(options)
sess := builder.Build()
creds, err := sess.Config.Credentials.Get()
assert.NoError(t, err)
assert.True(t, customProviderCalled)
assert.Contains(t, creds.ProviderName, "AssumeRoleProvider")
}
func Test_awsSession_fail(t *testing.T) {
tests := []struct {
name string
}{
{
name: "should fail with no credentials",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
os.Clearenv()
defer os.Clearenv()
builder := newSessionBuilder()
sess := builder.Build()
assert.NotNil(t, sess)
_, err := sess.Config.Credentials.Get()
assert.NotNil(t, err)
})
}
}

View File

@ -7,9 +7,6 @@ import (
"github.com/apex/log"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/goreleaser/goreleaser/internal/artifact"
"github.com/goreleaser/goreleaser/internal/pipeline"
@ -63,31 +60,13 @@ func (Pipe) Run(ctx *context.Context) error {
}
func upload(ctx *context.Context, conf config.S3) error {
var awsConfig = aws.NewConfig()
// TODO: add a test for this
builder := newSessionBuilder()
builder.Profile(conf.Profile)
if conf.Endpoint != "" {
awsConfig.Endpoint = aws.String(conf.Endpoint)
awsConfig.S3ForcePathStyle = aws.Bool(true)
}
awsConfig.Credentials = credentials.NewChainCredentials([]credentials.Provider{
&credentials.EnvProvider{},
&credentials.SharedCredentialsProvider{
Profile: conf.Profile,
},
})
_, err := awsConfig.Credentials.Get()
var sess *session.Session
if err == nil {
sess = session.Must(session.NewSession(awsConfig))
} else {
// Specify profile and assume an IAM role with MFA prompting for token code on stdin
sess = session.Must(session.NewSessionWithOptions(session.Options{
AssumeRoleTokenProvider: stscreds.StdinTokenProvider,
SharedConfigState: session.SharedConfigEnable,
Profile: conf.Profile,
}))
builder.Endpoint(conf.Endpoint)
builder.S3ForcePathStyle(true)
}
sess := builder.Build()
svc := s3.New(sess, &aws.Config{
Region: aws.String(conf.Region),

View File

@ -0,0 +1,4 @@
[profile cloudformation@flowlab-dev]
role_arn = arn:aws:iam::1111111111:role/CloudFormation
source_profile = user
mfa_serial = arn:aws:iam::1111111111:mfa/test

View File

@ -0,0 +1,12 @@
[default]
aws_access_key_id = accessKey
aws_secret_access_key = secret
aws_session_token = token
[no_token]
aws_access_key_id = accessKey
aws_secret_access_key = secret
[user]
aws_access_key_id = accesKey
aws_secret_access_key = secret