From ac25928fd0895b0b90dff9a7c875bb99f163b19a Mon Sep 17 00:00:00 2001 From: Vishal Rana Date: Sat, 12 Mar 2016 11:49:45 -0800 Subject: [PATCH] Closes #393 Signed-off-by: Vishal Rana --- context.go | 16 ++++++---------- context_test.go | 16 ++++++++++------ 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/context.go b/context.go index b5651b94..a5e24620 100644 --- a/context.go +++ b/context.go @@ -46,7 +46,7 @@ type ( XML(int, interface{}) error XMLBlob(int, []byte) error File(string) error - Attachment(string) error + Attachment(io.Reader, string) error NoContent(int) error Redirect(int, string) error Error(err error) @@ -288,17 +288,13 @@ func (c *context) File(file string) error { return ServeContent(c.Request(), c.Response(), f, fi) } -// Attachment sends a response as file attachment, prompting client to save the file. -func (c *context) Attachment(file string) (err error) { - f, err := os.Open(file) - if err != nil { - return - } - _, name := filepath.Split(file) - c.response.Header().Set(ContentType, detectContentType(file)) +// Attachment sends a response from `io.Reader` as attachment, prompting client +// to save the file. +func (c *context) Attachment(r io.Reader, name string) (err error) { + c.response.Header().Set(ContentType, detectContentType(name)) c.response.Header().Set(ContentDisposition, "attachment; filename="+name) c.response.WriteHeader(http.StatusOK) - _, err = io.Copy(c.response, f) + _, err = io.Copy(c.response, r) return } diff --git a/context_test.go b/context_test.go index 3abefe47..4aee5342 100644 --- a/context_test.go +++ b/context_test.go @@ -4,6 +4,7 @@ import ( "errors" "io" "net/http" + "os" "testing" "text/template" @@ -161,11 +162,14 @@ func TestContext(t *testing.T) { // Attachment rec = test.NewResponseRecorder() c = NewContext(req, rec, e) - err = c.Attachment("_fixture/images/walle.png") + file, err := os.Open("_fixture/images/walle.png") if assert.NoError(t, err) { - assert.Equal(t, http.StatusOK, rec.Status()) - assert.Equal(t, rec.Header().Get(ContentDisposition), "attachment; filename=walle.png") - assert.Equal(t, 219885, rec.Body.Len()) + err = c.Attachment(file, "walle.png") + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Status()) + assert.Equal(t, "attachment; filename=walle.png", rec.Header().Get(ContentDisposition)) + assert.Equal(t, 219885, rec.Body.Len()) + } } // NoContent @@ -198,12 +202,12 @@ func TestContextPath(t *testing.T) { r.Add(GET, "/users/:id", nil, e) c := NewContext(nil, nil, e) r.Find(GET, "/users/1", c) - assert.Equal(t, c.Path(), "/users/:id") + assert.Equal(t, "/users/:id", c.Path()) r.Add(GET, "/users/:uid/files/:fid", nil, e) c = NewContext(nil, nil, e) r.Find(GET, "/users/1/files/1", c) - assert.Equal(t, c.Path(), "/users/:uid/files/:fid") + assert.Equal(t, "/users/:uid/files/:fid", c.Path()) } func TestContextQuery(t *testing.T) {