diff --git a/sql.go b/sql.go index 2d7679e..98b23aa 100644 --- a/sql.go +++ b/sql.go @@ -24,14 +24,22 @@ func (uuid *UUID) Scan(src interface{}) error { *uuid = parsed case []byte: - // assumes a simple slice of bytes, just check validity and store - u := UUID(src.([]byte)) + b := src.([]byte) - if u.Variant() == Invalid { - return errors.New("Scan: invalid UUID format") + // assumes a simple slice of bytes if 16 bytes + // otherwise attempts to parse + if len(b) == 16 { + *uuid = UUID(b) + } else { + u := Parse(string(b)) + + if u == nil { + return errors.New("Scan: invalid UUID format") + } + + *uuid = u } - *uuid = u default: return fmt.Errorf("Scan: unable to scan type %T into UUID", src) } diff --git a/sql_test.go b/sql_test.go index d643567..83bac8c 100644 --- a/sql_test.go +++ b/sql_test.go @@ -22,6 +22,11 @@ func TestScan(t *testing.T) { t.Fatal(err) } + err = (&uuid).Scan([]byte(stringTest)) + if err != nil { + t.Fatal(err) + } + err = (&uuid).Scan(byteTest) if err != nil { t.Fatal(err)