// Copyright 2012 The Gorilla Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package securecookie import ( "crypto/aes" "crypto/hmac" "crypto/sha256" "encoding/base64" "errors" "fmt" "strings" "testing" ) var testCookies = []interface{}{ map[string]string{"foo": "bar"}, map[string]string{"baz": "ding"}, } var testStrings = []string{"foo", "bar", "baz"} func TestSecureCookie(t *testing.T) { // TODO test too old / too new timestamps compareMaps := func(m1, m2 map[string]interface{}) error { if len(m1) != len(m2) { return errors.New("different maps") } for k, v := range m1 { if m2[k] != v { return fmt.Errorf("Different value for key %v: expected %v, got %v", k, m2[k], v) } } return nil } s1 := New([]byte("12345"), []byte("1234567890123456")) s2 := New([]byte("54321"), []byte("6543210987654321")) value := map[string]interface{}{ "foo": "bar", "baz": 128, } for i := 0; i < 50; i++ { // Running this multiple times to check if any special character // breaks encoding/decoding. encoded, err1 := s1.Encode("sid", value) if err1 != nil { t.Error(err1) continue } dst := make(map[string]interface{}) err2 := s1.Decode("sid", encoded, &dst) if err2 != nil { t.Fatalf("%v: %v", err2, encoded) } if err := compareMaps(dst, value); err != nil { t.Fatalf("Expected %v, got %v.", value, dst) } dst2 := make(map[string]interface{}) err3 := s2.Decode("sid", encoded, &dst2) if err3 == nil { t.Fatalf("Expected failure decoding.") } } } func TestDecodeInvalid(t *testing.T) { // List of invalid cookies, which must not be accepted, base64-decoded // (they will be encoded before passing to Decode). invalidCookies := []string{ "", " ", "\n", "||", "|||", "cookie", } s := New([]byte("12345"), nil) var dst string for i, v := range invalidCookies { err := s.Decode("name", base64.StdEncoding.EncodeToString([]byte(v)), &dst) if err == nil { t.Fatalf("%d: expected failure decoding", i) } } } func TestAuthentication(t *testing.T) { hash := hmac.New(sha256.New, []byte("secret-key")) for _, value := range testStrings { hash.Reset() signed := createMac(hash, []byte(value)) hash.Reset() err := verifyMac(hash, []byte(value), signed) if err != nil { t.Error(err) } } } func TestEncription(t *testing.T) { block, err := aes.NewCipher([]byte("1234567890123456")) if err != nil { t.Fatalf("Block could not be created") } var encrypted, decrypted []byte for _, value := range testStrings { if encrypted, err = encrypt(block, []byte(value)); err != nil { t.Error(err) } else { if decrypted, err = decrypt(block, encrypted); err != nil { t.Error(err) } if string(decrypted) != value { t.Errorf("Expected %v, got %v.", value, string(decrypted)) } } } } func TestSerialization(t *testing.T) { var ( serialized []byte deserialized map[string]string err error ) for _, value := range testCookies { if serialized, err = serialize(value); err != nil { t.Error(err) } else { deserialized = make(map[string]string) if err = deserialize(serialized, &deserialized); err != nil { t.Error(err) } if fmt.Sprintf("%v", deserialized) != fmt.Sprintf("%v", value) { t.Errorf("Expected %v, got %v.", value, deserialized) } } } } func TestEncoding(t *testing.T) { for _, value := range testStrings { encoded := encode([]byte(value)) decoded, err := decode(encoded) if err != nil { t.Error(err) } else if string(decoded) != value { t.Errorf("Expected %v, got %s.", value, string(decoded)) } } } func TestMultiError(t *testing.T) { s1, s2 := New(nil, nil), New(nil, nil) _, err := EncodeMulti("sid", "value", s1, s2) if len(err.(MultiError)) != 2 { t.Errorf("Expected 2 errors, got %s.", err) } else { if strings.Index(err.Error(), "hash key is not set") == -1 { t.Errorf("Expected missing hash key error, got %s.", err.Error()) } } } func TestMultiNoCodecs(t *testing.T) { _, err := EncodeMulti("foo", "bar") if err != errNoCodecs { t.Errorf("EncodeMulti: bad value for error, got: %v", err) } var dst []byte err = DecodeMulti("foo", "bar", &dst) if err != errNoCodecs { t.Errorf("DecodeMulti: bad value for error, got: %v", err) } } func TestMissingKey(t *testing.T) { s1 := New(nil, nil) var dst []byte err := s1.Decode("sid", "value", &dst) if err != errHashKeyNotSet { t.Fatalf("Expected %#v, got %#v", errHashKeyNotSet, err) } } // ---------------------------------------------------------------------------- type FooBar struct { Foo int Bar string } func TestCustomType(t *testing.T) { s1 := New([]byte("12345"), []byte("1234567890123456")) // Type is not registered in gob. (!!!) src := &FooBar{42, "bar"} encoded, _ := s1.Encode("sid", src) dst := &FooBar{} _ = s1.Decode("sid", encoded, dst) if dst.Foo != 42 || dst.Bar != "bar" { t.Fatalf("Expected %#v, got %#v", src, dst) } }