package ch

import (
	"context"
	"database/sql"
	"reflect"
	"strings"
	"time"
)

// QueryEvent carries information about a single query to the registered
// QueryHooks. DB.beforeQuery populates DB, Model, IQuery, Query, QueryArgs
// and StartTime before any BeforeQuery hook runs; DB.afterQuery records Err
// and, when non-nil, Result before the AfterQuery hooks run. Within each
// phase the same *QueryEvent value is handed to every hook.
type QueryEvent struct {
	// DB is the handle whose query hooks are being run.
	DB *DB

	Model     Model  // model passed to beforeQuery; NOTE(review): may be nil — confirm at call sites
	IQuery    Query  // query builder that produced Query; may be nil (see Operation)
	Query     string // query text as passed to beforeQuery
	QueryArgs []any  // arguments passed to beforeQuery alongside Query

	StartTime time.Time  // taken with time.Now() before the BeforeQuery hooks run
	Result    sql.Result // set by afterQuery only when a non-nil *result is supplied
	Err       error      // error passed to afterQuery

	// Stash is not populated by beforeQuery, so it starts out nil; a hook must
	// allocate it before writing to it. NOTE(review): presumably intended for
	// hooks to carry per-query state from BeforeQuery to AfterQuery — confirm.
	Stash map[any]any
}

// Operation returns a short name for the kind of query this event describes.
// When the event carries a query builder (IQuery), the builder's own
// Operation result is used; otherwise the name is derived from the raw Query
// text by queryOperation.
func (e *QueryEvent) Operation() string {
	iq := e.IQuery
	if iq == nil {
		return queryOperation(e.Query)
	}
	return iq.Operation()
}

// queryOperation derives a short operation name (e.g. "SELECT") from raw
// query text. It is used when no query builder is available to report one.
//
// Leading ASCII whitespace is skipped. The result is the first word, cut at
// the first space, tab, newline, vertical tab, form feed or carriage return.
// It is capped at 16 bytes without splitting a multi-byte UTF-8 character.
// Empty or all-whitespace input yields "".
func queryOperation(query string) string {
	const (
		whitespace = " \t\n\v\f\r"
		maxLen     = 16
	)

	// Formatted SQL often starts with a newline/indent; skip it so the
	// operation is the first keyword rather than an empty-prefixed fragment.
	op := strings.TrimLeft(query, whitespace)

	// Split on any whitespace, not just ' ', so "SELECT\n*" yields "SELECT".
	if idx := strings.IndexAny(op, whitespace); idx >= 0 {
		op = op[:idx]
	}

	if len(op) > maxLen {
		n := maxLen
		// Back up over UTF-8 continuation bytes (10xxxxxx) so the cut lands on
		// a rune boundary. A rune is at most 4 bytes, so at most 3 steps are
		// needed. The bound also keeps invalid input from shrinking to nothing.
		for n > maxLen-3 && op[n]&0xC0 == 0x80 {
			n--
		}
		op = op[:n]
	}
	return op
}

// QueryHook observes queries executed through a DB. Register an
// implementation with DB.AddQueryHook. Registered hooks are called in
// registration order.
//
// BeforeQuery receives a freshly built event whose StartTime is already set.
// It returns the context that is passed to the next hook. The context returned
// by the last hook is handed back to the code issuing the query.
//
// AfterQuery receives the event after its Err field (and, when available, its
// Result field) has been filled in.
//
// NOTE(review): callers presumably pass the event and context obtained from
// BeforeQuery through to AfterQuery — confirm at call sites.
type QueryHook interface {
	BeforeQuery(ctx context.Context, event *QueryEvent) context.Context
	AfterQuery(ctx context.Context, event *QueryEvent)
}

// AddQueryHook registers hook on db. Registered hooks are invoked in the
// order they were added: BeforeQuery from beforeQuery and AfterQuery from
// afterQuery.
//
// No locking is performed here. NOTE(review): presumably hooks should be
// added before db is shared between goroutines — confirm.
func (db *DB) AddQueryHook(hook QueryHook) {
	hooks := db.queryHooks
	hooks = append(hooks, hook)
	db.queryHooks = hooks
}

// beforeQuery builds a QueryEvent for the query about to run. It then passes
// the event through every registered hook's BeforeQuery, threading the
// context returned by one hook into the next.
//
// When no hooks are registered it allocates nothing and returns ctx unchanged
// together with a nil event. afterQuery treats a nil event as a no-op.
func (db *DB) beforeQuery(
	ctx context.Context,
	iquery Query,
	query string,
	params []any,
	model Model,
) (context.Context, *QueryEvent) {
	hooks := db.queryHooks
	if len(hooks) == 0 {
		return ctx, nil
	}

	evt := new(QueryEvent)
	evt.StartTime = time.Now()
	evt.DB = db
	evt.Model = model
	evt.IQuery = iquery
	evt.Query = query
	evt.QueryArgs = params

	for i := range hooks {
		ctx = hooks[i].BeforeQuery(ctx, evt)
	}
	return ctx, evt
}

// afterQuery records the outcome of a query on evt and hands the event to
// every registered hook's AfterQuery, in registration order.
//
// A nil evt, which beforeQuery returns when no hooks are registered, makes
// this a no-op.
func (db *DB) afterQuery(
	ctx context.Context,
	evt *QueryEvent,
	res *result,
	err error,
) {
	if evt != nil {
		evt.Err = err
		// Only store a non-nil *result. Assigning a nil pointer to the
		// sql.Result interface would yield a non-nil interface value.
		if res != nil {
			evt.Result = res
		}
		for _, h := range db.queryHooks {
			h.AfterQuery(ctx, evt)
		}
	}
}

//---------------------------------------------------------------------------------------

// callAfterScanRowHook invokes AfterScanRow on the value held by v and
// returns its error. v must hold a value implementing AfterScanRowHook;
// the type assertion panics otherwise.
func callAfterScanRowHook(ctx context.Context, v reflect.Value) error {
	hook := v.Interface().(AfterScanRowHook)
	return hook.AfterScanRow(ctx)
}

func callAfterScanRowHookSlice(ctx context.Context, slice reflect.Value) error {
	return callHookSlice(ctx, slice, callAfterScanRowHook)
}

// callHookSlice calls hook once for every element of slice, in index order.
//
// Elements whose type is a pointer or an interface are passed as-is. Any
// other element is passed by address (slice.Index(i).Addr()), presumably so
// that methods with pointer receivers are reachable.
//
// Every element is visited even after a failure. The first non-nil error
// returned by hook is the one reported.
//
// NOTE(review): assumes slice is a reflect.Value of kind Slice, whose
// elements are always addressable as Addr requires.
func callHookSlice(
	ctx context.Context,
	slice reflect.Value,
	hook func(context.Context, reflect.Value) error,
) error {
	elemKind := slice.Type().Elem().Kind()
	byAddr := elemKind != reflect.Ptr && elemKind != reflect.Interface

	var firstErr error
	for i, n := 0, slice.Len(); i < n; i++ {
		elem := slice.Index(i)
		if byAddr {
			elem = elem.Addr()
		}
		if err := hook(ctx, elem); err != nil && firstErr == nil {
			firstErr = err
		}
	}
	return firstErr
}