1
0
mirror of https://github.com/uptrace/go-clickhouse.git synced 2025-06-08 23:26:11 +02:00

593 lines
12 KiB
Go
Raw Normal View History

2022-01-23 09:36:24 +02:00
package ch
import (
"context"
"crypto/tls"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"net"
"reflect"
"sync/atomic"
"time"
"github.com/uptrace/go-clickhouse/ch/chpool"
"github.com/uptrace/go-clickhouse/ch/chproto"
"github.com/uptrace/go-clickhouse/ch/chschema"
"github.com/uptrace/go-clickhouse/ch/internal"
)
// DBStats is a snapshot of the counters maintained by a DB.
type DBStats struct {
	// Queries counts operations run through the DB's connection wrapper;
	// Errors counts how many of those returned a non-nil error.
	Queries, Errors uint64
}
type DB struct {
cfg *Config
pool *chpool.ConnPool
queryHooks []QueryHook
fmter chschema.Formatter
flags internal.Flag
stats DBStats
}
// Connect creates a DB that starts from the default configuration, applies
// opts in the order given, and then builds the connection pool from the
// resulting configuration.
func Connect(opts ...Option) *DB {
	db := new(DB)
	db.cfg = defaultConfig()

	for i := range opts {
		opts[i](db)
	}

	// The pool is created last so that it sees every option.
	db.pool = newConnPool(db.cfg)
	return db
}
// newConnPool builds a connection pool from a copy of cfg.Config whose
// Dialer connects to cfg.Addr over cfg.Network, using TLS when
// cfg.TLSConfig is set.
//
// Both branches honour ctx: the TLS branch uses tls.Dialer.DialContext so
// that cancellation and deadlines also cover the TLS handshake.
func newConnPool(cfg *Config) *chpool.ConnPool {
	poolcfg := cfg.Config
	poolcfg.Dialer = func(ctx context.Context) (net.Conn, error) {
		if cfg.TLSConfig != nil {
			tlsDialer := &tls.Dialer{
				NetDialer: cfg.netDialer(),
				Config:    cfg.TLSConfig,
			}
			return tlsDialer.DialContext(ctx, cfg.Network, cfg.Addr)
		}
		return cfg.netDialer().DialContext(ctx, cfg.Network, cfg.Addr)
	}
	return chpool.New(&poolcfg)
}
// Close closes the underlying connection pool and returns its error.
//
// A DB is intended to be long-lived and shared between goroutines, so
// calling Close is rarely necessary.
func (db *DB) Close() error {
	err := db.pool.Close()
	return err
}
// String returns a short description of the DB that includes the
// configured address.
func (db *DB) String() string {
	return "DB<addr: " + db.cfg.Addr + ">"
}
// Config returns the configuration used by this DB. The pointer is the one
// held by the DB, not a copy.
func (db *DB) Config() *Config {
	cfg := db.cfg
	return cfg
}
// WithTimeout returns a clone of db whose configuration is a copy of the
// original with both ReadTimeout and WriteTimeout set to d. The receiver's
// configuration is left untouched.
func (db *DB) WithTimeout(d time.Duration) *DB {
	cfg := new(Config)
	*cfg = *db.cfg
	cfg.ReadTimeout, cfg.WriteTimeout = d, d

	out := db.clone()
	out.cfg = cfg
	return out
}
// clone returns a shallow copy of db. The hooks slice is re-sliced with its
// capacity capped at its length, so appending to the copy's hooks cannot
// write into the original's backing array.
func (db *DB) clone() *DB {
	dup := new(DB)
	*dup = *db

	n := len(db.queryHooks)
	dup.queryHooks = db.queryHooks[:n:n]
	return dup
}
// Stats returns a snapshot of the DB counters. Each counter is loaded
// atomically, but the two loads are independent of each other.
func (db *DB) Stats() DBStats {
	var snap DBStats
	snap.Queries = atomic.LoadUint64(&db.stats.Queries)
	snap.Errors = atomic.LoadUint64(&db.stats.Errors)
	return snap
}
// getConn takes a connection from the pool and makes sure it has been
// initialized. If initialization fails, the connection is removed from the
// pool and the error returned is internal.Unwrap(err) when that is non-nil,
// otherwise the initialization error itself.
func (db *DB) getConn(ctx context.Context) (*chpool.Conn, error) {
	cn, err := db.pool.Get(ctx)
	if err != nil {
		return nil, err
	}

	initErr := db.initConn(ctx, cn)
	if initErr == nil {
		return cn, nil
	}

	db.pool.Remove(cn, initErr)
	if cause := internal.Unwrap(initErr); cause != nil {
		return nil, cause
	}
	return nil, initErr
}
// initConn performs the hello exchange on a connection that has not been
// initialized yet. The Inited flag is set before the exchange, so hello is
// attempted at most once per connection.
func (db *DB) initConn(ctx context.Context, cn *chpool.Conn) error {
	if !cn.Inited {
		cn.Inited = true
		return db.hello(ctx, cn)
	}
	return nil
}
// releaseConn gives cn back to the pool. The connection is removed when
// isBadConn reports err as a bad-connection error or when cn is already
// closed; otherwise it is put back for reuse.
func (db *DB) releaseConn(cn *chpool.Conn, err error) {
	if !isBadConn(err, false) && !cn.Closed() {
		db.pool.Put(cn)
		return
	}
	db.pool.Remove(cn, err)
}
// withConn runs fn through _withConn and updates the DB counters: Queries
// is incremented on every call, Errors only when the call returned an error.
func (db *DB) withConn(ctx context.Context, fn func(*chpool.Conn) error) error {
	err := db._withConn(ctx, fn)

	atomic.AddUint64(&db.stats.Queries, 1)
	if err == nil {
		return nil
	}

	atomic.AddUint64(&db.stats.Errors, 1)
	return err
}
// _withConn obtains a connection with getConn, runs fn on it and finally
// passes the connection to releaseConn together with fn's error.
//
// When ctx can be cancelled, a watcher goroutine calls cancelConn if ctx is
// done before fn has finished. cancelConn is also called when fn returns a
// net.Error whose Timeout method reports true.
func (db *DB) _withConn(ctx context.Context, fn func(*chpool.Conn) error) error {
	cn, err := db.getConn(ctx)
	if err != nil {
		return err
	}

	// done stays nil when ctx.Done() is nil; in that case no watcher
	// goroutine is started and the deferred function skips the handshake.
	var done chan struct{}

	if ctxDone := ctx.Done(); ctxDone != nil {
		done = make(chan struct{})
		go func() {
			select {
			case <-done:
				// fn has finished, skip cancel
			case <-ctxDone:
				db.cancelConn(ctx, cn)
				// Signal end of conn use.
				done <- struct{}{}
			}
		}()
	}

	defer func() {
		// done is unbuffered, so exactly one of the two cases completes:
		// either the watcher reports that it has finished cancelling, or
		// the watcher receives the "fn finished" signal and exits.
		if done != nil {
			select {
			case <-done: // wait for cancel to finish request
			case done <- struct{}{}: // signal fn finish, skip cancel goroutine
			}
		}
		db.releaseConn(cn, err)
	}()

	// err is used in releaseConn above
	err = fn(cn)
	if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
		db.cancelConn(ctx, cn)
	}
	return err
}
// cancelConn writes a cancel packet to cn and then closes the connection.
// A failure to write the packet is only logged; the error returned by Close
// is discarded.
func (db *DB) cancelConn(ctx context.Context, cn *chpool.Conn) {
	err := cn.WithWriter(ctx, db.cfg.WriteTimeout, func(wr *chproto.Writer) {
		writeCancel(wr)
	})
	if err != nil {
		internal.Logger.Printf("writeCancel failed: %s", err)
	}

	_ = cn.Close()
}
// Ping writes a ping packet on a pooled connection and reads the pong
// reply, returning the first error encountered.
func (db *DB) Ping(ctx context.Context) error {
	roundTrip := func(cn *chpool.Conn) error {
		sendErr := cn.WithWriter(ctx, db.cfg.WriteTimeout, func(wr *chproto.Writer) {
			writePing(wr)
		})
		if sendErr != nil {
			return sendErr
		}

		return cn.WithReader(ctx, db.cfg.ReadTimeout, func(rd *chproto.Reader) error {
			return readPong(rd)
		})
	}
	return db.withConn(ctx, roundTrip)
}
// Exec is ExecContext called with context.Background().
func (db *DB) Exec(query string, args ...any) (sql.Result, error) {
	ctx := context.Background()
	return db.ExecContext(ctx, query, args...)
}
func (db *DB) ExecContext(
ctx context.Context, query string, args ...any,
) (sql.Result, error) {
query = db.FormatQuery(query, args...)
ctx, evt := db.beforeQuery(ctx, nil, query, args, nil)
res, err := db.exec(ctx, query)
2022-01-23 09:36:24 +02:00
db.afterQuery(ctx, evt, res, err)
return res, err
}
// exec runs query through _exec, retrying up to cfg.MaxRetries additional
// times while shouldRetry accepts the error. Before every retry it sleeps
// for the backoff of the previous attempt; an error from that sleep stops
// the loop and is returned.
func (db *DB) exec(ctx context.Context, query string) (*result, error) {
	var (
		res *result
		err error
	)

	for attempt := 0; attempt <= db.cfg.MaxRetries; attempt++ {
		if attempt > 0 {
			if err = internal.Sleep(ctx, db.retryBackoff(attempt-1)); err != nil {
				break
			}
		}

		if res, err = db._exec(ctx, query); !db.shouldRetry(err) {
			break
		}
	}

	return res, err
}
// _exec performs a single execution attempt on a pooled connection: it
// writes the query followed by a nil block, then reads the response with
// readDataBlocks.
func (db *DB) _exec(ctx context.Context, query string) (*result, error) {
	var out *result

	attempt := func(cn *chpool.Conn) error {
		writeErr := cn.WithWriter(ctx, db.cfg.WriteTimeout, func(wr *chproto.Writer) {
			db.writeQuery(ctx, cn, wr, query)
			db.writeBlock(ctx, wr, nil)
		})
		if writeErr != nil {
			return writeErr
		}

		return cn.WithReader(ctx, db.cfg.ReadTimeout, func(rd *chproto.Reader) error {
			var readErr error
			out, readErr = db.readDataBlocks(cn, rd)
			return readErr
		})
	}

	err := db.withConn(ctx, attempt)
	return out, err
}
2022-01-23 09:36:24 +02:00
// Query is QueryContext called with context.Background().
func (db *DB) Query(query string, args ...any) (*Rows, error) {
	ctx := context.Background()
	return db.QueryContext(ctx, query, args...)
}
func (db *DB) QueryContext(
ctx context.Context, query string, args ...any,
) (*Rows, error) {
query = db.FormatQuery(query, args...)
ctx, evt := db.beforeQuery(ctx, nil, query, args, nil)
blocks, err := db.query(ctx, query)
db.afterQuery(ctx, evt, nil, err)
2022-01-23 09:36:24 +02:00
if err != nil {
return nil, err
}
return newRows(ctx, blocks), nil
2022-01-23 09:36:24 +02:00
}
// QueryRow is QueryRowContext called with context.Background().
func (db *DB) QueryRow(query string, args ...any) *Row {
	ctx := context.Background()
	return db.QueryRowContext(ctx, query, args...)
}
// QueryRowContext runs the query and wraps the outcome in a Row. A query
// error is not returned here; it is stored in the Row and surfaces from
// Row.Err or Row.Scan.
func (db *DB) QueryRowContext(ctx context.Context, query string, args ...any) *Row {
	row := new(Row)
	row.rows, row.err = db.QueryContext(ctx, query, args...)
	return row
}
func (db *DB) query(ctx context.Context, query string) (*blockIter, error) {
var blocks *blockIter
2022-01-23 09:36:24 +02:00
var lastErr error
2022-01-23 09:36:24 +02:00
for attempt := 0; attempt <= db.cfg.MaxRetries; attempt++ {
if attempt > 0 {
lastErr = internal.Sleep(ctx, db.retryBackoff(attempt-1))
if lastErr != nil {
break
}
}
blocks, lastErr = db._query(ctx, query)
2022-01-23 09:36:24 +02:00
if !db.shouldRetry(lastErr) {
break
}
}
return blocks, lastErr
}
func (db *DB) _query(ctx context.Context, query string) (*blockIter, error) {
cn, err := db.getConn(ctx)
if err != nil {
return nil, err
2022-01-23 09:36:24 +02:00
}
if err := cn.WithWriter(ctx, db.cfg.WriteTimeout, func(wr *chproto.Writer) {
db.writeQuery(ctx, cn, wr, query)
db.writeBlock(ctx, wr, nil)
}); err != nil {
return nil, err
}
2022-01-23 09:36:24 +02:00
return newBlockIter(db, cn), nil
2022-01-23 09:36:24 +02:00
}
// insert builds the data block for fields once and sends it through
// _insert, retrying up to cfg.MaxRetries additional times while shouldRetry
// accepts the error. Before every retry it sleeps for the backoff of the
// previous attempt; an error from that sleep stops the loop and is returned.
func (db *DB) insert(
	ctx context.Context, model TableModel, query string, fields []*chschema.Field,
) (*result, error) {
	var (
		res *result
		err error
	)
	block := model.Block(fields)

	for attempt := 0; attempt <= db.cfg.MaxRetries; attempt++ {
		if attempt > 0 {
			if err = internal.Sleep(ctx, db.retryBackoff(attempt-1)); err != nil {
				break
			}
		}

		if res, err = db._insert(ctx, model, query, block); !db.shouldRetry(err) {
			break
		}
	}

	return res, err
}
func (db *DB) _insert(
ctx context.Context, model TableModel, query string, block *chschema.Block,
) (*result, error) {
var res *result
err := db.withConn(ctx, func(cn *chpool.Conn) error {
if err := cn.WithWriter(ctx, db.cfg.WriteTimeout, func(wr *chproto.Writer) {
db.writeQuery(ctx, cn, wr, query)
db.writeBlock(ctx, wr, nil)
2022-01-23 09:36:24 +02:00
}); err != nil {
return err
}
if err := cn.WithReader(ctx, db.cfg.ReadTimeout, func(rd *chproto.Reader) error {
_, err := db.readSampleBlock(rd)
2022-01-23 09:36:24 +02:00
return err
}); err != nil {
return err
}
if err := cn.WithWriter(ctx, db.cfg.WriteTimeout, func(wr *chproto.Writer) {
db.writeBlock(ctx, wr, block)
db.writeBlock(ctx, wr, nil)
2022-01-23 09:36:24 +02:00
}); err != nil {
return err
}
return cn.WithReader(ctx, db.cfg.ReadTimeout, func(rd *chproto.Reader) error {
var err error
res, err = readPacket(cn, rd)
2022-01-23 09:36:24 +02:00
if err != nil {
return err
}
res.affected = block.NumRow
return nil
})
})
return res, err
}
// NewSelect returns a new SelectQuery bound to db.
func (db *DB) NewSelect() *SelectQuery {
	q := NewSelectQuery(db)
	return q
}
// NewInsert returns a new InsertQuery bound to db.
func (db *DB) NewInsert() *InsertQuery {
	q := NewInsertQuery(db)
	return q
}
// NewCreateTable returns a new CreateTableQuery bound to db.
func (db *DB) NewCreateTable() *CreateTableQuery {
	q := NewCreateTableQuery(db)
	return q
}
// NewDropTable returns a new DropTableQuery bound to db.
func (db *DB) NewDropTable() *DropTableQuery {
	q := NewDropTableQuery(db)
	return q
}
// NewTruncateTable returns a new TruncateTableQuery bound to db.
func (db *DB) NewTruncateTable() *TruncateTableQuery {
	q := NewTruncateTableQuery(db)
	return q
}
// ResetModel recreates the table of every model in order: the table is
// dropped if it exists and then created again. It stops at the first error.
func (db *DB) ResetModel(ctx context.Context, models ...any) error {
	for i := range models {
		_, err := db.NewDropTable().Model(models[i]).IfExists().Exec(ctx)
		if err != nil {
			return err
		}

		_, err = db.NewCreateTable().Model(models[i]).Exec(ctx)
		if err != nil {
			return err
		}
	}
	return nil
}
// Formatter returns the query formatter used by db.
func (db *DB) Formatter() chschema.Formatter {
	fmter := db.fmter
	return fmter
}
// WithFormatter returns a clone of db that uses fmter as its query
// formatter. The receiver is not modified.
func (db *DB) WithFormatter(fmter chschema.Formatter) *DB {
	out := db.clone()
	out.fmter = fmter
	return out
}
// shouldRetry reports whether an operation that failed with err is worth
// retrying.
//
// driver.ErrBadConn is retried. A nil error, context.Canceled and
// context.DeadlineExceeded are not. A server *Error is retried only for the
// codes listed below. Errors are matched with errors.Is / errors.As, so
// wrapped errors are classified the same way as bare ones.
func (db *DB) shouldRetry(err error) bool {
	switch {
	case err == nil:
		return false
	case errors.Is(err, driver.ErrBadConn):
		return true
	case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded):
		return false
	}

	var chErr *Error
	if errors.As(err, &chErr) {
		// https://github.com/ClickHouse/ClickHouse/blob/master/src/Common/ErrorCodes.cpp
		const (
			timeoutExceeded            = 159
			tooManySimultaneousQueries = 202
			memoryLimitExceeded        = 241
		)

		switch chErr.Code {
		case timeoutExceeded, tooManySimultaneousQueries, memoryLimitExceeded:
			return true
		}
	}

	return false
}
// retryBackoff returns the delay for the given attempt, computed by
// internal.RetryBackoff from the configured minimum and maximum backoff.
func (db *DB) retryBackoff(attempt int) time.Duration {
	lo, hi := db.cfg.MinRetryBackoff, db.cfg.MaxRetryBackoff
	return internal.RetryBackoff(attempt, lo, hi)
}
// FormatQuery formats query with args using the DB's formatter and returns
// the resulting string.
func (db *DB) FormatQuery(query string, args ...any) string {
	formatted := db.fmter.FormatQuery(query, args...)
	return formatted
}
// makeQueryBytes returns an empty byte slice with 4096 bytes of capacity.
func (db *DB) makeQueryBytes() []byte {
	// TODO: make this configurable?
	const initialCap = 4096

	buf := make([]byte, 0, initialCap)
	return buf
}
//------------------------------------------------------------------------------
// Rows is the result of a query. Its cursor starts before the first row of the result set.
// Use Next to advance from row to row.
type Rows struct {
ctx context.Context
blocks *blockIter
block *chschema.Block
2022-01-23 09:36:24 +02:00
rowIndex int
hasNext bool
closed bool
2022-01-23 09:36:24 +02:00
}
func newRows(ctx context.Context, blocks *blockIter) *Rows {
return &Rows{
ctx: ctx,
blocks: blocks,
block: new(chschema.Block),
}
2022-01-23 09:36:24 +02:00
}
func (rs *Rows) Close() error {
if !rs.closed {
for rs.blocks.Next(rs.ctx, rs.block) {
}
rs.close()
}
2022-01-23 09:36:24 +02:00
return nil
}
// close marks rs as closed and closes the block iterator without draining
// it first.
func (rs *Rows) close() {
	rs.closed = true

	// The iterator's Close error is deliberately discarded.
	_ = rs.blocks.Close()
}
2022-01-23 09:36:24 +02:00
// ColumnTypes is not implemented and always returns an error.
func (rs *Rows) ColumnTypes() ([]*sql.ColumnType, error) {
	err := errors.New("not implemented")
	return nil, err
}
// Columns is not implemented and always returns an error.
func (rs *Rows) Columns() ([]string, error) {
	err := errors.New("not implemented")
	return nil, err
}
func (rs *Rows) Err() error {
return rs.blocks.Err()
2022-01-23 09:36:24 +02:00
}
func (rs *Rows) Next() bool {
if rs.closed {
return false
2022-01-23 09:36:24 +02:00
}
for rs.rowIndex >= rs.block.NumRow {
if !rs.blocks.Next(rs.ctx, rs.block) {
rs.close()
return false
2022-01-23 09:36:24 +02:00
}
rs.rowIndex = 0
2022-01-23 09:36:24 +02:00
}
rs.hasNext = true
rs.rowIndex++
return true
2022-01-23 09:36:24 +02:00
}
// NextResultSet always reports false.
func (rs *Rows) NextResultSet() bool {
	const more = false
	return more
}
func (rs *Rows) Scan(dest ...any) error {
if rs.closed {
return rs.Err()
}
if !rs.hasNext {
2022-01-23 09:36:24 +02:00
return errors.New("ch: Scan called without calling Next")
}
rs.hasNext = false
2022-01-23 09:36:24 +02:00
if rs.block.NumColumn != len(dest) {
return fmt.Errorf("ch: got %d columns, but Scan has %d values",
rs.block.NumColumn, len(dest))
}
for i, col := range rs.block.Columns {
if err := col.ConvertAssign(rs.rowIndex-1, reflect.ValueOf(dest[i]).Elem()); err != nil {
return err
}
}
return nil
}
// Row holds the outcome of a query that is expected to produce a single
// row: either the Rows to read from or the error the query returned.
type Row struct {
	rows *Rows // result set; used by Scan only when err is nil
	err  error // error returned by the query, if any
}
// Err returns the error stored in the Row when the query was run, if any.
func (r *Row) Err() error {
	err := r.err
	return err
}
// Scan copies the first row of the result into dest and closes the
// underlying Rows.
//
// It returns the stored query error if there is one, the iteration error if
// reading the first row failed, and sql.ErrNoRows only when the result set
// is genuinely empty.
func (r *Row) Scan(dest ...any) error {
	if r.err != nil {
		return r.err
	}
	defer r.rows.Close()

	if r.rows.Next() {
		return r.rows.Scan(dest...)
	}

	// Next also returns false when iteration failed; do not report that
	// failure as "no rows".
	if err := r.rows.Err(); err != nil {
		return err
	}
	return sql.ErrNoRows
}