diff --git a/download.go b/download.go index 60d93a43..2cc46f39 100644 --- a/download.go +++ b/download.go @@ -16,9 +16,12 @@ import ( var ( downloadClient *http.Client - imageDataCtxKey = ctxKey("imageData") - cacheControlHeaderCtxKey = ctxKey("cacheControlHeader") - expiresHeaderCtxKey = ctxKey("expiresHeader") + imageDataCtxKey = ctxKey("imageData") + + imageHeadersToStore = []string{ + "Cache-Control", + "Expires", + } errSourceResolutionTooBig = newError(422, "Source image resolution is too big", "Invalid source image") errSourceFileTooBig = newError(422, "Source image file is too big", "Invalid source image") @@ -147,10 +150,14 @@ func readAndCheckImage(r io.Reader, contentLength int) (*imageData, error) { if _, err = buf.ReadFrom(r); err != nil { cancel() - return nil, newError(404, err.Error(), msgSourceImageIsUnreachable) + return nil, newError(404, err.Error(), msgSourceImageIsUnreachable).SetUnexpected(conf.ReportDownloadingErrors) } - return &imageData{buf.Bytes(), imgtype, cancel}, nil + return &imageData{ + Data: buf.Bytes(), + Type: imgtype, + cancel: cancel, + }, nil } func requestImage(imageURL string) (*http.Response, error) { @@ -168,6 +175,8 @@ func requestImage(imageURL string) (*http.Response, error) { if res.StatusCode != 200 { body, _ := ioutil.ReadAll(res.Body) + res.Body.Close() + msg := fmt.Sprintf("Can't download image; Status: %d; %s", res.StatusCode, string(body)) return res, newError(404, msg, msgSourceImageIsUnreachable).SetUnexpected(conf.ReportDownloadingErrors) } @@ -175,7 +184,31 @@ func requestImage(imageURL string) (*http.Response, error) { return res, nil } -func downloadImage(ctx context.Context) (context.Context, context.CancelFunc, error) { +func downloadImage(imageURL string) (*imageData, error) { + res, err := requestImage(imageURL) + if res != nil { + defer res.Body.Close() + } + if err != nil { + return nil, err + } + + imgdata, err := readAndCheckImage(res.Body, int(res.ContentLength)) + if err != nil { + return nil, err + } + + imgdata.Headers = make(map[string]string) + for _, h := range imageHeadersToStore { + if val := res.Header.Get(h); len(val) != 0 { + imgdata.Headers[h] = val + } + } + + return imgdata, nil +} + +func downloadImageCtx(ctx context.Context) (context.Context, context.CancelFunc, error) { imageURL := getImageURL(ctx) if newRelicEnabled { @@ -187,36 +220,16 @@ func downloadImage(ctx context.Context) (context.Context, context.CancelFunc, er defer startPrometheusDuration(prometheusDownloadDuration)() } - res, err := requestImage(imageURL) - if res != nil { - defer res.Body.Close() - } - if err != nil { - return ctx, func() {}, err - } - - imgdata, err := readAndCheckImage(res.Body, int(res.ContentLength)) + imgdata, err := downloadImage(imageURL) if err != nil { return ctx, func() {}, err } ctx = context.WithValue(ctx, imageDataCtxKey, imgdata) - ctx = context.WithValue(ctx, cacheControlHeaderCtxKey, res.Header.Get("Cache-Control")) - ctx = context.WithValue(ctx, expiresHeaderCtxKey, res.Header.Get("Expires")) - return ctx, imgdata.Close, err + return ctx, imgdata.Close, nil } func getImageData(ctx context.Context) *imageData { return ctx.Value(imageDataCtxKey).(*imageData) } - -func getCacheControlHeader(ctx context.Context) string { - str, _ := ctx.Value(cacheControlHeaderCtxKey).(string) - return str -} - -func getExpiresHeader(ctx context.Context) string { - str, _ := ctx.Value(expiresHeaderCtxKey).(string) - return str -} diff --git a/image_data.go b/image_data.go index 96953f98..4e89ca9b 100644 --- a/image_data.go +++ b/image_data.go @@ -9,8 +9,9 @@ import ( ) type imageData struct { - Data []byte - Type imageType + Data []byte + Type imageType + Headers map[string]string cancel context.CancelFunc } @@ -87,18 +88,10 @@ func fileImageData(path, desc string) (*imageData, error) { } func remoteImageData(imageURL, desc string) (*imageData, error) { - res, err := requestImage(imageURL) - if res != nil { - defer res.Body.Close() - } + imgdata, err := downloadImage(imageURL) if err != nil { return nil, fmt.Errorf("Can't download %s: %s", desc, err) } - imgdata, err := readAndCheckImage(res.Body, int(res.ContentLength)) - if err != nil { - return nil, fmt.Errorf("Can't download %s: %s", desc, err) - } - - return imgdata, err + return imgdata, nil } diff --git a/processing_handler.go b/processing_handler.go index a66bce3b..d81d508c 100644 --- a/processing_handler.go +++ b/processing_handler.go @@ -42,6 +42,7 @@ func initProcessingHandler() error { func respondWithImage(ctx context.Context, reqID string, r *http.Request, rw http.ResponseWriter, data []byte) { po := getProcessingOptions(ctx) + imgdata := getImageData(ctx) var contentDisposition string if len(po.Filename) > 0 { @@ -63,9 +64,13 @@ func respondWithImage(ctx context.Context, reqID string, r *http.Request, rw htt var cacheControl, expires string - if conf.CacheControlPassthrough { - cacheControl = getCacheControlHeader(ctx) - expires = getExpiresHeader(ctx) + if conf.CacheControlPassthrough && imgdata.Headers != nil { + if val, ok := imgdata.Headers["Cache-Control"]; ok { + cacheControl = val + } + if val, ok := imgdata.Headers["Expires"]; ok { + expires = val + } } if len(cacheControl) == 0 && len(expires) == 0 { @@ -85,7 +90,6 @@ func respondWithImage(ctx context.Context, reqID string, r *http.Request, rw htt } if conf.EnableDebugHeaders { - imgdata := getImageData(ctx) rw.Header().Set("X-Origin-Content-Length", strconv.Itoa(len(imgdata.Data))) } @@ -135,7 +139,7 @@ func handleProcessing(reqID string, rw http.ResponseWriter, r *http.Request) { panic(err) } - ctx, downloadcancel, err := downloadImage(ctx) + ctx, downloadcancel, err := downloadImageCtx(ctx) defer downloadcancel() if err != nil { if newRelicEnabled {