From a09eecc6a261ce6b5fe62dee64b6b1a402cfd705 Mon Sep 17 00:00:00 2001 From: Nick Meves Date: Mon, 13 Jul 2020 12:56:05 -0700 Subject: [PATCH] Reduce SessionState size better with MessagePack + LZ4 (#632) * Encode sessions with MsgPack + LZ4 Assumes ciphers are now mandatory per #414. Cookie & Redis sessions can fallback to V5 style JSON in error cases. TODO: session_state.go unit tests & new unit tests for Legacy fallback scenarios. * Only compress encoded sessions with Cookie Store * Cleanup msgpack + lz4 error handling * Change NewBase64Cipher to take in an existing Cipher * Add msgpack & lz4 session state tests * Add required options for oauthproxy tests More aggressively assert.NoError on all validation.Validate(opts) calls to enforce legal options in all our tests. Add additional NoError checks wherever error return values were ignored. * Remove support for uncompressed session state fields * Improve error verbosity & add session state tests * Ensure all marshalled sessions are valid Invalid CFB decryptions can result in garbage data that 1/100 times might cause message pack unmarshal to not fail and instead return an empty session. This adds more rigor to make sure legacy sessions cause appropriate errors. * Add tests for legacy V5 session decoding Refactor common legacy JSON test cases to a legacy helpers area under session store tests. * Make ValidateSession a struct method & add CHANGELOG entry * Improve SessionState error & comments verbosity * Move legacy session test helpers to sessions pkg Placing these helpers under the sessions pkg removed all the circular import uses in housing it under the session store area. * Improve SignatureAuthenticator test helper formatting * Make redis.legacyV5DecodeSession internal * Make LegacyV5TestCase test table public for linter --- CHANGELOG.md | 4 + go.mod | 9 +- go.sum | 23 +- oauthproxy_test.go | 625 +++++++++++------- pkg/apis/sessions/legacy_v5_tester.go | 87 +++ pkg/apis/sessions/session_state.go | 223 +++++-- pkg/apis/sessions/session_state_test.go | 450 +++++-------- pkg/encryption/cipher.go | 8 +- pkg/encryption/cipher_test.go | 13 +- pkg/sessions/cookie/session_store.go | 33 +- pkg/sessions/redis/redis_store.go | 137 ++-- ...sion_store_test.go => redis_store_test.go} | 64 ++ 12 files changed, 1006 insertions(+), 670 deletions(-) create mode 100644 pkg/apis/sessions/legacy_v5_tester.go rename pkg/sessions/redis/{session_store_test.go => redis_store_test.go} (65%) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6975cf73..b78a44b1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,10 +4,14 @@ ## Important Notes +- [#632](https://github.com/oauth2-proxy/oauth2-proxy/pull/632) There is backwards compatibility to sessions from v5 + - Any unencrypted sessions from before v5 that only contained a Username & Email will trigger a reauthentication + ## Breaking Changes ## Changes since v6.0.0 +- [#632](https://github.com/oauth2-proxy/oauth2-proxy/pull/632) Reduce session size by encoding with MessagePack and using LZ4 compression (@NickMeves) - [#675](https://github.com/oauth2-proxy/oauth2-proxy/pull/675) Fix required ruby version and deprecated option for building docs (@mkontani) - [#669](https://github.com/oauth2-proxy/oauth2-proxy/pull/669) Reduce docker context to improve build times (@JoelSpeed) - [#668](https://github.com/oauth2-proxy/oauth2-proxy/pull/668) Use req.Host in --force-https when req.URL.Host is empty (@zucaritask) diff --git a/go.mod b/go.mod index cacf7ab8..cd1d061a 100644 --- a/go.mod +++ b/go.mod @@ -9,23 +9,24 @@ require ( github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 // indirect github.com/coreos/go-oidc v2.2.1+incompatible github.com/dgrijalva/jwt-go v3.2.0+incompatible + github.com/frankban/quicktest v1.10.0 // indirect github.com/fsnotify/fsnotify v1.4.9 github.com/go-redis/redis/v7 v7.2.0 github.com/justinas/alice v1.2.0 - github.com/kr/pretty v0.2.0 // indirect github.com/mbland/hmacauth v0.0.0-20170912233209-44256dfd4bfa github.com/mitchellh/mapstructure v1.1.2 github.com/onsi/ginkgo v1.12.0 github.com/onsi/gomega v1.9.0 + github.com/pierrec/lz4 v2.5.2+incompatible github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 // indirect github.com/spf13/pflag v1.0.3 github.com/spf13/viper v1.6.3 github.com/stretchr/testify v1.5.1 + github.com/vmihailenco/msgpack/v4 v4.3.11 github.com/yhat/wsutil v0.0.0-20170731153501-1d66fa95c997 - golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550 - golang.org/x/net v0.0.0-20200226121028-0de0cce0169b + golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 + golang.org/x/net v0.0.0-20200301022130-244492dfa37a golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d - golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 // indirect google.golang.org/api v0.20.0 gopkg.in/natefinch/lumberjack.v2 v2.0.0 gopkg.in/square/go-jose.v2 v2.4.1 diff --git a/go.sum b/go.sum index ba0d342b..619014de 100644 --- a/go.sum +++ b/go.sum @@ -47,6 +47,8 @@ github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZm github.com/dgryski/go-sip13 v0.0.0-20181026042036-e10d5fee7954/go.mod h1:vAd38F8PWV+bWy6jNmig1y/TA+kYO4g3RSRF0IAv0no= github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/frankban/quicktest v1.10.0 h1:Gfh+GAJZOAoKZsIZeZbdn2JF10kN1XHNvjsvQK8gVkE= +github.com/frankban/quicktest v1.10.0/go.mod h1:ui7WezCLWMWxVWr1GETZY3smRy0G4KWq9vcPtJmFl7Y= github.com/fsnotify/fsnotify v1.4.7 h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= @@ -69,6 +71,8 @@ github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5y github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.4 h1:87PNWwrRvUSnqS4dlcBU/ftvOIBep4sYuBLlh6rX2wk= +github.com/golang/protobuf v1.3.4/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= github.com/gomodule/redigo v1.7.1-0.20190322064113-39e2c31b7ca3 h1:6amM4HsNPOvMLVc2ZnyqrjeQ92YAVWn7T4WBKK87inY= github.com/gomodule/redigo v1.7.1-0.20190322064113-39e2c31b7ca3/go.mod h1:B4C85qUVwatsJoIUNIfCRsp7qO0iAmpGFZ4EELWSbC4= github.com/gomodule/redigo v1.8.1 h1:Abmo0bI7Xf0IhdIPc7HZQzZcShdnmxeoVuDDtIQp8N8= @@ -78,6 +82,8 @@ github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.3.0 h1:crn/baboCvb5fXaQ0IJ1SGTsTVrWpDsCWC8EGETZijY= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4= +github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= @@ -142,6 +148,8 @@ github.com/onsi/gomega v1.9.0 h1:R1uwffexN6Pr340GtYRIdZmAiN4J+iw6WG4wog1DUXg= github.com/onsi/gomega v1.9.0/go.mod h1:Ho0h+IUsWyvy1OpqCwxlQ/21gkhVunqlU8fDGcoTdcA= github.com/pelletier/go-toml v1.2.0 h1:T5zMGML61Wp+FlcbWjRDT7yAxhJNAiPPLOFECq181zc= github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= +github.com/pierrec/lz4 v2.5.2+incompatible h1:WCjObylUIOlKy/+7Abdn34TLIkXiA4UWUMhxq9m9ZXI= +github.com/pierrec/lz4 v2.5.2+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -184,6 +192,10 @@ github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5 github.com/subosito/gotenv v1.2.0 h1:Slr1R9HxAlEKefgq5jn9U+DnETlIUa6HfgEzj0g5d7s= github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw= github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= +github.com/vmihailenco/msgpack/v4 v4.3.11 h1:Q47CePddpNGNhk4GCnAx9DDtASi2rasatE0cd26cZoE= +github.com/vmihailenco/msgpack/v4 v4.3.11/go.mod h1:gborTTJjAo/GWTqqRjrLCn9pgNN+NXzzngzBKDPIqw4= +github.com/vmihailenco/tagparser v0.1.1 h1:quXMXlA39OCbd2wAdTsGDlK9RkOk6Wuw+x37wVyIuWY= +github.com/vmihailenco/tagparser v0.1.1/go.mod h1:OeAg3pn3UbLjkWt+rN9oFYB6u/cQgqMEUPoW2WPyhdI= github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU= github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q= github.com/yhat/wsutil v0.0.0-20170731153501-1d66fa95c997 h1:1+FQ4Ns+UZtUiQ4lP0sTCyKSQ0EXoiwAdHZB0Pd5t9Q= @@ -201,8 +213,6 @@ go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 h1:VklqNMn3ovrHsnt90PveolxSbWFaJdECFbxSq0Mqo2M= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550 h1:ObdrDkeb4kJdCP557AjRjq69pTHfNouLtWZG7j9rPN8= -golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= @@ -217,14 +227,14 @@ golang.org/x/net v0.0.0-20181220203305-927f97764cc3/go.mod h1:mL1N/T3taQHkDXs73r golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190503192946-f4e77d36d62c h1:uOCk1iQW6Vc18bnC13MfzScl+wdKBmM9Y9kU7Z83/lw= golang.org/x/net v0.0.0-20190503192946-f4e77d36d62c/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190522155817-f3200d17e092/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= +golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= golang.org/x/net v0.0.0-20190923162816-aa69164e4478 h1:l5EDrHhldLYb3ZRHDUhXF7Om7MvYXnkV9/iQNo1lX6g= golang.org/x/net v0.0.0-20190923162816-aa69164e4478/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20200226121028-0de0cce0169b h1:0mm1VjtFUOIlE1SbDlwjYaDxZVDP2S5ou6y0gSgXHu8= -golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200301022130-244492dfa37a h1:GuSPYbZzB5/dcLNCwLQLsg3obCJtX9IJhpXkvY7kzk0= +golang.org/x/net v0.0.0-20200301022130-244492dfa37a/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45 h1:SVwTIAaPC2U/AvvLNZ2a7OVsmBpC8L5BlwK1whH3hm0= @@ -243,7 +253,6 @@ golang.org/x/sys v0.0.0-20181107165924-66b7b1311ac8/go.mod h1:STP8DvDyc/dI5b8T5h golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190204203706-41f3e6584952/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b h1:ag/x1USPSsqHud38I9BAC88qdNLDHHtQ4mlgQIZPPNA= golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -276,6 +285,8 @@ google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9Ywl google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine v1.5.0 h1:KxkO13IPW4Lslp2bz+KHP2E3gtFlrIGNThxkZQ3g+4c= google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/appengine v1.6.5 h1:tycE03LOZYQNhDpS27tcQdAzLCVMaj7QT2SXxebnpCM= +google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= google.golang.org/genproto v0.0.0-20190418145605-e7d98fc518a7/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= diff --git a/oauthproxy_test.go b/oauthproxy_test.go index 0510a252..5bafd048 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -23,7 +23,7 @@ import ( "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options" "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/pkg/logger" - "github.com/oauth2-proxy/oauth2-proxy/pkg/sessions/cookie" + sessionscookie "github.com/oauth2-proxy/oauth2-proxy/pkg/sessions/cookie" "github.com/oauth2-proxy/oauth2-proxy/pkg/validation" "github.com/oauth2-proxy/oauth2-proxy/providers" "github.com/stretchr/testify/assert" @@ -36,11 +36,12 @@ const ( // encoded version of this. rawCookieSecret = "secretthirtytwobytes+abcdefghijk" base64CookieSecret = "c2VjcmV0dGhpcnR5dHdvYnl0ZXMrYWJjZGVmZ2hpams" + clientID = "3984n253984d7348dm8234yf982t" + clientSecret = "gv3498mfc9t23y23974dm2394dm9" ) func init() { logger.SetFlags(logger.Lshortfile) - } type WebSocketOrRestHandler struct { @@ -61,24 +62,30 @@ func TestWebSocketProxy(t *testing.T) { restHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(200) hostname, _, _ := net.SplitHostPort(r.Host) - w.Write([]byte(hostname)) + _, err := w.Write([]byte(hostname)) + if err != nil { + t.Fatal(err) + } }), wsHandler: websocket.Handler(func(ws *websocket.Conn) { - defer ws.Close() + defer func(t *testing.T) { + if err := ws.Close(); err != nil { + t.Fatal(err) + } + }(t) var data []byte err := websocket.Message.Receive(ws, &data) if err != nil { - t.Fatalf("err %s", err) - return + t.Fatal(err) } err = websocket.Message.Send(ws, data) if err != nil { - t.Fatalf("err %s", err) + t.Fatal(err) } }), } backend := httptest.NewServer(&handler) - defer backend.Close() + t.Cleanup(backend.Close) backendURL, _ := url.Parse(backend.URL) @@ -87,24 +94,24 @@ func TestWebSocketProxy(t *testing.T) { opts.PassHostHeader = true proxyHandler := NewWebSocketOrRestReverseProxy(backendURL, opts, auth) frontend := httptest.NewServer(proxyHandler) - defer frontend.Close() + t.Cleanup(frontend.Close) frontendURL, _ := url.Parse(frontend.URL) frontendWSURL := "ws://" + frontendURL.Host + "/" ws, err := websocket.Dial(frontendWSURL, "", "http://localhost/") if err != nil { - t.Fatalf("err %s", err) + t.Fatal(err) } request := []byte("hello, world!") err = websocket.Message.Send(ws, request) if err != nil { - t.Fatalf("err %s", err) + t.Fatal(err) } var response = make([]byte, 1024) - websocket.Message.Receive(ws, &response) + err = websocket.Message.Receive(ws, &response) if err != nil { - t.Fatalf("err %s", err) + t.Fatal(err) } if g, e := string(request), string(response); g != e { t.Errorf("got body %q; expected %q", g, e) @@ -123,9 +130,12 @@ func TestNewReverseProxy(t *testing.T) { backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(200) hostname, _, _ := net.SplitHostPort(r.Host) - w.Write([]byte(hostname)) + _, err := w.Write([]byte(hostname)) + if err != nil { + t.Fatal(err) + } })) - defer backend.Close() + t.Cleanup(backend.Close) backendURL, _ := url.Parse(backend.URL) backendHostname, backendPort, _ := net.SplitHostPort(backendURL.Host) @@ -135,7 +145,7 @@ func TestNewReverseProxy(t *testing.T) { proxyHandler := NewReverseProxy(proxyURL, &options.Options{FlushInterval: time.Second}) setProxyUpstreamHostHeader(proxyHandler, proxyURL) frontend := httptest.NewServer(proxyHandler) - defer frontend.Close() + t.Cleanup(frontend.Close) getReq, _ := http.NewRequest("GET", frontend.URL, nil) res, _ := http.DefaultClient.Do(getReq) @@ -151,20 +161,20 @@ func TestEncodedSlashes(t *testing.T) { w.WriteHeader(200) seen = r.RequestURI })) - defer backend.Close() + t.Cleanup(backend.Close) b, _ := url.Parse(backend.URL) proxyHandler := NewReverseProxy(b, &options.Options{FlushInterval: time.Second}) setProxyDirector(proxyHandler) frontend := httptest.NewServer(proxyHandler) - defer frontend.Close() + t.Cleanup(frontend.Close) f, _ := url.Parse(frontend.URL) encodedPath := "/a%2Fb/?c=1" getReq := &http.Request{URL: &url.URL{Scheme: "http", Host: f.Host, Opaque: encodedPath}} _, err := http.DefaultClient.Do(getReq) if err != nil { - t.Fatalf("err %s", err) + t.Fatal(err) } if seen != encodedPath { t.Errorf("got bad request %q expected %q", seen, encodedPath) @@ -173,13 +183,13 @@ func TestEncodedSlashes(t *testing.T) { func TestRobotsTxt(t *testing.T) { opts := baseTestOptions() - opts.ClientID = "asdlkjx" - opts.ClientSecret = "alkgks" - opts.Cookie.Secret = rawCookieSecret - validation.Validate(opts) + err := validation.Validate(opts) + assert.NoError(t, err) proxy, err := NewOAuthProxy(opts, func(string) bool { return true }) - assert.NoError(t, err) + if err != nil { + t.Fatal(err) + } rw := httptest.NewRecorder() req, _ := http.NewRequest("GET", "/robots.txt", nil) proxy.ServeHTTP(rw, req) @@ -189,9 +199,6 @@ func TestRobotsTxt(t *testing.T) { func TestIsValidRedirect(t *testing.T) { opts := baseTestOptions() - opts.ClientID = "skdlfj" - opts.ClientSecret = "fgkdsgj" - opts.Cookie.Secret = base64CookieSecret // Should match domains that are exactly foo.bar and any subdomain of bar.foo opts.WhitelistDomains = []string{ "foo.bar", @@ -201,10 +208,13 @@ func TestIsValidRedirect(t *testing.T) { "anyport.bar:*", ".sub.anyport.bar:*", } - validation.Validate(opts) + err := validation.Validate(opts) + assert.NoError(t, err) proxy, err := NewOAuthProxy(opts, func(string) bool { return true }) - assert.NoError(t, err) + if err != nil { + t.Fatal(err) + } testCases := []struct { Desc, Redirect string @@ -439,11 +449,7 @@ func TestIsValidRedirect(t *testing.T) { } func TestOpenRedirects(t *testing.T) { - opts := options.NewOptions() - opts.ClientID = "skdlfj" - opts.ClientSecret = "fgkdsgj" - opts.Cookie.Secret = rawCookieSecret - opts.EmailDomains = []string{"*"} + opts := baseTestOptions() // Should match domains that are exactly foo.bar and any subdomain of bar.foo opts.WhitelistDomains = []string{ "foo.bar", @@ -458,13 +464,19 @@ func TestOpenRedirects(t *testing.T) { assert.NoError(t, err) proxy, err := NewOAuthProxy(opts, func(string) bool { return true }) - assert.NoError(t, err) + if err != nil { + t.Fatal(err) + } file, err := os.Open("./test/openredirects.txt") if err != nil { t.Fatal(err) } - defer file.Close() + defer func(t *testing.T) { + if err := file.Close(); err != nil { + t.Fatal(err) + } + }(t) scanner := bufio.NewScanner(file) for scanner.Scan() { @@ -544,22 +556,21 @@ func TestBasicAuthPassword(t *testing.T) { } } w.WriteHeader(200) - w.Write([]byte(payload)) + _, err := w.Write([]byte(payload)) + if err != nil { + t.Fatal(err) + } })) opts := baseTestOptions() opts.Upstreams = append(opts.Upstreams, providerServer.URL) - // The CookieSecret must be 32 bytes in order to create the AES - // cipher. - opts.Cookie.Secret = "xyzzyplughxyzzyplughxyzzyplughxp" - opts.ClientID = "dlgkj" - opts.ClientSecret = "alkgret" opts.Cookie.Secure = false opts.PassBasicAuth = true opts.SetBasicAuth = true opts.PassUserHeaders = true opts.PreferEmailToUser = true opts.BasicAuthPassword = "This is a secure password" - validation.Validate(opts) + err := validation.Validate(opts) + assert.NoError(t, err) providerURL, _ := url.Parse(providerServer.URL) const emailAddress = "john.doe@example.com" @@ -568,7 +579,9 @@ func TestBasicAuthPassword(t *testing.T) { proxy, err := NewOAuthProxy(opts, func(email string) bool { return email == emailAddress }) - assert.NoError(t, err) + if err != nil { + t.Fatal(err) + } rw := httptest.NewRecorder() req, _ := http.NewRequest("GET", "/oauth2/callback?code=callback_code&state=nonce:", strings.NewReader("")) @@ -618,8 +631,8 @@ func TestBasicAuthWithEmail(t *testing.T) { opts.PassUserHeaders = false opts.PreferEmailToUser = false opts.BasicAuthPassword = "This is a secure password" - opts.Cookie.Secret = rawCookieSecret - validation.Validate(opts) + err := validation.Validate(opts) + assert.NoError(t, err) const emailAddress = "john.doe@example.com" const userName = "9fcab5c9b889a557" @@ -641,7 +654,9 @@ func TestBasicAuthWithEmail(t *testing.T) { proxy, err := NewOAuthProxy(opts, func(email string) bool { return email == emailAddress }) - assert.NoError(t, err) + if err != nil { + t.Fatal(err) + } proxy.addHeadersForProxying(rw, req, session) assert.Equal(t, expectedUserHeader, req.Header["Authorization"][0]) assert.Equal(t, userName, req.Header["X-Forwarded-User"][0]) @@ -655,7 +670,9 @@ func TestBasicAuthWithEmail(t *testing.T) { proxy, err := NewOAuthProxy(opts, func(email string) bool { return email == emailAddress }) - assert.NoError(t, err) + if err != nil { + t.Fatal(err) + } proxy.addHeadersForProxying(rw, req, session) assert.Equal(t, expectedEmailHeader, req.Header["Authorization"][0]) assert.Equal(t, emailAddress, req.Header["X-Forwarded-User"][0]) @@ -664,11 +681,8 @@ func TestBasicAuthWithEmail(t *testing.T) { func TestPassUserHeadersWithEmail(t *testing.T) { opts := baseTestOptions() - opts.PassBasicAuth = false - opts.PassUserHeaders = true - opts.PreferEmailToUser = false - opts.Cookie.Secret = base64CookieSecret - validation.Validate(opts) + err := validation.Validate(opts) + assert.NoError(t, err) const emailAddress = "john.doe@example.com" const userName = "9fcab5c9b889a557" @@ -686,7 +700,9 @@ func TestPassUserHeadersWithEmail(t *testing.T) { proxy, err := NewOAuthProxy(opts, func(email string) bool { return email == emailAddress }) - assert.NoError(t, err) + if err != nil { + t.Fatal(err) + } proxy.addHeadersForProxying(rw, req, session) assert.Equal(t, userName, req.Header["X-Forwarded-User"][0]) } @@ -699,7 +715,9 @@ func TestPassUserHeadersWithEmail(t *testing.T) { proxy, err := NewOAuthProxy(opts, func(email string) bool { return email == emailAddress }) - assert.NoError(t, err) + if err != nil { + t.Fatal(err) + } proxy.addHeadersForProxying(rw, req, session) assert.Equal(t, emailAddress, req.Header["X-Forwarded-User"][0]) } @@ -716,10 +734,10 @@ type PassAccessTokenTestOptions struct { ProxyUpstream string } -func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) *PassAccessTokenTest { - t := &PassAccessTokenTest{} +func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) (*PassAccessTokenTest, error) { + patt := &PassAccessTokenTest{} - t.providerServer = httptest.NewServer( + patt.providerServer = httptest.NewServer( http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var payload string switch r.URL.Path { @@ -732,35 +750,35 @@ func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) *PassAccessTokenTes } } w.WriteHeader(200) - w.Write([]byte(payload)) + _, err := w.Write([]byte(payload)) + if err != nil { + panic(err) + } })) - t.opts = baseTestOptions() - t.opts.Upstreams = append(t.opts.Upstreams, t.providerServer.URL) + patt.opts = baseTestOptions() + patt.opts.Upstreams = append(patt.opts.Upstreams, patt.providerServer.URL) if opts.ProxyUpstream != "" { - t.opts.Upstreams = append(t.opts.Upstreams, opts.ProxyUpstream) + patt.opts.Upstreams = append(patt.opts.Upstreams, opts.ProxyUpstream) + } + patt.opts.Cookie.Secure = false + patt.opts.PassAccessToken = opts.PassAccessToken + err := validation.Validate(patt.opts) + if err != nil { + return nil, err } - // The CookieSecret must be 32 bytes in order to create the AES - // cipher. - t.opts.Cookie.Secret = "xyzzyplughxyzzyplughxyzzyplughxp" - t.opts.ClientID = "slgkj" - t.opts.ClientSecret = "gfjgojl" - t.opts.Cookie.Secure = false - t.opts.PassAccessToken = opts.PassAccessToken - validation.Validate(t.opts) - providerURL, _ := url.Parse(t.providerServer.URL) + providerURL, _ := url.Parse(patt.providerServer.URL) const emailAddress = "michael.bland@gsa.gov" - t.opts.SetProvider(NewTestProvider(providerURL, emailAddress)) - var err error - t.proxy, err = NewOAuthProxy(t.opts, func(email string) bool { + patt.opts.SetProvider(NewTestProvider(providerURL, emailAddress)) + patt.proxy, err = NewOAuthProxy(patt.opts, func(email string) bool { return email == emailAddress }) if err != nil { - panic(err) + return nil, err } - return t + return patt, nil } func (patTest *PassAccessTokenTest) Close() { @@ -817,17 +835,20 @@ func (patTest *PassAccessTokenTest) getEndpointWithCookie(cookie string, endpoin } func TestForwardAccessTokenUpstream(t *testing.T) { - patTest := NewPassAccessTokenTest(PassAccessTokenTestOptions{ + patTest, err := NewPassAccessTokenTest(PassAccessTokenTestOptions{ PassAccessToken: true, }) - defer patTest.Close() + if err != nil { + t.Fatal(err) + } + t.Cleanup(patTest.Close) // A successful validation will redirect and set the auth cookie. code, cookie := patTest.getCallbackEndpoint() if code != 302 { t.Fatalf("expected 302; got %d", code) } - assert.NotEqual(t, nil, cookie) + assert.NotNil(t, cookie) // Now we make a regular request; the access_token from the cookie is // forwarded as the "X-Forwarded-Access-Token" header. The token is @@ -840,12 +861,14 @@ func TestForwardAccessTokenUpstream(t *testing.T) { } func TestStaticProxyUpstream(t *testing.T) { - patTest := NewPassAccessTokenTest(PassAccessTokenTestOptions{ + patTest, err := NewPassAccessTokenTest(PassAccessTokenTestOptions{ PassAccessToken: true, ProxyUpstream: "static://200/static-proxy", }) - - defer patTest.Close() + if err != nil { + t.Fatal(err) + } + t.Cleanup(patTest.Close) // A successful validation will redirect and set the auth cookie. code, cookie := patTest.getCallbackEndpoint() @@ -864,10 +887,13 @@ func TestStaticProxyUpstream(t *testing.T) { } func TestDoNotForwardAccessTokenUpstream(t *testing.T) { - patTest := NewPassAccessTokenTest(PassAccessTokenTestOptions{ + patTest, err := NewPassAccessTokenTest(PassAccessTokenTestOptions{ PassAccessToken: false, }) - defer patTest.Close() + if err != nil { + t.Fatal(err) + } + t.Cleanup(patTest.Close) // A successful validation will redirect and set the auth cookie. code, cookie := patTest.getCallbackEndpoint() @@ -895,27 +921,26 @@ type SignInPageTest struct { const signInRedirectPattern = `` const signInSkipProvider = `>Found<` -func NewSignInPageTest(skipProvider bool) *SignInPageTest { +func NewSignInPageTest(skipProvider bool) (*SignInPageTest, error) { var sipTest SignInPageTest sipTest.opts = baseTestOptions() - sipTest.opts.Cookie.Secret = rawCookieSecret - sipTest.opts.ClientID = "lkdgj" - sipTest.opts.ClientSecret = "sgiufgoi" sipTest.opts.SkipProviderButton = skipProvider - validation.Validate(sipTest.opts) + err := validation.Validate(sipTest.opts) + if err != nil { + return nil, err + } - var err error sipTest.proxy, err = NewOAuthProxy(sipTest.opts, func(email string) bool { return true }) if err != nil { - panic(err) + return nil, err } sipTest.signInRegexp = regexp.MustCompile(signInRedirectPattern) sipTest.signInProviderRegexp = regexp.MustCompile(signInSkipProvider) - return &sipTest + return &sipTest, nil } func (sipTest *SignInPageTest) GetEndpoint(endpoint string) (int, string) { @@ -926,7 +951,10 @@ func (sipTest *SignInPageTest) GetEndpoint(endpoint string) (int, string) { } func TestSignInPageIncludesTargetRedirect(t *testing.T) { - sipTest := NewSignInPageTest(false) + sipTest, err := NewSignInPageTest(false) + if err != nil { + t.Fatal(err) + } const endpoint = "/some/random/endpoint" code, body := sipTest.GetEndpoint(endpoint) @@ -944,7 +972,10 @@ func TestSignInPageIncludesTargetRedirect(t *testing.T) { } func TestSignInPageDirectAccessRedirectsToRoot(t *testing.T) { - sipTest := NewSignInPageTest(false) + sipTest, err := NewSignInPageTest(false) + if err != nil { + t.Fatal(err) + } code, body := sipTest.GetEndpoint("/oauth2/sign_in") assert.Equal(t, 200, code) @@ -959,8 +990,12 @@ func TestSignInPageDirectAccessRedirectsToRoot(t *testing.T) { } func TestSignInPageSkipProvider(t *testing.T) { - sipTest := NewSignInPageTest(true) - const endpoint = "/some/random/endpoint" + sipTest, err := NewSignInPageTest(true) + if err != nil { + t.Fatal(err) + } + + endpoint := "/some/random/endpoint" code, body := sipTest.GetEndpoint(endpoint) assert.Equal(t, 302, code) @@ -973,8 +1008,12 @@ func TestSignInPageSkipProvider(t *testing.T) { } func TestSignInPageSkipProviderDirect(t *testing.T) { - sipTest := NewSignInPageTest(true) - const endpoint = "/sign_in" + sipTest, err := NewSignInPageTest(true) + if err != nil { + t.Fatal(err) + } + + endpoint := "/sign_in" code, body := sipTest.GetEndpoint(endpoint) assert.Equal(t, 302, code) @@ -1000,27 +1039,26 @@ type ProcessCookieTestOpts struct { type OptionsModifier func(*options.Options) -func NewProcessCookieTest(opts ProcessCookieTestOpts, modifiers ...OptionsModifier) *ProcessCookieTest { +func NewProcessCookieTest(opts ProcessCookieTestOpts, modifiers ...OptionsModifier) (*ProcessCookieTest, error) { var pcTest ProcessCookieTest pcTest.opts = baseTestOptions() for _, modifier := range modifiers { modifier(pcTest.opts) } - pcTest.opts.ClientID = "asdfljk" - pcTest.opts.ClientSecret = "lkjfdsig" - pcTest.opts.Cookie.Secret = "0123456789abcdef0123456789abcdef" // First, set the CookieRefresh option so proxy.AesCipher is created, // needed to encrypt the access_token. pcTest.opts.Cookie.Refresh = time.Hour - validation.Validate(pcTest.opts) + err := validation.Validate(pcTest.opts) + if err != nil { + return nil, err + } - var err error pcTest.proxy, err = NewOAuthProxy(pcTest.opts, func(email string) bool { return pcTest.validateUser }) if err != nil { - panic(err) + return nil, err } pcTest.proxy.provider = &TestProvider{ ValidToken: opts.providerValidateCookieResponse, @@ -1032,16 +1070,16 @@ func NewProcessCookieTest(opts ProcessCookieTestOpts, modifiers ...OptionsModifi pcTest.rw = httptest.NewRecorder() pcTest.req, _ = http.NewRequest("GET", "/", strings.NewReader("")) pcTest.validateUser = true - return &pcTest + return &pcTest, nil } -func NewProcessCookieTestWithDefaults() *ProcessCookieTest { +func NewProcessCookieTestWithDefaults() (*ProcessCookieTest, error) { return NewProcessCookieTest(ProcessCookieTestOpts{ providerValidateCookieResponse: true, }) } -func NewProcessCookieTestWithOptionsModifiers(modifiers ...OptionsModifier) *ProcessCookieTest { +func NewProcessCookieTestWithOptionsModifiers(modifiers ...OptionsModifier) (*ProcessCookieTest, error) { return NewProcessCookieTest(ProcessCookieTestOpts{ providerValidateCookieResponse: true, }, modifiers...) @@ -1063,37 +1101,51 @@ func (p *ProcessCookieTest) LoadCookiedSession() (*sessions.SessionState, error) } func TestLoadCookiedSession(t *testing.T) { - pcTest := NewProcessCookieTestWithDefaults() + pcTest, err := NewProcessCookieTestWithDefaults() + if err != nil { + t.Fatal(err) + } created := time.Now() startSession := &sessions.SessionState{Email: "john.doe@example.com", AccessToken: "my_access_token", CreatedAt: &created} - pcTest.SaveSession(startSession) + err = pcTest.SaveSession(startSession) + assert.NoError(t, err) session, err := pcTest.LoadCookiedSession() - assert.Equal(t, nil, err) + if err != nil { + t.Fatal(err) + } assert.Equal(t, startSession.Email, session.Email) assert.Equal(t, "", session.User) assert.Equal(t, startSession.AccessToken, session.AccessToken) } func TestProcessCookieNoCookieError(t *testing.T) { - pcTest := NewProcessCookieTestWithDefaults() + pcTest, err := NewProcessCookieTestWithDefaults() + if err != nil { + t.Fatal(err) + } session, err := pcTest.LoadCookiedSession() - assert.Equal(t, "cookie \"_oauth2_proxy\" not present", err.Error()) + assert.Error(t, err, "cookie \"_oauth2_proxy\" not present") if session != nil { t.Errorf("expected nil session. got %#v", session) } } func TestProcessCookieRefreshNotSet(t *testing.T) { - pcTest := NewProcessCookieTestWithOptionsModifiers(func(opts *options.Options) { + pcTest, err := NewProcessCookieTestWithOptionsModifiers(func(opts *options.Options) { opts.Cookie.Expire = time.Duration(23) * time.Hour }) + if err != nil { + t.Fatal(err) + } + reference := time.Now().Add(time.Duration(-2) * time.Hour) startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &reference} - pcTest.SaveSession(startSession) + err = pcTest.SaveSession(startSession) + assert.NoError(t, err) session, err := pcTest.LoadCookiedSession() assert.Equal(t, nil, err) @@ -1104,12 +1156,17 @@ func TestProcessCookieRefreshNotSet(t *testing.T) { } func TestProcessCookieFailIfCookieExpired(t *testing.T) { - pcTest := NewProcessCookieTestWithOptionsModifiers(func(opts *options.Options) { + pcTest, err := NewProcessCookieTestWithOptionsModifiers(func(opts *options.Options) { opts.Cookie.Expire = time.Duration(24) * time.Hour }) + if err != nil { + t.Fatal(err) + } + reference := time.Now().Add(time.Duration(25) * time.Hour * -1) startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &reference} - pcTest.SaveSession(startSession) + err = pcTest.SaveSession(startSession) + assert.NoError(t, err) session, err := pcTest.LoadCookiedSession() assert.NotEqual(t, nil, err) @@ -1119,12 +1176,17 @@ func TestProcessCookieFailIfCookieExpired(t *testing.T) { } func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) { - pcTest := NewProcessCookieTestWithOptionsModifiers(func(opts *options.Options) { + pcTest, err := NewProcessCookieTestWithOptionsModifiers(func(opts *options.Options) { opts.Cookie.Expire = time.Duration(24) * time.Hour }) + if err != nil { + t.Fatal(err) + } + reference := time.Now().Add(time.Duration(25) * time.Hour * -1) startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &reference} - pcTest.SaveSession(startSession) + err = pcTest.SaveSession(startSession) + assert.NoError(t, err) pcTest.proxy.CookieRefresh = time.Hour session, err := pcTest.LoadCookiedSession() @@ -1134,18 +1196,26 @@ func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) { } } -func NewUserInfoEndpointTest() *ProcessCookieTest { - pcTest := NewProcessCookieTestWithDefaults() +func NewUserInfoEndpointTest() (*ProcessCookieTest, error) { + pcTest, err := NewProcessCookieTestWithDefaults() + if err != nil { + return nil, err + } pcTest.req, _ = http.NewRequest("GET", pcTest.opts.ProxyPrefix+"/userinfo", nil) - return pcTest + return pcTest, nil } func TestUserInfoEndpointAccepted(t *testing.T) { - test := NewUserInfoEndpointTest() + test, err := NewUserInfoEndpointTest() + if err != nil { + t.Fatal(err) + } + startSession := &sessions.SessionState{ Email: "john.doe@example.com", AccessToken: "my_access_token"} - test.SaveSession(startSession) + err = test.SaveSession(startSession) + assert.NoError(t, err) test.proxy.ServeHTTP(test.rw, test.req) assert.Equal(t, http.StatusOK, test.rw.Code) @@ -1154,25 +1224,36 @@ func TestUserInfoEndpointAccepted(t *testing.T) { } func TestUserInfoEndpointUnauthorizedOnNoCookieSetError(t *testing.T) { - test := NewUserInfoEndpointTest() + test, err := NewUserInfoEndpointTest() + if err != nil { + t.Fatal(err) + } test.proxy.ServeHTTP(test.rw, test.req) assert.Equal(t, http.StatusUnauthorized, test.rw.Code) } -func NewAuthOnlyEndpointTest(modifiers ...OptionsModifier) *ProcessCookieTest { - pcTest := NewProcessCookieTestWithOptionsModifiers(modifiers...) +func NewAuthOnlyEndpointTest(modifiers ...OptionsModifier) (*ProcessCookieTest, error) { + pcTest, err := NewProcessCookieTestWithOptionsModifiers(modifiers...) + if err != nil { + return nil, err + } pcTest.req, _ = http.NewRequest("GET", pcTest.opts.ProxyPrefix+"/auth", nil) - return pcTest + return pcTest, nil } func TestAuthOnlyEndpointAccepted(t *testing.T) { - test := NewAuthOnlyEndpointTest() + test, err := NewAuthOnlyEndpointTest() + if err != nil { + t.Fatal(err) + } + created := time.Now() startSession := &sessions.SessionState{ Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &created} - test.SaveSession(startSession) + err = test.SaveSession(startSession) + assert.NoError(t, err) test.proxy.ServeHTTP(test.rw, test.req) assert.Equal(t, http.StatusAccepted, test.rw.Code) @@ -1181,7 +1262,10 @@ func TestAuthOnlyEndpointAccepted(t *testing.T) { } func TestAuthOnlyEndpointUnauthorizedOnNoCookieSetError(t *testing.T) { - test := NewAuthOnlyEndpointTest() + test, err := NewAuthOnlyEndpointTest() + if err != nil { + t.Fatal(err) + } test.proxy.ServeHTTP(test.rw, test.req) assert.Equal(t, http.StatusUnauthorized, test.rw.Code) @@ -1190,13 +1274,18 @@ func TestAuthOnlyEndpointUnauthorizedOnNoCookieSetError(t *testing.T) { } func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) { - test := NewAuthOnlyEndpointTest(func(opts *options.Options) { + test, err := NewAuthOnlyEndpointTest(func(opts *options.Options) { opts.Cookie.Expire = time.Duration(24) * time.Hour }) + if err != nil { + t.Fatal(err) + } + reference := time.Now().Add(time.Duration(25) * time.Hour * -1) startSession := &sessions.SessionState{ Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &reference} - test.SaveSession(startSession) + err = test.SaveSession(startSession) + assert.NoError(t, err) test.proxy.ServeHTTP(test.rw, test.req) assert.Equal(t, http.StatusUnauthorized, test.rw.Code) @@ -1205,11 +1294,16 @@ func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) { } func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) { - test := NewAuthOnlyEndpointTest() + test, err := NewAuthOnlyEndpointTest() + if err != nil { + t.Fatal(err) + } + created := time.Now() startSession := &sessions.SessionState{ Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &created} - test.SaveSession(startSession) + err = test.SaveSession(startSession) + assert.NoError(t, err) test.validateUser = false test.proxy.ServeHTTP(test.rw, test.req) @@ -1224,15 +1318,13 @@ func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) { pcTest.opts = baseTestOptions() pcTest.opts.SetXAuthRequest = true err := validation.Validate(pcTest.opts) - if err != nil { - panic(err) - } + assert.NoError(t, err) pcTest.proxy, err = NewOAuthProxy(pcTest.opts, func(email string) bool { return pcTest.validateUser }) if err != nil { - panic(err) + t.Fatal(err) } pcTest.proxy.provider = &TestProvider{ ValidToken: true, @@ -1247,7 +1339,8 @@ func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) { created := time.Now() startSession := &sessions.SessionState{ User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token", CreatedAt: &created} - pcTest.SaveSession(startSession) + err = pcTest.SaveSession(startSession) + assert.NoError(t, err) pcTest.proxy.ServeHTTP(pcTest.rw, pcTest.req) assert.Equal(t, http.StatusAccepted, pcTest.rw.Code) @@ -1261,14 +1354,14 @@ func TestAuthOnlyEndpointSetBasicAuthTrueRequestHeaders(t *testing.T) { pcTest.opts = baseTestOptions() pcTest.opts.SetXAuthRequest = true pcTest.opts.SetBasicAuth = true - validation.Validate(pcTest.opts) + err := validation.Validate(pcTest.opts) + assert.NoError(t, err) - var err error pcTest.proxy, err = NewOAuthProxy(pcTest.opts, func(email string) bool { return pcTest.validateUser }) if err != nil { - panic(err) + t.Fatal(err) } pcTest.proxy.provider = &TestProvider{ ValidToken: true, @@ -1283,7 +1376,8 @@ func TestAuthOnlyEndpointSetBasicAuthTrueRequestHeaders(t *testing.T) { created := time.Now() startSession := &sessions.SessionState{ User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token", CreatedAt: &created} - pcTest.SaveSession(startSession) + err = pcTest.SaveSession(startSession) + assert.NoError(t, err) pcTest.proxy.ServeHTTP(pcTest.rw, pcTest.req) assert.Equal(t, http.StatusAccepted, pcTest.rw.Code) @@ -1299,14 +1393,14 @@ func TestAuthOnlyEndpointSetBasicAuthFalseRequestHeaders(t *testing.T) { pcTest.opts = baseTestOptions() pcTest.opts.SetXAuthRequest = true pcTest.opts.SetBasicAuth = false - validation.Validate(pcTest.opts) + err := validation.Validate(pcTest.opts) + assert.NoError(t, err) - var err error pcTest.proxy, err = NewOAuthProxy(pcTest.opts, func(email string) bool { return pcTest.validateUser }) if err != nil { - panic(err) + t.Fatal(err) } pcTest.proxy.provider = &TestProvider{ ValidToken: true, @@ -1321,7 +1415,8 @@ func TestAuthOnlyEndpointSetBasicAuthFalseRequestHeaders(t *testing.T) { created := time.Now() startSession := &sessions.SessionState{ User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token", CreatedAt: &created} - pcTest.SaveSession(startSession) + err = pcTest.SaveSession(startSession) + assert.NoError(t, err) pcTest.proxy.ServeHTTP(pcTest.rw, pcTest.req) assert.Equal(t, http.StatusAccepted, pcTest.rw.Code) @@ -1333,20 +1428,26 @@ func TestAuthOnlyEndpointSetBasicAuthFalseRequestHeaders(t *testing.T) { func TestAuthSkippedForPreflightRequests(t *testing.T) { upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(200) - w.Write([]byte("response")) + _, err := w.Write([]byte("response")) + if err != nil { + t.Fatal(err) + } })) - defer upstream.Close() + t.Cleanup(upstream.Close) opts := baseTestOptions() opts.Upstreams = append(opts.Upstreams, upstream.URL) opts.SkipAuthPreflight = true - validation.Validate(opts) + err := validation.Validate(opts) + assert.NoError(t, err) upstreamURL, _ := url.Parse(upstream.URL) opts.SetProvider(NewTestProvider(upstreamURL, "")) proxy, err := NewOAuthProxy(opts, func(string) bool { return false }) - assert.NoError(t, err) + if err != nil { + t.Fatal(err) + } rw := httptest.NewRecorder() req, _ := http.NewRequest("OPTIONS", "/preflight-request", nil) proxy.ServeHTTP(rw, req) @@ -1361,16 +1462,25 @@ type SignatureAuthenticator struct { func (v *SignatureAuthenticator) Authenticate(w http.ResponseWriter, r *http.Request) { result, headerSig, computedSig := v.auth.AuthenticateRequest(r) - if result == hmacauth.ResultNoSignature { - w.Write([]byte("no signature received")) - } else if result == hmacauth.ResultMatch { - w.Write([]byte("signatures match")) - } else if result == hmacauth.ResultMismatch { - w.Write([]byte("signatures do not match:" + - "\n received: " + headerSig + - "\n computed: " + computedSig)) - } else { - panic("Unknown result value: " + result.String()) + + var msg string + switch result { + case hmacauth.ResultNoSignature: + msg = "no signature received" + case hmacauth.ResultMatch: + msg = "signatures match" + case hmacauth.ResultMismatch: + msg = fmt.Sprintf( + "signatures do not match:\n received: %s\n computed: %s", + headerSig, + computedSig) + default: + panic("unknown result value: " + result.String()) + } + + _, err := w.Write([]byte(msg)) + if err != nil { + panic(err) } } @@ -1384,24 +1494,30 @@ type SignatureTest struct { authenticator *SignatureAuthenticator } -func NewSignatureTest() *SignatureTest { +func NewSignatureTest() (*SignatureTest, error) { opts := baseTestOptions() - opts.Cookie.Secret = rawCookieSecret - opts.ClientID = "client ID" - opts.ClientSecret = "client secret" opts.EmailDomains = []string{"acm.org"} authenticator := &SignatureAuthenticator{} upstream := httptest.NewServer( http.HandlerFunc(authenticator.Authenticate)) - upstreamURL, _ := url.Parse(upstream.URL) + upstreamURL, err := url.Parse(upstream.URL) + if err != nil { + return nil, err + } opts.Upstreams = append(opts.Upstreams, upstream.URL) providerHandler := func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte(`{"access_token": "my_auth_token"}`)) + _, err := w.Write([]byte(`{"access_token": "my_auth_token"}`)) + if err != nil { + panic(err) + } } provider := httptest.NewServer(http.HandlerFunc(providerHandler)) - providerURL, _ := url.Parse(provider.URL) + providerURL, err := url.Parse(provider.URL) + if err != nil { + return nil, err + } opts.SetProvider(NewTestProvider(providerURL, "mbland@acm.org")) return &SignatureTest{ @@ -1412,7 +1528,7 @@ func NewSignatureTest() *SignatureTest { make(http.Header), httptest.NewRecorder(), authenticator, - } + }, nil } func (st *SignatureTest) Close() { @@ -1436,14 +1552,14 @@ func (fnc *fakeNetConn) Read(p []byte) (n int, err error) { return 0, io.EOF } -func (st *SignatureTest) MakeRequestWithExpectedKey(method, body, key string) { +func (st *SignatureTest) MakeRequestWithExpectedKey(method, body, key string) error { err := validation.Validate(st.opts) if err != nil { - panic(err) + return err } proxy, err := NewOAuthProxy(st.opts, func(email string) bool { return true }) if err != nil { - panic(err) + return err } var bodyBuf io.ReadCloser @@ -1457,7 +1573,7 @@ func (st *SignatureTest) MakeRequestWithExpectedKey(method, body, key string) { Email: "mbland@acm.org", AccessToken: "my_access_token"} err = proxy.SaveSession(st.rw, req, state) if err != nil { - panic(err) + return err } for _, c := range st.rw.Result().Cookies() { req.AddCookie(c) @@ -1466,33 +1582,52 @@ func (st *SignatureTest) MakeRequestWithExpectedKey(method, body, key string) { st.authenticator.auth = hmacauth.NewHmacAuth( crypto.SHA1, []byte(key), SignatureHeader, SignatureHeaders) proxy.ServeHTTP(st.rw, req) + + return nil } -func TestNoRequestSignature(t *testing.T) { - st := NewSignatureTest() - defer st.Close() - st.MakeRequestWithExpectedKey("GET", "", "") - assert.Equal(t, 200, st.rw.Code) - assert.Equal(t, st.rw.Body.String(), "no signature received") -} - -func TestRequestSignatureGetRequest(t *testing.T) { - st := NewSignatureTest() - defer st.Close() - st.opts.SignatureKey = "sha1:7d9e1aa87a5954e6f9fc59266b3af9d7c35fda2d" - st.MakeRequestWithExpectedKey("GET", "", "7d9e1aa87a5954e6f9fc59266b3af9d7c35fda2d") - assert.Equal(t, 200, st.rw.Code) - assert.Equal(t, st.rw.Body.String(), "signatures match") -} - -func TestRequestSignaturePostRequest(t *testing.T) { - st := NewSignatureTest() - defer st.Close() - st.opts.SignatureKey = "sha1:d90df39e2d19282840252612dd7c81421a372f61" - payload := `{ "hello": "world!" }` - st.MakeRequestWithExpectedKey("POST", payload, "d90df39e2d19282840252612dd7c81421a372f61") - assert.Equal(t, 200, st.rw.Code) - assert.Equal(t, st.rw.Body.String(), "signatures match") +func TestRequestSignature(t *testing.T) { + testCases := map[string]struct { + method string + body string + key string + resp string + }{ + "No request signature": { + method: "GET", + body: "", + key: "", + resp: "no signature received", + }, + "Get request": { + method: "GET", + body: "", + key: "7d9e1aa87a5954e6f9fc59266b3af9d7c35fda2d", + resp: "signatures match", + }, + "Post request": { + method: "POST", + body: `{ "hello": "world!" }`, + key: "d90df39e2d19282840252612dd7c81421a372f61", + resp: "signatures match", + }, + } + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + st, err := NewSignatureTest() + if err != nil { + t.Fatal(err) + } + t.Cleanup(st.Close) + if tc.key != "" { + st.opts.SignatureKey = fmt.Sprintf("sha1:%s", tc.key) + } + err = st.MakeRequestWithExpectedKey(tc.method, tc.body, tc.key) + assert.NoError(t, err) + assert.Equal(t, 200, st.rw.Code) + assert.Equal(t, tc.resp, st.rw.Body.String()) + }) + } } func TestGetRedirect(t *testing.T) { @@ -1501,7 +1636,9 @@ func TestGetRedirect(t *testing.T) { assert.NoError(t, err) require.NotEmpty(t, opts.ProxyPrefix) proxy, err := NewOAuthProxy(opts, func(s string) bool { return false }) - assert.NoError(t, err) + if err != nil { + t.Fatal(err) + } tests := []struct { name string @@ -1535,22 +1672,21 @@ type ajaxRequestTest struct { proxy *OAuthProxy } -func newAjaxRequestTest() *ajaxRequestTest { +func newAjaxRequestTest() (*ajaxRequestTest, error) { test := &ajaxRequestTest{} test.opts = baseTestOptions() - test.opts.Cookie.Secret = base64CookieSecret - test.opts.ClientID = "gkljfdl" - test.opts.ClientSecret = "sdflkjs" - validation.Validate(test.opts) + err := validation.Validate(test.opts) + if err != nil { + return nil, err + } - var err error test.proxy, err = NewOAuthProxy(test.opts, func(email string) bool { return true }) if err != nil { - panic(err) + return nil, err } - return test + return test, nil } func (test *ajaxRequestTest) getEndpoint(endpoint string, header http.Header) (int, http.Header, error) { @@ -1565,7 +1701,10 @@ func (test *ajaxRequestTest) getEndpoint(endpoint string, header http.Header) (i } func testAjaxUnauthorizedRequest(t *testing.T, header http.Header) { - test := newAjaxRequestTest() + test, err := newAjaxRequestTest() + if err != nil { + t.Fatal(err) + } endpoint := "/test" code, rh, err := test.getEndpoint(endpoint, header) @@ -1589,7 +1728,10 @@ func TestAjaxUnauthorizedRequest2(t *testing.T) { } func TestAjaxForbiddendRequest(t *testing.T) { - test := newAjaxRequestTest() + test, err := newAjaxRequestTest() + if err != nil { + t.Fatal(err) + } endpoint := "/test" header := make(http.Header) code, rh, err := test.getEndpoint(endpoint, header) @@ -1604,8 +1746,14 @@ func TestClearSplitCookie(t *testing.T) { opts.Cookie.Secret = base64CookieSecret opts.Cookie.Name = "oauth2" opts.Cookie.Domains = []string{"abc"} - store, err := cookie.NewCookieSessionStore(&opts.Session, &opts.Cookie) - assert.Equal(t, nil, err) + err := validation.Validate(opts) + assert.NoError(t, err) + + store, err := sessionscookie.NewCookieSessionStore(&opts.Session, &opts.Cookie) + if err != nil { + t.Fatal(err) + } + p := OAuthProxy{CookieName: opts.Cookie.Name, CookieDomains: opts.Cookie.Domains, sessionStore: store} var rw = httptest.NewRecorder() req := httptest.NewRequest("get", "/", nil) @@ -1623,7 +1771,8 @@ func TestClearSplitCookie(t *testing.T) { Value: "oauth2_1", }) - p.ClearSessionCookie(rw, req) + err = p.ClearSessionCookie(rw, req) + assert.NoError(t, err) header := rw.Header() assert.Equal(t, 2, len(header["Set-Cookie"]), "should have 3 set-cookie header entries") @@ -1633,8 +1782,11 @@ func TestClearSingleCookie(t *testing.T) { opts := baseTestOptions() opts.Cookie.Name = "oauth2" opts.Cookie.Domains = []string{"abc"} - store, err := cookie.NewCookieSessionStore(&opts.Session, &opts.Cookie) - assert.Equal(t, nil, err) + store, err := sessionscookie.NewCookieSessionStore(&opts.Session, &opts.Cookie) + if err != nil { + t.Fatal(err) + } + p := OAuthProxy{CookieName: opts.Cookie.Name, CookieDomains: opts.Cookie.Domains, sessionStore: store} var rw = httptest.NewRecorder() req := httptest.NewRequest("get", "/", nil) @@ -1648,7 +1800,8 @@ func TestClearSingleCookie(t *testing.T) { Value: "oauth2", }) - p.ClearSessionCookie(rw, req) + err = p.ClearSessionCookie(rw, req) + assert.NoError(t, err) header := rw.Header() assert.Equal(t, 1, len(header["Set-Cookie"]), "should have 1 set-cookie header entries") @@ -1686,13 +1839,16 @@ func TestGetJwtSession(t *testing.T) { verifier := oidc.NewVerifier("https://issuer.example.com", keyset, &oidc.Config{ClientID: "https://test.myapp.com", SkipExpiryCheck: true}) - test := NewAuthOnlyEndpointTest(func(opts *options.Options) { + test, err := NewAuthOnlyEndpointTest(func(opts *options.Options) { opts.PassAuthorization = true opts.SetAuthorization = true opts.SetXAuthRequest = true opts.SkipJwtBearerTokens = true opts.SetJWTBearerVerifiers(append(opts.GetJWTBearerVerifiers(), verifier)) }) + if err != nil { + t.Fatal(err) + } tp, _ := test.proxy.provider.(*TestProvider) tp.GroupValidator = func(s string) bool { return true @@ -1705,7 +1861,8 @@ func TestGetJwtSession(t *testing.T) { // Bearer expires := time.Unix(1912151821, 0) - session, _ := test.proxy.GetJwtSession(test.req) + session, err := test.proxy.GetJwtSession(test.req) + assert.NoError(t, err) assert.Equal(t, session.User, "1234567890") assert.Equal(t, session.Email, "john@example.com") assert.Equal(t, session.ExpiresOn, &expires) @@ -1739,22 +1896,26 @@ func TestFindJwtBearerToken(t *testing.T) { "Authorization": {fmt.Sprintf("Bearer %s", validToken)}, } - token, _ = p.findBearerToken(getReq) + token, err := p.findBearerToken(getReq) + assert.NoError(t, err) assert.Equal(t, validToken, token) // Basic - no password getReq.SetBasicAuth(token, "") - token, _ = p.findBearerToken(getReq) + token, err = p.findBearerToken(getReq) + assert.NoError(t, err) assert.Equal(t, validToken, token) // Basic - sentinel password getReq.SetBasicAuth(token, "x-oauth-basic") - token, _ = p.findBearerToken(getReq) + token, err = p.findBearerToken(getReq) + assert.NoError(t, err) assert.Equal(t, validToken, token) // Basic - any username, password matching jwt pattern getReq.SetBasicAuth("any-username-you-could-wish-for", token) - token, _ = p.findBearerToken(getReq) + token, err = p.findBearerToken(getReq) + assert.NoError(t, err) assert.Equal(t, validToken, token) failures := []string{ @@ -1785,8 +1946,6 @@ func TestFindJwtBearerToken(t *testing.T) { _, err := p.findBearerToken(getReq) assert.Error(t, err) } - - fmt.Printf("%s", token) } func Test_prepareNoCache(t *testing.T) { @@ -1807,18 +1966,22 @@ func Test_prepareNoCache(t *testing.T) { func Test_noCacheHeaders(t *testing.T) { upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte("upstream")) + _, err := w.Write([]byte("upstream")) + if err != nil { + t.Error(err) + } })) t.Cleanup(upstream.Close) opts := baseTestOptions() opts.Upstreams = []string{upstream.URL} opts.SkipAuthRegex = []string{".*"} - _ = validation.Validate(opts) - proxy, err := NewOAuthProxy(opts, func(email string) bool { - return true - }) + err := validation.Validate(opts) assert.NoError(t, err) + proxy, err := NewOAuthProxy(opts, func(_ string) bool { return true }) + if err != nil { + t.Fatal(err) + } t.Run("not exist in response from upstream", func(t *testing.T) { rec := httptest.NewRecorder() @@ -1887,8 +2050,8 @@ func Test_noCacheHeaders(t *testing.T) { func baseTestOptions() *options.Options { opts := options.NewOptions() opts.Cookie.Secret = rawCookieSecret - opts.ClientID = "cliend-id" - opts.ClientSecret = "client-secret" + opts.ClientID = clientID + opts.ClientSecret = clientSecret opts.EmailDomains = []string{"*"} return opts } diff --git a/pkg/apis/sessions/legacy_v5_tester.go b/pkg/apis/sessions/legacy_v5_tester.go new file mode 100644 index 00000000..44cf4e73 --- /dev/null +++ b/pkg/apis/sessions/legacy_v5_tester.go @@ -0,0 +1,87 @@ +package sessions + +import ( + "fmt" + "testing" + "time" + + "github.com/oauth2-proxy/oauth2-proxy/pkg/encryption" + "github.com/stretchr/testify/assert" +) + +// LegacyV5TestCase provides V5 JSON based test cases for legacy fallback code +type LegacyV5TestCase struct { + Input string + Error bool + Output *SessionState +} + +// CreateLegacyV5TestCases makes various V5 JSON sessions as test cases +// +// Used for `apis/sessions/session_state_test.go` & `sessions/redis/redis_store_test.go` +// +// TODO: Remove when this is deprecated (likely V7) +func CreateLegacyV5TestCases(t *testing.T) (map[string]LegacyV5TestCase, encryption.Cipher, encryption.Cipher) { + const secret = "0123456789abcdefghijklmnopqrstuv" + + created := time.Now() + createdJSON, err := created.MarshalJSON() + assert.NoError(t, err) + createdString := string(createdJSON) + e := time.Now().Add(time.Duration(1) * time.Hour) + eJSON, err := e.MarshalJSON() + assert.NoError(t, err) + eString := string(eJSON) + + cfbCipher, err := encryption.NewCFBCipher([]byte(secret)) + assert.NoError(t, err) + legacyCipher := encryption.NewBase64Cipher(cfbCipher) + + testCases := map[string]LegacyV5TestCase{ + "User & email unencrypted": { + Input: `{"Email":"user@domain.com","User":"just-user"}`, + Error: true, + }, + "Only email unencrypted": { + Input: `{"Email":"user@domain.com"}`, + Error: true, + }, + "Just user unencrypted": { + Input: `{"User":"just-user"}`, + Error: true, + }, + "User and Email unencrypted while rest is encrypted": { + Input: fmt.Sprintf(`{"Email":"user@domain.com","User":"just-user","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","CreatedAt":%s,"ExpiresOn":%s}`, createdString, eString), + Error: true, + }, + "Full session with cipher": { + Input: fmt.Sprintf(`{"Email":"FsKKYrTWZWrxSOAqA/fTNAUZS5QWCqOBjuAbBlbVOw==","User":"rT6JP3dxQhxUhkWrrd7yt6c1mDVyQCVVxw==","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","CreatedAt":%s,"ExpiresOn":%s}`, createdString, eString), + Output: &SessionState{ + Email: "user@domain.com", + User: "just-user", + AccessToken: "token1234", + IDToken: "rawtoken1234", + CreatedAt: &created, + ExpiresOn: &e, + RefreshToken: "refresh4321", + }, + }, + "Minimal session encrypted with cipher": { + Input: `{"Email":"EGTllJcOFC16b7LBYzLekaHAC5SMMSPdyUrg8hd25g==","User":"rT6JP3dxQhxUhkWrrd7yt6c1mDVyQCVVxw=="}`, + Output: &SessionState{ + Email: "user@domain.com", + User: "just-user", + }, + }, + "Unencrypted User, Email and AccessToken": { + Input: `{"Email":"user@domain.com","User":"just-user","AccessToken":"X"}`, + Error: true, + }, + "Unencrypted User, Email and IDToken": { + Input: `{"Email":"user@domain.com","User":"just-user","IDToken":"XXXX"}`, + Error: true, + }, + } + + return testCases, cfbCipher, legacyCipher +} diff --git a/pkg/apis/sessions/session_state.go b/pkg/apis/sessions/session_state.go index 44b91bd2..2015df8c 100644 --- a/pkg/apis/sessions/session_state.go +++ b/pkg/apis/sessions/session_state.go @@ -1,25 +1,30 @@ package sessions import ( + "bytes" "encoding/json" "errors" "fmt" + "io" + "io/ioutil" "time" "unicode/utf8" "github.com/oauth2-proxy/oauth2-proxy/pkg/encryption" + "github.com/pierrec/lz4" + "github.com/vmihailenco/msgpack/v4" ) // SessionState is used to store information about the currently authenticated user session type SessionState struct { - AccessToken string `json:",omitempty"` - IDToken string `json:",omitempty"` - CreatedAt *time.Time `json:",omitempty"` - ExpiresOn *time.Time `json:",omitempty"` - RefreshToken string `json:",omitempty"` - Email string `json:",omitempty"` - User string `json:",omitempty"` - PreferredUsername string `json:",omitempty"` + AccessToken string `json:",omitempty" msgpack:"at,omitempty"` + IDToken string `json:",omitempty" msgpack:"it,omitempty"` + CreatedAt *time.Time `json:",omitempty" msgpack:"ca,omitempty"` + ExpiresOn *time.Time `json:",omitempty" msgpack:"eo,omitempty"` + RefreshToken string `json:",omitempty" msgpack:"rt,omitempty"` + Email string `json:",omitempty" msgpack:"e,omitempty"` + User string `json:",omitempty" msgpack:"u,omitempty"` + PreferredUsername string `json:",omitempty" msgpack:"pu,omitempty"` } // IsExpired checks whether the session has expired @@ -59,78 +64,79 @@ func (s *SessionState) String() string { return o + "}" } -// EncodeSessionState returns string representation of the current session -func (s *SessionState) EncodeSessionState(c encryption.Cipher) (string, error) { - var ss SessionState - if c == nil { - // Store only Email and User when cipher is unavailable - ss.Email = s.Email - ss.User = s.User - ss.PreferredUsername = s.PreferredUsername - } else { - ss = *s - for _, s := range []*string{ - &ss.Email, - &ss.User, - &ss.PreferredUsername, - &ss.AccessToken, - &ss.IDToken, - &ss.RefreshToken, - } { - err := into(s, c.Encrypt) - if err != nil { - return "", err - } +// EncodeSessionState returns an encrypted, lz4 compressed, MessagePack encoded session +func (s *SessionState) EncodeSessionState(c encryption.Cipher, compress bool) ([]byte, error) { + packed, err := msgpack.Marshal(s) + if err != nil { + return nil, fmt.Errorf("error marshalling session state to msgpack: %w", err) + } + + if !compress { + return c.Encrypt(packed) + } + + compressed, err := lz4Compress(packed) + if err != nil { + return nil, err + } + return c.Encrypt(compressed) +} + +// DecodeSessionState decodes a LZ4 compressed MessagePack into a Session State +func DecodeSessionState(data []byte, c encryption.Cipher, compressed bool) (*SessionState, error) { + decrypted, err := c.Decrypt(data) + if err != nil { + return nil, fmt.Errorf("error decrypting the session state: %w", err) + } + + packed := decrypted + if compressed { + packed, err = lz4Decompress(decrypted) + if err != nil { + return nil, err } } - b, err := json.Marshal(ss) - return string(b), err + var ss SessionState + err = msgpack.Unmarshal(packed, &ss) + if err != nil { + return nil, fmt.Errorf("error unmarshalling data to session state: %w", err) + } + + err = ss.validate() + if err != nil { + return nil, err + } + + return &ss, nil } -// DecodeSessionState decodes the session cookie string into a SessionState -func DecodeSessionState(v string, c encryption.Cipher) (*SessionState, error) { +// LegacyV5DecodeSessionState decodes a legacy JSON session cookie string into a SessionState +func LegacyV5DecodeSessionState(v string, c encryption.Cipher) (*SessionState, error) { var ss SessionState err := json.Unmarshal([]byte(v), &ss) if err != nil { return nil, fmt.Errorf("error unmarshalling session: %w", err) } - if c == nil { - // Load only Email and User when cipher is unavailable - ss = SessionState{ - Email: ss.Email, - User: ss.User, - PreferredUsername: ss.PreferredUsername, - } - } else { - // Backward compatibility with using unencrypted Email or User - // Decryption errors will leave original string - err = into(&ss.Email, c.Decrypt) - if err == nil { - if !utf8.ValidString(ss.Email) { - return nil, errors.New("invalid value for decrypted email") - } - } - err = into(&ss.User, c.Decrypt) - if err == nil { - if !utf8.ValidString(ss.User) { - return nil, errors.New("invalid value for decrypted user") - } - } - - for _, s := range []*string{ - &ss.PreferredUsername, - &ss.AccessToken, - &ss.IDToken, - &ss.RefreshToken, - } { - err := into(s, c.Decrypt) - if err != nil { - return nil, err - } + for _, s := range []*string{ + &ss.User, + &ss.Email, + &ss.PreferredUsername, + &ss.AccessToken, + &ss.IDToken, + &ss.RefreshToken, + } { + err := into(s, c.Decrypt) + if err != nil { + return nil, err } } + err = ss.validate() + if err != nil { + return nil, err + } + return &ss, nil } @@ -150,3 +156,86 @@ func into(s *string, f codecFunc) error { *s = string(d) return nil } + +// lz4Compress compresses with LZ4 +// +// The Compress:Decompress ratio is 1:Many. LZ4 gives fastest decompress speeds +// at the expense of greater compression compared to other compression +// algorithms. +func lz4Compress(payload []byte) ([]byte, error) { + buf := new(bytes.Buffer) + zw := lz4.NewWriter(nil) + zw.Header = lz4.Header{ + BlockMaxSize: 65536, + CompressionLevel: 0, + } + zw.Reset(buf) + + reader := bytes.NewReader(payload) + _, err := io.Copy(zw, reader) + if err != nil { + return nil, fmt.Errorf("error copying lz4 stream to buffer: %w", err) + } + err = zw.Close() + if err != nil { + return nil, fmt.Errorf("error closing lz4 writer: %w", err) + } + + compressed, err := ioutil.ReadAll(buf) + if err != nil { + return nil, fmt.Errorf("error reading lz4 buffer: %w", err) + } + + return compressed, nil +} + +// lz4Decompress decompresses with LZ4 +func lz4Decompress(compressed []byte) ([]byte, error) { + reader := bytes.NewReader(compressed) + buf := new(bytes.Buffer) + zr := lz4.NewReader(nil) + zr.Reset(reader) + _, err := io.Copy(buf, zr) + if err != nil { + return nil, fmt.Errorf("error copying lz4 stream to buffer: %w", err) + } + + payload, err := ioutil.ReadAll(buf) + if err != nil { + return nil, fmt.Errorf("error reading lz4 buffer: %w", err) + } + + return payload, nil +} + +// validate ensures the decoded session is non-empty and contains valid data +// +// Non-empty check is needed due to ensure the non-authenticated AES-CFB +// decryption doesn't result in garbage data that collides with a valid +// MessagePack header bytes (which MessagePack will unmarshal to an empty +// default SessionState). <1% chance, but observed with random test data. +// +// UTF-8 check ensures the strings are valid and not raw bytes overloaded +// into Latin-1 encoding. The occurs when legacy unencrypted fields are +// decrypted with AES-CFB which results in random bytes. +func (s *SessionState) validate() error { + for _, field := range []string{ + s.User, + s.Email, + s.PreferredUsername, + s.AccessToken, + s.IDToken, + s.RefreshToken, + } { + if !utf8.ValidString(field) { + return errors.New("invalid non-UTF8 field in session") + } + } + + empty := new(SessionState) + if *s == *empty { + return errors.New("invalid empty session unmarshalled") + } + + return nil +} diff --git a/pkg/apis/sessions/session_state_test.go b/pkg/apis/sessions/session_state_test.go index 3e9554c5..ac554c60 100644 --- a/pkg/apis/sessions/session_state_test.go +++ b/pkg/apis/sessions/session_state_test.go @@ -12,132 +12,11 @@ import ( "github.com/stretchr/testify/assert" ) -const secret = "0123456789abcdefghijklmnopqrstuv" -const altSecret = "0000000000abcdefghijklmnopqrstuv" - func timePtr(t time.Time) *time.Time { return &t } -func newTestCipher(secret []byte) (encryption.Cipher, error) { - return encryption.NewBase64Cipher(encryption.NewCFBCipher, secret) -} - -func TestSessionStateSerialization(t *testing.T) { - c, err := newTestCipher([]byte(secret)) - assert.Equal(t, nil, err) - c2, err := newTestCipher([]byte(altSecret)) - assert.Equal(t, nil, err) - s := &SessionState{ - Email: "user@domain.com", - PreferredUsername: "user", - AccessToken: "token1234", - IDToken: "rawtoken1234", - CreatedAt: timePtr(time.Now()), - ExpiresOn: timePtr(time.Now().Add(time.Duration(1) * time.Hour)), - RefreshToken: "refresh4321", - } - encoded, err := s.EncodeSessionState(c) - assert.Equal(t, nil, err) - - ss, err := DecodeSessionState(encoded, c) - t.Logf("%#v", ss) - assert.Equal(t, nil, err) - assert.Equal(t, "", ss.User) - assert.Equal(t, s.Email, ss.Email) - assert.Equal(t, s.PreferredUsername, ss.PreferredUsername) - assert.Equal(t, s.AccessToken, ss.AccessToken) - assert.Equal(t, s.IDToken, ss.IDToken) - assert.Equal(t, s.CreatedAt.Unix(), ss.CreatedAt.Unix()) - assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) - assert.Equal(t, s.RefreshToken, ss.RefreshToken) - - // ensure a different cipher can't decode properly (ie: it gets gibberish) - ss, err = DecodeSessionState(encoded, c2) - t.Logf("%#v", ss) - assert.NotEqual(t, nil, err) -} - -func TestSessionStateSerializationWithUser(t *testing.T) { - c, err := newTestCipher([]byte(secret)) - assert.Equal(t, nil, err) - c2, err := newTestCipher([]byte(altSecret)) - assert.Equal(t, nil, err) - s := &SessionState{ - User: "just-user", - PreferredUsername: "ju", - Email: "user@domain.com", - AccessToken: "token1234", - CreatedAt: timePtr(time.Now()), - ExpiresOn: timePtr(time.Now().Add(time.Duration(1) * time.Hour)), - RefreshToken: "refresh4321", - } - encoded, err := s.EncodeSessionState(c) - assert.Equal(t, nil, err) - - ss, err := DecodeSessionState(encoded, c) - t.Logf("%#v", ss) - assert.Equal(t, nil, err) - assert.Equal(t, s.User, ss.User) - assert.Equal(t, s.Email, ss.Email) - assert.Equal(t, s.PreferredUsername, ss.PreferredUsername) - assert.Equal(t, s.AccessToken, ss.AccessToken) - assert.Equal(t, s.CreatedAt.Unix(), ss.CreatedAt.Unix()) - assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) - assert.Equal(t, s.RefreshToken, ss.RefreshToken) - - // ensure a different cipher can't decode properly (ie: it gets gibberish) - ss, err = DecodeSessionState(encoded, c2) - t.Logf("%#v", ss) - assert.NotEqual(t, nil, err) -} - -func TestSessionStateSerializationNoCipher(t *testing.T) { - s := &SessionState{ - Email: "user@domain.com", - PreferredUsername: "user", - AccessToken: "token1234", - CreatedAt: timePtr(time.Now()), - ExpiresOn: timePtr(time.Now().Add(time.Duration(1) * time.Hour)), - RefreshToken: "refresh4321", - } - encoded, err := s.EncodeSessionState(nil) - assert.Equal(t, nil, err) - - // only email should have been serialized - ss, err := DecodeSessionState(encoded, nil) - assert.Equal(t, nil, err) - assert.Equal(t, "", ss.User) - assert.Equal(t, s.Email, ss.Email) - assert.Equal(t, s.PreferredUsername, ss.PreferredUsername) - assert.Equal(t, "", ss.AccessToken) - assert.Equal(t, "", ss.RefreshToken) -} - -func TestSessionStateSerializationNoCipherWithUser(t *testing.T) { - s := &SessionState{ - User: "just-user", - Email: "user@domain.com", - PreferredUsername: "user", - AccessToken: "token1234", - CreatedAt: timePtr(time.Now()), - ExpiresOn: timePtr(time.Now().Add(time.Duration(1) * time.Hour)), - RefreshToken: "refresh4321", - } - encoded, err := s.EncodeSessionState(nil) - assert.Equal(t, nil, err) - - // only email should have been serialized - ss, err := DecodeSessionState(encoded, nil) - assert.Equal(t, nil, err) - assert.Equal(t, s.User, ss.User) - assert.Equal(t, s.Email, ss.Email) - assert.Equal(t, s.PreferredUsername, ss.PreferredUsername) - assert.Equal(t, "", ss.AccessToken) - assert.Equal(t, "", ss.RefreshToken) -} - -func TestExpired(t *testing.T) { +func TestIsExpired(t *testing.T) { s := &SessionState{ExpiresOn: timePtr(time.Now().Add(time.Duration(-1) * time.Minute))} assert.Equal(t, true, s.IsExpired()) @@ -148,161 +27,7 @@ func TestExpired(t *testing.T) { assert.Equal(t, false, s.IsExpired()) } -type testCase struct { - SessionState - Encoded string - Cipher encryption.Cipher - Error bool -} - -// TestEncodeSessionState tests EncodeSessionState with the test vector -// -// Currently only tests without cipher here because we have no way to mock -// the random generator used in EncodeSessionState. -func TestEncodeSessionState(t *testing.T) { - c := time.Now() - e := time.Now().Add(time.Duration(1) * time.Hour) - - testCases := []testCase{ - { - SessionState: SessionState{ - Email: "user@domain.com", - User: "just-user", - }, - Encoded: `{"Email":"user@domain.com","User":"just-user"}`, - }, - { - SessionState: SessionState{ - Email: "user@domain.com", - User: "just-user", - AccessToken: "token1234", - IDToken: "rawtoken1234", - CreatedAt: &c, - ExpiresOn: &e, - RefreshToken: "refresh4321", - }, - Encoded: `{"Email":"user@domain.com","User":"just-user"}`, - }, - } - - for i, tc := range testCases { - encoded, err := tc.EncodeSessionState(tc.Cipher) - t.Logf("i:%d Encoded:%#vSessionState:%#v Error:%#v", i, encoded, tc.SessionState, err) - if tc.Error { - assert.Error(t, err) - assert.Empty(t, encoded) - continue - } - assert.NoError(t, err) - assert.JSONEq(t, tc.Encoded, encoded) - } -} - -// TestDecodeSessionState testssessions.DecodeSessionState with the test vector -func TestDecodeSessionState(t *testing.T) { - created := time.Now() - createdJSON, _ := created.MarshalJSON() - createdString := string(createdJSON) - e := time.Now().Add(time.Duration(1) * time.Hour) - eJSON, _ := e.MarshalJSON() - eString := string(eJSON) - - c, err := newTestCipher([]byte(secret)) - assert.NoError(t, err) - - testCases := []testCase{ - { - SessionState: SessionState{ - Email: "user@domain.com", - User: "just-user", - }, - Encoded: `{"Email":"user@domain.com","User":"just-user"}`, - }, - { - SessionState: SessionState{ - Email: "user@domain.com", - User: "", - }, - Encoded: `{"Email":"user@domain.com"}`, - }, - { - SessionState: SessionState{ - User: "just-user", - }, - Encoded: `{"User":"just-user"}`, - }, - { - SessionState: SessionState{ - Email: "user@domain.com", - User: "just-user", - }, - Encoded: fmt.Sprintf(`{"Email":"user@domain.com","User":"just-user","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","CreatedAt":%s,"ExpiresOn":%s}`, createdString, eString), - }, - { - SessionState: SessionState{ - Email: "user@domain.com", - User: "just-user", - AccessToken: "token1234", - IDToken: "rawtoken1234", - CreatedAt: &created, - ExpiresOn: &e, - RefreshToken: "refresh4321", - }, - Encoded: fmt.Sprintf(`{"Email":"FsKKYrTWZWrxSOAqA/fTNAUZS5QWCqOBjuAbBlbVOw==","User":"rT6JP3dxQhxUhkWrrd7yt6c1mDVyQCVVxw==","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","CreatedAt":%s,"ExpiresOn":%s}`, createdString, eString), - Cipher: c, - }, - { - SessionState: SessionState{ - Email: "user@domain.com", - User: "just-user", - }, - Encoded: `{"Email":"EGTllJcOFC16b7LBYzLekaHAC5SMMSPdyUrg8hd25g==","User":"rT6JP3dxQhxUhkWrrd7yt6c1mDVyQCVVxw=="}`, - Cipher: c, - }, - { - Encoded: `{"Email":"user@domain.com","User":"just-user","AccessToken":"X"}`, - Cipher: c, - Error: true, - }, - { - Encoded: `{"Email":"user@domain.com","User":"just-user","IDToken":"XXXX"}`, - Cipher: c, - Error: true, - }, - { - SessionState: SessionState{ - Email: "user@domain.com", - User: "YmFzZTY0LWVuY29kZWQtdXNlcgo=", // Base64 encoding of base64-encoded-user - }, - Error: true, - Cipher: c, - }, - } - - for i, tc := range testCases { - ss, err := DecodeSessionState(tc.Encoded, tc.Cipher) - t.Logf("i:%d Encoded:%#vSessionState:%#v Error:%#v", i, tc.Encoded, ss, err) - if tc.Error { - assert.Error(t, err) - assert.Nil(t, ss) - continue - } - assert.NoError(t, err) - if assert.NotNil(t, ss) { - assert.Equal(t, tc.User, ss.User) - assert.Equal(t, tc.Email, ss.Email) - assert.Equal(t, tc.AccessToken, ss.AccessToken) - assert.Equal(t, tc.RefreshToken, ss.RefreshToken) - assert.Equal(t, tc.IDToken, ss.IDToken) - if tc.ExpiresOn != nil { - assert.NotEqual(t, nil, ss.ExpiresOn) - assert.Equal(t, tc.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) - } - } - } -} - -func TestSessionStateAge(t *testing.T) { +func TestAge(t *testing.T) { ss := &SessionState{} // Created at unset so should be 0 @@ -313,7 +38,149 @@ func TestSessionStateAge(t *testing.T) { assert.Equal(t, time.Hour, ss.Age().Round(time.Minute)) } -func TestIntoEncryptAndIntoDecrypt(t *testing.T) { +// TestEncodeAndDecodeSessionState encodes & decodes various session states +// and confirms the operation is 1:1 +func TestEncodeAndDecodeSessionState(t *testing.T) { + created := time.Now() + expires := time.Now().Add(time.Duration(1) * time.Hour) + + // Tokens in the test table are purposefully redundant + // Otherwise compressing small payloads could result in a compressed value + // that is larger (compression dictionary + limited like strings to compress) + // which breaks the len(compressed) < len(uncompressed) assertion. + testCases := map[string]SessionState{ + "Full session": { + Email: "username@example.com", + User: "username", + PreferredUsername: "preferred.username", + AccessToken: "AccessToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7", + IDToken: "IDToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7", + CreatedAt: &created, + ExpiresOn: &expires, + RefreshToken: "RefreshToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7", + }, + "No ExpiresOn": { + Email: "username@example.com", + User: "username", + PreferredUsername: "preferred.username", + AccessToken: "AccessToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7", + IDToken: "IDToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7", + CreatedAt: &created, + RefreshToken: "RefreshToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7", + }, + "No PreferredUsername": { + Email: "username@example.com", + User: "username", + AccessToken: "AccessToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7", + IDToken: "IDToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7", + CreatedAt: &created, + ExpiresOn: &expires, + RefreshToken: "RefreshToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7", + }, + "Minimal session": { + User: "username", + IDToken: "IDToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7", + CreatedAt: &created, + RefreshToken: "RefreshToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7", + }, + "Bearer authorization header created session": { + Email: "username", + User: "username", + AccessToken: "IDToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7", + IDToken: "IDToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7", + ExpiresOn: &expires, + }, + } + + for _, secretSize := range []int{16, 24, 32} { + t.Run(fmt.Sprintf("%d byte secret", secretSize), func(t *testing.T) { + secret := make([]byte, secretSize) + _, err := io.ReadFull(rand.Reader, secret) + assert.NoError(t, err) + + cfb, err := encryption.NewCFBCipher([]byte(secret)) + assert.NoError(t, err) + gcm, err := encryption.NewGCMCipher([]byte(secret)) + assert.NoError(t, err) + + ciphers := map[string]encryption.Cipher{ + "CFB cipher": cfb, + "GCM cipher": gcm, + } + + for cipherName, c := range ciphers { + t.Run(cipherName, func(t *testing.T) { + for testName, ss := range testCases { + t.Run(testName, func(t *testing.T) { + encoded, err := ss.EncodeSessionState(c, false) + assert.NoError(t, err) + encodedCompressed, err := ss.EncodeSessionState(c, true) + assert.NoError(t, err) + // Make sure compressed version is smaller than if not compressed + assert.Greater(t, len(encoded), len(encodedCompressed)) + + decoded, err := DecodeSessionState(encoded, c, false) + assert.NoError(t, err) + decodedCompressed, err := DecodeSessionState(encodedCompressed, c, true) + assert.NoError(t, err) + + compareSessionStates(t, decoded, decodedCompressed) + compareSessionStates(t, decoded, &ss) + }) + } + }) + } + + t.Run("Mixed cipher types cause errors", func(t *testing.T) { + for testName, ss := range testCases { + t.Run(testName, func(t *testing.T) { + cfbEncoded, err := ss.EncodeSessionState(cfb, false) + assert.NoError(t, err) + _, err = DecodeSessionState(cfbEncoded, gcm, false) + assert.Error(t, err) + + gcmEncoded, err := ss.EncodeSessionState(gcm, false) + assert.NoError(t, err) + _, err = DecodeSessionState(gcmEncoded, cfb, false) + assert.Error(t, err) + }) + } + }) + }) + } +} + +// TestLegacyV5DecodeSessionState confirms V5 JSON sessions decode +// +// TODO: Remove when this is deprecated (likely V7) +func TestLegacyV5DecodeSessionState(t *testing.T) { + testCases, cipher, legacyCipher := CreateLegacyV5TestCases(t) + + for testName, tc := range testCases { + t.Run(testName, func(t *testing.T) { + // Legacy sessions fail in DecodeSessionState which results in + // the fallback to LegacyV5DecodeSessionState + _, err := DecodeSessionState([]byte(tc.Input), cipher, false) + assert.Error(t, err) + _, err = DecodeSessionState([]byte(tc.Input), cipher, true) + assert.Error(t, err) + + ss, err := LegacyV5DecodeSessionState(tc.Input, legacyCipher) + if tc.Error { + assert.Error(t, err) + assert.Nil(t, ss) + return + } + assert.NoError(t, err) + compareSessionStates(t, tc.Output, ss) + }) + } +} + +// Test_into tests the into helper function used in LegacyV5DecodeSessionState +// +// TODO: Remove when this is deprecated (likely V7) +func Test_into(t *testing.T) { const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" // Test all 3 valid AES sizes @@ -323,8 +190,9 @@ func TestIntoEncryptAndIntoDecrypt(t *testing.T) { _, err := io.ReadFull(rand.Reader, secret) assert.Equal(t, nil, err) - c, err := newTestCipher(secret) + cfb, err := encryption.NewCFBCipher(secret) assert.NoError(t, err) + c := encryption.NewBase64Cipher(cfb) // Check no errors with empty or nil strings empty := "" @@ -353,3 +221,27 @@ func TestIntoEncryptAndIntoDecrypt(t *testing.T) { }) } } + +func compareSessionStates(t *testing.T, expected *SessionState, actual *SessionState) { + if expected.CreatedAt != nil { + assert.NotNil(t, actual.CreatedAt) + assert.Equal(t, true, expected.CreatedAt.Equal(*actual.CreatedAt)) + } else { + assert.Nil(t, actual.CreatedAt) + } + if expected.ExpiresOn != nil { + assert.NotNil(t, actual.ExpiresOn) + assert.Equal(t, true, expected.ExpiresOn.Equal(*actual.ExpiresOn)) + } else { + assert.Nil(t, actual.ExpiresOn) + } + + // Compare sessions without *time.Time fields + exp := *expected + exp.CreatedAt = nil + exp.ExpiresOn = nil + act := *actual + act.CreatedAt = nil + act.ExpiresOn = nil + assert.Equal(t, exp, act) +} diff --git a/pkg/encryption/cipher.go b/pkg/encryption/cipher.go index c1158b5c..37e08ba8 100644 --- a/pkg/encryption/cipher.go +++ b/pkg/encryption/cipher.go @@ -21,12 +21,8 @@ type base64Cipher struct { // NewBase64Cipher returns a new AES Cipher for encrypting cookie values // and wrapping them in Base64 -- Supports Legacy encryption scheme -func NewBase64Cipher(initCipher func([]byte) (Cipher, error), secret []byte) (Cipher, error) { - c, err := initCipher(secret) - if err != nil { - return nil, err - } - return &base64Cipher{Cipher: c}, nil +func NewBase64Cipher(c Cipher) Cipher { + return &base64Cipher{Cipher: c} } // Encrypt encrypts a value with the embedded Cipher & Base64 encodes it diff --git a/pkg/encryption/cipher_test.go b/pkg/encryption/cipher_test.go index b552e70c..16e12929 100644 --- a/pkg/encryption/cipher_test.go +++ b/pkg/encryption/cipher_test.go @@ -13,8 +13,9 @@ import ( func TestEncodeAndDecodeAccessToken(t *testing.T) { const secret = "0123456789abcdefghijklmnopqrstuv" const token = "my access token" - c, err := NewBase64Cipher(NewCFBCipher, []byte(secret)) - assert.Equal(t, nil, err) + cfb, err := NewCFBCipher([]byte(secret)) + assert.NoError(t, err) + c := NewBase64Cipher(cfb) encoded, err := c.Encrypt([]byte(token)) assert.Equal(t, nil, err) @@ -32,8 +33,9 @@ func TestEncodeAndDecodeAccessTokenB64(t *testing.T) { secret, err := base64.URLEncoding.DecodeString(secretBase64) assert.Equal(t, nil, err) - c, err := NewBase64Cipher(NewCFBCipher, []byte(secret)) - assert.Equal(t, nil, err) + cfb, err := NewCFBCipher([]byte(secret)) + assert.NoError(t, err) + c := NewBase64Cipher(cfb) encoded, err := c.Encrypt([]byte(token)) assert.Equal(t, nil, err) @@ -64,8 +66,7 @@ func TestEncryptAndDecrypt(t *testing.T) { cstd, err := initCipher(secret) assert.Equal(t, nil, err) - cb64, err := NewBase64Cipher(initCipher, secret) - assert.Equal(t, nil, err) + cb64 := NewBase64Cipher(cstd) ciphers := map[string]Cipher{ "Standard": cstd, diff --git a/pkg/sessions/cookie/session_store.go b/pkg/sessions/cookie/session_store.go index 6fa6b5ea..69b55d11 100644 --- a/pkg/sessions/cookie/session_store.go +++ b/pkg/sessions/cookie/session_store.go @@ -60,7 +60,7 @@ func (s *SessionStore) Load(req *http.Request) (*sessions.SessionState, error) { return nil, errors.New("cookie signature not valid") } - session, err := sessionFromCookie(string(val), s.CookieCipher) + session, err := sessionFromCookie(val, s.CookieCipher) if err != nil { return nil, err } @@ -85,17 +85,26 @@ func (s *SessionStore) Clear(rw http.ResponseWriter, req *http.Request) error { } // cookieForSession serializes a session state for storage in a cookie -func cookieForSession(s *sessions.SessionState, c encryption.Cipher) (string, error) { - return s.EncodeSessionState(c) +func cookieForSession(s *sessions.SessionState, c encryption.Cipher) ([]byte, error) { + return s.EncodeSessionState(c, true) } // sessionFromCookie deserializes a session from a cookie value -func sessionFromCookie(v string, c encryption.Cipher) (s *sessions.SessionState, err error) { - return sessions.DecodeSessionState(v, c) +func sessionFromCookie(v []byte, c encryption.Cipher) (s *sessions.SessionState, err error) { + ss, err := sessions.DecodeSessionState(v, c, true) + // If anything fails (Decrypt, LZ4, MessagePack), try legacy JSON decode + // LZ4 will likely fail for wrong header after AES-CFB spits out garbage + // data from trying to decrypt JSON it things is ciphertext + if err != nil { + // Legacy used Base64 + AES CFB + legacyCipher := encryption.NewBase64Cipher(c) + return sessions.LegacyV5DecodeSessionState(string(v), legacyCipher) + } + return ss, nil } // setSessionCookie adds the user's session cookie to the response -func (s *SessionStore) setSessionCookie(rw http.ResponseWriter, req *http.Request, val string, created time.Time) { +func (s *SessionStore) setSessionCookie(rw http.ResponseWriter, req *http.Request, val []byte, created time.Time) { for _, c := range s.makeSessionCookie(req, val, created) { http.SetCookie(rw, c) } @@ -103,12 +112,12 @@ func (s *SessionStore) setSessionCookie(rw http.ResponseWriter, req *http.Reques // makeSessionCookie creates an http.Cookie containing the authenticated user's // authentication details -func (s *SessionStore) makeSessionCookie(req *http.Request, value string, now time.Time) []*http.Cookie { - if value != "" { - value = encryption.SignedValue(s.Cookie.Secret, s.Cookie.Name, []byte(value), now) +func (s *SessionStore) makeSessionCookie(req *http.Request, value []byte, now time.Time) []*http.Cookie { + strValue := string(value) + if strValue != "" { + strValue = encryption.SignedValue(s.Cookie.Secret, s.Cookie.Name, value, now) } - c := s.makeCookie(req, s.Cookie.Name, value, s.Cookie.Expire, now) - + c := s.makeCookie(req, s.Cookie.Name, strValue, s.Cookie.Expire, now) if len(c.String()) > maxCookieLength { return splitCookie(c) } @@ -129,7 +138,7 @@ func (s *SessionStore) makeCookie(req *http.Request, name string, value string, // NewCookieSessionStore initialises a new instance of the SessionStore from // the configuration given func NewCookieSessionStore(opts *options.SessionOptions, cookieOpts *options.Cookie) (sessions.SessionStore, error) { - cipher, err := encryption.NewBase64Cipher(encryption.NewCFBCipher, encryption.SecretBytes(cookieOpts.Secret)) + cipher, err := encryption.NewCFBCipher(encryption.SecretBytes(cookieOpts.Secret)) if err != nil { return nil, fmt.Errorf("error initialising cipher: %v", err) } diff --git a/pkg/sessions/redis/redis_store.go b/pkg/sessions/redis/redis_store.go index a89349e8..0e0d7cd9 100644 --- a/pkg/sessions/redis/redis_store.go +++ b/pkg/sessions/redis/redis_store.go @@ -40,7 +40,7 @@ type SessionStore struct { // NewRedisSessionStore initialises a new instance of the SessionStore from // the configuration given func NewRedisSessionStore(opts *options.SessionOptions, cookieOpts *options.Cookie) (sessions.SessionStore, error) { - cipher, err := encryption.NewBase64Cipher(encryption.NewCFBCipher, encryption.SecretBytes(cookieOpts.Secret)) + cfbCipher, err := encryption.NewCFBCipher(encryption.SecretBytes(cookieOpts.Secret)) if err != nil { return nil, fmt.Errorf("error initialising cipher: %v", err) } @@ -52,7 +52,7 @@ func NewRedisSessionStore(opts *options.SessionOptions, cookieOpts *options.Cook rs := &SessionStore{ Client: client, - CookieCipher: cipher, + CookieCipher: cfbCipher, Cookie: cookieOpts, } return rs, nil @@ -146,12 +146,8 @@ func (store *SessionStore) Save(rw http.ResponseWriter, req *http.Request, s *se // Old sessions that we are refreshing would have a request cookie // New sessions don't, so we ignore the error. storeValue will check requestCookie requestCookie, _ := req.Cookie(store.Cookie.Name) - value, err := s.EncodeSessionState(store.CookieCipher) - if err != nil { - return err - } ctx := req.Context() - ticketString, err := store.storeValue(ctx, value, store.Cookie.Expire, requestCookie) + ticketString, err := store.saveSession(ctx, s, store.Cookie.Expire, requestCookie) if err != nil { return err } @@ -180,40 +176,13 @@ func (store *SessionStore) Load(req *http.Request) (*sessions.SessionState, erro return nil, fmt.Errorf("cookie signature not valid") } ctx := req.Context() - session, err := store.loadSessionFromString(ctx, string(val)) + session, err := store.loadSessionFromTicket(ctx, string(val)) if err != nil { return nil, fmt.Errorf("error loading session: %s", err) } return session, nil } -// loadSessionFromString loads the session based on the ticket value -func (store *SessionStore) loadSessionFromString(ctx context.Context, value string) (*sessions.SessionState, error) { - ticket, err := decodeTicket(store.Cookie.Name, value) - if err != nil { - return nil, err - } - - resultBytes, err := store.Client.Get(ctx, ticket.asHandle(store.Cookie.Name)) - if err != nil { - return nil, err - } - - block, err := aes.NewCipher(ticket.Secret) - if err != nil { - return nil, err - } - // Use secret as the IV too, because each entry has it's own key - stream := cipher.NewCFBDecrypter(block, ticket.Secret) - stream.XORKeyStream(resultBytes, resultBytes) - - session, err := sessions.DecodeSessionState(string(resultBytes), store.CookieCipher) - if err != nil { - return nil, err - } - return session, nil -} - // Clear clears any saved session information for a given ticket cookie // from redis, and then clears the session func (store *SessionStore) Clear(rw http.ResponseWriter, req *http.Request) error { @@ -253,6 +222,80 @@ func (store *SessionStore) Clear(rw http.ResponseWriter, req *http.Request) erro return nil } +// saveSession encodes a session with a GCM cipher & saves the data into Redis +func (store *SessionStore) saveSession(ctx context.Context, s *sessions.SessionState, expiration time.Duration, requestCookie *http.Cookie) (string, error) { + ticket, err := store.getTicket(requestCookie) + if err != nil { + return "", fmt.Errorf("error getting ticket: %v", err) + } + + c, err := encryption.NewGCMCipher(ticket.Secret) + if err != nil { + return "", fmt.Errorf("error initiating cipher block %s", err) + } + + // Use AES-GCM since it provides authenticated encryption + // AES-CFB used in cookies has the cookie signing SHA to get around the lack of + // authentication in AES-CFB + ciphertext, err := s.EncodeSessionState(c, false) + if err != nil { + return "", err + } + + handle := ticket.asHandle(store.Cookie.Name) + err = store.Client.Set(ctx, handle, ciphertext, expiration) + if err != nil { + return "", err + } + return ticket.encodeTicket(store.Cookie.Name), nil +} + +// loadSessionFromTicket loads the session based on the ticket value +func (store *SessionStore) loadSessionFromTicket(ctx context.Context, value string) (*sessions.SessionState, error) { + ticket, err := decodeTicket(store.Cookie.Name, value) + if err != nil { + return nil, err + } + + resultBytes, err := store.Client.Get(ctx, ticket.asHandle(store.Cookie.Name)) + if err != nil { + return nil, err + } + + c, err := encryption.NewGCMCipher(ticket.Secret) + if err != nil { + return nil, err + } + + session, err := sessions.DecodeSessionState(resultBytes, c, false) + if err != nil { + // The GCM cipher will error due to a legacy JSON payload not passing + // the authentication check part of AES GCM encryption. + // In that case, we can attempt to fallback to try a legacy load + legacyCipher := encryption.NewBase64Cipher(store.CookieCipher) + return legacyV5DecodeSession(resultBytes, ticket, legacyCipher) + } + return session, nil +} + +// legacyV5DecodeSession loads the session based on the ticket value +// This fallback uses V5 style encryption of Base64 + AES CFB +func legacyV5DecodeSession(resultBytes []byte, ticket *TicketData, c encryption.Cipher) (*sessions.SessionState, error) { + block, err := aes.NewCipher(ticket.Secret) + if err != nil { + return nil, err + } + // Use secret as the IV too, because each entry has it's own key + stream := cipher.NewCFBDecrypter(block, ticket.Secret) + stream.XORKeyStream(resultBytes, resultBytes) + + session, err := sessions.LegacyV5DecodeSessionState(string(resultBytes), c) + if err != nil { + return nil, err + } + return session, nil +} + // makeCookie makes a cookie, signing the value if present func (store *SessionStore) makeCookie(req *http.Request, value string, expires time.Duration, now time.Time) *http.Cookie { if value != "" { @@ -268,30 +311,6 @@ func (store *SessionStore) makeCookie(req *http.Request, value string, expires t ) } -func (store *SessionStore) storeValue(ctx context.Context, value string, expiration time.Duration, requestCookie *http.Cookie) (string, error) { - ticket, err := store.getTicket(requestCookie) - if err != nil { - return "", fmt.Errorf("error getting ticket: %v", err) - } - - ciphertext := make([]byte, len(value)) - block, err := aes.NewCipher(ticket.Secret) - if err != nil { - return "", fmt.Errorf("error initiating cipher block %s", err) - } - - // Use secret as the Initialization Vector too, because each entry has it's own key - stream := cipher.NewCFBEncrypter(block, ticket.Secret) - stream.XORKeyStream(ciphertext, []byte(value)) - - handle := ticket.asHandle(store.Cookie.Name) - err = store.Client.Set(ctx, handle, ciphertext, expiration) - if err != nil { - return "", err - } - return ticket.encodeTicket(store.Cookie.Name), nil -} - // getTicket retrieves an existing ticket from the cookie if present, // or creates a new ticket func (store *SessionStore) getTicket(requestCookie *http.Cookie) (*TicketData, error) { diff --git a/pkg/sessions/redis/session_store_test.go b/pkg/sessions/redis/redis_store_test.go similarity index 65% rename from pkg/sessions/redis/session_store_test.go rename to pkg/sessions/redis/redis_store_test.go index 78dd111f..12965705 100644 --- a/pkg/sessions/redis/session_store_test.go +++ b/pkg/sessions/redis/redis_store_test.go @@ -1,6 +1,11 @@ package redis import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "fmt" + "io" "log" "os" "testing" @@ -17,6 +22,65 @@ import ( . "github.com/onsi/gomega" ) +// TestLegacyV5DecodeSession tests the fallback to LegacyV5DecodeSession +// when a V5 encoded session is in Redis +// +// TODO: Remove when this is deprecated (likely V7) +func Test_legacyV5DecodeSession(t *testing.T) { + testCases, _, legacyCipher := sessionsapi.CreateLegacyV5TestCases(t) + + for testName, tc := range testCases { + t.Run(testName, func(t *testing.T) { + g := NewWithT(t) + + secret := make([]byte, aes.BlockSize) + _, err := io.ReadFull(rand.Reader, secret) + g.Expect(err).ToNot(HaveOccurred()) + ticket := &TicketData{ + TicketID: "", + Secret: secret, + } + + encrypted, err := legacyStoreValue(tc.Input, ticket.Secret) + g.Expect(err).ToNot(HaveOccurred()) + + ss, err := legacyV5DecodeSession(encrypted, ticket, legacyCipher) + if tc.Error { + g.Expect(err).To(HaveOccurred()) + g.Expect(ss).To(BeNil()) + return + } + g.Expect(err).ToNot(HaveOccurred()) + + // Compare sessions without *time.Time fields + exp := *tc.Output + exp.CreatedAt = nil + exp.ExpiresOn = nil + act := *ss + act.CreatedAt = nil + act.ExpiresOn = nil + g.Expect(exp).To(Equal(act)) + }) + } +} + +// legacyStoreValue implements the legacy V5 Redis store AES-CFB value encryption +// +// TODO: Remove when this is deprecated (likely V7) +func legacyStoreValue(value string, ticketSecret []byte) ([]byte, error) { + ciphertext := make([]byte, len(value)) + block, err := aes.NewCipher(ticketSecret) + if err != nil { + return nil, fmt.Errorf("error initiating cipher block: %v", err) + } + + // Use secret as the Initialization Vector too, because each entry has it's own key + stream := cipher.NewCFBEncrypter(block, ticketSecret) + stream.XORKeyStream(ciphertext, []byte(value)) + + return ciphertext, nil +} + func TestSessionStore(t *testing.T) { logger.SetOutput(GinkgoWriter)