package meddler

import (
	"database/sql"
	"fmt"
	"log"
	"reflect"
	"strconv"
	"strings"
	"sync"
)

const (
	// tagName is the struct tag key that getFields reads column options
	// from, e.g. `meddler:"user_id,pk"`.
	tagName = "meddler"
)

// Database contains database-specific options used when generating SQL.
// MySQL, PostgreSQL, and SQLite are provided for convenience.
// Setting Default to any of these lets you use the package-level convenience functions.
type Database struct {
	Quote               string // the quote character for table and column names
	Placeholder         string // the placeholder style to use in generated queries; the first "1" in it is replaced by the parameter number (so "$1" yields "$2", "$3", ... while "?" is used as-is)
	UseReturningToGetID bool   // use PostgreSQL-style RETURNING "ID" instead of calling sql.Result.LastInsertId
}

// Ready-made Database configurations for common SQL dialects, plus the
// Default used by the package-level convenience functions.
var (
	// MySQL quotes identifiers with backticks and uses "?" placeholders.
	MySQL = &Database{
		Quote:               "`",
		Placeholder:         "?",
		UseReturningToGetID: false,
	}

	// PostgreSQL quotes identifiers with double quotes, uses numbered
	// placeholders ($1, $2, ...), and sets UseReturningToGetID.
	PostgreSQL = &Database{
		Quote:               `"`,
		Placeholder:         "$1",
		UseReturningToGetID: true,
	}

	// SQLite quotes identifiers with double quotes and uses "?" placeholders.
	SQLite = &Database{
		Quote:               `"`,
		Placeholder:         "?",
		UseReturningToGetID: false,
	}

	// Default is the Database used by the package-level convenience
	// functions; it is initially MySQL.
	Default = MySQL
)

// quoted wraps s in d.Quote on both sides, e.g. `name` for MySQL.
func (d *Database) quoted(s string) string {
	q := d.Quote
	return q + s + q
}

// placeholder returns the query placeholder for parameter number n. The
// first "1" in d.Placeholder is replaced by n, so "$1" becomes "$n", while a
// placeholder without a "1" (such as "?") is returned unchanged.
func (d *Database) placeholder(n int) string {
	num := strconv.Itoa(n)
	return strings.Replace(d.Placeholder, "1", num, 1)
}

// Debug enables debug logging: when true, every column name that has no
// matching struct field is reported with log.Printf by SomeValues, Targets,
// and WriteTargets. It is enabled by default.
var Debug = true

// structField describes one struct field that maps to a database column.
type structField struct {
	column     string  // column name: the tag's name part, or the Go field name
	index      int     // field index within the struct, for reflect.Value.Field
	primaryKey bool    // true if this field is the one tagged "pk"
	meddler    Meddler // converts values on write (PreWrite) and read (PreRead/PostRead)
}

// structData is the column metadata getFields builds for one struct type.
type structData struct {
	columns []string                // column names in struct field order
	fields  map[string]*structField // column name -> field metadata
	pk      string                  // primary key column name, or "" if none
}

// cache reflection data: fieldsCache memoizes successful getFields results
// keyed by the pointer-to-struct type; getFields holds fieldsCacheMutex
// while reading or writing it.
var fieldsCache = make(map[reflect.Type]*structData)
var fieldsCacheMutex sync.Mutex

// getFields returns column metadata for the struct type that dstType points
// to. The result is built with reflection on first use and memoized in
// fieldsCache; only successful results are cached.
//
// Each exported field becomes a column unless its meddler tag name is "-".
// The first comma-separated tag element, when non-empty, overrides the column
// name (which otherwise is the Go field name). Each later element is either
// "pk", marking the field as the integer primary key, or the name of a
// meddler in registry; fields without one use registry["identity"].
//
// It returns an error if dstType is nil or not a pointer to a struct, if a
// "pk" field is a pointer, is not an integer, or is a second "pk" field, if a
// tag names an unregistered meddler, or if two fields map to the same column.
func getFields(dstType reflect.Type) (*structData, error) {
	// reflect.TypeOf(nil) is a nil Type, and calling Kind on it panics;
	// report it as an error like the other invalid destinations.
	if dstType == nil {
		return nil, fmt.Errorf("meddler called with nil destination")
	}

	fieldsCacheMutex.Lock()
	defer fieldsCacheMutex.Unlock()

	if result, present := fieldsCache[dstType]; present {
		return result, nil
	}

	// make sure dst is a pointer to a struct
	if dstType.Kind() != reflect.Ptr {
		return nil, fmt.Errorf("meddler called with non-pointer destination %v", dstType)
	}
	structType := dstType.Elem()
	if structType.Kind() != reflect.Struct {
		return nil, fmt.Errorf("meddler called with pointer to non-struct %v", dstType)
	}

	// gather the list of fields in the struct
	data := new(structData)
	data.fields = make(map[string]*structField)

	for i := 0; i < structType.NumField(); i++ {
		f := structType.Field(i)

		// skip non-exported fields
		if f.PkgPath != "" {
			continue
		}

		// examine the tag for metadata; strings.Split always returns at
		// least one element, so tag[0] is safe to read even with no tag
		tag := strings.Split(f.Tag.Get(tagName), ",")

		// was this field marked for skipping?
		if tag[0] == "-" {
			continue
		}

		// default to the field name; the tag can override it
		name := f.Name
		if tag[0] != "" {
			name = tag[0]
		}

		// the remaining tag elements are options: "pk" or a meddler name
		var meddler Meddler = registry["identity"]
		for _, opt := range tag[1:] {
			if opt == "pk" {
				if f.Type.Kind() == reflect.Ptr {
					return nil, fmt.Errorf("meddler found field %s which is marked as the primary key but is a pointer", f.Name)
				}

				// make sure it is an int of some kind
				switch f.Type.Kind() {
				case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
				case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
				default:
					return nil, fmt.Errorf("meddler found field %s which is marked as the primary key, but is not an integer type", f.Name)
				}

				if data.pk != "" {
					return nil, fmt.Errorf("meddler found field %s which is marked as the primary key, but a primary key field was already found", f.Name)
				}
				data.pk = name
			} else if m, present := registry[opt]; present {
				meddler = m
			} else {
				return nil, fmt.Errorf("meddler found field %s with meddler %s, but that meddler is not registered", f.Name, opt)
			}
		}

		if _, present := data.fields[name]; present {
			return nil, fmt.Errorf("meddler found multiple fields for column %s", name)
		}
		data.fields[name] = &structField{
			column:     name,
			primaryKey: name == data.pk,
			index:      i,
			meddler:    meddler,
		}
		data.columns = append(data.columns, name)
	}

	fieldsCache[dstType] = data
	return data, nil
}

// Columns lists the column names of the struct that src points to, in field
// declaration order. The primary key column is left out unless includePk is
// true. It returns an error if src is not a pointer to a struct or if the
// struct's meddler tags are invalid.
func (d *Database) Columns(src interface{}, includePk bool) ([]string, error) {
	data, err := getFields(reflect.TypeOf(src))
	if err != nil {
		return nil, err
	}

	var result []string
	for i := range data.columns {
		if col := data.columns[i]; includePk || col != data.pk {
			result = append(result, col)
		}
	}
	return result, nil
}

// Columns calls Columns on the Default Database.
func Columns(src interface{}, includePk bool) ([]string, error) {
	db := Default
	return db.Columns(src, includePk)
}

// ColumnsQuoted is similar to Columns, but it returns the list of columns as
// a single comma-separated string in the form:
//
//	`column1`,`column2`,...
//
// using d.Quote as the quote character. It returns "" when there are no
// columns.
func (d *Database) ColumnsQuoted(src interface{}, includePk bool) (string, error) {
	// use this Database's Columns, not the package-level Columns, which
	// would route through Default
	unquoted, err := d.Columns(src, includePk)
	if err != nil {
		return "", err
	}

	parts := make([]string, len(unquoted))
	for i, elt := range unquoted {
		parts[i] = d.quoted(elt)
	}

	return strings.Join(parts, ","), nil
}

// ColumnsQuoted using the Default Database type
func ColumnsQuoted(src interface{}, includePk bool) (string, error) {
	return Default.ColumnsQuoted(src, includePk)
}

// PrimaryKey reports the column name and current value of the primary key
// field of the struct that src points to; unsigned values are converted to
// int64. If no field is tagged "pk", it returns "", 0, nil.
func (d *Database) PrimaryKey(src interface{}) (name string, pk int64, err error) {
	data, err := getFields(reflect.TypeOf(src))
	if err != nil {
		return "", 0, err
	}
	pkName := data.pk
	if pkName == "" {
		return "", 0, nil
	}

	pkField := reflect.ValueOf(src).Elem().Field(data.fields[pkName].index)
	switch pkField.Type().Kind() {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return pkName, pkField.Int(), nil
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return pkName, int64(pkField.Uint()), nil
	}
	return "", 0, fmt.Errorf("meddler found field %s which is marked as the primary key, but is not an integer type", pkName)
}

// PrimaryKey calls PrimaryKey on the Default Database.
func PrimaryKey(src interface{}) (name string, pk int64, err error) {
	db := Default
	return db.PrimaryKey(src)
}

// SetPrimaryKey stores pk in the primary key field of the struct that src
// points to; for unsigned fields pk is converted with uint64(pk). It returns
// an error if the struct has no field tagged "pk".
func (d *Database) SetPrimaryKey(src interface{}, pk int64) error {
	data, err := getFields(reflect.TypeOf(src))
	if err != nil {
		return err
	}
	if data.pk == "" {
		return fmt.Errorf("meddler.SetPrimaryKey: no primary key field found")
	}

	target := reflect.ValueOf(src).Elem().Field(data.fields[data.pk].index)
	switch target.Type().Kind() {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		target.SetInt(pk)
		return nil
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		target.SetUint(uint64(pk))
		return nil
	default:
		return fmt.Errorf("meddler found field %s which is marked as the primary key, but is not an integer type", data.pk)
	}
}

// SetPrimaryKey calls SetPrimaryKey on the Default Database.
func SetPrimaryKey(src interface{}, pk int64) error {
	db := Default
	return db.SetPrimaryKey(src, pk)
}

// Values returns the PreWrite-processed field values of the struct that src
// points to, for use in an INSERT or UPDATE query. The values line up with
// Columns(src, includePk): same columns, same order, with the primary key
// omitted unless includePk is true.
func (d *Database) Values(src interface{}, includePk bool) ([]interface{}, error) {
	cols, err := d.Columns(src, includePk)
	if err == nil {
		return d.SomeValues(src, cols)
	}
	return nil, err
}

// Values calls Values on the Default Database.
func Values(src interface{}, includePk bool) ([]interface{}, error) {
	db := Default
	return db.Values(src, includePk)
}

// SomeValues returns the PreWrite-processed values of the named columns of
// the struct that src points to, in the order given by columns, for use in
// an INSERT or UPDATE query. A column with no matching struct field yields a
// nil value (logged when Debug is set). A PreWrite failure is returned as an
// error naming the column.
func (d *Database) SomeValues(src interface{}, columns []string) ([]interface{}, error) {
	data, err := getFields(reflect.TypeOf(src))
	if err != nil {
		return nil, err
	}
	structVal := reflect.ValueOf(src).Elem()

	var values []interface{}
	for _, name := range columns {
		field, present := data.fields[name]
		if present {
			raw := structVal.Field(field.index).Interface()
			saveVal, err := field.meddler.PreWrite(raw)
			if err != nil {
				return nil, fmt.Errorf("meddler.SomeValues: PreWrite error on column [%s]: %v", name, err)
			}
			values = append(values, saveVal)
			continue
		}

		// no such field: write null to the database in its place
		values = append(values, nil)
		if Debug {
			log.Printf("meddler.SomeValues: column [%s] not found in struct", name)
		}
	}
	return values, nil
}

// SomeValues calls SomeValues on the Default Database.
func SomeValues(src interface{}, columns []string) ([]interface{}, error) {
	db := Default
	return db.SomeValues(src, columns)
}

// Placeholders returns one query placeholder per column of the struct that
// src points to, numbered consecutively from 1 in column order, for use in
// an INSERT or UPDATE query. The primary key column is skipped (and does not
// consume a number) unless includePk is true.
func (d *Database) Placeholders(src interface{}, includePk bool) ([]string, error) {
	data, err := getFields(reflect.TypeOf(src))
	if err != nil {
		return nil, err
	}

	var placeholders []string
	n := 0
	for _, name := range data.columns {
		if name == data.pk && !includePk {
			continue
		}
		n++
		placeholders = append(placeholders, d.placeholder(n))
	}
	return placeholders, nil
}

// Placeholders calls Placeholders on the Default Database.
func Placeholders(src interface{}, includePk bool) ([]string, error) {
	db := Default
	return db.Placeholders(src, includePk)
}

// PlaceholdersString is like Placeholders but joins the placeholders with
// commas into a single string suitable for an INSERT or UPDATE query, e.g.:
//
//	?,?,?,?
//
// The primary key column is omitted unless includePk is true.
func (d *Database) PlaceholdersString(src interface{}, includePk bool) (string, error) {
	lst, err := d.Placeholders(src, includePk)
	if err == nil {
		return strings.Join(lst, ","), nil
	}
	return "", err
}

// PlaceholdersString calls PlaceholdersString on the Default Database.
func PlaceholdersString(src interface{}, includePk bool) (string, error) {
	db := Default
	return db.PlaceholdersString(src, includePk)
}

// scanRow advances rows by one row and scans it into the struct that dst
// points to, pairing each entry of columns with the struct field of the same
// name. It returns rows.Err() if iteration failed, or sql.ErrNoRows when no
// row remains. The data argument is currently unused; Targets and
// WriteTargets look up the field metadata themselves.
func (d *Database) scanRow(data *structData, rows *sql.Rows, dst interface{}, columns []string) error {
	if !rows.Next() {
		if iterErr := rows.Err(); iterErr != nil {
			return iterErr
		}
		return sql.ErrNoRows
	}

	// build meddled scan targets, scan into them, then copy them into dst;
	// stop at the first step that fails
	targets, err := d.Targets(dst, columns)
	if err == nil {
		err = rows.Scan(targets...)
	}
	if err == nil {
		err = d.WriteTargets(dst, columns, targets)
	}
	if err != nil {
		return err
	}
	return rows.Err()
}

// Targets returns one Scan destination per entry in columns for the struct
// that dst points to. Each destination is what the field's meddler PreRead
// returns for the field's address; a column with no matching field gets a
// throwaway *interface{} (logged when Debug is set). After the Scan, hand the
// same values to WriteTargets to finish the meddling and store them in dst.
func (d *Database) Targets(dst interface{}, columns []string) ([]interface{}, error) {
	data, err := getFields(reflect.TypeOf(dst))
	if err != nil {
		return nil, err
	}
	structVal := reflect.ValueOf(dst).Elem()

	var targets []interface{}
	for _, name := range columns {
		field, present := data.fields[name]
		if !present {
			// no destination, so scan into a value that is thrown away
			targets = append(targets, new(interface{}))
			if Debug {
				log.Printf("meddler.Targets: column [%s] not found in struct", name)
			}
			continue
		}

		fieldAddr := structVal.Field(field.index).Addr().Interface()
		scanTarget, err := field.meddler.PreRead(fieldAddr)
		if err != nil {
			return nil, fmt.Errorf("meddler.Targets: PreRead error on column %s: %v", name, err)
		}
		targets = append(targets, scanTarget)
	}
	return targets, nil
}

// Targets calls Targets on the Default Database.
func Targets(dst interface{}, columns []string) ([]interface{}, error) {
	db := Default
	return db.Targets(dst, columns)
}

// WriteTargets post-processes values with meddlers after a Scan from the
// sql package has been performed, storing the results in the struct that dst
// points to. columns and targets are parallel slices and must have the same
// length; targets is normally produced by Targets for the same dst and
// columns. Columns with no matching struct field are skipped (and logged
// when Debug is set). A PostRead failure is returned as an error naming the
// column.
func (d *Database) WriteTargets(dst interface{}, columns []string, targets []interface{}) error {
	if len(columns) != len(targets) {
		// both counts are ints, so use %d; %s would render "%!s(int=N)"
		return fmt.Errorf("meddler.WriteTargets: mismatch in number of columns (%d) and targets (%d)",
			len(columns), len(targets))
	}

	data, err := getFields(reflect.TypeOf(dst))
	if err != nil {
		return err
	}
	structVal := reflect.ValueOf(dst).Elem()

	for i, name := range columns {
		field, present := data.fields[name]
		if !present {
			// no destination, so this target is thrown away
			if Debug {
				log.Printf("meddler.WriteTargets: column [%s] not found in struct", name)
			}
			continue
		}

		fieldAddr := structVal.Field(field.index).Addr().Interface()
		if err := field.meddler.PostRead(fieldAddr, targets[i]); err != nil {
			return fmt.Errorf("meddler.WriteTargets: PostRead error on column [%s]: %v", name, err)
		}
	}

	return nil
}

// WriteTargets using the Default Database type
func WriteTargets(dst interface{}, columns []string, targets []interface{}) error {
	return Default.WriteTargets(dst, columns, targets)
}

// Scan reads the next row of rows into the struct that dst points to,
// matching result columns to struct fields by name. rows is left open so it
// can be scanned again for the following row. It returns sql.ErrNoRows when
// no row remains.
func (d *Database) Scan(rows *sql.Rows, dst interface{}) error {
	var (
		data    *structData
		columns []string
		err     error
	)
	if data, err = getFields(reflect.TypeOf(dst)); err != nil {
		return err
	}
	if columns, err = rows.Columns(); err != nil {
		return err
	}
	return d.scanRow(data, rows, dst, columns)
}

// Scan calls Scan on the Default Database.
func Scan(rows *sql.Rows, dst interface{}) error {
	db := Default
	return db.Scan(rows, dst)
}

// ScanRow reads a single result row of rows into the struct that dst points
// to and closes rows before returning, on success or failure. It returns
// sql.ErrNoRows if there is no result row, and otherwise the error, if any,
// from closing rows.
func (d *Database) ScanRow(rows *sql.Rows, dst interface{}) error {
	// the deferred Close covers the error paths; the explicit Close below
	// lets a Close error be returned (sql.Rows.Close is idempotent)
	defer rows.Close()

	if err := d.Scan(rows, dst); err != nil {
		return err
	}
	return rows.Close()
}

// ScanRow calls ScanRow on the Default Database.
func ScanRow(rows *sql.Rows, dst interface{}) error {
	db := Default
	return db.ScanRow(rows, dst)
}

// ScanAll scans every remaining row of rows into a newly allocated struct
// and appends a pointer to each one to the slice that dst points to; elements
// already in that slice are kept. dst must be a non-nil pointer to a slice of
// struct pointers (for example *[]*Person). rows is always closed before
// ScanAll returns.
func (d *Database) ScanAll(rows *sql.Rows, dst interface{}) error {
	defer rows.Close()

	// validate that dst has the shape *[]*Struct
	dstVal := reflect.ValueOf(dst)
	if dstVal.Kind() != reflect.Ptr || dstVal.IsNil() {
		return fmt.Errorf("ScanAll called with non-pointer destination: %T", dst)
	}
	sliceVal := dstVal.Elem()
	if sliceVal.Kind() != reflect.Slice {
		return fmt.Errorf("ScanAll called with pointer to non-slice: %T", dst)
	}
	elemPtrType := sliceVal.Type().Elem()
	if elemPtrType.Kind() != reflect.Ptr {
		return fmt.Errorf("ScanAll expects element to be pointers, found %T", dst)
	}
	elemType := elemPtrType.Elem()
	if elemType.Kind() != reflect.Struct {
		return fmt.Errorf("ScanAll expects element to be pointers to structs, found %T", dst)
	}

	data, err := getFields(elemPtrType)
	if err != nil {
		return err
	}
	columns, err := rows.Columns()
	if err != nil {
		return err
	}

	// keep scanning until scanRow reports that no rows are left
	for {
		newElem := reflect.New(elemType)
		switch err := d.scanRow(data, rows, newElem.Interface(), columns); {
		case err == sql.ErrNoRows:
			return nil
		case err != nil:
			return err
		}
		sliceVal.Set(reflect.Append(sliceVal, newElem))
	}
}

// ScanAll calls ScanAll on the Default Database.
func ScanAll(rows *sql.Rows, dst interface{}) error {
	db := Default
	return db.ScanAll(rows, dst)
}