diff --git a/src/NzbDrone.Api/Extensions/Pipelines/CacheHeaderPipeline.cs b/src/NzbDrone.Api/Extensions/Pipelines/CacheHeaderPipeline.cs index 94c738d9b..d8e9266ad 100644 --- a/src/NzbDrone.Api/Extensions/Pipelines/CacheHeaderPipeline.cs +++ b/src/NzbDrone.Api/Extensions/Pipelines/CacheHeaderPipeline.cs @@ -23,6 +23,8 @@ public void Register(IPipelines pipelines) private void Handle(NancyContext context) { + if (context.Request.Method == "OPTIONS") return; + if (_cacheableSpecification.IsCacheable(context)) { context.Response.Headers.EnableCache(); @@ -33,4 +35,4 @@ private void Handle(NancyContext context) } } } -} \ No newline at end of file +} diff --git a/src/NzbDrone.Api/Extensions/Pipelines/CorsPipeline.cs b/src/NzbDrone.Api/Extensions/Pipelines/CorsPipeline.cs index fcdd6317a..ad98837e8 100644 --- a/src/NzbDrone.Api/Extensions/Pipelines/CorsPipeline.cs +++ b/src/NzbDrone.Api/Extensions/Pipelines/CorsPipeline.cs @@ -12,9 +12,24 @@ public class CorsPipeline : IRegisterNancyPipeline public void Register(IPipelines pipelines) { + pipelines.BeforeRequest.AddItemToEndOfPipeline(HandleRequest); pipelines.AfterRequest.AddItemToEndOfPipeline(HandleResponse); } + private Response HandleRequest(NancyContext context) + { + if (context == null || context.Request.Method != "OPTIONS") + { + return null; + } + + var response = new Response() + .WithStatusCode(HttpStatusCode.OK) + .WithContentType(""); + ApplyResponseHeaders(response, context.Request); + return response; + } + private void HandleResponse(NancyContext context) { if (context == null || context.Response.Headers.ContainsKey(AccessControlHeaders.AllowOrigin)) @@ -45,18 +60,21 @@ private static void ApplyCorsResponseHeaders(Response response, Request request, { response.Headers.Add(AccessControlHeaders.AllowOrigin, allowOrigin); - if (response.Headers.ContainsKey("Allow")) + if (request.Method == "OPTIONS") { - allowedMethods = response.Headers["Allow"]; - } + if (response.Headers.ContainsKey("Allow")) + { + allowedMethods = response.Headers["Allow"]; + } - response.Headers.Add(AccessControlHeaders.AllowMethods, allowedMethods); + response.Headers.Add(AccessControlHeaders.AllowMethods, allowedMethods); - if (request.Headers[AccessControlHeaders.RequestHeaders].Any()) - { - var requestedHeaders = request.Headers[AccessControlHeaders.RequestHeaders].Join(", "); + if (request.Headers[AccessControlHeaders.RequestHeaders].Any()) + { + var requestedHeaders = request.Headers[AccessControlHeaders.RequestHeaders].Join(", "); - response.Headers.Add(AccessControlHeaders.AllowHeaders, requestedHeaders); + response.Headers.Add(AccessControlHeaders.AllowHeaders, requestedHeaders); + } } } } diff --git a/src/NzbDrone.Api/Extensions/Pipelines/GZipPipeline.cs b/src/NzbDrone.Api/Extensions/Pipelines/GZipPipeline.cs index 12293f23c..8aa9f4ad2 100644 --- a/src/NzbDrone.Api/Extensions/Pipelines/GZipPipeline.cs +++ b/src/NzbDrone.Api/Extensions/Pipelines/GZipPipeline.cs @@ -33,7 +33,8 @@ private void CompressResponse(NancyContext context) try { if ( - !response.ContentType.Contains("image") + response.Contents != Response.NoBody + && !response.ContentType.Contains("image") && !response.ContentType.Contains("font") && request.Headers.AcceptEncoding.Any(x => x.Contains("gzip")) && !AlreadyGzipEncoded(response) @@ -80,4 +81,4 @@ private static bool AlreadyGzipEncoded(Response response) return false; } } -} \ No newline at end of file +} diff --git a/src/NzbDrone.Integration.Test/CorsFixture.cs b/src/NzbDrone.Integration.Test/CorsFixture.cs index 2d9d8ac4f..a37936a2a 100644 --- a/src/NzbDrone.Integration.Test/CorsFixture.cs +++ b/src/NzbDrone.Integration.Test/CorsFixture.cs @@ -8,30 +8,37 @@ namespace NzbDrone.Integration.Test [TestFixture] public class CorsFixture : IntegrationTest { - private RestRequest BuildRequest() + private RestRequest BuildGet(string route = "series") { - var request = new RestRequest("series"); + var request = new RestRequest(route, Method.GET); request.AddHeader(AccessControlHeaders.RequestMethod, "POST"); return request; } + private RestRequest BuildOptions(string route = "series") + { + var request = new RestRequest(route, Method.OPTIONS); + + return request; + } + [Test] public void should_not_have_allow_headers_in_response_when_not_included_in_the_request() { - var request = BuildRequest(); - var response = RestClient.Get(request); - + var request = BuildOptions(); + var response = RestClient.Execute(request); + response.Headers.Should().NotContain(h => h.Name == AccessControlHeaders.AllowHeaders); } [Test] public void should_have_allow_headers_in_response_when_included_in_the_request() { - var request = BuildRequest(); + var request = BuildOptions(); request.AddHeader(AccessControlHeaders.RequestHeaders, "X-Test"); - var response = RestClient.Get(request); + var response = RestClient.Execute(request); response.Headers.Should().Contain(h => h.Name == AccessControlHeaders.AllowHeaders); } @@ -39,8 +46,8 @@ public void should_have_allow_headers_in_response_when_included_in_the_request() [Test] public void should_have_allow_origin_in_response() { - var request = BuildRequest(); - var response = RestClient.Get(request); + var request = BuildOptions(); + var response = RestClient.Execute(request); response.Headers.Should().Contain(h => h.Name == AccessControlHeaders.AllowOrigin); } @@ -48,10 +55,37 @@ public void should_have_allow_origin_in_response() [Test] public void should_have_allow_methods_in_response() { - var request = BuildRequest(); - var response = RestClient.Get(request); + var request = BuildOptions(); + var response = RestClient.Execute(request); response.Headers.Should().Contain(h => h.Name == AccessControlHeaders.AllowMethods); } + + [Test] + public void should_not_have_allow_methods_in_non_options_request() + { + var request = BuildGet(); + var response = RestClient.Execute(request); + + response.Headers.Should().NotContain(h => h.Name == AccessControlHeaders.AllowMethods); + } + + [Test] + public void should_have_allow_origin_in_non_options_request() + { + var request = BuildGet(); + var response = RestClient.Execute(request); + + response.Headers.Should().Contain(h => h.Name == AccessControlHeaders.AllowOrigin); + } + + [Test] + public void should_not_have_allow_origin_in_non_api_request() + { + var request = BuildGet("../abc"); + var response = RestClient.Execute(request); + + response.Headers.Should().NotContain(h => h.Name == AccessControlHeaders.AllowOrigin); + } } }