1
0
mirror of https://github.com/uptrace/go-clickhouse.git synced 2024-11-21 17:56:48 +02:00

feat: initial commit

This commit is contained in:
Vladimir Mihailenco 2022-01-23 09:36:24 +02:00
commit 092a2dbf28
125 changed files with 14450 additions and 0 deletions

1
.github/FUNDING.yml vendored Normal file
View File

@ -0,0 +1 @@
custom: ['https://uptrace.dev/sponsor']

10
.github/dependabot.yml vendored Normal file
View File

@ -0,0 +1,10 @@
version: 2
updates:
- package-ecosystem: gomod
directory: /
schedule:
interval: weekly
- package-ecosystem: github-actions
directory: /
schedule:
interval: weekly

36
.github/workflows/build.yml vendored Normal file
View File

@ -0,0 +1,36 @@
name: Go
on:
push:
branches: [master]
pull_request:
branches: [master]
jobs:
build:
name: build
runs-on: ubuntu-latest
services:
clickhouse:
image: clickhouse/clickhouse-server:21.12
options: >-
--health-cmd "clickhouse-client -q 'select 1'" --health-interval 10s --health-timeout 5s
--health-retries 5
ports:
- 9000:9000
steps:
- name: Set up ${{ matrix.go-version }}
uses: actions/setup-go@v2
with:
go-version: 1.18.0-beta1
stable: false
- name: Checkout code
uses: actions/checkout@v3
- name: Test
run: make test
env:
CH: clickhouse://localhost:9000/default?sslmode=disable

11
.github/workflows/commitlint.yml vendored Normal file
View File

@ -0,0 +1,11 @@
name: Lint Commit Messages
on: [pull_request]
jobs:
commitlint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
with:
fetch-depth: 0
- uses: wagoid/commitlint-github-action@v4

19
.github/workflows/golangci-lint.yml vendored Normal file
View File

@ -0,0 +1,19 @@
name: golangci-lint
on:
push:
tags:
- v*
branches:
- master
- main
pull_request:
jobs:
golangci:
name: lint
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: golangci-lint
uses: golangci/golangci-lint-action@v3.1.0

18
.github/workflows/release.yml vendored Normal file
View File

@ -0,0 +1,18 @@
name: Releases
on:
push:
tags:
- 'v*'
jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: ncipollo/release-action@v1
with:
body:
Please refer to
[CHANGELOG.md](https://github.com/uptrace/go-clickhouse/blob/master/CHANGELOG.md) for
details

6
.prettierrc.yml Normal file
View File

@ -0,0 +1,6 @@
trailingComma: all
tabWidth: 2
semi: false
singleQuote: true
proseWrap: always
printWidth: 100

19
CHANGELOG.md Normal file
View File

@ -0,0 +1,19 @@
# [](https://github.com/uptrace/go-clickhouse/compare/v0.1.0...v) (2022-03-17)
# (2022-03-09)
### Bug Fixes
- parse query settings from DSN
([6dd2a1a](https://github.com/uptrace/go-clickhouse/commit/6dd2a1adde7a6992d25bf319ce447556fd21aa39))
### Features
- add CreateTableQuery.Order
([50192cd](https://github.com/uptrace/go-clickhouse/commit/50192cd8fb1bb6aa65f50daee5e7b11435627255))
- add migrations example
([98ecef3](https://github.com/uptrace/go-clickhouse/commit/98ecef3fdb7b10dc947fccb31d641a4ebce2f650))
- initial commit
([2f20600](https://github.com/uptrace/go-clickhouse/commit/2f20600f5e4fc9a20e12f1f027e65e0c2bd4f046))

24
LICENSE Normal file
View File

@ -0,0 +1,24 @@
Copyright (c) 2021 Vladimir Mihailenco. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

22
Makefile Normal file
View File

@ -0,0 +1,22 @@
ALL_GO_MOD_DIRS := $(shell find . -type f -name 'go.mod' -exec dirname {} \; | sort)
test:
set -e; for dir in $(ALL_GO_MOD_DIRS); do \
echo "go test in $${dir}"; \
(cd "$${dir}" && \
go test && \
env GOOS=linux GOARCH=386 go test && \
go vet); \
done
go_mod_tidy:
set -e; for dir in $(ALL_GO_MOD_DIRS); do \
echo "go mod tidy in $${dir}"; \
(cd "$${dir}" && \
go get -u ./... && \
go mod tidy); \
done
fmt:
gofmt -w -s ./
goimports -w -local github.com/uptrace/go-clickhouse ./

109
README.md Normal file
View File

@ -0,0 +1,109 @@
# ClickHouse client for Go 1.18+
[![build workflow](https://github.com/uptrace/go-clickhouse/actions/workflows/build.yml/badge.svg)](https://github.com/uptrace/go-clickhouse/actions)
[![PkgGoDev](https://pkg.go.dev/badge/github.com/uptrace/go-clickhouse/ch)](https://pkg.go.dev/github.com/go-clickhouse/ch)
[![Documentation](https://img.shields.io/badge/ch-documentation-informational)](https://clickhouse.uptrace.dev/)
[![Chat](https://discordapp.com/api/guilds/752070105847955518/widget.png)](https://discord.gg/rWtp5Aj)
This client uses native protocol to communicate with ClickHouse server and requires Go 1.18+ in
order to use generics. This is not a database/sql driver, but the API is compatible.
Main features are:
- ClickHouse native protocol support and efficient column-oriented design.
- API compatible with database/sql.
- [Bun](https://github.com/uptrace/bun/)-like query builder.
- [Selecting](https://clickhouse.uptrace.dev/guide/query-select.html) into scalars, structs, maps,
slices of maps/structs/scalars.
- `Array(T)` including nested arrays.
- Enums and `LowCardinality(String)`.
- `Nullable(T)` except `Nullable(Array(T))`.
- [Migrations](https://clickhouse.uptrace.dev/guide/migrations.html).
- [OpenTelemetry](https://clickhouse.uptrace.dev/guide/monitoring.html) support.
- In production at [Uptrace](https://uptrace.dev/)
Unsupported:
- Server timezones other than UTC.
Resources:
- [**Get started**](https://clickhouse.uptrace.dev/guide/getting-started.html)
- [Examples](https://github.com/uptrace/go-clickhouse/tree/master/example)
- [Discussions](https://github.com/uptrace/go-clickhouse/discussions)
- [Chat](https://discord.gg/rWtp5Aj)
- [Reference](https://pkg.go.dev/github.com/uptrace/go-clickhouse/ch)
- [Example app](https://github.com/uptrace/uptrace)
## Benchmark
**Read** (best of 3 runs):
| Library | Timing |
| ---------------------------------------------------------------------------------------------------------------- | ------ |
| [This library](example/benchmark/read-native/main.go) | 655ms |
| [ClickHouse/clickhouse-go](https://github.com/ClickHouse/clickhouse-go/blob/v2/benchmark/v2/read-native/main.go) | 849ms |
**Write** (best of 3 runs):
| Library | Timing |
| -------------------------------------------------------------------------------------------------------------------------- | ------ |
| [This library](example/benchmark/write-native-columnar/main.go) | 475ms |
| [ClickHouse/clickhouse-go](https://github.com/ClickHouse/clickhouse-go/blob/v2/benchmark/v2/write-native-columnar/main.go) | 881ms |
## Example
A [basic](example/basic) example:
```go
package main
import (
"context"
"fmt"
"time"
"github.com/uptrace/go-clickhouse/ch"
"github.com/uptrace/go-clickhouse/chdebug"
)
type Model struct {
ch.CHModel `ch:"partition:toYYYYMM(time)"`
ID uint64
Text string `ch:",lc"`
Time time.Time `ch:",pk"`
}
func main() {
ctx := context.Background()
db := ch.Connect(ch.WithDatabase("test"))
db.AddQueryHook(chdebug.NewQueryHook(chdebug.WithVerbose(true)))
if err := db.Ping(ctx); err != nil {
panic(err)
}
var num int
if err := db.QueryRowContext(ctx, "SELECT 123").Scan(&num); err != nil {
panic(err)
}
fmt.Println(num)
if err := db.ResetModel(ctx, (*Model)(nil)); err != nil {
panic(err)
}
src := &Model{ID: 1, Text: "hello", Time: time.Now()}
if _, err := db.NewInsert().Model(src).Exec(ctx); err != nil {
panic(err)
}
dest := new(Model)
if err := db.NewSelect().Model(dest).Where("id = ?", src.ID).Limit(1).Scan(ctx); err != nil {
panic(err)
}
fmt.Println(dest)
}
```

120
ch/ch.go Normal file
View File

@ -0,0 +1,120 @@
package ch
import (
"database/sql"
"errors"
"fmt"
"net"
"reflect"
"github.com/uptrace/go-clickhouse/ch/chschema"
)
type (
Safe = chschema.Safe
Ident = chschema.Ident
CHModel = chschema.CHModel
AfterScanRowHook = chschema.AfterScanRowHook
)
func SafeQuery(query string, args ...any) chschema.QueryWithArgs {
return chschema.SafeQuery(query, args)
}
//------------------------------------------------------------------------------
type result struct {
model Model
affected int
}
var _ sql.Result = (*result)(nil)
func (res *result) Model() Model {
return res.model
}
func (res *result) RowsAffected() (int64, error) {
return int64(res.affected), nil
}
func (res *result) LastInsertId() (int64, error) {
return 0, errors.New("not implemented")
}
//------------------------------------------------------------------------------
type Error struct {
Code int32
Name string
Message string
StackTrace string
nested error // TODO: wrap/unwrap
}
func (exc *Error) Error() string {
return exc.Name + ": " + exc.Message
}
func isBadConn(err error, allowTimeout bool) bool {
if err == nil {
return false
}
if allowTimeout {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
return !netErr.Temporary()
}
}
return true
}
//------------------------------------------------------------------------------
type InValues struct {
slice reflect.Value
err error
}
var _ chschema.QueryAppender = InValues{}
func In(slice any) InValues {
v := reflect.ValueOf(slice)
if v.Kind() != reflect.Slice {
return InValues{
err: fmt.Errorf("ch: In(non-slice %T)", slice),
}
}
return InValues{
slice: v,
}
}
func (in InValues) AppendQuery(fmter chschema.Formatter, b []byte) (_ []byte, err error) {
if in.err != nil {
return nil, in.err
}
return appendIn(fmter, b, in.slice), nil
}
func appendIn(fmter chschema.Formatter, b []byte, slice reflect.Value) []byte {
sliceLen := slice.Len()
for i := 0; i < sliceLen; i++ {
if i > 0 {
b = append(b, ", "...)
}
elem := slice.Index(i)
if elem.Kind() == reflect.Interface {
elem = elem.Elem()
}
if elem.Kind() == reflect.Slice {
b = append(b, '(')
b = appendIn(fmter, b, elem)
b = append(b, ')')
} else {
b = chschema.AppendValue(fmter, b, elem)
}
}
return b
}

123
ch/chpool/conn.go Normal file
View File

@ -0,0 +1,123 @@
package chpool
import (
"context"
"net"
"sync/atomic"
"time"
"github.com/uptrace/go-clickhouse/ch/chproto"
)
var noDeadline = time.Time{}
type Conn struct {
netConn net.Conn
rd *chproto.Reader
wr *chproto.Writer
ServerInfo chproto.ServerInfo
pooled bool
Inited bool
createdAt time.Time
usedAt int64 // atomic
closed uint32 // atomic
}
func NewConn(netConn net.Conn) *Conn {
cn := &Conn{
netConn: netConn,
rd: chproto.NewReader(netConn),
wr: chproto.NewWriter(netConn),
createdAt: time.Now(),
}
return cn
}
func (cn *Conn) UsedAt() time.Time {
unix := atomic.LoadInt64(&cn.usedAt)
return time.Unix(unix, 0)
}
func (cn *Conn) SetUsedAt(tm time.Time) {
atomic.StoreInt64(&cn.usedAt, tm.Unix())
}
func (cn *Conn) RemoteAddr() net.Addr {
return cn.netConn.RemoteAddr()
}
func (cn *Conn) WithReader(
ctx context.Context,
timeout time.Duration,
fn func(rd *chproto.Reader) error,
) error {
if err := cn.netConn.SetReadDeadline(cn.deadline(ctx, timeout)); err != nil {
return err
}
if err := fn(cn.rd); err != nil {
return err
}
return nil
}
func (cn *Conn) WithWriter(
ctx context.Context,
timeout time.Duration,
fn func(wb *chproto.Writer),
) error {
if err := cn.netConn.SetWriteDeadline(cn.deadline(ctx, timeout)); err != nil {
return err
}
fn(cn.wr)
if err := cn.wr.Flush(); err != nil {
return err
}
return nil
}
func (cn *Conn) Close() error {
if !atomic.CompareAndSwapUint32(&cn.closed, 0, 1) {
return nil
}
return cn.netConn.Close()
}
func (cn *Conn) Closed() bool {
return atomic.LoadUint32(&cn.closed) == 1
}
func (cn *Conn) deadline(ctx context.Context, timeout time.Duration) time.Time {
tm := time.Now()
cn.SetUsedAt(tm)
if timeout > 0 {
tm = tm.Add(timeout)
}
if ctx != nil {
deadline, ok := ctx.Deadline()
if ok {
if timeout == 0 {
return deadline
}
if deadline.Before(tm) {
return deadline
}
return tm
}
}
if timeout > 0 {
return tm
}
return noDeadline
}

455
ch/chpool/pool.go Normal file
View File

@ -0,0 +1,455 @@
package chpool
import (
"context"
"errors"
"net"
"sync"
"sync/atomic"
"time"
"github.com/uptrace/go-clickhouse/ch/internal"
)
var (
ErrClosed = errors.New("ch: database is closed")
ErrPoolTimeout = errors.New("ch: connection pool timeout")
)
var timers = sync.Pool{
New: func() any {
t := time.NewTimer(time.Hour)
t.Stop()
return t
},
}
//------------------------------------------------------------------------------
type BadConnError struct {
wrapped error
}
var _ error = (*BadConnError)(nil)
func (e BadConnError) Error() string {
s := "ch: Conn is in a bad state"
if e.wrapped != nil {
s += ": " + e.wrapped.Error()
}
return s
}
func (e BadConnError) Unwrap() error {
return e.wrapped
}
//------------------------------------------------------------------------------
// Stats contains pool state information and accumulated stats.
type Stats struct {
Hits uint32 // number of times free connection was found in the pool
Misses uint32 // number of times free connection was NOT found in the pool
Timeouts uint32 // number of times a wait timeout occurred
TotalConns uint32 // number of total connections in the pool
IdleConns uint32 // number of idle connections in the pool
StaleConns uint32 // number of stale connections removed from the pool
}
type Pooler interface {
NewConn(context.Context) (*Conn, error)
CloseConn(*Conn) error
Get(context.Context) (*Conn, error)
Put(*Conn)
Remove(*Conn, error)
Len() int
IdleLen() int
Stats() *Stats
Close() error
}
type Config struct {
Dialer func(context.Context) (net.Conn, error)
OnClose func(*Conn) error
PoolSize int
PoolTimeout time.Duration
MinIdleConns int
MaxIdleConns int
MaxConnAge time.Duration
}
type ConnPool struct {
cfg *Config
dialErrorsNum uint32 // atomic
_closed uint32 // atomic
lastDialErrorMu sync.RWMutex
lastDialError error
queue chan struct{}
stats Stats
connsMu sync.Mutex
conns []*Conn
idleConns []*Conn
poolSize int
idleConnsLen int
}
var _ Pooler = (*ConnPool)(nil)
func New(cfg *Config) *ConnPool {
p := &ConnPool{
cfg: cfg,
queue: make(chan struct{}, cfg.PoolSize),
conns: make([]*Conn, 0, cfg.PoolSize),
idleConns: make([]*Conn, 0, cfg.PoolSize),
}
p.connsMu.Lock()
p.checkMinIdleConns()
p.connsMu.Unlock()
return p
}
func (p *ConnPool) checkMinIdleConns() {
if p.cfg.MinIdleConns == 0 {
return
}
for p.poolSize < p.cfg.PoolSize && p.idleConnsLen < p.cfg.MinIdleConns {
p.poolSize++
p.idleConnsLen++
go func() {
err := p.addIdleConn()
if err != nil {
p.connsMu.Lock()
p.poolSize--
p.idleConnsLen--
p.connsMu.Unlock()
}
}()
}
}
func (p *ConnPool) addIdleConn() error {
cn, err := p.dialConn(context.TODO(), true)
if err != nil {
return err
}
p.connsMu.Lock()
p.conns = append(p.conns, cn)
p.idleConns = append(p.idleConns, cn)
p.connsMu.Unlock()
return nil
}
func (p *ConnPool) NewConn(c context.Context) (*Conn, error) {
return p.newConn(c, false)
}
func (p *ConnPool) newConn(c context.Context, pooled bool) (*Conn, error) {
cn, err := p.dialConn(c, pooled)
if err != nil {
return nil, err
}
p.connsMu.Lock()
p.conns = append(p.conns, cn)
if pooled {
// If pool is full remove the cn on next Put.
if p.poolSize >= p.cfg.PoolSize {
cn.pooled = false
} else {
p.poolSize++
}
}
p.connsMu.Unlock()
return cn, nil
}
func (p *ConnPool) dialConn(c context.Context, pooled bool) (*Conn, error) {
if p.closed() {
return nil, ErrClosed
}
if atomic.LoadUint32(&p.dialErrorsNum) >= uint32(p.cfg.PoolSize) {
return nil, p.getLastDialError()
}
netConn, err := p.cfg.Dialer(c)
if err != nil {
p.setLastDialError(err)
if atomic.AddUint32(&p.dialErrorsNum, 1) == uint32(p.cfg.PoolSize) {
go p.tryDial()
}
return nil, err
}
cn := NewConn(netConn)
cn.pooled = pooled
return cn, nil
}
func (p *ConnPool) tryDial() {
for {
if p.closed() {
return
}
conn, err := p.cfg.Dialer(context.TODO())
if err != nil {
p.setLastDialError(err)
time.Sleep(time.Second)
continue
}
atomic.StoreUint32(&p.dialErrorsNum, 0)
_ = conn.Close()
return
}
}
func (p *ConnPool) setLastDialError(err error) {
p.lastDialErrorMu.Lock()
p.lastDialError = err
p.lastDialErrorMu.Unlock()
}
func (p *ConnPool) getLastDialError() error {
p.lastDialErrorMu.RLock()
err := p.lastDialError
p.lastDialErrorMu.RUnlock()
return err
}
// Get returns an existing connection from the pool or creates a new one.
func (p *ConnPool) Get(ctx context.Context) (*Conn, error) {
if p.closed() {
return nil, ErrClosed
}
err := p.waitTurn(ctx)
if err != nil {
return nil, err
}
for {
p.connsMu.Lock()
cn := p.popIdle()
p.connsMu.Unlock()
if cn == nil {
break
}
if p.cfg.MaxConnAge > 0 && time.Since(cn.createdAt) >= p.cfg.MaxConnAge {
_ = p.CloseConn(cn)
continue
}
atomic.AddUint32(&p.stats.Hits, 1)
return cn, nil
}
atomic.AddUint32(&p.stats.Misses, 1)
newcn, err := p.newConn(ctx, true)
if err != nil {
p.freeTurn()
return nil, err
}
return newcn, nil
}
func (p *ConnPool) getTurn() {
p.queue <- struct{}{}
}
func (p *ConnPool) waitTurn(ctx context.Context) error {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
select {
case p.queue <- struct{}{}:
return nil
default:
}
timer := timers.Get().(*time.Timer)
timer.Reset(p.cfg.PoolTimeout)
select {
case <-ctx.Done():
if !timer.Stop() {
<-timer.C
}
timers.Put(timer)
return ctx.Err()
case p.queue <- struct{}{}:
if !timer.Stop() {
<-timer.C
}
timers.Put(timer)
return nil
case <-timer.C:
timers.Put(timer)
atomic.AddUint32(&p.stats.Timeouts, 1)
return ErrPoolTimeout
}
}
func (p *ConnPool) freeTurn() {
<-p.queue
}
func (p *ConnPool) popIdle() *Conn {
if len(p.idleConns) == 0 {
return nil
}
idx := len(p.idleConns) - 1
cn := p.idleConns[idx]
p.idleConns = p.idleConns[:idx]
p.idleConnsLen--
p.checkMinIdleConns()
return cn
}
func (p *ConnPool) Put(cn *Conn) {
if cn.rd.Buffered() > 0 {
internal.Logger.Printf("Conn has unread data")
p.Remove(cn, BadConnError{})
return
}
if !cn.pooled {
p.Remove(cn, nil)
return
}
var atMaxCap bool
p.connsMu.Lock()
if len(p.idleConns) < p.cfg.MaxIdleConns {
p.idleConns = append(p.idleConns, cn)
p.idleConnsLen++
} else {
atMaxCap = true
}
p.connsMu.Unlock()
if atMaxCap {
p.Remove(cn, nil)
}
p.freeTurn()
}
func (p *ConnPool) Remove(cn *Conn, reason error) {
p.removeConnWithLock(cn)
p.freeTurn()
_ = p.closeConn(cn)
}
func (p *ConnPool) CloseConn(cn *Conn) error {
p.removeConnWithLock(cn)
return p.closeConn(cn)
}
func (p *ConnPool) removeConnWithLock(cn *Conn) {
p.connsMu.Lock()
p.removeConn(cn)
p.connsMu.Unlock()
}
func (p *ConnPool) removeConn(cn *Conn) {
for i, c := range p.conns {
if c == cn {
p.conns = append(p.conns[:i], p.conns[i+1:]...)
if cn.pooled {
p.poolSize--
p.checkMinIdleConns()
}
return
}
}
}
func (p *ConnPool) closeConn(cn *Conn) error {
if p.cfg.OnClose != nil {
_ = p.cfg.OnClose(cn)
}
return cn.Close()
}
// Len returns total number of connections.
func (p *ConnPool) Len() int {
p.connsMu.Lock()
n := len(p.conns)
p.connsMu.Unlock()
return n
}
// IdleLen returns number of idle connections.
func (p *ConnPool) IdleLen() int {
p.connsMu.Lock()
n := p.idleConnsLen
p.connsMu.Unlock()
return n
}
func (p *ConnPool) Stats() *Stats {
idleLen := p.IdleLen()
return &Stats{
Hits: atomic.LoadUint32(&p.stats.Hits),
Misses: atomic.LoadUint32(&p.stats.Misses),
Timeouts: atomic.LoadUint32(&p.stats.Timeouts),
TotalConns: uint32(p.Len()),
IdleConns: uint32(idleLen),
StaleConns: atomic.LoadUint32(&p.stats.StaleConns),
}
}
func (p *ConnPool) closed() bool {
return atomic.LoadUint32(&p._closed) == 1
}
func (p *ConnPool) Close() error {
if !atomic.CompareAndSwapUint32(&p._closed, 0, 1) {
return ErrClosed
}
var firstErr error
p.connsMu.Lock()
for _, cn := range p.conns {
if err := p.closeConn(cn); err != nil && firstErr == nil {
firstErr = err
}
}
p.conns = nil
p.poolSize = 0
p.idleConns = nil
p.idleConnsLen = 0
p.connsMu.Unlock()
return firstErr
}

124
ch/chproto/lz4_reader.go Normal file
View File

@ -0,0 +1,124 @@
package chproto
import (
"bufio"
"encoding/binary"
"errors"
"fmt"
"io"
"github.com/pierrec/lz4/v4"
)
var errUnreadData = errors.New("ch: lz4 reader was closed with unread data")
type lz4Reader struct {
rd *bufio.Reader
header []byte
data []byte
pos int
}
func newLZ4Reader(r *bufio.Reader) *lz4Reader {
return &lz4Reader{
rd: r,
header: make([]byte, headerSize),
}
}
func (r *lz4Reader) Init() {}
func (r *lz4Reader) Release() error {
var err error
if r.Buffered() > 0 {
err = errUnreadData
}
r.data = nil
r.pos = 0
return err
}
func (r *lz4Reader) Buffered() int {
return len(r.data) - r.pos
}
func (r *lz4Reader) Read(buf []byte) (int, error) {
var nread int
if r.pos < len(r.data) {
n := copy(buf, r.data[r.pos:])
nread += n
r.pos += n
}
for nread < len(buf) {
if err := r.readData(); err != nil {
return nread, err
}
n := copy(buf[nread:], r.data)
nread += n
r.pos = n
}
return nread, nil
}
func (r *lz4Reader) ReadByte() (byte, error) {
if r.pos == len(r.data) {
if err := r.readData(); err != nil {
return 0, err
}
}
if r.pos < len(r.data) {
c := r.data[r.pos]
r.pos++
return c, nil
}
return 0, io.EOF
}
func (r *lz4Reader) readData() error {
if r.pos != len(r.data) {
panic("not reached")
}
_, err := io.ReadFull(r.rd, r.header)
if err != nil {
return err
}
if r.header[16] != lz4Compression {
return fmt.Errorf("ch: unsupported compression method: 0x%02x", r.header[16])
}
compressedSize := int(binary.LittleEndian.Uint32(r.header[17:])) - compressionHeaderSize
uncompressedSize := int(binary.LittleEndian.Uint32(r.header[21:]))
zdata := make([]byte, compressedSize)
r.data = grow(r.data, uncompressedSize)
if _, err := io.ReadFull(r.rd, zdata); err != nil {
return err
}
if _, err := lz4.UncompressBlock(zdata, r.data); err != nil {
return err
}
r.pos = 0
return nil
}
func grow(b []byte, n int) []byte {
if cap(b) < n {
return make([]byte, n)
}
return b[:n]
}

152
ch/chproto/lz4_writer.go Normal file
View File

@ -0,0 +1,152 @@
package chproto
import (
"bufio"
"encoding/binary"
"sync"
"github.com/pierrec/lz4/v4"
"github.com/uptrace/go-clickhouse/ch/internal"
"github.com/uptrace/go-clickhouse/ch/internal/cityhash102"
)
const (
noCompression = 0x02
lz4Compression = 0x82
zstdCompression = 0x90
)
const (
checksumSize = 16 // city hash 128
compressionHeaderSize = 1 + 4 + 4 // method + compressed + uncompressed
headerSize = checksumSize + compressionHeaderSize
blockSize = 1 << 20 // 1 MB
)
type writeBuffer struct {
buf []byte
}
var writeBufferPool = sync.Pool{
New: func() any {
return &writeBuffer{
buf: make([]byte, blockSize),
}
},
}
func getWriterBuffer() *writeBuffer {
return writeBufferPool.Get().(*writeBuffer)
}
func putWriterBuffer(db *writeBuffer) {
writeBufferPool.Put(db)
}
//------------------------------------------------------------------------------
type lz4Writer struct {
wr *bufio.Writer
data *writeBuffer
pos int
}
func newLZ4Writer(w *bufio.Writer) *lz4Writer {
return &lz4Writer{
wr: w,
}
}
func (w *lz4Writer) Init() {
w.data = getWriterBuffer()
w.pos = 0
}
func (w *lz4Writer) Close() error {
err := w.flush()
putWriterBuffer(w.data)
w.data = nil
return err
}
func (w *lz4Writer) Flush() error {
return w.Close()
}
func (w *lz4Writer) WriteByte(c byte) error {
w.data.buf[w.pos] = c
w.pos++
return w.checkFlush()
}
func (w *lz4Writer) WriteString(s string) (int, error) {
return w.Write(internal.Bytes(s))
}
func (w *lz4Writer) Write(data []byte) (int, error) {
var written int
for len(data) > 0 {
n := copy(w.data.buf[w.pos:], data)
data = data[n:]
w.pos += n
if err := w.checkFlush(); err != nil {
return written, err
}
written += n
}
return written, nil
}
func (w *lz4Writer) checkFlush() error {
if w.pos < len(w.data.buf) {
return nil
}
return w.flush()
}
func (w *lz4Writer) flush() error {
if w.pos == 0 {
return nil
}
zlen := headerSize + lz4.CompressBlockBound(w.pos)
zdata := make([]byte, zlen)
compressedSize, err := compress(zdata[headerSize:], w.data.buf[:w.pos])
if err != nil {
return err
}
compressedSize += compressionHeaderSize
zdata[16] = lz4Compression
binary.LittleEndian.PutUint32(zdata[17:], uint32(compressedSize))
binary.LittleEndian.PutUint32(zdata[21:], uint32(w.pos))
checkSum := cityhash102.CityHash128(zdata[16:], uint32(compressedSize))
binary.LittleEndian.PutUint64(zdata[0:], checkSum.Lower64())
binary.LittleEndian.PutUint64(zdata[8:], checkSum.Higher64())
w.wr.Write(zdata[:checksumSize+compressedSize])
w.pos = 0
return nil
}
//------------------------------------------------------------------------------
func compress(dest, src []byte) (int, error) {
if len(src) < 16 {
return uncompressable(dest, src), nil
}
var c lz4.Compressor
return c.CompressBlock(src, dest)
}
func uncompressable(dest, src []byte) int {
dest[0] = byte(len(src)) << 4
copy(dest[1:], src)
return len(src) + 1
}

37
ch/chproto/proto.go Normal file
View File

@ -0,0 +1,37 @@
package chproto
const (
ClientHello = 0
ClientQuery = 1
ClientData = 2
ClientCancel = 3
ClientPing = 4
ClientTablesStatus = 5
ClientKeepAlive = 6
)
const (
CompressionDisabled = 0
CompressionEnabled = 1
)
const (
ServerHello = 0
ServerData = 1
ServerException = 2
ServerProgress = 3
ServerPong = 4
ServerEndOfStream = 5
ServerProfileInfo = 6
ServerTotals = 7
ServerExtremes = 8
ServerTablesStatus = 9
ServerLog = 10
ServerTableColumns = 11
)
const (
QueryNo = 0
QueryInitial = 1
QuerySecondary = 2
)

202
ch/chproto/reader.go Normal file
View File

@ -0,0 +1,202 @@
package chproto
import (
"bufio"
"encoding/binary"
"fmt"
"io"
"math"
"time"
"github.com/uptrace/go-clickhouse/ch/internal"
)
type reader interface {
io.Reader
io.ByteReader
Buffered() int
}
type Reader struct {
br *bufio.Reader
zr *lz4Reader
rd reader // points to br or zr
buf []byte
}
func NewReader(r io.Reader) *Reader {
br := bufio.NewReader(r)
return &Reader{
br: br,
zr: newLZ4Reader(br),
rd: br,
buf: make([]byte, uuidLen),
}
}
func (r *Reader) WithCompression(fn func() error) error {
r.zr.Init()
r.rd = r.zr
firstErr := fn()
r.rd = r.br
if err := r.zr.Release(); err != nil && firstErr == nil {
firstErr = err
}
return firstErr
}
func (r *Reader) Read(buf []byte) (int, error) {
return r.rd.Read(buf)
}
func (r *Reader) Buffered() int {
return r.rd.Buffered()
}
func (r *Reader) Bool() (bool, error) {
c, err := r.rd.ReadByte()
if err != nil {
return false, err
}
return c == 1, nil
}
func (r *Reader) Uvarint() (uint64, error) {
return binary.ReadUvarint(r.rd)
}
func (r *Reader) Uint8() (uint8, error) {
c, err := r.rd.ReadByte()
if err != nil {
return 0, err
}
return c, nil
}
func (r *Reader) Uint16() (uint16, error) {
b, err := r.readNTemp(2)
if err != nil {
return 0, err
}
return binary.LittleEndian.Uint16(b), nil
}
func (r *Reader) Uint32() (uint32, error) {
b, err := r.readNTemp(4)
if err != nil {
return 0, err
}
return binary.LittleEndian.Uint32(b), nil
}
func (r *Reader) Uint64() (uint64, error) {
b, err := r.readNTemp(8)
if err != nil {
return 0, err
}
return binary.LittleEndian.Uint64(b), nil
}
func (r *Reader) Int8() (int8, error) {
num, err := r.Uint8()
return int8(num), err
}
func (r *Reader) Int16() (int16, error) {
num, err := r.Uint16()
return int16(num), err
}
func (r *Reader) Int32() (int32, error) {
num, err := r.Uint32()
return int32(num), err
}
func (r *Reader) Int64() (int64, error) {
num, err := r.Uint64()
return int64(num), err
}
func (r *Reader) Float32() (float32, error) {
num, err := r.Uint32()
if err != nil {
return 0, err
}
return math.Float32frombits(num), nil
}
func (r *Reader) Float64() (float64, error) {
num, err := r.Uint64()
if err != nil {
return 0, err
}
return math.Float64frombits(num), nil
}
func (r *Reader) Bytes() ([]byte, error) {
num, err := r.Uvarint()
if err != nil {
return nil, err
}
b := make([]byte, int(num))
_, err = io.ReadFull(r.rd, b)
if err != nil {
return nil, err
}
return b, nil
}
func (r *Reader) String() (string, error) {
b, err := r.Bytes()
if err != nil {
return "", err
}
return internal.String(b), nil
}
func (r *Reader) UUID(b []byte) error {
if len(b) != uuidLen {
return fmt.Errorf("got %d bytes, wanted %d", len(b), uuidLen)
}
_, err := io.ReadFull(r.rd, b)
if err != nil {
return err
}
packUUID(b)
return nil
}
func (r *Reader) readNTemp(n int) ([]byte, error) {
buf := r.buf[:n]
_, err := io.ReadFull(r.rd, buf)
return buf, err
}
func (r *Reader) DateTime() (time.Time, error) {
sec, err := r.Uint32()
if err != nil {
return time.Time{}, err
}
if sec == 0 {
return time.Time{}, nil
}
return time.Unix(int64(sec), 0), nil
}
func (r *Reader) Date() (time.Time, error) {
days, err := r.Uint16()
if err != nil {
return time.Time{}, err
}
if days == 0 {
return time.Time{}, nil
}
return time.Unix(int64(days)*secsInDay, 0), nil
}

44
ch/chproto/server_info.go Normal file
View File

@ -0,0 +1,44 @@
package chproto
import (
"fmt"
)
type ServerInfo struct {
Name string
MinorVersion uint64
MajorVersion uint64
Revision uint64
}
func (srv *ServerInfo) ReadFrom(rd *Reader) (err error) {
if srv.Name, err = rd.String(); err != nil {
return err
}
if srv.MajorVersion, err = rd.Uvarint(); err != nil {
return err
}
if srv.MinorVersion, err = rd.Uvarint(); err != nil {
return err
}
if srv.Revision, err = rd.Uvarint(); err != nil {
return err
}
timezone, err := rd.String()
if err != nil {
return err
}
if timezone != "UTC" {
return fmt.Errorf("ch: ClickHouse server uses timezone=%q, expected UTC", timezone)
}
if _, err = rd.String(); err != nil { // display name
return err
}
if _, err = rd.Uvarint(); err != nil { // server version patch
return err
}
return nil
}

189
ch/chproto/writer.go Normal file
View File

@ -0,0 +1,189 @@
package chproto
import (
"bufio"
"encoding/binary"
"io"
"math"
"time"
"github.com/uptrace/go-clickhouse/ch/internal"
)
const uuidLen = 16
type writer interface {
io.Writer
io.ByteWriter
Flush() error
}
type Writer struct {
bw *bufio.Writer
zw *lz4Writer
wr writer // points to bw or zw
err error
buf []byte
}
func NewWriter(w io.Writer) *Writer {
bw := bufio.NewWriter(w)
return &Writer{
bw: bw,
zw: newLZ4Writer(bw),
wr: bw,
buf: make([]byte, uuidLen),
}
}
func (w *Writer) WithCompression(fn func() error) {
if w.err != nil {
return
}
w.zw.Init()
w.wr = w.zw
w.err = fn()
if err := w.zw.Close(); err != nil && w.err == nil {
w.err = err
}
w.wr = w.bw
}
func (w *Writer) Flush() (err error) {
if w.err != nil {
err, w.err = w.err, nil
return err
}
return w.wr.Flush()
}
func (w *Writer) Write(b []byte) {
if w.err != nil {
return
}
_, err := w.wr.Write(b)
w.err = err
}
func (w *Writer) writeByte(c byte) {
if w.err != nil {
return
}
w.err = w.wr.WriteByte(c)
}
func (w *Writer) Bool(flag bool) {
var num uint8
if flag {
num = 1
}
w.Uint8(num)
}
func (w *Writer) Uvarint(num uint64) {
n := binary.PutUvarint(w.buf, num)
w.Write(w.buf[:n])
}
func (w *Writer) Uint8(num uint8) {
w.writeByte(num)
}
func (w *Writer) Uint16(num uint16) {
binary.LittleEndian.PutUint16(w.buf, num)
w.Write(w.buf[:2])
}
func (w *Writer) Uint32(num uint32) {
binary.LittleEndian.PutUint32(w.buf, num)
w.Write(w.buf[:4])
}
func (w *Writer) Uint64(num uint64) {
binary.LittleEndian.PutUint64(w.buf, num)
w.Write(w.buf[:8])
}
func (w *Writer) Int8(num int8) {
w.Uint8(uint8(num))
}
func (w *Writer) Int16(num int16) {
w.Uint16(uint16(num))
}
func (w *Writer) Int32(num int32) {
w.Uint32(uint32(num))
}
func (w *Writer) Int64(num int64) {
w.Uint64(uint64(num))
}
func (w *Writer) Float32(num float32) {
w.Uint32(math.Float32bits(num))
}
func (w *Writer) Float64(num float64) {
w.Uint64(math.Float64bits(num))
}
func (w *Writer) String(s string) {
w.Uvarint(uint64(len(s)))
w.Write(internal.Bytes(s))
}
func (w *Writer) Bytes(b []byte) {
w.Uvarint(uint64(len(b)))
w.Write(b)
}
func (w *Writer) UUID(b []byte) {
if len(b) != uuidLen {
panic("not reached")
}
buf := w.buf[:uuidLen]
copy(buf, b)
packUUID(buf)
w.Write(buf)
}
// 2 int64 in little endian order?
func packUUID(b []byte) []byte {
_ = b[15]
b[0], b[7] = b[7], b[0]
b[1], b[6] = b[6], b[1]
b[2], b[5] = b[5], b[2]
b[3], b[4] = b[4], b[3]
b[8], b[15] = b[15], b[8]
b[9], b[14] = b[14], b[9]
b[10], b[13] = b[13], b[10]
b[11], b[12] = b[12], b[11]
return b
}
func (w *Writer) DateTime(tm time.Time) {
w.Uint32(uint32(unixTime(tm)))
}
const secsInDay = 24 * 3600
func (w *Writer) Date(tm time.Time) {
w.Uint16(uint16(unixTime(tm) / secsInDay))
}
func unixTime(tm time.Time) int64 {
if tm.IsZero() {
return 0
}
return tm.Unix()
}

128
ch/chschema/append.go Normal file
View File

@ -0,0 +1,128 @@
package chschema
import (
"database/sql/driver"
"encoding/hex"
"fmt"
"math"
"strconv"
"time"
)
func Append(fmter Formatter, b []byte, v any) []byte {
switch v := v.(type) {
case nil:
return AppendNull(b)
case bool:
return AppendBool(b, v)
case int8:
return strconv.AppendInt(b, int64(v), 10)
case int16:
return strconv.AppendInt(b, int64(v), 10)
case int32:
return strconv.AppendInt(b, int64(v), 10)
case int64:
return strconv.AppendInt(b, v, 10)
case int:
return strconv.AppendInt(b, int64(v), 10)
case uint8:
return strconv.AppendUint(b, uint64(v), 10)
case uint16:
return strconv.AppendUint(b, uint64(v), 10)
case uint32:
return strconv.AppendUint(b, uint64(v), 10)
case uint64:
return strconv.AppendUint(b, v, 10)
case uint:
return strconv.AppendUint(b, uint64(v), 10)
case float32:
return appendFloat(b, float64(v), 32)
case float64:
return appendFloat(b, v, 64)
case string:
return AppendString(b, v)
case time.Time:
return AppendTime(b, v)
case []byte:
return AppendBytes(b, v)
case QueryAppender:
return AppendQueryAppender(fmter, b, v)
case driver.Valuer:
return appendDriverValue(fmter, b, v)
default:
return AppendError(b, fmt.Errorf("ch: can't append %T", v))
}
}
func AppendError(b []byte, err error) []byte {
b = append(b, "?!("...)
b = append(b, err.Error()...)
b = append(b, ')')
return b
}
func AppendNull(b []byte) []byte {
return append(b, "NULL"...)
}
func AppendBool(dst []byte, v bool) []byte {
var c byte
if v {
c = 1
}
return append(dst, c)
}
func AppendFloat(dst []byte, v float64) []byte {
return appendFloat(dst, v, 64)
}
func appendFloat(dst []byte, v float64, bitSize int) []byte {
switch {
case math.IsNaN(v):
return append(dst, "nan"...)
case math.IsInf(v, 1):
return append(dst, "inf"...)
case math.IsInf(v, -1):
return append(dst, "-inf"...)
default:
return strconv.AppendFloat(dst, v, 'f', -1, bitSize)
}
}
func AppendString(b []byte, s string) []byte {
b = append(b, '\'')
for i := 0; i < len(s); i++ {
c := s[i]
if c == '\'' {
b = append(b, '\\', '\'')
} else {
b = append(b, c)
}
}
b = append(b, '\'')
return b
}
func AppendTime(b []byte, tm time.Time) []byte {
return tm.UTC().AppendFormat(b, "'2006-01-02 15:04:05'")
}
func AppendBytes(b []byte, bytes []byte) []byte {
if bytes == nil {
return AppendNull(b)
}
b = append(b, '\'')
tmp := make([]byte, hex.EncodedLen(len(bytes)))
hex.Encode(tmp, bytes)
b = append(b, "\\x"...)
b = append(b, tmp...)
b = append(b, '\'')
return b
}

204
ch/chschema/append_value.go Normal file
View File

@ -0,0 +1,204 @@
package chschema
import (
"database/sql/driver"
"fmt"
"net"
"reflect"
"strconv"
"time"
"github.com/uptrace/go-clickhouse/ch/internal"
)
var (
driverValuerType = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
queryAppenderType = reflect.TypeOf((*QueryAppender)(nil)).Elem()
)
type AppenderFunc func(fmter Formatter, b []byte, v reflect.Value) []byte
var valueAppenders []AppenderFunc
//nolint
func init() {
valueAppenders = []AppenderFunc{
reflect.Bool: appendBoolValue,
reflect.Int: appendIntValue,
reflect.Int8: appendIntValue,
reflect.Int16: appendIntValue,
reflect.Int32: appendIntValue,
reflect.Int64: appendIntValue,
reflect.Uint: appendUintValue,
reflect.Uint8: appendUintValue,
reflect.Uint16: appendUintValue,
reflect.Uint32: appendUintValue,
reflect.Uint64: appendUintValue,
reflect.Uintptr: nil,
reflect.Float32: appendFloat32Value,
reflect.Float64: appendFloat64Value,
reflect.Complex64: nil,
reflect.Complex128: nil,
reflect.Array: nil,
reflect.Chan: nil,
reflect.Func: nil,
reflect.Interface: appendIfaceValue,
reflect.Map: nil,
reflect.Ptr: nil,
reflect.Slice: nil,
reflect.String: appendStringValue,
reflect.Struct: nil,
reflect.UnsafePointer: nil,
}
}
func Appender(typ reflect.Type) AppenderFunc {
switch typ {
case timeType:
return appendTimeValue
case ipType:
return appendIPValue
case ipNetType:
return appendIPNetValue
}
if typ.Implements(queryAppenderType) {
return appendQueryAppenderValue
}
if typ.Implements(driverValuerType) {
return appendDriverValuerValue
}
kind := typ.Kind()
if kind != reflect.Ptr {
ptr := reflect.PtrTo(typ)
if ptr.Implements(queryAppenderType) {
return addrAppender(appendQueryAppenderValue)
}
if ptr.Implements(driverValuerType) {
return addrAppender(appendDriverValuerValue)
}
}
switch kind {
case reflect.Ptr:
return ptrAppenderFunc(typ)
case reflect.Slice:
if typ.Elem().Kind() == reflect.Uint8 {
return appendBytesValue
}
case reflect.Array:
if typ.Elem().Kind() == reflect.Uint8 {
return appendArrayBytesValue
}
}
return valueAppenders[kind]
}
func ptrAppenderFunc(typ reflect.Type) AppenderFunc {
appender := Appender(typ.Elem())
return func(fmter Formatter, b []byte, v reflect.Value) []byte {
if v.IsNil() {
return AppendNull(b)
}
return appender(fmter, b, v.Elem())
}
}
func AppendValue(fmter Formatter, b []byte, v reflect.Value) []byte {
if v.Kind() == reflect.Ptr && v.IsNil() {
return AppendNull(b)
}
appender := Appender(v.Type())
return appender(fmter, b, v)
}
func appendIfaceValue(fmter Formatter, b []byte, v reflect.Value) []byte {
return Append(fmter, b, v.Interface())
}
func appendBoolValue(fmter Formatter, b []byte, v reflect.Value) []byte {
return AppendBool(b, v.Bool())
}
func appendIntValue(fmter Formatter, b []byte, v reflect.Value) []byte {
return strconv.AppendInt(b, v.Int(), 10)
}
func appendUintValue(fmter Formatter, b []byte, v reflect.Value) []byte {
return strconv.AppendUint(b, v.Uint(), 10)
}
func appendFloat32Value(fmter Formatter, b []byte, v reflect.Value) []byte {
return appendFloat(b, v.Float(), 32)
}
func appendFloat64Value(fmter Formatter, b []byte, v reflect.Value) []byte {
return appendFloat(b, v.Float(), 64)
}
func appendBytesValue(fmter Formatter, b []byte, v reflect.Value) []byte {
return AppendBytes(b, v.Bytes())
}
func appendArrayBytesValue(fmter Formatter, b []byte, v reflect.Value) []byte {
return AppendBytes(b, v.Slice(0, v.Len()).Bytes())
}
func appendStringValue(fmter Formatter, b []byte, v reflect.Value) []byte {
return AppendString(b, v.String())
}
func appendTimeValue(fmter Formatter, b []byte, v reflect.Value) []byte {
tm := v.Interface().(time.Time)
return AppendTime(b, tm)
}
func appendIPValue(fmter Formatter, b []byte, v reflect.Value) []byte {
ip := v.Interface().(net.IP)
return AppendString(b, ip.String())
}
func appendIPNetValue(fmter Formatter, b []byte, v reflect.Value) []byte {
ipnet := v.Interface().(net.IPNet)
return AppendString(b, ipnet.String())
}
func appendJSONRawMessageValue(fmter Formatter, b []byte, v reflect.Value) []byte {
return AppendString(b, internal.String(v.Bytes()))
}
func appendQueryAppenderValue(fmter Formatter, b []byte, v reflect.Value) []byte {
return AppendQueryAppender(fmter, b, v.Interface().(QueryAppender))
}
func appendDriverValuerValue(fmter Formatter, b []byte, v reflect.Value) []byte {
return appendDriverValue(fmter, b, v.Interface().(driver.Valuer))
}
func addrAppender(fn AppenderFunc) AppenderFunc {
return func(fmter Formatter, b []byte, v reflect.Value) []byte {
if !v.CanAddr() {
err := fmt.Errorf("ch: Append(nonaddressable %T)", v.Interface())
return AppendError(b, err)
}
return fn(fmter, b, v.Addr())
}
}
func AppendQueryAppender(fmter Formatter, b []byte, app QueryAppender) []byte {
bb, err := app.AppendQuery(fmter, b)
if err != nil {
return AppendError(b, err)
}
return bb
}
func appendDriverValue(fmter Formatter, b []byte, v driver.Valuer) []byte {
value, err := v.Value()
if err != nil {
return AppendError(b, err)
}
return Append(fmter, b, value)
}

84
ch/chschema/block.go Normal file
View File

@ -0,0 +1,84 @@
package chschema
import (
"fmt"
"github.com/uptrace/go-clickhouse/ch/chproto"
)
type Block struct {
Table *Table
NumColumn int // read-only
NumRow int // read-only
Columns []*Column
columnMap map[string]*Column
}
func NewBlock(table *Table, numCol, numRow int) *Block {
return &Block{
Table: table,
NumColumn: numCol,
NumRow: numRow,
}
}
func (b *Block) ColumnForField(field *Field) *Column {
col := b.Column(field.CHName, field.CHType)
col.Field = field
return col
}
func (b *Block) Column(colName, colType string) *Column {
if col, ok := b.columnMap[colName]; ok {
return col
}
var col *Column
if b.Table != nil {
col = b.Table.NewColumn(colName, colType, b.NumRow)
}
if col == nil {
col = &Column{
Name: colName,
Type: colType,
Columnar: NewColumnFromCHType(colType, b.NumRow),
}
}
if b.Columns == nil && b.columnMap == nil {
b.Columns = make([]*Column, 0, b.NumColumn)
b.columnMap = make(map[string]*Column, b.NumColumn)
}
b.Columns = append(b.Columns, col)
b.columnMap[colName] = col
return col
}
func (b *Block) WriteTo(wr *chproto.Writer) error {
// Can't use b.NumRow for column oriented struct.
var numRow int
if len(b.Columns) > 0 {
numRow = b.Columns[0].Len()
}
wr.Uvarint(uint64(len(b.Columns)))
wr.Uvarint(uint64(numRow))
for _, col := range b.Columns {
if col.Len() != numRow {
err := fmt.Errorf("%s does not have expected number of rows: got %d, wanted %d",
col, col.Len(), numRow)
panic(err)
}
wr.String(col.Name)
wr.String(col.Type)
if err := col.WriteTo(wr); err != nil {
return err
}
}
return nil
}

1323
ch/chschema/column.go Normal file

File diff suppressed because it is too large Load Diff

354
ch/chschema/column_array.go Normal file
View File

@ -0,0 +1,354 @@
package chschema
import (
"fmt"
"reflect"
"github.com/uptrace/go-clickhouse/ch/chproto"
)
type ArrayColumnar interface {
WriteOffset(wr *chproto.Writer, offset int) int
WriteData(wr *chproto.Writer) error
}
type ArrayLCStringColumn struct {
*LCStringColumn
}
func (c ArrayLCStringColumn) Type() reflect.Type {
return stringSliceType
}
func (c *ArrayLCStringColumn) WriteTo(wr *chproto.Writer) error {
c.writeData(wr)
return nil
}
func (c *ArrayLCStringColumn) ReadFrom(rd *chproto.Reader, numRow int) error {
if numRow == 0 {
return nil
}
return c.readData(rd, numRow)
}
//------------------------------------------------------------------------------
type ArrayColumn struct {
Column reflect.Value
typ reflect.Type
elem Columnar
arrayElem ArrayColumnar
}
var _ Columnar = (*ArrayColumn)(nil)
func NewArrayColumn(typ reflect.Type, chType string, numRow int) Columnar {
elemType := chArrayElemType(chType)
if elemType == "" {
panic(fmt.Errorf("invalid array type: %q (Go type is %s)",
chType, typ.String()))
}
elem := NewColumn(typ.Elem(), elemType, 0)
var arrayElem ArrayColumnar
if _, ok := elem.(*LCStringColumn); ok {
panic("not reached")
}
arrayElem, _ = elem.(ArrayColumnar)
c := &ArrayColumn{
typ: reflect.SliceOf(typ),
elem: elem,
arrayElem: arrayElem,
}
c.Column = reflect.MakeSlice(c.typ, 0, numRow)
return c
}
func (c ArrayColumn) Type() reflect.Type {
return c.typ.Elem()
}
func (c *ArrayColumn) Reset(numRow int) {
if c.Column.Cap() >= numRow {
c.Column = c.Column.Slice(0, 0)
} else {
c.Column = reflect.MakeSlice(c.typ, 0, numRow)
}
}
func (c *ArrayColumn) Set(v any) {
c.Column = reflect.ValueOf(v)
}
func (c *ArrayColumn) Value() any {
return c.Column.Interface()
}
func (c *ArrayColumn) Nullable(nulls Uint8Column) any {
panic("not implemented")
}
func (c *ArrayColumn) Len() int {
return c.Column.Len()
}
func (c *ArrayColumn) Index(idx int) any {
return c.Column.Index(idx).Interface()
}
func (c ArrayColumn) Slice(s, e int) any {
return c.Column.Slice(s, e).Interface()
}
func (c *ArrayColumn) ConvertAssign(idx int, v reflect.Value) error {
v.Set(c.Column.Index(idx))
return nil
}
func (c *ArrayColumn) AppendValue(v reflect.Value) {
c.Column = reflect.Append(c.Column, v)
}
func (c *ArrayColumn) ReadFrom(rd *chproto.Reader, numRow int) error {
if c.Column.Cap() >= numRow {
c.Column = c.Column.Slice(0, numRow)
} else {
c.Column = reflect.MakeSlice(c.typ, numRow, numRow)
}
if numRow == 0 {
return nil
}
offsets := make([]int, numRow)
for i := 0; i < len(offsets); i++ {
offset, err := rd.Uint64()
if err != nil {
return err
}
offsets[i] = int(offset)
}
if err := c.elem.ReadFrom(rd, offsets[len(offsets)-1]); err != nil {
return err
}
var prev int
for i, offset := range offsets {
c.Column.Index(i).Set(reflect.ValueOf(c.elem.Slice(prev, offset)))
prev = offset
}
return nil
}
func (c *ArrayColumn) WriteTo(wr *chproto.Writer) error {
_ = c.WriteOffset(wr, 0)
colLen := c.Column.Len()
for i := 0; i < colLen; i++ {
// TODO: add SetValue or SetPointer
c.elem.Set(c.Column.Index(i).Interface())
var err error
if c.arrayElem != nil {
err = c.arrayElem.WriteData(wr)
} else {
err = c.elem.WriteTo(wr)
}
if err != nil {
return err
}
}
return nil
}
func (c *ArrayColumn) WriteOffset(wr *chproto.Writer, offset int) int {
colLen := c.Column.Len()
for i := 0; i < colLen; i++ {
el := c.Column.Index(i)
offset += el.Len()
wr.Uint64(uint64(offset))
}
if c.arrayElem == nil {
return offset
}
offset = 0
for i := 0; i < colLen; i++ {
el := c.Column.Index(i)
c.elem.Set(el.Interface()) // Use SetValue or SetPointer
offset = c.arrayElem.WriteOffset(wr, offset)
}
return offset
}
//------------------------------------------------------------------------------
type StringArrayColumn struct {
Column [][]string
elem Columnar
stringElem *StringColumn
lcElem *LCStringColumn
}
var _ Columnar = (*StringArrayColumn)(nil)
func NewStringArrayColumn(typ reflect.Type, chType string, numRow int) Columnar {
if _, funcType := aggFuncNameAndType(chType); funcType != "" {
chType = funcType
}
elemType := chArrayElemType(chType)
if elemType == "" {
panic(fmt.Errorf("invalid array type: %q (Go type is %s)",
chType, typ.String()))
}
columnar := NewColumn(typ.Elem(), elemType, 0)
var stringElem *StringColumn
var lcElem *LCStringColumn
switch v := columnar.(type) {
case *StringColumn:
stringElem = v
case *LCStringColumn:
stringElem = &v.StringColumn
lcElem = v
columnar = &ArrayLCStringColumn{v}
case *EnumColumn:
stringElem = &v.StringColumn
default:
panic(fmt.Errorf("unsupported column: %T", v))
}
return &StringArrayColumn{
Column: make([][]string, 0, numRow),
elem: columnar,
stringElem: stringElem,
lcElem: lcElem,
}
}
func (c *StringArrayColumn) Reset(numRow int) {
if cap(c.Column) >= numRow {
c.Column = c.Column[:0]
} else {
c.Column = make([][]string, 0, numRow)
}
}
func (c *StringArrayColumn) Type() reflect.Type {
return stringSliceType
}
func (c *StringArrayColumn) Set(v any) {
c.Column = v.([][]string)
}
func (c *StringArrayColumn) Value() any {
return c.Column
}
func (c *StringArrayColumn) Nullable(nulls Uint8Column) any {
panic("not implemented")
}
func (c *StringArrayColumn) Len() int {
return len(c.Column)
}
func (c *StringArrayColumn) Index(idx int) any {
return c.Column[idx]
}
func (c StringArrayColumn) Slice(s, e int) any {
return c.Column[s:e]
}
func (c *StringArrayColumn) ConvertAssign(idx int, v reflect.Value) error {
v.Set(reflect.ValueOf(c.Column[idx]))
return nil
}
func (c *StringArrayColumn) AppendValue(v reflect.Value) {
c.Column = append(c.Column, v.Interface().([]string))
}
func (c *StringArrayColumn) ReadFrom(rd *chproto.Reader, numRow int) error {
if numRow == 0 {
return nil
}
if cap(c.Column) >= numRow {
c.Column = c.Column[:numRow]
} else {
c.Column = make([][]string, numRow)
}
if c.lcElem != nil {
if err := c.lcElem.readPrefix(rd, numRow); err != nil {
return err
}
}
offsets := make([]int, numRow)
for i := 0; i < len(offsets); i++ {
offset, err := rd.Uint64()
if err != nil {
return err
}
offsets[i] = int(offset)
}
if err := c.elem.ReadFrom(rd, offsets[len(offsets)-1]); err != nil {
return err
}
var prev int
for i, offset := range offsets {
c.Column[i] = c.stringElem.Column[prev:offset]
prev = offset
}
return nil
}
func (c *StringArrayColumn) WriteTo(wr *chproto.Writer) error {
if c.lcElem != nil {
c.lcElem.writePrefix(wr)
}
_ = c.WriteOffset(wr, 0)
return c.WriteData(wr)
}
var _ ArrayColumnar = (*StringArrayColumn)(nil)
func (c *StringArrayColumn) WriteOffset(wr *chproto.Writer, offset int) int {
for _, el := range c.Column {
offset += len(el)
wr.Uint64(uint64(offset))
}
return offset
}
func (c *StringArrayColumn) WriteData(wr *chproto.Writer) error {
for _, ss := range c.Column {
c.stringElem.Column = ss
if err := c.elem.WriteTo(wr); err != nil {
return err
}
}
return nil
}

View File

@ -0,0 +1,100 @@
package chschema
import (
"reflect"
"github.com/uptrace/go-clickhouse/ch/chproto"
)
type NullableColumn struct {
Nulls Uint8Column
Values Columnar
nullable reflect.Value // reflect.Slice
}
func NullableNewColumnFunc(fn NewColumnFunc) NewColumnFunc {
return func(typ reflect.Type, chType string, numRow int) Columnar {
return &NullableColumn{
Values: fn(typ, chType, numRow),
}
}
}
var _ Columnar = (*NullableColumn)(nil)
func (c *NullableColumn) Type() reflect.Type {
return reflect.PtrTo(c.Values.Type())
}
func (c *NullableColumn) Set(v any) {
panic("not reached")
}
func (c *NullableColumn) AppendValue(v reflect.Value) {
if v.IsNil() {
c.Nulls.Column = append(c.Nulls.Column, 1)
c.Values.AppendValue(reflect.New(c.Values.Type()).Elem())
} else {
c.Nulls.Column = append(c.Nulls.Column, 0)
c.Values.AppendValue(v.Elem())
}
}
func (c *NullableColumn) Value() any {
return c.nullable.Interface()
}
func (c *NullableColumn) Nullable(nulls Uint8Column) any {
panic("not implemented")
}
func (c *NullableColumn) Len() int {
return c.Values.Len()
}
func (c *NullableColumn) Index(idx int) any {
elem := c.nullable.Index(idx)
if elem.IsNil() {
return nil
}
return elem.Elem().Interface()
}
func (c *NullableColumn) Slice(s, e int) any {
panic("not implemented")
}
func (c *NullableColumn) ConvertAssign(idx int, dest reflect.Value) error {
if idx < len(c.Nulls.Column) && c.Nulls.Column[idx] == 1 {
return nil
}
if dest.IsNil() {
dest.Set(reflect.New(dest.Type().Elem()))
}
return c.Values.ConvertAssign(idx, dest.Elem())
}
func (c *NullableColumn) ReadFrom(rd *chproto.Reader, numRow int) error {
if numRow == 0 {
return nil
}
if err := c.Nulls.ReadFrom(rd, numRow); err != nil {
return err
}
if err := c.Values.ReadFrom(rd, numRow); err != nil {
return err
}
c.nullable = reflect.ValueOf(c.Values.Nullable(c.Nulls))
return nil
}
func (c *NullableColumn) WriteTo(wr *chproto.Writer) error {
if err := c.Nulls.WriteTo(wr); err != nil {
return err
}
return c.Values.WriteTo(wr)
}
func isNilValue(v reflect.Value) bool {
return false
}

150
ch/chschema/enum.go Normal file
View File

@ -0,0 +1,150 @@
package chschema
import (
"fmt"
"strconv"
"strings"
"sync"
)
var enumMap sync.Map
type enumInfo struct {
chType string
dec []string
enc map[string]int16
}
func (e *enumInfo) Encode(val string) (int16, bool) {
i, ok := e.enc[val]
return i, ok
}
func (e *enumInfo) Decode(i int16) string {
return e.dec[i]
}
func parseEnum(s string) *enumInfo {
if v, ok := enumMap.Load(s); ok {
return v.(*enumInfo)
}
enumInfo, err := _parseEnum(s)
if err != nil {
panic(err)
}
enumInfo.chType = s
enumMap.Store(s, enumInfo)
return enumInfo
}
func _parseEnum(chType string) (*enumInfo, error) {
s := enumType(chType)
if s == "" {
return nil, fmt.Errorf("can't parse enum type: %q", chType)
}
var dec []string
for s != "" {
var key, val string
var ok bool
s, key, ok = scanEnumKey(s)
if !ok {
return nil, fmt.Errorf("can't parse enum key: %q", s)
}
s, ok = scanEnumChar(s, '=')
if !ok {
return nil, fmt.Errorf("can't parse enum '=': %q", s)
}
s, val = scanEnumValue(s)
if val == "" {
return nil, fmt.Errorf("can't parse enum value: %q", s)
}
n, err := strconv.ParseInt(val, 10, 16)
if err != nil {
return nil, err
}
ln := int(n + 1)
if len(dec) < ln {
dec = append(dec, make([]string, ln-len(dec))...)
}
dec[n] = key
s, _ = scanEnumChar(s, ',')
}
enc := make(map[string]int16, len(dec))
for i, s := range dec {
enc[s] = int16(i)
}
return &enumInfo{
chType: chType,
dec: dec,
enc: enc,
}, nil
}
func scanEnumKey(s string) (string, string, bool) {
loop:
for i := 0; i < len(s); i++ {
c := s[i]
switch c {
case ' ':
// ignore
case '\'':
s = s[i+1:]
break loop
default:
return s, "", false
}
}
i := strings.IndexByte(s, '\'')
if i == -1 {
return s, "", false
}
key := s[:i]
s = s[i+1:]
return s, key, true
}
func scanEnumChar(s string, ch byte) (string, bool) {
var start int
loop:
for i := 0; i < len(s); i++ {
c := s[i]
switch c {
case ' ':
start = i + 1
case ch:
return s[i+1:], true
default:
break loop
}
}
return s[start:], false
}
func scanEnumValue(s string) (string, string) {
var start int
for i := 0; i < len(s); i++ {
c := s[i]
switch {
case c == ' ':
start = i + 1
case c >= '0' && c <= '9':
// continue
default:
return s[i:], s[start:i]
}
}
return "", s[start:]
}

58
ch/chschema/field.go Normal file
View File

@ -0,0 +1,58 @@
package chschema
import (
"fmt"
"reflect"
)
const (
customTypeFlag = uint8(1) << iota
)
type Field struct {
Field reflect.StructField
Type reflect.Type
Index []int
GoName string // struct field name, e.g. Id
CHName string // SQL name, .e.g. id
Column Safe // escaped SQL name, e.g. "id"
CHType string
CHDefault Safe
NewColumn NewColumnFunc
appendValue AppenderFunc
IsPK bool
NotNull bool
flags uint8
}
func (f *Field) String() string {
return "field=" + f.GoName
}
func (f *Field) Value(strct reflect.Value) reflect.Value {
return fieldByIndexAlloc(strct, f.Index)
}
func (f *Field) AppendValue(fmter Formatter, b []byte, strct reflect.Value) []byte {
fv, ok := fieldByIndex(strct, f.Index)
if !ok {
return AppendNull(b)
}
if f.appendValue == nil {
return AppendError(b, fmt.Errorf("ch: AppendValue(unsupported %s)", fv.Type()))
}
return f.appendValue(fmter, b, fv)
}
func (f *Field) setFlag(flag uint8) {
f.flags |= flag
}
func (f *Field) hasFlag(flag uint8) bool {
return f.flags&flag != 0
}

217
ch/chschema/formatter.go Normal file
View File

@ -0,0 +1,217 @@
package chschema
import (
"reflect"
"strconv"
"strings"
"github.com/uptrace/go-clickhouse/ch/internal"
"github.com/uptrace/go-clickhouse/ch/internal/parser"
)
var emptyFmter Formatter
func FormatQuery(query string, args ...any) string {
return emptyFmter.FormatQuery(query, args...)
}
func AppendQuery(b []byte, query string, args ...any) []byte {
return emptyFmter.AppendQuery(b, query, args...)
}
type Formatter struct {
args *namedArgList
}
func NewFormatter() Formatter {
return Formatter{}
}
func (f Formatter) AppendIdent(b []byte, ident string) []byte {
return AppendIdent(b, ident)
}
func (f Formatter) WithArg(arg NamedArgAppender) Formatter {
return Formatter{
args: f.args.WithArg(arg),
}
}
func (f Formatter) WithNamedArg(name string, value any) Formatter {
return Formatter{
args: f.args.WithArg(&namedArg{name: name, value: value}),
}
}
func (f Formatter) FormatQuery(query string, args ...any) string {
if (args == nil && f.args == nil) || strings.IndexByte(query, '?') == -1 {
return query
}
return internal.String(f.AppendQuery(nil, query, args...))
}
func (f Formatter) AppendQuery(b []byte, query string, args ...any) []byte {
if (args == nil && f.args == nil) || strings.IndexByte(query, '?') == -1 {
return append(b, query...)
}
return f.append(b, parser.NewString(query), args)
}
func (f Formatter) append(dst []byte, p *parser.Parser, args []any) []byte {
var namedArgs NamedArgAppender
if len(args) == 1 {
if v, ok := args[0].(NamedArgAppender); ok {
namedArgs = v
} else if v, ok := newStructArgs(f, args[0]); ok {
namedArgs = v
}
}
var argIndex int
for p.Valid() {
b, ok := p.ReadSep('?')
if !ok {
dst = append(dst, b...)
continue
}
if len(b) > 0 && b[len(b)-1] == '\\' {
dst = append(dst, b[:len(b)-1]...)
dst = append(dst, '?')
continue
}
dst = append(dst, b...)
name, numeric := p.ReadIdentifier()
if name != "" {
if numeric {
idx, err := strconv.Atoi(name)
if err != nil {
goto restore_arg
}
if idx >= len(args) {
goto restore_arg
}
dst = f.appendArg(dst, args[idx])
continue
}
if namedArgs != nil {
dst, ok = namedArgs.AppendNamedArg(f, dst, name)
if ok {
continue
}
}
dst, ok = f.args.AppendNamedArg(f, dst, name)
if ok {
continue
}
restore_arg:
dst = append(dst, '?')
dst = append(dst, name...)
continue
}
if argIndex >= len(args) {
dst = append(dst, '?')
continue
}
arg := args[argIndex]
argIndex++
dst = f.appendArg(dst, arg)
}
return dst
}
func (f Formatter) appendArg(b []byte, arg any) []byte {
switch arg := arg.(type) {
case QueryAppender:
bb, err := arg.AppendQuery(f, b)
if err != nil {
return AppendError(b, err)
}
return bb
default:
return Append(f, b, arg)
}
}
//------------------------------------------------------------------------------
type NamedArgAppender interface {
AppendNamedArg(fmter Formatter, b []byte, name string) ([]byte, bool)
}
type namedArgList struct {
arg NamedArgAppender
next *namedArgList
}
func (l *namedArgList) WithArg(arg NamedArgAppender) *namedArgList {
return &namedArgList{
arg: arg,
next: l,
}
}
func (l *namedArgList) AppendNamedArg(fmter Formatter, b []byte, name string) ([]byte, bool) {
for l != nil && l.arg != nil {
if b, ok := l.arg.AppendNamedArg(fmter, b, name); ok {
return b, true
}
l = l.next
}
return b, false
}
//------------------------------------------------------------------------------
type namedArg struct {
name string
value any
}
var _ NamedArgAppender = (*namedArg)(nil)
func (a *namedArg) AppendNamedArg(fmter Formatter, b []byte, name string) ([]byte, bool) {
if a.name == name {
return fmter.appendArg(b, a.value), true
}
return b, false
}
//------------------------------------------------------------------------------
type structArgs struct {
table *Table
strct reflect.Value
}
var _ NamedArgAppender = (*structArgs)(nil)
func newStructArgs(fmter Formatter, strct any) (*structArgs, bool) {
v := reflect.ValueOf(strct)
if !v.IsValid() {
return nil, false
}
v = reflect.Indirect(v)
if v.Kind() != reflect.Struct {
return nil, false
}
return &structArgs{
table: TableForType(v.Type()),
strct: v,
}, true
}
func (m *structArgs) AppendNamedArg(fmter Formatter, b []byte, name string) ([]byte, bool) {
return m.table.AppendNamedArg(fmter, b, name, m.strct)
}

23
ch/chschema/hooks.go Normal file
View File

@ -0,0 +1,23 @@
package chschema
import (
"context"
"reflect"
)
type Query interface {
QueryAppender
Operation() string
GetModel() Model
GetTableName() string
}
type Model interface {
ScanBlock(*Block) error
}
type AfterScanRowHook interface {
AfterScanRow(context.Context) error
}
var afterScanBlockHookType = reflect.TypeOf((*AfterScanRowHook)(nil)).Elem()

53
ch/chschema/lowcard.go Normal file
View File

@ -0,0 +1,53 @@
package chschema
type lowCard struct {
slice sliceMap
dict map[string]int
}
func (lc *lowCard) Add(word string) int {
if i, ok := lc.dict[word]; ok {
return i
}
if lc.dict == nil {
lc.dict = make(map[string]int)
}
i := lc.slice.Add(word)
lc.dict[word] = i
return i
}
func (lc *lowCard) Dict() []string {
return lc.slice.Slice()
}
//------------------------------------------------------------------------------
type sliceMap struct {
ss []string
}
func (m sliceMap) Len() int {
return len(m.ss)
}
func (m sliceMap) Get(word string) (int, bool) {
for i, s := range m.ss {
if s == word {
return i, true
}
}
return 0, false
}
func (m *sliceMap) Add(word string) int {
m.ss = append(m.ss, word)
return len(m.ss) - 1
}
func (m sliceMap) Slice() []string {
return m.ss
}

74
ch/chschema/reflect.go Normal file
View File

@ -0,0 +1,74 @@
package chschema
import (
"reflect"
)
func indirect(v reflect.Value) reflect.Value {
switch v.Kind() {
case reflect.Interface:
return indirect(v.Elem())
case reflect.Ptr:
return v.Elem()
default:
return v
}
}
func indirectType(t reflect.Type) reflect.Type {
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
return t
}
func fieldByIndex(v reflect.Value, index []int) (_ reflect.Value, ok bool) {
if len(index) == 1 {
return v.Field(index[0]), true
}
for i, idx := range index {
if i > 0 {
if v.Kind() == reflect.Ptr {
if v.IsNil() {
return v, false
}
v = v.Elem()
}
}
v = v.Field(idx)
}
return v, true
}
func fieldByIndexAlloc(v reflect.Value, index []int) reflect.Value {
if len(index) == 1 {
return v.Field(index[0])
}
for i, idx := range index {
if i > 0 {
v = indirectNil(v)
}
v = v.Field(idx)
}
return v
}
func indirectNil(v reflect.Value) reflect.Value {
if v.Kind() == reflect.Ptr {
if v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
}
v = v.Elem()
}
return v
}
func sliceElemType(v reflect.Value) reflect.Type {
elemType := v.Type().Elem()
if elemType.Kind() == reflect.Interface && v.Len() > 0 {
return indirect(v.Index(0).Elem()).Type()
}
return indirectType(elemType)
}

161
ch/chschema/sqlfmt.go Normal file
View File

@ -0,0 +1,161 @@
package chschema
import (
"strings"
"github.com/uptrace/go-clickhouse/ch/internal"
)
type QueryAppender interface {
AppendQuery(fmter Formatter, b []byte) ([]byte, error)
}
type ColumnsAppender interface {
AppendColumns(fmter Formatter, b []byte) ([]byte, error)
}
//------------------------------------------------------------------------------
// Safe represents a safe SQL query.
type Safe string
var _ QueryAppender = (*Safe)(nil)
func (s Safe) AppendQuery(fmter Formatter, b []byte) ([]byte, error) {
return append(b, s...), nil
}
//------------------------------------------------------------------------------
// FQN represents a fully qualified SQL name, for example, table or column name.
type FQN string
var _ QueryAppender = (*FQN)(nil)
func (s FQN) AppendQuery(fmter Formatter, b []byte) ([]byte, error) {
return fmter.AppendIdent(b, string(s)), nil
}
func AppendFQN(b []byte, field string) []byte {
return appendFQN(b, internal.Bytes(field))
}
func appendFQN(b, src []byte) []byte {
const quote = '"'
var quoted bool
loop:
for _, c := range src {
switch c {
case '*':
if !quoted {
b = append(b, '*')
continue loop
}
case '.':
if quoted {
b = append(b, quote)
quoted = false
}
b = append(b, '.')
continue loop
}
if !quoted {
b = append(b, quote)
quoted = true
}
if c == quote {
b = append(b, quote, quote)
} else {
b = append(b, c)
}
}
if quoted {
b = append(b, quote)
}
return b
}
// Ident represents a SQL identifier, for example, table or column name.
type Ident string
var _ QueryAppender = (*Ident)(nil)
func (s Ident) AppendQuery(fmter Formatter, b []byte) ([]byte, error) {
return fmter.AppendIdent(b, string(s)), nil
}
func AppendIdent(b []byte, field string) []byte {
return appendIdent(b, internal.Bytes(field))
}
func appendIdent(b, src []byte) []byte {
const quote = '"'
b = append(b, quote)
for _, c := range src {
if c == quote {
b = append(b, quote, quote)
} else {
b = append(b, c)
}
}
b = append(b, quote)
return b
}
//------------------------------------------------------------------------------
type QueryWithArgs struct {
Query string
Args []any
}
var _ QueryAppender = (*QueryWithArgs)(nil)
func SafeQuery(query string, args []any) QueryWithArgs {
if args == nil {
args = make([]any, 0)
} else if len(query) > 0 && strings.IndexByte(query, '?') == -1 {
internal.Warn.Printf("query %q has %v args, but no placeholders", query, args)
}
return QueryWithArgs{
Query: query,
Args: args,
}
}
func UnsafeIdent(ident string) QueryWithArgs {
return QueryWithArgs{Query: ident}
}
func (q QueryWithArgs) IsZero() bool {
return q.Query == "" && q.Args == nil
}
func (q QueryWithArgs) AppendQuery(fmter Formatter, b []byte) ([]byte, error) {
if q.Args == nil {
return fmter.AppendIdent(b, q.Query), nil
}
return fmter.AppendQuery(b, q.Query, q.Args...), nil
}
func (q QueryWithArgs) Value() Safe {
b, _ := q.AppendQuery(emptyFmter, nil)
return Safe(b)
}
//------------------------------------------------------------------------------
type QueryWithSep struct {
QueryWithArgs
Sep string
}
func SafeQueryWithSep(query string, args []any, sep string) QueryWithSep {
return QueryWithSep{
QueryWithArgs: SafeQuery(query, args),
Sep: sep,
}
}

311
ch/chschema/table.go Normal file
View File

@ -0,0 +1,311 @@
package chschema
import (
"fmt"
"reflect"
"github.com/codemodus/kace"
"github.com/jinzhu/inflection"
"github.com/uptrace/go-clickhouse/ch/chtype"
"github.com/uptrace/go-clickhouse/ch/internal"
"github.com/uptrace/go-clickhouse/ch/internal/tagparser"
)
const (
discardUnknownColumnsFlag = internal.Flag(1) << iota
columnarFlag
afterScanBlockHookFlag
)
var (
chModelType = reflect.TypeOf((*CHModel)(nil)).Elem()
tableNameInflector = inflection.Plural
)
type CHModel struct{}
// SetTableNameInflector overrides the default func that pluralizes
// model name to get table name, e.g. my_article becomes my_articles.
func SetTableNameInflector(fn func(string) string) {
tableNameInflector = fn
}
type Table struct {
Type reflect.Type
ModelName string
Name string
CHName Safe
CHInsertName Safe
CHAlias Safe
CHEngine string
CHPartition string
Fields []*Field // PKs + DataFields
PKs []*Field
DataFields []*Field
FieldMap map[string]*Field
flags internal.Flag
}
func newTable(typ reflect.Type) *Table {
t := new(Table)
t.Type = typ
t.ModelName = kace.Snake(t.Type.Name())
tableName := tableNameInflector(t.ModelName)
t.setName(tableName)
t.CHAlias = quoteColumnName(t.ModelName)
t.initFields()
typ = reflect.PtrTo(t.Type)
if typ.Implements(afterScanBlockHookType) {
t.flags.Set(afterScanBlockHookFlag)
}
return t
}
func (t *Table) String() string {
return "model=" + t.ModelName
}
func (t *Table) IsColumnar() bool {
return t.flags.Has(columnarFlag)
}
func (t *Table) setName(name string) {
quoted := quoteTableName(name)
t.Name = name
t.CHName = quoted
t.CHInsertName = quoted
if t.CHAlias == "" {
t.CHAlias = quoted
}
}
func (t *Table) Field(name string) (*Field, error) {
field, ok := t.FieldMap[name]
if !ok {
return nil, &UnknownColumnError{
Table: t,
Column: name,
}
}
return field, nil
}
func (t *Table) initFields() {
t.Fields = make([]*Field, 0, t.Type.NumField())
t.FieldMap = make(map[string]*Field, t.Type.NumField())
t.addFields(t.Type, nil)
}
func (t *Table) addFields(typ reflect.Type, baseIndex []int) {
for i := 0; i < typ.NumField(); i++ {
f := typ.Field(i)
tag := tagparser.Parse(f.Tag.Get("ch"))
if tag.Name == "-" {
continue
}
// Make a copy so slice is not shared between fields.
index := make([]int, len(baseIndex))
copy(index, baseIndex)
if f.Anonymous {
if f.Name == "CHModel" && f.Type == chModelType {
if len(index) == 0 {
t.processCHModelField(f)
}
continue
}
fieldType := indirectType(f.Type)
if fieldType.Kind() != reflect.Struct {
continue
}
t.addFields(fieldType, append(index, f.Index...))
if _, ok := tag.Options["inherit"]; ok {
embeddedTable := globalTables.Get(fieldType)
t.ModelName = embeddedTable.ModelName
t.CHName = embeddedTable.CHName
t.CHAlias = embeddedTable.CHAlias
}
continue
}
if field := t.newField(f, index, tag); field != nil {
t.addField(field)
}
}
for _, f := range t.FieldMap {
if t.IsColumnar() {
f.Type = f.Type.Elem()
if !f.hasFlag(customTypeFlag) {
if s := chArrayElemType(f.CHType); s != "" {
f.CHType = s
}
}
}
if f.NewColumn == nil {
f.NewColumn = ColumnFactory(f.Type, f.CHType)
}
}
}
func (t *Table) processCHModelField(f reflect.StructField) {
tag := tagparser.Parse(f.Tag.Get("ch"))
if tag.Name != "" {
t.setName(tag.Name)
}
if s, ok := tag.Option("table"); ok {
t.setName(s)
}
if s, ok := tag.Option("alias"); ok {
t.CHAlias = quoteColumnName(s)
}
if s, ok := tag.Option("insert"); ok {
t.CHInsertName = quoteTableName(s)
}
if s, ok := tag.Option("engine"); ok {
t.CHEngine = s
}
if s, ok := tag.Option("partition"); ok {
t.CHPartition = s
}
if tag.HasOption("columnar") {
t.flags |= columnarFlag
}
}
func (t *Table) newField(f reflect.StructField, index []int, tag tagparser.Tag) *Field {
if f.PkgPath != "" {
return nil
}
if tag.Name == "" {
tag.Name = kace.Snake(f.Name)
}
field := &Field{
Field: f,
Type: f.Type,
GoName: f.Name,
CHName: tag.Name,
Column: quoteColumnName(tag.Name),
Index: append(index, f.Index...),
}
field.NotNull = tag.HasOption("notnull")
field.IsPK = tag.HasOption("pk")
if s, ok := tag.Option("type"); ok {
field.CHType = s
field.setFlag(customTypeFlag)
} else {
field.CHType = clickhouseType(f.Type)
}
if tag.HasOption("lc") {
if s := chSubType(field.CHType, "Array("); s != "" && s == chtype.String {
field.CHType = "Array(LowCardinality(String))"
} else if field.CHType == chtype.String {
field.CHType = "LowCardinality(String)"
} else {
panic(fmt.Errorf("unsupported lc option on %s type", field.CHType))
}
}
if s, ok := tag.Option("default"); ok {
field.CHDefault = Safe(s)
}
field.appendValue = Appender(f.Type)
if s, ok := tag.Option("alt"); ok {
t.FieldMap[s] = field
}
if tag.HasOption("scanonly") {
t.FieldMap[field.CHName] = field
return nil
}
return field
}
func (t *Table) addField(field *Field) {
t.Fields = append(t.Fields, field)
if field.IsPK {
t.PKs = append(t.PKs, field)
} else {
t.DataFields = append(t.DataFields, field)
}
t.FieldMap[field.CHName] = field
}
func (t *Table) NewColumn(colName, colType string, numRow int) *Column {
field, ok := t.FieldMap[colName]
if !ok {
internal.Logger.Printf("ch: %s has no column=%q", t, colName)
return nil
}
if colType != field.CHType {
if field.CHType != chtype.Any {
internal.Logger.Printf("got column type %q, but %s.%s has type %q",
colType, t.Type.Name(), field.GoName, field.CHType)
}
return &Column{
Name: colName,
Type: colType,
Columnar: ColumnFactory(field.Type, colType)(field.Type, colType, numRow),
}
}
return &Column{
Name: colName,
Type: field.CHType,
Columnar: field.NewColumn(field.Type, field.CHType, numRow),
}
}
func (t *Table) HasAfterScanRowHook() bool { return t.flags.Has(afterScanBlockHookFlag) }
func (t *Table) AppendNamedArg(
fmter Formatter, b []byte, name string, strct reflect.Value,
) ([]byte, bool) {
if field, ok := t.FieldMap[name]; ok {
return field.AppendValue(fmter, b, strct), true
}
return b, false
}
func quoteTableName(s string) Safe {
return Safe(appendFQN(nil, internal.Bytes(s)))
}
func quoteColumnName(s string) Safe {
return Safe(appendIdent(nil, internal.Bytes(s)))
}
//------------------------------------------------------------------------------
type UnknownColumnError struct {
Table *Table
Column string
}
func (err *UnknownColumnError) Error() string {
return fmt.Sprintf("ch: %s does not have column=%q",
err.Table, err.Column)
}

51
ch/chschema/tables.go Normal file
View File

@ -0,0 +1,51 @@
package chschema
import (
"fmt"
"reflect"
"sync"
)
var globalTables = newTablesMap()
func TableForType(typ reflect.Type) *Table {
return globalTables.Get(typ)
}
type tablesMap struct {
m sync.Map
}
func newTablesMap() *tablesMap {
return new(tablesMap)
}
func (t *tablesMap) Get(typ reflect.Type) *Table {
if typ.Kind() != reflect.Struct {
panic(fmt.Errorf("got %s, wanted %s", typ.Kind(), reflect.Struct))
}
if v, ok := t.m.Load(typ); ok {
return v.(*Table)
}
table := newTable(typ)
if v, loaded := t.m.LoadOrStore(typ, table); loaded {
return v.(*Table)
}
return table
}
func (t *tablesMap) getByName(name string) *Table {
var found *Table
t.m.Range(func(key, value any) bool {
t := value.(*Table)
if t.Name == name || t.ModelName == name {
found = t
return false
}
return true
})
return found
}

404
ch/chschema/types.go Normal file
View File

@ -0,0 +1,404 @@
package chschema
import (
"fmt"
"net"
"reflect"
"strings"
"time"
"github.com/uptrace/go-clickhouse/ch/chtype"
"github.com/uptrace/go-clickhouse/ch/internal"
)
var chType = [...]string{
reflect.Bool: chtype.UInt8,
reflect.Int: chtype.Int64,
reflect.Int8: chtype.Int8,
reflect.Int16: chtype.Int16,
reflect.Int32: chtype.Int32,
reflect.Int64: chtype.Int64,
reflect.Uint: chtype.UInt64,
reflect.Uint8: chtype.UInt8,
reflect.Uint16: chtype.UInt16,
reflect.Uint32: chtype.UInt32,
reflect.Uint64: chtype.UInt64,
reflect.Uintptr: "",
reflect.Float32: chtype.Float32,
reflect.Float64: chtype.Float64,
reflect.Complex64: "",
reflect.Complex128: "",
reflect.Array: "",
reflect.Chan: "",
reflect.Func: "",
reflect.Interface: chtype.Any,
reflect.Map: chtype.String,
reflect.Ptr: "",
reflect.Slice: "",
reflect.String: chtype.String,
reflect.Struct: chtype.String,
reflect.UnsafePointer: "",
}
// keep in sync with ColumnFactory
func clickhouseType(typ reflect.Type) string {
switch typ {
case timeType:
return chtype.DateTime
case ipType:
return chtype.IPv6
}
kind := typ.Kind()
switch kind {
case reflect.Ptr:
if typ.Elem().Kind() == reflect.Struct {
return chtype.String
}
return fmt.Sprintf("Nullable(%s)", clickhouseType(typ.Elem()))
case reflect.Slice:
switch elem := typ.Elem(); elem.Kind() {
case reflect.Ptr:
if elem.Elem().Kind() == reflect.Struct {
return chtype.String // json
}
case reflect.Struct:
if elem != timeType {
return chtype.String // json
}
case reflect.Uint8:
return chtype.String // []byte
}
return "Array(" + clickhouseType(typ.Elem()) + ")"
case reflect.Array:
if isUUID(typ) {
return chtype.UUID
}
}
if s := chType[kind]; s != "" {
return s
}
panic(fmt.Errorf("ch: unsupported Go type: %s", typ))
}
type NewColumnFunc func(typ reflect.Type, chType string, numRow int) Columnar
var kindToColumn = [...]NewColumnFunc{
reflect.Bool: NewBoolColumn,
reflect.Int: NewInt64Column,
reflect.Int8: NewInt8Column,
reflect.Int16: NewInt16Column,
reflect.Int32: NewInt32Column,
reflect.Int64: NewInt64Column,
reflect.Uint: NewUint64Column,
reflect.Uint8: NewUint8Column,
reflect.Uint16: NewUint16Column,
reflect.Uint32: NewUint32Column,
reflect.Uint64: NewUint64Column,
reflect.Uintptr: nil,
reflect.Float32: NewFloat32Column,
reflect.Float64: NewFloat64Column,
reflect.Complex64: nil,
reflect.Complex128: nil,
reflect.Array: nil,
reflect.Chan: nil,
reflect.Func: nil,
reflect.Interface: nil,
reflect.Map: NewJSONColumn,
reflect.Ptr: nil,
reflect.Slice: nil,
reflect.String: NewStringColumn,
reflect.Struct: NewJSONColumn,
reflect.UnsafePointer: nil,
}
// keep in sync with clickhouseType
func ColumnFactory(typ reflect.Type, chType string) NewColumnFunc {
if chType == chtype.Any {
return nil
}
if s := lowCardinalityType(chType); s != "" {
switch s {
case chtype.String:
return NewLCStringColumn
}
panic(fmt.Errorf("got %s, wanted LowCardinality(String)", chType))
}
if s := enumType(chType); s != "" {
return NewEnumColumn
}
if strings.HasPrefix(chType, "SimpleAggregateFunction(") {
chType = chSubType(chType, "SimpleAggregateFunction(")
} else if s := dateTimeType(chType); s != "" {
chType = s
}
switch typ {
case timeType:
switch chType {
case chtype.DateTime:
return NewDateTimeColumn
case chtype.Date:
return NewDateColumn
case chtype.Int64:
return NewTimeColumn
}
case ipType:
return NewIPColumn
}
kind := typ.Kind()
switch kind {
case reflect.Ptr:
if typ.Elem().Kind() == reflect.Struct {
return NewJSONColumn
}
return NullableNewColumnFunc(ColumnFactory(typ.Elem(), nullableType(chType)))
case reflect.Slice:
switch elem := typ.Elem(); elem.Kind() {
case reflect.Ptr:
if elem.Elem().Kind() == reflect.Struct {
return NewJSONColumn
}
case reflect.Uint8:
if chType == chtype.String {
return NewBytesColumn
}
case reflect.String:
return NewStringArrayColumn
case reflect.Struct:
if elem != timeType {
return NewJSONColumn
}
}
return NewArrayColumn
case reflect.Array:
if isUUID(typ) {
return NewUUIDColumn
}
case reflect.Interface:
return columnFromCHType(chType)
}
switch chType {
case chtype.DateTime:
switch typ {
case uint32Type:
return NewUint32Column
case int64Type:
return NewInt64TimeColumn
default:
return NewDateTimeColumn
}
}
fn := kindToColumn[kind]
if fn != nil {
return fn
}
panic(fmt.Errorf("unsupported go_type=%q ch_type=%q", typ.String(), chType))
}
func columnFromCHType(chType string) NewColumnFunc {
switch chType {
case chtype.String:
return NewStringColumn
case chtype.UUID:
return NewUUIDColumn
case chtype.Int8:
return NewInt8Column
case chtype.Int16:
return NewInt16Column
case chtype.Int32:
return NewInt32Column
case chtype.Int64:
return NewInt64Column
case chtype.UInt8:
return NewUint8Column
case chtype.UInt16:
return NewUint16Column
case chtype.UInt32:
return NewUint32Column
case chtype.UInt64:
return NewUint64Column
case chtype.Float32:
return NewFloat32Column
case chtype.Float64:
return NewFloat64Column
case chtype.DateTime:
return NewDateTimeColumn
case chtype.Date:
return NewDateColumn
case chtype.IPv6:
return NewIPColumn
default:
return nil
}
}
var (
boolType = reflect.TypeOf(false)
int8Type = reflect.TypeOf(int8(0))
int16Type = reflect.TypeOf(int16(0))
int32Type = reflect.TypeOf(int32(0))
int64Type = reflect.TypeOf(int64(0))
uint8Type = reflect.TypeOf(uint8(0))
uint16Type = reflect.TypeOf(uint16(0))
uint32Type = reflect.TypeOf(uint32(0))
uint64Type = reflect.TypeOf(uint64(0))
float32Type = reflect.TypeOf(float32(0))
float64Type = reflect.TypeOf(float64(0))
stringType = reflect.TypeOf("")
bytesType = reflect.TypeOf((*[]byte)(nil)).Elem()
uuidType = reflect.TypeOf((*UUID)(nil)).Elem()
timeType = reflect.TypeOf((*time.Time)(nil)).Elem()
ipType = reflect.TypeOf((*net.IP)(nil)).Elem()
ipNetType = reflect.TypeOf((*net.IPNet)(nil)).Elem()
bfloat16HistType = reflect.TypeOf((*map[chtype.BFloat16]uint64)(nil)).Elem()
int64SliceType = reflect.TypeOf((*[]int64)(nil)).Elem()
uint64SliceType = reflect.TypeOf((*[]uint64)(nil)).Elem()
float32SliceType = reflect.TypeOf((*[]float32)(nil)).Elem()
float64SliceType = reflect.TypeOf((*[]float64)(nil)).Elem()
stringSliceType = reflect.TypeOf((*[]string)(nil)).Elem()
)
func goType(chType string) reflect.Type {
switch chType {
case chtype.Int8:
return int8Type
case chtype.Int32:
return int32Type
case chtype.Int64:
return int64Type
case chtype.UInt8:
return uint8Type
case chtype.UInt16:
return uint16Type
case chtype.UInt32:
return uint32Type
case chtype.UInt64:
return uint64Type
case chtype.Float32:
return float32Type
case chtype.Float64:
return float64Type
case chtype.String:
return stringType
case chtype.UUID:
return uuidType
case chtype.DateTime:
return timeType
case chtype.Date:
return timeType
case chtype.IPv6:
return ipType
default:
}
if s := chArrayElemType(chType); s != "" {
return reflect.SliceOf(goType(s))
}
if s := lowCardinalityType(chType); s != "" {
return goType(s)
}
if s := enumType(chType); s != "" {
return stringType
}
if s := dateTimeType(chType); s != "" {
return timeType
}
if s := nullableType(chType); s != "" {
return reflect.PtrTo(goType(s))
}
if _, funcType := aggFuncNameAndType(chType); funcType != "" {
return goType(funcType)
}
panic(fmt.Errorf("unsupported ClickHouse type=%q", chType))
}
func chArrayElemType(s string) string {
s = chSubType(s, "Array(")
if s == "" {
return ""
}
elemType := s
s = chSubType(s, "SimpleAggregateFunction(")
if s == "" {
return elemType
}
if i := strings.Index(s, ", "); i >= 0 {
return s[i+2:]
}
return s
}
func lowCardinalityType(s string) string {
return chSubType(s, "LowCardinality(")
}
func enumType(s string) string {
return chSubType(s, "Enum8(")
}
func dateTimeType(s string) string {
s = chSubType(s, "DateTime(")
if s == "" {
return ""
}
if s != "'UTC'" {
internal.Logger.Printf("DateTime has timezeone=%q, expected UTC", s)
}
return chtype.DateTime
}
func nullableType(s string) string {
return chSubType(s, "Nullable(")
}
func aggFuncNameAndType(chType string) (funcName, funcType string) {
s := chSubType(chType, "SimpleAggregateFunction(")
if s == "" {
return "", ""
}
const sep = ", "
idx := strings.LastIndex(s, sep)
if idx == -1 {
return "", ""
}
funcName = s[:idx]
funcType = s[idx+len(sep):]
if idx := strings.IndexByte(funcName, '('); idx >= 0 {
funcName = funcName[:idx]
}
return funcName, funcType
}
func chSubType(s, prefix string) string {
if strings.HasPrefix(s, prefix) && strings.HasSuffix(s, ")") {
return s[len(prefix) : len(s)-1]
}
return ""
}
func isUUID(typ reflect.Type) bool {
return typ.Len() == 16 && typ.Elem().Kind() == reflect.Uint8
}

20
ch/chtype/chtype.go Normal file
View File

@ -0,0 +1,20 @@
package chtype
const (
Any = "_" // for decoding into interface{}
String = "String"
UUID = "UUID"
Int8 = "Int8"
Int16 = "Int16"
Int32 = "Int32"
Int64 = "Int64"
UInt8 = "UInt8"
UInt16 = "UInt16"
UInt32 = "UInt32"
UInt64 = "UInt64"
Float32 = "Float32"
Float64 = "Float64"
DateTime = "DateTime"
Date = "Date"
IPv6 = "IPv6"
)

View File

@ -0,0 +1,13 @@
package chtype
import "math"
type BFloat16 uint16
func ToBFloat16(f float64) BFloat16 {
return BFloat16(math.Float32bits(float32(f)) >> 16)
}
func (f BFloat16) Float32() float32 {
return math.Float32frombits(uint32(f) << 16)
}

372
ch/config.go Normal file
View File

@ -0,0 +1,372 @@
package ch
import (
"crypto/tls"
"errors"
"fmt"
"net"
"net/url"
"runtime"
"strconv"
"strings"
"time"
"github.com/uptrace/go-clickhouse/ch/chpool"
"github.com/uptrace/go-clickhouse/ch/internal"
)
const (
discardUnknownColumnsFlag internal.Flag = 1 << iota
)
type Config struct {
chpool.Config
Network string
Addr string
User string
Password string
Database string
DialTimeout time.Duration
TLSConfig *tls.Config
QuerySettings map[string]any
ReadTimeout time.Duration
WriteTimeout time.Duration
MaxRetries int
MinRetryBackoff time.Duration
MaxRetryBackoff time.Duration
}
func (cfg *Config) netDialer() *net.Dialer {
return &net.Dialer{
Timeout: cfg.DialTimeout,
KeepAlive: 5 * time.Minute,
}
}
func defaultConfig() *Config {
var cfg *Config
poolSize := 2 * runtime.GOMAXPROCS(0)
cfg = &Config{
Network: "tcp",
Addr: "localhost:9000",
User: "default",
Database: "default",
DialTimeout: 5 * time.Second,
ReadTimeout: 5 * time.Second,
WriteTimeout: 5 * time.Second,
MaxRetries: 2,
MinRetryBackoff: 500 * time.Millisecond,
MaxRetryBackoff: time.Second,
Config: chpool.Config{
PoolSize: poolSize,
MaxIdleConns: poolSize,
PoolTimeout: 30 * time.Second,
},
}
return cfg
}
type Option func(db *DB)
func WithDiscardUnknownColumns() Option {
return func(db *DB) {
db.flags.Set(discardUnknownColumnsFlag)
}
}
// WithAddr configures TCP host:port or Unix socket depending on Network.
func WithAddr(addr string) Option {
return func(db *DB) {
db.cfg.Addr = addr
}
}
// WithTLSConfig configures TLS config for secure connections.
func WithTLSConfig(cfg *tls.Config) Option {
return func(db *DB) {
db.cfg.TLSConfig = cfg
}
}
func WithQuerySettings(params map[string]any) Option {
return func(db *DB) {
db.cfg.QuerySettings = params
}
}
func WithInsecure(on bool) Option {
return func(db *DB) {
if on {
db.cfg.TLSConfig = nil
} else {
db.cfg.TLSConfig = &tls.Config{InsecureSkipVerify: true}
}
}
}
func WithUser(user string) Option {
return func(db *DB) {
db.cfg.User = user
}
}
func WithPassword(password string) Option {
return func(db *DB) {
db.cfg.Password = password
}
}
func WithDatabase(database string) Option {
return func(db *DB) {
db.cfg.Database = database
}
}
// WithDialTimeout configures dial timeout for establishing new connections.
// Default is 5 seconds.
func WithDialTimeout(timeout time.Duration) Option {
return func(db *DB) {
db.cfg.DialTimeout = timeout
}
}
// WithReadTimeout configures timeout for socket reads. If reached, commands will fail
// with a timeout instead of blocking.
func WithReadTimeout(timeout time.Duration) Option {
return func(db *DB) {
db.cfg.ReadTimeout = timeout
}
}
// WithWriteTimeout configures timeout for socket writes. If reached, commands will fail
// with a timeout instead of blocking.
func WithWriteTimeout(timeout time.Duration) Option {
return func(db *DB) {
db.cfg.WriteTimeout = timeout
}
}
func WithTimeout(timeout time.Duration) Option {
return func(db *DB) {
db.cfg.DialTimeout = timeout
db.cfg.ReadTimeout = timeout
db.cfg.WriteTimeout = timeout
}
}
// WithMaxRetries configures maximum number of retries before giving up.
// Default is to retry query 2 times.
func WithMaxRetries(maxRetries int) Option {
return func(db *DB) {
db.cfg.MaxRetries = maxRetries
}
}
// WithMinRetryBackoff configures minimum backoff between each retry.
// Default is 250 milliseconds; -1 disables backoff.
func WithMinRetryBackoff(backoff time.Duration) Option {
return func(db *DB) {
db.cfg.MinRetryBackoff = backoff
}
}
// WithMaxRetryBackoff configures maximum backoff between each retry.
// Default is 4 seconds; -1 disables backoff.
func WithMaxRetryBackoff(backoff time.Duration) Option {
return func(db *DB) {
db.cfg.MaxRetryBackoff = backoff
}
}
// WithPoolSize configures maximum number of socket connections.
// Default is 2 connections per every CPU as reported by runtime.NumCPU.
func WithPoolSize(poolSize int) Option {
return func(db *DB) {
db.cfg.PoolSize = poolSize
db.cfg.MaxIdleConns = poolSize
}
}
// WithMinIdleConns configures minimum number of idle connections which is useful when establishing
// new connection is slow.
func WithMinIdleConns(minIdleConns int) Option {
return func(db *DB) {
db.cfg.MinIdleConns = minIdleConns
}
}
// WithMaxConnAge configures Connection age at which client retires (closes) the connection.
// It is useful with proxies like HAProxy.
// Default is to not close aged connections.
func WithMaxConnAge(timeout time.Duration) Option {
return func(db *DB) {
db.cfg.MaxConnAge = timeout
}
}
// WithPoolTimeout configures time for which client waits for free connection if all
// connections are busy before returning an error.
// Default is 30 seconds if ReadTimeOut is not defined, otherwise,
// ReadTimeout + 1 second.
func WithPoolTimeout(timeout time.Duration) Option {
return func(db *DB) {
db.cfg.PoolTimeout = timeout
}
}
func WithDSN(dsn string) Option {
return func(db *DB) {
opts, err := parseDSN(dsn)
if err != nil {
panic(err)
}
for _, opt := range opts {
opt(db)
}
}
}
func parseDSN(dsn string) ([]Option, error) {
u, err := url.Parse(dsn)
if err != nil {
return nil, err
}
q := queryOptions{q: u.Query()}
var opts []Option
switch u.Scheme {
case "ch", "clickhouse":
if u.Host != "" {
addr := u.Host
if !strings.Contains(addr, ":") {
addr += ":5432"
}
opts = append(opts, WithAddr(addr))
}
if len(u.Path) > 1 {
opts = append(opts, WithDatabase(u.Path[1:]))
}
if host := q.string("host"); host != "" {
opts = append(opts, WithAddr(host))
}
default:
return nil, errors.New("ch: unknown scheme: " + u.Scheme)
}
if u.User != nil {
opts = append(opts, WithUser(u.User.Username()))
if password, ok := u.User.Password(); ok {
opts = append(opts, WithPassword(password))
}
}
switch sslMode := q.string("sslmode"); sslMode {
case "verify-ca", "verify-full":
opts = append(opts, WithTLSConfig(new(tls.Config)))
case "allow", "prefer", "require", "":
opts = append(opts, WithTLSConfig(&tls.Config{InsecureSkipVerify: true}))
case "disable":
opts = append(opts, WithInsecure(true))
default:
return nil, fmt.Errorf("ch: sslmode '%s' is not supported", sslMode)
}
if d := q.duration("timeout"); d != 0 {
opts = append(opts, WithTimeout(d))
}
if d := q.duration("dial_timeout"); d != 0 {
opts = append(opts, WithDialTimeout(d))
}
if d := q.duration("read_timeout"); d != 0 {
opts = append(opts, WithReadTimeout(d))
}
if d := q.duration("write_timeout"); d != 0 {
opts = append(opts, WithWriteTimeout(d))
}
rem, err := q.remaining()
if err != nil {
return nil, q.err
}
if len(rem) > 0 {
params := make(map[string]any, len(rem))
for k, v := range rem {
params[k] = parseSettingValue(v)
}
opts = append(opts, WithQuerySettings(params))
}
return opts, nil
}
func parseSettingValue(s string) any {
if b, err := strconv.ParseBool(s); err == nil {
return b
}
if i, err := strconv.ParseInt(s, 10, 64); err == nil {
return i
}
return s
}
type queryOptions struct {
q url.Values
err error
}
func (o *queryOptions) string(name string) string {
vs := o.q[name]
if len(vs) == 0 {
return ""
}
delete(o.q, name) // enable detection of unknown parameters
return vs[len(vs)-1]
}
func (o *queryOptions) duration(name string) time.Duration {
s := o.string(name)
if s == "" {
return 0
}
// try plain number first
if i, err := strconv.Atoi(s); err == nil {
if i <= 0 {
// disable timeouts
return -1
}
return time.Duration(i) * time.Second
}
dur, err := time.ParseDuration(s)
if err == nil {
return dur
}
if o.err == nil {
o.err = fmt.Errorf("ch: invalid %s duration: %w", name, err)
}
return 0
}
func (o *queryOptions) remaining() (map[string]string, error) {
if o.err != nil {
return nil, o.err
}
if len(o.q) == 0 {
return nil, nil
}
m := make(map[string]string, len(o.q))
for k, ss := range o.q {
m[k] = ss[len(ss)-1]
}
return m, nil
}

548
ch/db.go Normal file
View File

@ -0,0 +1,548 @@
package ch
import (
"context"
"crypto/tls"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"net"
"reflect"
"sync/atomic"
"time"
"github.com/uptrace/go-clickhouse/ch/chpool"
"github.com/uptrace/go-clickhouse/ch/chproto"
"github.com/uptrace/go-clickhouse/ch/chschema"
"github.com/uptrace/go-clickhouse/ch/internal"
)
type DBStats struct {
Queries uint64
Errors uint64
}
type DB struct {
cfg *Config
pool *chpool.ConnPool
queryHooks []QueryHook
fmter chschema.Formatter
flags internal.Flag
stats DBStats
}
func Connect(opts ...Option) *DB {
db := &DB{
cfg: defaultConfig(),
}
for _, opt := range opts {
opt(db)
}
db.pool = newConnPool(db.cfg)
return db
}
func newConnPool(cfg *Config) *chpool.ConnPool {
poolcfg := cfg.Config
poolcfg.Dialer = func(ctx context.Context) (net.Conn, error) {
if cfg.TLSConfig != nil {
return tls.DialWithDialer(
cfg.netDialer(),
cfg.Network,
cfg.Addr,
cfg.TLSConfig,
)
}
return cfg.netDialer().DialContext(ctx, cfg.Network, cfg.Addr)
}
return chpool.New(&poolcfg)
}
// Close closes the database client, releasing any open resources.
//
// It is rare to Close a DB, as the DB handle is meant to be
// long-lived and shared between many goroutines.
func (db *DB) Close() error {
return db.pool.Close()
}
func (db *DB) String() string {
return fmt.Sprintf("DB<addr: %s>", db.cfg.Addr)
}
func (db *DB) Config() *Config {
return db.cfg
}
func (db *DB) WithTimeout(d time.Duration) *DB {
newcfg := *db.cfg
newcfg.ReadTimeout = d
newcfg.WriteTimeout = d
clone := db.clone()
clone.cfg = &newcfg
return clone
}
func (db *DB) clone() *DB {
clone := *db
l := len(db.queryHooks)
clone.queryHooks = db.queryHooks[:l:l]
return &clone
}
func (db *DB) Stats() DBStats {
return DBStats{
Queries: atomic.LoadUint64(&db.stats.Queries),
Errors: atomic.LoadUint64(&db.stats.Errors),
}
}
func (db *DB) getConn(ctx context.Context) (*chpool.Conn, error) {
cn, err := db.pool.Get(ctx)
if err != nil {
return nil, err
}
if err := db.initConn(ctx, cn); err != nil {
db.pool.Remove(cn, err)
if err := internal.Unwrap(err); err != nil {
return nil, err
}
return nil, err
}
return cn, nil
}
func (db *DB) initConn(ctx context.Context, cn *chpool.Conn) error {
if cn.Inited {
return nil
}
cn.Inited = true
return db.hello(ctx, cn)
}
func (db *DB) releaseConn(cn *chpool.Conn, err error) {
if isBadConn(err, false) || cn.Closed() {
db.pool.Remove(cn, err)
} else {
db.pool.Put(cn)
}
}
func (db *DB) withConn(ctx context.Context, fn func(*chpool.Conn) error) error {
err := db._withConn(ctx, fn)
atomic.AddUint64(&db.stats.Queries, 1)
if err != nil {
atomic.AddUint64(&db.stats.Errors, 1)
}
return err
}
func (db *DB) _withConn(ctx context.Context, fn func(*chpool.Conn) error) error {
cn, err := db.getConn(ctx)
if err != nil {
return err
}
var done chan struct{}
if ctxDone := ctx.Done(); ctxDone != nil {
done = make(chan struct{})
go func() {
select {
case <-done:
// fn has finished, skip cancel
case <-ctxDone:
db.cancelConn(ctx, cn)
// Signal end of conn use.
done <- struct{}{}
}
}()
}
defer func() {
if done != nil {
select {
case <-done: // wait for cancel to finish request
case done <- struct{}{}: // signal fn finish, skip cancel goroutine
}
}
db.releaseConn(cn, err)
}()
// err is used in releaseConn above
err = fn(cn)
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
db.cancelConn(ctx, cn)
}
return err
}
func (db *DB) cancelConn(ctx context.Context, cn *chpool.Conn) {
if err := cn.WithWriter(ctx, db.cfg.WriteTimeout, func(wr *chproto.Writer) {
writeCancel(wr)
}); err != nil {
internal.Logger.Printf("writeCancel failed: %s", err)
}
_ = cn.Close()
}
func (db *DB) Ping(ctx context.Context) error {
return db.withConn(ctx, func(cn *chpool.Conn) error {
if err := cn.WithWriter(ctx, db.cfg.WriteTimeout, func(wr *chproto.Writer) {
writePing(wr)
}); err != nil {
return err
}
return cn.WithReader(ctx, db.cfg.ReadTimeout, func(rd *chproto.Reader) error {
return readPong(rd)
})
})
}
func (db *DB) Exec(query string, args ...any) (sql.Result, error) {
return db.ExecContext(context.Background(), query, args...)
}
func (db *DB) ExecContext(
ctx context.Context, query string, args ...any,
) (sql.Result, error) {
query = db.FormatQuery(query, args...)
ctx, evt := db.beforeQuery(ctx, nil, query, args, nil)
res, err := db.query(ctx, nil, query)
db.afterQuery(ctx, evt, res, err)
return res, err
}
func (db *DB) Query(query string, args ...any) (*Rows, error) {
return db.QueryContext(context.Background(), query, args...)
}
func (db *DB) QueryContext(
ctx context.Context, query string, args ...any,
) (*Rows, error) {
rows := newRows()
query = db.FormatQuery(query, args...)
ctx, evt := db.beforeQuery(ctx, nil, query, args, nil)
res, err := db.query(ctx, rows, query)
db.afterQuery(ctx, evt, res, err)
if err != nil {
return nil, err
}
return rows, nil
}
func (db *DB) QueryRow(query string, args ...any) *Row {
return db.QueryRowContext(context.Background(), query, args...)
}
func (db *DB) QueryRowContext(ctx context.Context, query string, args ...any) *Row {
rows, err := db.QueryContext(ctx, query, args...)
return &Row{rows: rows, err: err}
}
func (db *DB) query(ctx context.Context, model Model, query string) (*result, error) {
var res *result
var lastErr error
for attempt := 0; attempt <= db.cfg.MaxRetries; attempt++ {
if attempt > 0 {
lastErr = internal.Sleep(ctx, db.retryBackoff(attempt-1))
if lastErr != nil {
break
}
}
res, lastErr = db._query(ctx, model, query)
if !db.shouldRetry(lastErr) {
break
}
}
if lastErr == nil {
if model, ok := model.(AfterScanRowHook); ok {
if err := model.AfterScanRow(ctx); err != nil {
lastErr = err
}
}
}
return res, lastErr
}
func (db *DB) _query(ctx context.Context, model Model, query string) (*result, error) {
var res *result
err := db.withConn(ctx, func(cn *chpool.Conn) error {
if err := cn.WithWriter(ctx, db.cfg.WriteTimeout, func(wr *chproto.Writer) {
db.writeQuery(wr, query)
writeBlock(ctx, wr, nil)
}); err != nil {
return err
}
return cn.WithReader(ctx, db.cfg.ReadTimeout, func(rd *chproto.Reader) error {
var err error
res, err = readDataBlocks(rd, model)
return err
})
})
return res, err
}
func (db *DB) insert(
ctx context.Context, model TableModel, query string, fields []*chschema.Field,
) (*result, error) {
block := model.Block(fields)
var res *result
var lastErr error
for attempt := 0; attempt <= db.cfg.MaxRetries; attempt++ {
if attempt > 0 {
lastErr = internal.Sleep(ctx, db.retryBackoff(attempt-1))
if lastErr != nil {
break
}
}
res, lastErr = db._insert(ctx, model, query, block)
if !db.shouldRetry(lastErr) {
break
}
}
return res, lastErr
}
func (db *DB) _insert(
ctx context.Context, model TableModel, query string, block *chschema.Block,
) (*result, error) {
var res *result
err := db.withConn(ctx, func(cn *chpool.Conn) error {
if err := cn.WithWriter(ctx, db.cfg.WriteTimeout, func(wr *chproto.Writer) {
db.writeQuery(wr, query)
writeBlock(ctx, wr, nil)
}); err != nil {
return err
}
if err := cn.WithReader(ctx, db.cfg.ReadTimeout, func(rd *chproto.Reader) error {
_, err := readSampleBlock(rd)
return err
}); err != nil {
return err
}
if err := cn.WithWriter(ctx, db.cfg.WriteTimeout, func(wr *chproto.Writer) {
writeBlock(ctx, wr, block)
writeBlock(ctx, wr, nil)
}); err != nil {
return err
}
return cn.WithReader(ctx, db.cfg.ReadTimeout, func(rd *chproto.Reader) error {
var err error
res, err = readPacket(rd)
if err != nil {
return err
}
res.affected = block.NumRow
return nil
})
})
return res, err
}
func (db *DB) NewSelect() *SelectQuery {
return NewSelectQuery(db)
}
func (db *DB) NewInsert() *InsertQuery {
return NewInsertQuery(db)
}
func (db *DB) NewCreateTable() *CreateTableQuery {
return NewCreateTableQuery(db)
}
func (db *DB) NewDropTable() *DropTableQuery {
return NewDropTableQuery(db)
}
func (db *DB) NewTruncateTable() *TruncateTableQuery {
return NewTruncateTableQuery(db)
}
func (db *DB) ResetModel(ctx context.Context, models ...any) error {
for _, model := range models {
if _, err := db.NewDropTable().Model(model).IfExists().Exec(ctx); err != nil {
return err
}
if _, err := db.NewCreateTable().Model(model).Exec(ctx); err != nil {
return err
}
}
return nil
}
func (db *DB) Formatter() chschema.Formatter {
return db.fmter
}
func (db *DB) WithFormatter(fmter chschema.Formatter) *DB {
clone := db.clone()
clone.fmter = fmter
return clone
}
func (db *DB) shouldRetry(err error) bool {
switch err {
case driver.ErrBadConn:
return true
case nil, context.Canceled, context.DeadlineExceeded:
return false
}
if err, ok := err.(*Error); ok {
// https://github.com/ClickHouse/ClickHouse/blob/master/src/Common/ErrorCodes.cpp
const (
timeoutExceeded = 159
tooManySimultaneousQueries = 202
memoryLimitExceeded = 241
)
switch err.Code {
case timeoutExceeded, tooManySimultaneousQueries, memoryLimitExceeded:
return true
}
}
return false
}
func (db *DB) retryBackoff(attempt int) time.Duration {
return internal.RetryBackoff(
attempt, db.cfg.MinRetryBackoff, db.cfg.MaxRetryBackoff)
}
func (db *DB) FormatQuery(query string, args ...any) string {
return db.fmter.FormatQuery(query, args...)
}
func (db *DB) makeQueryBytes() []byte {
// TODO: make this configurable?
return make([]byte, 0, 4096)
}
//------------------------------------------------------------------------------
// Rows is the result of a query. Its cursor starts before the first row of the result set.
// Use Next to advance from row to row.
type Rows struct {
blocks []*chschema.Block
block *chschema.Block
blockIndex int
rowIndex int
}
func newRows() *Rows {
return new(Rows)
}
func (rs *Rows) Close() error {
return nil
}
func (rs *Rows) ColumnTypes() ([]*sql.ColumnType, error) {
return nil, errors.New("not implemented")
}
func (rs *Rows) Columns() ([]string, error) {
return nil, errors.New("not implemented")
}
func (rs *Rows) Err() error {
return nil
}
func (rs *Rows) Next() bool {
if rs.block != nil && rs.rowIndex < rs.block.NumRow {
rs.rowIndex++
return true
}
for rs.blockIndex < len(rs.blocks) {
rs.block = rs.blocks[rs.blockIndex]
rs.blockIndex++
if rs.block.NumRow > 0 {
rs.rowIndex = 1
return true
}
}
return false
}
func (rs *Rows) NextResultSet() bool {
return false
}
func (rs *Rows) Scan(dest ...any) error {
if rs.block == nil {
return errors.New("ch: Scan called without calling Next")
}
if rs.block.NumColumn != len(dest) {
return fmt.Errorf("ch: got %d columns, but Scan has %d values",
rs.block.NumColumn, len(dest))
}
for i, col := range rs.block.Columns {
if err := col.ConvertAssign(rs.rowIndex-1, reflect.ValueOf(dest[i]).Elem()); err != nil {
return err
}
}
return nil
}
func (rs *Rows) ScanBlock(block *chschema.Block) error {
rs.blocks = append(rs.blocks, block)
return nil
}
type Row struct {
rows *Rows
err error
}
func (r *Row) Err() error {
return r.err
}
func (r *Row) Scan(dest ...any) error {
if r.err != nil {
return r.err
}
defer r.rows.Close()
if r.rows.Next() {
return r.rows.Scan(dest...)
}
return sql.ErrNoRows
}

475
ch/db_test.go Normal file
View File

@ -0,0 +1,475 @@
package ch_test
import (
"context"
"database/sql"
"fmt"
"os"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/uptrace/go-clickhouse/ch"
"github.com/uptrace/go-clickhouse/chdebug"
)
func chDB(opts ...ch.Option) *ch.DB {
dsn := os.Getenv("CH")
if dsn == "" {
dsn = "clickhouse://localhost:9000/test?sslmode=disable"
}
opts = append(opts, ch.WithDSN(dsn))
db := ch.Connect(opts...)
db.AddQueryHook(chdebug.NewQueryHook(
chdebug.WithEnabled(false),
chdebug.FromEnv("CHDEBUG"),
))
return db
}
func TestCHError(t *testing.T) {
ctx := context.Background()
db := chDB()
defer db.Close()
err := db.Ping(ctx)
require.NoError(t, err)
res, err := db.ExecContext(ctx, "hi")
require.Error(t, err)
require.Nil(t, res)
exc := err.(*ch.Error)
require.Equal(t, int32(62), exc.Code)
require.Equal(t, "DB::Exception", exc.Name)
}
func TestCHTimeout(t *testing.T) {
ctx := context.Background()
db := chDB(ch.WithTimeout(time.Second), ch.WithMaxRetries(0))
defer db.Close()
_, err := db.ExecContext(
ctx, "SELECT sleepEachRow(0.01) from numbers(10000) settings max_block_size=10")
require.Error(t, err)
require.Contains(t, err.Error(), "i/o timeout")
require.Eventually(t, func() bool {
var num int
err := db.NewSelect().ColumnExpr("count()").TableExpr("system.processes").Scan(ctx, &num)
require.NoError(t, err)
return num == 1
}, time.Second, 100*time.Millisecond)
}
func TestDSNSetting(t *testing.T) {
ctx := context.Background()
for _, value := range []int{0, 1} {
t.Run("prefer_column_name_to_alias=%d", func(t *testing.T) {
db := ch.Connect(ch.WithDSN(fmt.Sprintf(
"clickhouse://localhost:9000/default?sslmode=disable&prefer_column_name_to_alias=%d",
value,
)))
defer db.Close()
err := db.Ping(ctx)
require.NoError(t, err)
var got string
err = db.NewSelect().
ColumnExpr("value").
TableExpr("system.settings").
Where("name = 'prefer_column_name_to_alias'").
Scan(ctx, &got)
require.NoError(t, err)
require.Equal(t, got, fmt.Sprint(value))
})
}
}
func TestNullable(t *testing.T) {
ctx := context.Background()
db := chDB()
defer db.Close()
type Model struct {
Name *string
CreatedAt time.Time `ch:",pk"`
}
err := db.ResetModel(ctx, (*Model)(nil))
require.NoError(t, err)
models := []Model{
{Name: strptr("hello"), CreatedAt: time.Unix(1e6, 0).Local()},
{Name: strptr(""), CreatedAt: time.Unix(1e6+1, 0).Local()},
{Name: nil, CreatedAt: time.Unix(1e6+2, 0).Local()},
}
_, err = db.NewInsert().Model(&models).Exec(ctx)
require.NoError(t, err)
var models2 []Model
err = db.NewSelect().Model(&models2).Scan(ctx)
require.NoError(t, err)
require.Equal(t, models, models2)
var ms []map[string]any
err = db.NewSelect().Model((*Model)(nil)).OrderExpr("created_at").Scan(ctx, &ms)
require.NoError(t, err)
require.Equal(t, []map[string]any{
{"name": "hello", "created_at": time.Unix(1e6, 0)},
{"name": "", "created_at": time.Unix(1e6+1, 0)},
{"name": nil, "created_at": time.Unix(1e6+2, 0)},
}, ms)
}
func TestPlaceholder(t *testing.T) {
ctx := context.Background()
db := chDB()
defer db.Close()
params := struct {
A int
B int
Alias ch.Ident
}{
A: 1,
B: 2,
Alias: "sum",
}
t.Run("raw", func(t *testing.T) {
var sum int
err := db.QueryRow("SELECT ?a + ?b AS ?alias", params).Scan(&sum)
require.NoError(t, err)
require.Equal(t, 3, sum)
res, err := db.Exec("SELECT ?a + ?b AS ?alias", params)
require.NoError(t, err)
n, err := res.RowsAffected()
require.NoError(t, err)
require.Equal(t, int64(1), n)
})
t.Run("query builder", func(t *testing.T) {
var sum int
err := db.NewSelect().ColumnExpr("?a + ?b AS ?alias", params).Scan(ctx, &sum)
require.NoError(t, err)
require.Equal(t, 3, sum)
})
}
func TestScanArray(t *testing.T) {
ctx := context.Background()
db := chDB()
defer db.Close()
t.Run("uint64", func(t *testing.T) {
var nums []uint64
err := db.NewSelect().
ColumnExpr("groupArray(number)").
TableExpr("numbers(3)").
Scan(ctx, &nums)
require.NoError(t, err)
require.Equal(t, []uint64{0, 1, 2}, nums)
})
t.Run("float64", func(t *testing.T) {
var nums []float64
var str string
err := db.NewSelect().ColumnExpr("[1., 2, 3], 'hello'").Scan(ctx, &nums, &str)
require.NoError(t, err)
require.Equal(t, []float64{1, 2, 3}, nums)
require.Equal(t, "hello", str)
})
}
func TestScanEmptyResult(t *testing.T) {
ctx := context.Background()
db := chDB()
defer db.Close()
var m map[string]any
err := db.NewSelect().TableExpr("numbers(0)").Scan(ctx, &m)
require.NoError(t, err)
require.Equal(t, map[string]any{
"number": uint64(0),
}, m)
}
func TestScanNaN(t *testing.T) {
ctx := context.Background()
db := chDB()
defer db.Close()
t.Run("uint32", func(t *testing.T) {
var num uint32
err := db.QueryRowContext(ctx, "SELECT NaN").Scan(&num)
require.NoError(t, err)
require.Equal(t, uint32(0), num)
})
t.Run("int32", func(t *testing.T) {
var num int32
err := db.QueryRowContext(ctx, "SELECT NaN").Scan(&num)
require.NoError(t, err)
require.Equal(t, int32(0), num)
})
}
func TestScanArrayUint8(t *testing.T) {
ctx := context.Background()
db := chDB()
defer db.Close()
var m map[string]any
err := db.NewSelect().
ColumnExpr("topK(3)(toUInt8(number)) AS ns").
TableExpr("numbers(10)").
Scan(ctx, &m)
require.NoError(t, err)
require.Equal(t, map[string]any{"ns": []uint8{0, 1, 2}}, m)
}
type Event struct {
ch.CHModel `ch:"goch_events,partition:toYYYYMM(created_at)"`
ID uint64
Name string `ch:",lc"`
Count uint32
Keys []string `ch:",lc"`
Values [][]string
Kind string `ch:"type:Enum8('invalid' = 0, 'hello' = 1, 'world' = 2)"`
CreatedAt time.Time `ch:",pk"`
}
type EventColumnar struct {
ch.CHModel `ch:"goch_events,columnar"`
ID []uint64
Name []string `ch:",lc"`
Count []uint32
Keys [][]string `ch:"type:Array(LowCardinality(String))"`
Values [][][]string
Kind []string `ch:"type:Enum8('invalid' = 0, 'hello' = 1, 'world' = 2)"`
CreatedAt []time.Time
}
func TestORM(t *testing.T) {
ctx := context.Background()
db := chDB()
defer db.Close()
err := db.ResetModel(ctx, (*Event)(nil))
require.NoError(t, err)
tests := []func(t *testing.T, db *ch.DB){
testORMStruct,
testORMSlice,
testORMColumnarStruct,
testORMInvalidEnumValue,
}
for _, fn := range tests {
_, err := db.NewTruncateTable().Model((*Event)(nil)).Exec(ctx)
require.NoError(t, err)
t.Run("", func(t *testing.T) {
fn(t, db)
})
}
}
func testORMStruct(t *testing.T, db *ch.DB) {
ctx := context.Background()
err := db.NewSelect().Model(new(Event)).Scan(ctx)
require.Equal(t, sql.ErrNoRows, err)
src := &Event{
ID: 1,
Name: "hello",
Count: 42,
Keys: []string{"foo", "bar"},
Values: [][]string{{}, {"hello", "world"}},
Kind: "hello",
CreatedAt: time.Time{},
}
_, err = db.NewInsert().Model(src).Exec(ctx)
require.NoError(t, err)
dest := new(Event)
err = db.NewSelect().Model(dest).Scan(ctx)
require.NoError(t, err)
require.Equal(t, src, dest)
n, err := db.NewSelect().Model((*Event)(nil)).Count(ctx)
require.NoError(t, err)
require.Equal(t, 1, n)
names := make([]string, 0)
counts := make([]uint32, 0)
err = db.NewSelect().
Model((*Event)(nil)).
Column("name", "count").
ScanColumns(ctx, &names, &counts)
require.NoError(t, err)
require.Equal(t, []string{"hello"}, names)
require.Equal(t, []uint32{42}, counts)
var m map[string]any
err = db.NewSelect().Model((*Event)(nil)).ScanColumns(ctx, &m)
require.NoError(t, err)
require.Equal(t, map[string]any{
"id": []uint64{1},
"name": []string{"hello"},
"count": []uint32{42},
"keys": [][]string{{"foo", "bar"}},
"values": [][][]string{{{}, {"hello", "world"}}},
"kind": []string{"hello"},
"created_at": []time.Time{{}},
}, m)
}
func testORMSlice(t *testing.T, db *ch.DB) {
ctx := context.Background()
var events []*Event
err := db.NewSelect().Model(&events).Scan(ctx)
require.NoError(t, err)
require.Equal(t, 0, len(events))
src := []*Event{{
ID: 1,
Name: "hello",
Count: 42,
Keys: []string{"foo", "bar"},
Values: [][]string{{}, {"hello", "world"}},
Kind: "hello",
CreatedAt: time.Time{},
}, {
ID: 2,
Name: "world",
Count: 84,
Keys: []string{"1", "2", "3"},
Values: [][]string{{}, {"hello", "world"}, {}},
Kind: "world",
CreatedAt: time.Unix(1000, 0),
}}
_, err = db.NewInsert().Model(&src).Exec(ctx)
require.NoError(t, err)
var dest []*Event
err = db.NewSelect().Model(&dest).OrderExpr("id ASC").Scan(ctx)
require.NoError(t, err)
require.Equal(t, src, dest)
n, err := db.NewSelect().Model((*Event)(nil)).Count(ctx)
require.NoError(t, err)
require.Equal(t, 2, n)
var temp []struct {
Name string `ch:"type:LowCardinality(String)"`
Count uint64
}
err = db.NewSelect().
Model((*Event)(nil)).
ColumnExpr("name, count(*) as count").
GroupExpr("name").
OrderExpr("name asc").
Scan(ctx, &temp)
require.NoError(t, err)
require.Equal(t, 2, len(temp))
require.Equal(t, "hello", temp[0].Name)
require.Equal(t, uint64(1), temp[0].Count)
require.Equal(t, "world", temp[1].Name)
require.Equal(t, uint64(1), temp[1].Count)
names := make([]string, 0)
counts := make([]uint32, 0)
err = db.NewSelect().
Model((*Event)(nil)).
Column("name", "count").
ScanColumns(ctx, &names, &counts)
require.NoError(t, err)
require.Equal(t, []string{"hello", "world"}, names)
require.Equal(t, []uint32{42, 84}, counts)
var values []map[string]any
err = db.NewSelect().Model((*Event)(nil)).Scan(ctx, &values)
require.NoError(t, err)
require.Equal(t, []map[string]any{{
"id": uint64(1),
"name": "hello",
"count": uint32(42),
"keys": []string{"foo", "bar"},
"values": [][]string{{}, {"hello", "world"}},
"kind": "hello",
"created_at": time.Time{},
}, {
"id": uint64(2),
"name": "world",
"count": uint32(84),
"keys": []string{"1", "2", "3"},
"values": [][]string{{}, {"hello", "world"}, {}},
"kind": "world",
"created_at": time.Unix(1000, 0),
}}, values)
}
func testORMColumnarStruct(t *testing.T, db *ch.DB) {
ctx := context.Background()
err := db.NewSelect().Model(new(EventColumnar)).Scan(ctx)
require.NoError(t, err)
src := &EventColumnar{
ID: []uint64{1, 2},
Name: []string{"hello", "world"},
Count: []uint32{42, 84},
Keys: [][]string{{"foo", "bar"}, {"1", "2", "3"}},
Values: [][][]string{{{}, {"hello", "world"}}, {{}, {}, {}}},
Kind: []string{"hello", "world"},
CreatedAt: []time.Time{{}, time.Unix(1000, 0)},
}
_, err = db.NewInsert().Model(src).Exec(ctx)
require.NoError(t, err)
dest := new(EventColumnar)
err = db.NewSelect().Model(dest).OrderExpr("id ASC").Scan(ctx)
require.NoError(t, err)
require.Equal(t, src, dest)
}
func testORMInvalidEnumValue(t *testing.T, db *ch.DB) {
ctx := context.Background()
src := &Event{
Kind: "foobar",
}
_, err := db.NewInsert().Model(src).Exec(ctx)
require.NoError(t, err)
dest := new(Event)
err = db.NewSelect().Model(dest).Scan(ctx)
require.NoError(t, err)
require.Equal(t, "invalid", dest.Kind)
}
func strptr(s string) *string {
return &s
}

134
ch/hook.go Normal file
View File

@ -0,0 +1,134 @@
package ch
import (
"context"
"database/sql"
"reflect"
"strings"
"time"
)
type QueryEvent struct {
DB *DB
Model Model
IQuery Query
Query string
QueryArgs []any
StartTime time.Time
Result sql.Result
Err error
Stash map[any]any
}
func (e *QueryEvent) Operation() string {
if e.IQuery != nil {
return e.IQuery.Operation()
}
return queryOperation(e.Query)
}
func queryOperation(query string) string {
if idx := strings.IndexByte(query, ' '); idx > 0 {
query = query[:idx]
}
if len(query) > 16 {
query = query[:16]
}
return query
}
// QueryHook ...
type QueryHook interface {
BeforeQuery(context.Context, *QueryEvent) context.Context
AfterQuery(context.Context, *QueryEvent)
}
// AddQueryHook adds a hook into query processing.
func (db *DB) AddQueryHook(hook QueryHook) {
db.queryHooks = append(db.queryHooks, hook)
}
func (db *DB) beforeQuery(
ctx context.Context,
iquery Query,
query string,
params []any,
model Model,
) (context.Context, *QueryEvent) {
if len(db.queryHooks) == 0 {
return ctx, nil
}
evt := &QueryEvent{
StartTime: time.Now(),
DB: db,
Model: model,
IQuery: iquery,
Query: query,
QueryArgs: params,
}
for _, hook := range db.queryHooks {
ctx = hook.BeforeQuery(ctx, evt)
}
return ctx, evt
}
func (db *DB) afterQuery(
ctx context.Context,
evt *QueryEvent,
res *result,
err error,
) {
if evt == nil {
return
}
evt.Err = err
if res != nil {
evt.Result = res
}
for _, hook := range db.queryHooks {
hook.AfterQuery(ctx, evt)
}
}
//---------------------------------------------------------------------------------------
func callAfterScanRowHook(ctx context.Context, v reflect.Value) error {
return v.Interface().(AfterScanRowHook).AfterScanRow(ctx)
}
func callAfterScanRowHookSlice(ctx context.Context, slice reflect.Value) error {
return callHookSlice(ctx, slice, callAfterScanRowHook)
}
func callHookSlice(
ctx context.Context,
slice reflect.Value,
hook func(context.Context, reflect.Value) error,
) error {
var ptr bool
switch slice.Type().Elem().Kind() {
case reflect.Ptr, reflect.Interface:
ptr = true
}
var firstErr error
sliceLen := slice.Len()
for i := 0; i < sliceLen; i++ {
v := slice.Index(i)
if !ptr {
v = v.Addr()
}
err := hook(ctx, v)
if err != nil && firstErr == nil {
firstErr = err
}
}
return firstErr
}

View File

@ -0,0 +1,20 @@
The MIT License (MIT)
Copyright (c) 2013 zhenjl
Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software is furnished to do so,
subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

View File

@ -0,0 +1,45 @@
package cityhash102
import (
"encoding/binary"
"hash"
)
type City64 struct {
s []byte
}
var _ hash.Hash64 = (*City64)(nil)
var _ hash.Hash = (*City64)(nil)
func New64() hash.Hash64 {
return &City64{}
}
func (this *City64) Sum(b []byte) []byte {
b2 := make([]byte, 8)
binary.BigEndian.PutUint64(b2, this.Sum64())
b = append(b, b2...)
return b
}
func (this *City64) Sum64() uint64 {
return CityHash64(this.s, uint32(len(this.s)))
}
func (this *City64) Reset() {
this.s = this.s[0:0]
}
func (this *City64) BlockSize() int {
return 1
}
func (this *City64) Write(s []byte) (n int, err error) {
this.s = append(this.s, s...)
return len(s), nil
}
func (this *City64) Size() int {
return 8
}

View File

@ -0,0 +1,383 @@
/*
* Go implementation of Google city hash (MIT license)
* https://code.google.com/p/cityhash/
*
* MIT License http://www.opensource.org/licenses/mit-license.php
*
* I don't even want to pretend to understand the details of city hash.
* I am only reproducing the logic in Go as faithfully as I can.
*
*/
package cityhash102
import (
"encoding/binary"
)
const (
k0 uint64 = 0xc3a5c85c97cb3127
k1 uint64 = 0xb492b66fbe98f273
k2 uint64 = 0x9ae16a3b2f90404f
k3 uint64 = 0xc949d7c7509e6557
kMul uint64 = 0x9ddfea08eb382d69
)
func fetch64(p []byte) uint64 {
return binary.LittleEndian.Uint64(p)
//return uint64InExpectedOrder(unalignedLoad64(p))
}
func fetch32(p []byte) uint32 {
return binary.LittleEndian.Uint32(p)
//return uint32InExpectedOrder(unalignedLoad32(p))
}
func rotate64(val uint64, shift uint32) uint64 {
if shift != 0 {
return ((val >> shift) | (val << (64 - shift)))
}
return val
}
func rotate32(val uint32, shift uint32) uint32 {
if shift != 0 {
return ((val >> shift) | (val << (32 - shift)))
}
return val
}
func swap64(a, b *uint64) {
*a, *b = *b, *a
}
func swap32(a, b *uint32) {
*a, *b = *b, *a
}
func permute3(a, b, c *uint32) {
swap32(a, b)
swap32(a, c)
}
func rotate64ByAtLeast1(val uint64, shift uint32) uint64 {
return (val >> shift) | (val << (64 - shift))
}
func shiftMix(val uint64) uint64 {
return val ^ (val >> 47)
}
type Uint128 [2]uint64
func (this *Uint128) setLower64(l uint64) {
this[0] = l
}
func (this *Uint128) setHigher64(h uint64) {
this[1] = h
}
func (this Uint128) Lower64() uint64 {
return this[0]
}
func (this Uint128) Higher64() uint64 {
return this[1]
}
func (this Uint128) Bytes() []byte {
b := make([]byte, 16)
binary.LittleEndian.PutUint64(b, this[0])
binary.LittleEndian.PutUint64(b[8:], this[1])
return b
}
func hash128to64(x Uint128) uint64 {
// Murmur-inspired hashing.
var a = (x.Lower64() ^ x.Higher64()) * kMul
a ^= (a >> 47)
var b = (x.Higher64() ^ a) * kMul
b ^= (b >> 47)
b *= kMul
return b
}
func hashLen16(u, v uint64) uint64 {
return hash128to64(Uint128{u, v})
}
func hashLen16_3(u, v, mul uint64) uint64 {
// Murmur-inspired hashing.
var a = (u ^ v) * mul
a ^= (a >> 47)
var b = (v ^ a) * mul
b ^= (b >> 47)
b *= mul
return b
}
func hashLen0to16(s []byte, length uint32) uint64 {
if length > 8 {
var a = fetch64(s)
var b = fetch64(s[length-8:])
return hashLen16(a, rotate64ByAtLeast1(b+uint64(length), length)) ^ b
}
if length >= 4 {
var a = fetch32(s)
return hashLen16(uint64(length)+(uint64(a)<<3), uint64(fetch32(s[length-4:])))
}
if length > 0 {
var a uint8 = uint8(s[0])
var b uint8 = uint8(s[length>>1])
var c uint8 = uint8(s[length-1])
var y uint32 = uint32(a) + (uint32(b) << 8)
var z uint32 = length + (uint32(c) << 2)
return shiftMix(uint64(y)*k2^uint64(z)*k3) * k2
}
return k2
}
// This probably works well for 16-byte strings as well, but it may be overkill
func hashLen17to32(s []byte, length uint32) uint64 {
var a = fetch64(s) * k1
var b = fetch64(s[8:])
var c = fetch64(s[length-8:]) * k2
var d = fetch64(s[length-16:]) * k0
return hashLen16(rotate64(a-b, 43)+rotate64(c, 30)+d,
a+rotate64(b^k3, 20)-c+uint64(length))
}
func weakHashLen32WithSeeds(w, x, y, z, a, b uint64) Uint128 {
a += w
b = rotate64(b+a+z, 21)
var c uint64 = a
a += x
a += y
b += rotate64(a, 44)
return Uint128{a + z, b + c}
}
func weakHashLen32WithSeeds_3(s []byte, a, b uint64) Uint128 {
return weakHashLen32WithSeeds(fetch64(s), fetch64(s[8:]), fetch64(s[16:]), fetch64(s[24:]), a, b)
}
func hashLen33to64(s []byte, length uint32) uint64 {
var z uint64 = fetch64(s[24:])
var a uint64 = fetch64(s) + (uint64(length)+fetch64(s[length-16:]))*k0
var b uint64 = rotate64(a+z, 52)
var c uint64 = rotate64(a, 37)
a += fetch64(s[8:])
c += rotate64(a, 7)
a += fetch64(s[16:])
var vf uint64 = a + z
var vs = b + rotate64(a, 31) + c
a = fetch64(s[16:]) + fetch64(s[length-32:])
z = fetch64(s[length-8:])
b = rotate64(a+z, 52)
c = rotate64(a, 37)
a += fetch64(s[length-24:])
c += rotate64(a, 7)
a += fetch64(s[length-16:])
wf := a + z
ws := b + rotate64(a, 31) + c
r := shiftMix((vf+ws)*k2 + (wf+vs)*k0)
return shiftMix(r*k0+vs) * k2
}
func CityHash64(s []byte, length uint32) uint64 {
if length <= 32 {
if length <= 16 {
return hashLen0to16(s, length)
} else {
return hashLen17to32(s, length)
}
} else if length <= 64 {
return hashLen33to64(s, length)
}
var x uint64 = fetch64(s)
var y uint64 = fetch64(s[length-16:]) ^ k1
var z uint64 = fetch64(s[length-56:]) ^ k0
var v Uint128 = weakHashLen32WithSeeds_3(s[length-64:], uint64(length), y)
var w Uint128 = weakHashLen32WithSeeds_3(s[length-32:], uint64(length)*k1, k0)
z += shiftMix(v.Higher64()) * k1
x = rotate64(z+x, 39) * k1
y = rotate64(y, 33) * k1
length = (length - 1) & ^uint32(63)
for {
x = rotate64(x+y+v.Lower64()+fetch64(s[16:]), 37) * k1
y = rotate64(y+v.Higher64()+fetch64(s[48:]), 42) * k1
x ^= w.Higher64()
y ^= v.Lower64()
z = rotate64(z^w.Lower64(), 33)
v = weakHashLen32WithSeeds_3(s, v.Higher64()*k1, x+w.Lower64())
w = weakHashLen32WithSeeds_3(s[32:], z+w.Higher64(), y)
swap64(&z, &x)
s = s[64:]
length -= 64
if length == 0 {
break
}
}
return hashLen16(hashLen16(v.Lower64(), w.Lower64())+shiftMix(y)*k1+z, hashLen16(v.Higher64(), w.Higher64())+x)
}
func CityHash64WithSeed(s []byte, length uint32, seed uint64) uint64 {
return CityHash64WithSeeds(s, length, k2, seed)
}
func CityHash64WithSeeds(s []byte, length uint32, seed0, seed1 uint64) uint64 {
return hashLen16(CityHash64(s, length)-seed0, seed1)
}
func cityMurmur(s []byte, length uint32, seed Uint128) Uint128 {
var a uint64 = seed.Lower64()
var b uint64 = seed.Higher64()
var c uint64 = 0
var d uint64 = 0
var l int32 = int32(length) - 16
if l <= 0 { // len <= 16
a = shiftMix(a*k1) * k1
c = b*k1 + hashLen0to16(s, length)
if length >= 8 {
d = shiftMix(a + fetch64(s))
} else {
d = shiftMix(a + c)
}
} else { // len > 16
c = hashLen16(fetch64(s[length-8:])+k1, a)
d = hashLen16(b+uint64(length), c+fetch64(s[length-16:]))
a += d
for {
a ^= shiftMix(fetch64(s)*k1) * k1
a *= k1
b ^= a
c ^= shiftMix(fetch64(s[8:])*k1) * k1
c *= k1
d ^= c
s = s[16:]
l -= 16
if l <= 0 {
break
}
}
}
a = hashLen16(a, c)
b = hashLen16(d, b)
return Uint128{a ^ b, hashLen16(b, a)}
}
func CityHash128WithSeed(s []byte, length uint32, seed Uint128) Uint128 {
if length < 128 {
return cityMurmur(s, length, seed)
}
// We expect length >= 128 to be the common case. Keep 56 bytes of state:
// v, w, x, y, and z.
var v, w Uint128
var x uint64 = seed.Lower64()
var y uint64 = seed.Higher64()
var z uint64 = uint64(length) * k1
var pos uint32
var t = s
v.setLower64(rotate64(y^k1, 49)*k1 + fetch64(s))
v.setHigher64(rotate64(v.Lower64(), 42)*k1 + fetch64(s[8:]))
w.setLower64(rotate64(y+z, 35)*k1 + x)
w.setHigher64(rotate64(x+fetch64(s[88:]), 53) * k1)
// This is the same inner loop as CityHash64(), manually unrolled.
for {
x = rotate64(x+y+v.Lower64()+fetch64(s[16:]), 37) * k1
y = rotate64(y+v.Higher64()+fetch64(s[48:]), 42) * k1
x ^= w.Higher64()
y ^= v.Lower64()
z = rotate64(z^w.Lower64(), 33)
v = weakHashLen32WithSeeds_3(s, v.Higher64()*k1, x+w.Lower64())
w = weakHashLen32WithSeeds_3(s[32:], z+w.Higher64(), y)
swap64(&z, &x)
s = s[64:]
pos += 64
x = rotate64(x+y+v.Lower64()+fetch64(s[16:]), 37) * k1
y = rotate64(y+v.Higher64()+fetch64(s[48:]), 42) * k1
x ^= w.Higher64()
y ^= v.Lower64()
z = rotate64(z^w.Lower64(), 33)
v = weakHashLen32WithSeeds_3(s, v.Higher64()*k1, x+w.Lower64())
w = weakHashLen32WithSeeds_3(s[32:], z+w.Higher64(), y)
swap64(&z, &x)
s = s[64:]
pos += 64
length -= 128
if length < 128 {
break
}
}
y += rotate64(w.Lower64(), 37)*k0 + z
x += rotate64(v.Lower64()+z, 49) * k0
// If 0 < length < 128, hash up to 4 chunks of 32 bytes each from the end of s.
var tailDone uint32
for tailDone = 0; tailDone < length; {
tailDone += 32
y = rotate64(y-x, 42)*k0 + v.Higher64()
//TODO why not use origin_len ?
w.setLower64(w.Lower64() + fetch64(t[pos+length-tailDone+16:]))
x = rotate64(x, 49)*k0 + w.Lower64()
w.setLower64(w.Lower64() + v.Lower64())
v = weakHashLen32WithSeeds_3(t[pos+length-tailDone:], v.Lower64(), v.Higher64())
}
// At this point our 48 bytes of state should contain more than
// enough information for a strong 128-bit hash. We use two
// different 48-byte-to-8-byte hashes to get a 16-byte final result.
x = hashLen16(x, v.Lower64())
y = hashLen16(y, w.Lower64())
return Uint128{hashLen16(x+v.Higher64(), w.Higher64()) + y,
hashLen16(x+w.Higher64(), y+v.Higher64())}
}
func CityHash128(s []byte, length uint32) (result Uint128) {
if length >= 16 {
result = CityHash128WithSeed(s[16:length], length-16, Uint128{fetch64(s) ^ k3, fetch64(s[8:])})
} else if length >= 8 {
result = CityHash128WithSeed(nil, 0, Uint128{fetch64(s) ^ (uint64(length) * k0), fetch64(s[length-8:]) ^ k1})
} else {
result = CityHash128WithSeed(s, length, Uint128{k0, k1})
}
return
}

View File

@ -0,0 +1,65 @@
package cityhash102
import (
"bufio"
"os"
"strconv"
"strings"
"testing"
)
const (
kSeed0 uint64 = 1234567
kSeed1 uint64 = k0
)
type TestCase struct {
key string
lower uint64
upper uint64
}
var testdata = []TestCase{}
func buildData(t *testing.T) {
f, err := os.Open("testdata/hashs.txt")
if err != nil {
t.Fatal(err)
}
scanner := bufio.NewScanner(f)
var lower uint64
var upper uint64
for scanner.Scan() {
strs := strings.Split(scanner.Text(), ",")
lower, _ = strconv.ParseUint(strs[1], 16, 64)
upper, _ = strconv.ParseUint(strs[2], 16, 64)
testdata = append(testdata, TestCase{strs[0], lower, upper})
}
}
func check(str string, expected, actual uint64, t *testing.T) {
if expected != actual {
t.Errorf("ERROR: %s expected 0x%x but got 0x%x\n", str, expected, actual)
}
}
func test(str string, lower uint64, upper uint64, t *testing.T) {
var u Uint128 = CityHash128([]byte(str), uint32(len(str)))
check(str, lower, u.Lower64(), t)
check(str, upper, u.Higher64(), t)
}
func Test_Hash(t *testing.T) {
buildData(t)
var i int
for i = 0; i < len(testdata); i++ {
t.Logf("INFO: offset = %d, length = %d", i, len(testdata))
test(testdata[i].key, testdata[i].lower, testdata[i].upper, t)
}
return
}

View File

@ -0,0 +1,5 @@
/** COPY from https://github.com/zentures/cityhash/
NOTE: The code is modified to be compatible with CityHash128 used in ClickHouse
*/
package cityhash102

View File

@ -0,0 +1,365 @@
#include <fstream>
#include <iostream>
#include <cstdio>
#include <string.h>
#include <algorithm>
typedef uint8_t uint8;
typedef uint32_t uint32;
typedef uint64_t uint64;
typedef std::pair<uint64, uint64> uint128;
using namespace std;
uint64 Uint128Low64(const uint128& x) { return x.first; }
uint64 Uint128High64(const uint128& x) { return x.second; }
// Hash function for a byte array.
uint64 CityHash64(const char *buf, size_t len);
// Hash function for a byte array. For convenience, a 64-bit seed is also
// hashed into the result.
uint64 CityHash64WithSeed(const char *buf, size_t len, uint64 seed);
// Hash function for a byte array. For convenience, two seeds are also
// hashed into the result.
uint64 CityHash64WithSeeds(const char *buf, size_t len,
uint64 seed0, uint64 seed1);
// Hash function for a byte array.
uint128 CityHash128(const char *s, size_t len);
// Hash function for a byte array. For convenience, a 128-bit seed is also
// hashed into the result.
uint128 CityHash128WithSeed(const char *s, size_t len, uint128 seed);
uint64 Hash128to64(const uint128& x) {
// Murmur-inspired hashing.
const uint64 kMul = 0x9ddfea08eb382d69ULL;
uint64 a = (Uint128Low64(x) ^ Uint128High64(x)) * kMul;
a ^= (a >> 47);
uint64 b = (Uint128High64(x) ^ a) * kMul;
b ^= (b >> 47);
b *= kMul;
return b;
}
#define uint32_in_expected_order(x) (x)
#define uint64_in_expected_order(x) (x)
static uint64 UNALIGNED_LOAD64(const char *p) {
uint64 result;
memcpy(&result, p, sizeof(result));
return result;
}
static uint32 UNALIGNED_LOAD32(const char *p) {
uint32 result;
memcpy(&result, p, sizeof(result));
return result;
}
static uint64 Fetch64(const char *p) {
return uint64_in_expected_order(UNALIGNED_LOAD64(p));
}
static uint32 Fetch32(const char *p) {
return uint32_in_expected_order(UNALIGNED_LOAD32(p));
}
// Some primes between 2^63 and 2^64 for various uses.
static const uint64 k0 = 0xc3a5c85c97cb3127ULL;
static const uint64 k1 = 0xb492b66fbe98f273ULL;
static const uint64 k2 = 0x9ae16a3b2f90404fULL;
static const uint64 k3 = 0xc949d7c7509e6557ULL;
// Bitwise right rotate. Normally this will compile to a single
// instruction, especially if the shift is a manifest constant.
static uint64 Rotate(uint64 val, int shift) {
// Avoid shifting by 64: doing so yields an undefined result.
return shift == 0 ? val : ((val >> shift) | (val << (64 - shift)));
}
// Equivalent to Rotate(), but requires the second arg to be non-zero.
// On x86-64, and probably others, it's possible for this to compile
// to a single instruction if both args are already in registers.
static uint64 RotateByAtLeast1(uint64 val, int shift) {
return (val >> shift) | (val << (64 - shift));
}
static uint64 ShiftMix(uint64 val) {
return val ^ (val >> 47);
}
static uint64 HashLen16(uint64 u, uint64 v) {
return Hash128to64(uint128(u, v));
}
static uint64 HashLen0to16(const char *s, size_t len) {
if (len > 8) {
uint64 a = Fetch64(s);
uint64 b = Fetch64(s + len - 8);
return HashLen16(a, RotateByAtLeast1(b + len, len)) ^ b;
}
if (len >= 4) {
uint64 a = Fetch32(s);
return HashLen16(len + (a << 3), Fetch32(s + len - 4));
}
if (len > 0) {
uint8 a = s[0];
uint8 b = s[len >> 1];
uint8 c = s[len - 1];
uint32 y = static_cast<uint32>(a) + (static_cast<uint32>(b) << 8);
uint32 z = len + (static_cast<uint32>(c) << 2);
return ShiftMix(y * k2 ^ z * k3) * k2;
}
return k2;
}
// This probably works well for 16-byte strings as well, but it may be overkill
// in that case.
static uint64 HashLen17to32(const char *s, size_t len) {
uint64 a = Fetch64(s) * k1;
uint64 b = Fetch64(s + 8);
uint64 c = Fetch64(s + len - 8) * k2;
uint64 d = Fetch64(s + len - 16) * k0;
return HashLen16(Rotate(a - b, 43) + Rotate(c, 30) + d,
a + Rotate(b ^ k3, 20) - c + len);
}
// Return a 16-byte hash for 48 bytes. Quick and dirty.
// Callers do best to use "random-looking" values for a and b.
static pair<uint64, uint64> WeakHashLen32WithSeeds(
uint64 w, uint64 x, uint64 y, uint64 z, uint64 a, uint64 b) {
a += w;
b = Rotate(b + a + z, 21);
uint64 c = a;
a += x;
a += y;
b += Rotate(a, 44);
return make_pair(a + z, b + c);
}
// Return a 16-byte hash for s[0] ... s[31], a, and b. Quick and dirty.
static pair<uint64, uint64> WeakHashLen32WithSeeds(
const char* s, uint64 a, uint64 b) {
return WeakHashLen32WithSeeds(Fetch64(s),
Fetch64(s + 8),
Fetch64(s + 16),
Fetch64(s + 24),
a,
b);
}
// Return an 8-byte hash for 33 to 64 bytes.
static uint64 HashLen33to64(const char *s, size_t len) {
uint64 z = Fetch64(s + 24);
uint64 a = Fetch64(s) + (len + Fetch64(s + len - 16)) * k0;
uint64 b = Rotate(a + z, 52);
uint64 c = Rotate(a, 37);
a += Fetch64(s + 8);
c += Rotate(a, 7);
a += Fetch64(s + 16);
uint64 vf = a + z;
uint64 vs = b + Rotate(a, 31) + c;
a = Fetch64(s + 16) + Fetch64(s + len - 32);
z = Fetch64(s + len - 8);
b = Rotate(a + z, 52);
c = Rotate(a, 37);
a += Fetch64(s + len - 24);
c += Rotate(a, 7);
a += Fetch64(s + len - 16);
uint64 wf = a + z;
uint64 ws = b + Rotate(a, 31) + c;
uint64 r = ShiftMix((vf + ws) * k2 + (wf + vs) * k0);
return ShiftMix(r * k0 + vs) * k2;
}
uint64 CityHash64(const char *s, size_t len) {
if (len <= 32) {
if (len <= 16) {
return HashLen0to16(s, len);
} else {
return HashLen17to32(s, len);
}
} else if (len <= 64) {
return HashLen33to64(s, len);
}
// For strings over 64 bytes we hash the end first, and then as we
// loop we keep 56 bytes of state: v, w, x, y, and z.
uint64 x = Fetch64(s);
uint64 y = Fetch64(s + len - 16) ^ k1;
uint64 z = Fetch64(s + len - 56) ^ k0;
pair<uint64, uint64> v = WeakHashLen32WithSeeds(s + len - 64, len, y);
pair<uint64, uint64> w = WeakHashLen32WithSeeds(s + len - 32, len * k1, k0);
z += ShiftMix(v.second) * k1;
x = Rotate(z + x, 39) * k1;
y = Rotate(y, 33) * k1;
// Decrease len to the nearest multiple of 64, and operate on 64-byte chunks.
len = (len - 1) & ~static_cast<size_t>(63);
do {
x = Rotate(x + y + v.first + Fetch64(s + 16), 37) * k1;
y = Rotate(y + v.second + Fetch64(s + 48), 42) * k1;
x ^= w.second;
y ^= v.first;
z = Rotate(z ^ w.first, 33);
v = WeakHashLen32WithSeeds(s, v.second * k1, x + w.first);
w = WeakHashLen32WithSeeds(s + 32, z + w.second, y);
std::swap(z, x);
s += 64;
len -= 64;
} while (len != 0);
return HashLen16(HashLen16(v.first, w.first) + ShiftMix(y) * k1 + z,
HashLen16(v.second, w.second) + x);
}
uint64 CityHash64WithSeed(const char *s, size_t len, uint64 seed) {
return CityHash64WithSeeds(s, len, k2, seed);
}
uint64 CityHash64WithSeeds(const char *s, size_t len,
uint64 seed0, uint64 seed1) {
return HashLen16(CityHash64(s, len) - seed0, seed1);
}
// A subroutine for CityHash128(). Returns a decent 128-bit hash for strings
// of any length representable in ssize_t. Based on City and Murmur.
static uint128 CityMurmur(const char *s, size_t len, uint128 seed) {
uint64 a = Uint128Low64(seed);
uint64 b = Uint128High64(seed);
uint64 c = 0;
uint64 d = 0;
ssize_t l = len - 16;
if (l <= 0) { // len <= 16
a = ShiftMix(a * k1) * k1;
c = b * k1 + HashLen0to16(s, len);
d = ShiftMix(a + (len >= 8 ? Fetch64(s) : c));
} else { // len > 16
c = HashLen16(Fetch64(s + len - 8) + k1, a);
d = HashLen16(b + len, c + Fetch64(s + len - 16));
a += d;
do {
a ^= ShiftMix(Fetch64(s) * k1) * k1;
a *= k1;
b ^= a;
c ^= ShiftMix(Fetch64(s + 8) * k1) * k1;
c *= k1;
d ^= c;
s += 16;
l -= 16;
} while (l > 0);
}
a = HashLen16(a, c);
b = HashLen16(d, b);
return uint128(a ^ b, HashLen16(b, a));
}
uint128 CityHash128WithSeed(const char *s, size_t len, uint128 seed) {
if (len < 128) {
return CityMurmur(s, len, seed);
}
// We expect len >= 128 to be the common case. Keep 56 bytes of state:
// v, w, x, y, and z.
pair<uint64, uint64> v, w;
uint64 x = Uint128Low64(seed);
uint64 y = Uint128High64(seed);
uint64 z = len * k1;
v.first = Rotate(y ^ k1, 49) * k1 + Fetch64(s);
v.second = Rotate(v.first, 42) * k1 + Fetch64(s + 8);
w.first = Rotate(y + z, 35) * k1 + x;
w.second = Rotate(x + Fetch64(s + 88), 53) * k1;
// This is the same inner loop as CityHash64(), manually unrolled.
do {
x = Rotate(x + y + v.first + Fetch64(s + 16), 37) * k1;
y = Rotate(y + v.second + Fetch64(s + 48), 42) * k1;
x ^= w.second;
y ^= v.first;
z = Rotate(z ^ w.first, 33);
v = WeakHashLen32WithSeeds(s, v.second * k1, x + w.first);
w = WeakHashLen32WithSeeds(s + 32, z + w.second, y);
std::swap(z, x);
s += 64;
x = Rotate(x + y + v.first + Fetch64(s + 16), 37) * k1;
y = Rotate(y + v.second + Fetch64(s + 48), 42) * k1;
x ^= w.second;
y ^= v.first;
z = Rotate(z ^ w.first, 33);
v = WeakHashLen32WithSeeds(s, v.second * k1, x + w.first);
w = WeakHashLen32WithSeeds(s + 32, z + w.second, y);
std::swap(z, x);
s += 64;
len -= 128;
} while (len >= 128);
y += Rotate(w.first, 37) * k0 + z;
x += Rotate(v.first + z, 49) * k0;
// If 0 < len < 128, hash up to 4 chunks of 32 bytes each from the end of s.
for (size_t tail_done = 0; tail_done < len; ) {
tail_done += 32;
y = Rotate(y - x, 42) * k0 + v.second;
w.first += Fetch64(s + len - tail_done + 16);
x = Rotate(x, 49) * k0 + w.first;
w.first += v.first;
v = WeakHashLen32WithSeeds(s + len - tail_done, v.first, v.second);
}
// At this point our 48 bytes of state should contain more than
// enough information for a strong 128-bit hash. We use two
// different 48-byte-to-8-byte hashes to get a 16-byte final result.
x = HashLen16(x, v.first);
y = HashLen16(y, w.first);
return uint128(HashLen16(x + v.second, w.second) + y,
HashLen16(x + w.second, y + v.second));
}
uint128 CityHash128(const char *s, size_t len) {
if (len >= 16) {
return CityHash128WithSeed(s + 16,
len - 16,
uint128(Fetch64(s) ^ k3,
Fetch64(s + 8)));
} else if (len >= 8) {
return CityHash128WithSeed(NULL,
0,
uint128(Fetch64(s) ^ (len * k0),
Fetch64(s + len - 8) ^ k1));
} else {
return CityHash128WithSeed(s, len, uint128(k0, k1));
}
}
std::string random_string( size_t length )
{
auto randchar = []() -> char
{
const char charset[] =
"0123456789"
"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
"abcdefghijklmnopqrstuvwxyz";
const size_t max_index = (sizeof(charset) - 1);
return charset[ rand() % max_index ];
};
std::string str(length,0);
std::generate_n( str.begin(), length, randchar );
return str;
}
// g++ cityhash.cpp && ./a.out > hashs.txt
int main()
{
for (int i = 0; i < 1000; i++)
{
auto str = random_string( rand() % 1000 + 1);
auto res = CityHash128(str.c_str(), str.size() );
printf("%s,%lx,%lx\n", str.data() , res.first, res.second);
}
}

1000
ch/internal/cityhash102/testdata/hashs.txt vendored Normal file

File diff suppressed because it is too large Load Diff

15
ch/internal/flag.go Normal file
View File

@ -0,0 +1,15 @@
package internal
type Flag uint64
func (flag Flag) Has(other Flag) bool {
return flag&other == other
}
func (flag *Flag) Set(other Flag) {
*flag = *flag | other
}
func (flag *Flag) Remove(other Flag) {
*flag &= ^other
}

View File

@ -0,0 +1,141 @@
package parser
import (
"bytes"
"strconv"
"github.com/uptrace/go-clickhouse/ch/internal"
)
type Parser struct {
b []byte
i int
}
func New(b []byte) *Parser {
return &Parser{
b: b,
}
}
func NewString(s string) *Parser {
return New(internal.Bytes(s))
}
func (p *Parser) Valid() bool {
return p.i < len(p.b)
}
func (p *Parser) Bytes() []byte {
return p.b[p.i:]
}
func (p *Parser) Read() byte {
if p.Valid() {
c := p.b[p.i]
p.Advance()
return c
}
return 0
}
func (p *Parser) Peek() byte {
if p.Valid() {
return p.b[p.i]
}
return 0
}
func (p *Parser) Advance() {
p.i++
}
func (p *Parser) Skip(skip byte) bool {
if p.Peek() == skip {
p.Advance()
return true
}
return false
}
func (p *Parser) SkipBytes(skip []byte) bool {
if len(skip) > len(p.b[p.i:]) {
return false
}
if !bytes.Equal(p.b[p.i:p.i+len(skip)], skip) {
return false
}
p.i += len(skip)
return true
}
func (p *Parser) ReadSep(sep byte) ([]byte, bool) {
ind := bytes.IndexByte(p.b[p.i:], sep)
if ind == -1 {
b := p.b[p.i:]
p.i = len(p.b)
return b, false
}
b := p.b[p.i : p.i+ind]
p.i += ind + 1
return b, true
}
func (p *Parser) ReadIdentifier() (string, bool) {
if p.i < len(p.b) && p.b[p.i] == '(' {
s := p.i + 1
if ind := bytes.IndexByte(p.b[s:], ')'); ind != -1 {
b := p.b[s : s+ind]
p.i = s + ind + 1
return internal.String(b), false
}
}
ind := len(p.b) - p.i
var alpha bool
for i, c := range p.b[p.i:] {
if isDigit(c) {
continue
}
if isAlpha(c) || (i > 0 && alpha && c == '_') {
alpha = true
continue
}
ind = i
break
}
if ind == 0 {
return "", false
}
b := p.b[p.i : p.i+ind]
p.i += ind
return internal.String(b), !alpha
}
func (p *Parser) ReadNumber() int {
ind := len(p.b) - p.i
for i, c := range p.b[p.i:] {
if !isDigit(c) {
ind = i
break
}
}
if ind == 0 {
return 0
}
n, err := strconv.Atoi(string(p.b[p.i : p.i+ind]))
if err != nil {
panic(err)
}
p.i += ind
return n
}
func isDigit(c byte) bool {
return c >= '0' && c <= '9'
}
func isAlpha(c byte) bool {
return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z')
}

11
ch/internal/safe.go Normal file
View File

@ -0,0 +1,11 @@
//go:build appengine
package internal
func String(b []byte) string {
return string(b)
}
func Bytes(s string) []byte {
return []byte(s)
}

View File

@ -0,0 +1,184 @@
package tagparser
import (
"strings"
)
type Tag struct {
Name string
Options map[string][]string
}
func (t Tag) IsZero() bool {
return t.Name == "" && t.Options == nil
}
func (t Tag) HasOption(name string) bool {
_, ok := t.Options[name]
return ok
}
func (t Tag) Option(name string) (string, bool) {
if vs, ok := t.Options[name]; ok {
return vs[len(vs)-1], true
}
return "", false
}
func Parse(s string) Tag {
if s == "" {
return Tag{}
}
p := parser{
s: s,
}
p.parse()
return p.tag
}
type parser struct {
s string
i int
tag Tag
seenName bool // for empty names
}
func (p *parser) setName(name string) {
if p.seenName {
p.addOption(name, "")
} else {
p.seenName = true
p.tag.Name = name
}
}
func (p *parser) addOption(key, value string) {
p.seenName = true
if key == "" {
return
}
if p.tag.Options == nil {
p.tag.Options = make(map[string][]string)
}
if vs, ok := p.tag.Options[key]; ok {
p.tag.Options[key] = append(vs, value)
} else {
p.tag.Options[key] = []string{value}
}
}
func (p *parser) parse() {
for p.valid() {
p.parseKeyValue()
if p.peek() == ',' {
p.i++
}
}
}
func (p *parser) parseKeyValue() {
start := p.i
for p.valid() {
switch c := p.read(); c {
case ',':
key := p.s[start : p.i-1]
p.setName(key)
return
case ':':
key := p.s[start : p.i-1]
value := p.parseValue()
p.addOption(key, value)
return
case '"':
key := p.parseQuotedValue()
p.setName(key)
return
}
}
key := p.s[start:p.i]
p.setName(key)
}
func (p *parser) parseValue() string {
start := p.i
for p.valid() {
switch c := p.read(); c {
case '"':
return p.parseQuotedValue()
case ',':
return p.s[start : p.i-1]
case '(':
p.skipPairs('(', ')')
}
}
if p.i == start {
return ""
}
return p.s[start:p.i]
}
func (p *parser) parseQuotedValue() string {
if i := strings.IndexByte(p.s[p.i:], '"'); i >= 0 && p.s[p.i+i-1] != '\\' {
s := p.s[p.i : p.i+i]
p.i += i + 1
return s
}
b := make([]byte, 0, 16)
for p.valid() {
switch c := p.read(); c {
case '\\':
b = append(b, p.read())
case '"':
return string(b)
default:
b = append(b, c)
}
}
return ""
}
func (p *parser) skipPairs(start, end byte) {
var lvl int
for p.valid() {
switch c := p.read(); c {
case '"':
_ = p.parseQuotedValue()
case start:
lvl++
case end:
if lvl == 0 {
return
}
lvl--
}
}
}
func (p *parser) valid() bool {
return p.i < len(p.s)
}
func (p *parser) read() byte {
if !p.valid() {
return 0
}
c := p.s[p.i]
p.i++
return c
}
func (p *parser) peek() byte {
if !p.valid() {
return 0
}
c := p.s[p.i]
return c
}

View File

@ -0,0 +1,45 @@
package tagparser_test
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/uptrace/go-clickhouse/ch/internal/tagparser"
)
var tagTests = []struct {
tag string
name string
options map[string][]string
}{
{"", "", nil},
{"hello", "hello", nil},
{"hello,world", "hello", map[string][]string{"world": {""}}},
{`"hello,world'`, "", nil},
{`"hello:world"`, `hello:world`, nil},
{",hello", "", map[string][]string{"hello": {""}}},
{",hello,world", "", map[string][]string{"hello": {""}, "world": {""}}},
{"hello:", "", map[string][]string{"hello": {""}}},
{"hello:world", "", map[string][]string{"hello": {"world"}}},
{"hello:world,foo", "", map[string][]string{"hello": {"world"}, "foo": {""}}},
{"hello:world,foo:bar", "", map[string][]string{"hello": {"world"}, "foo": {"bar"}}},
{"hello:\"world1,world2\"", "", map[string][]string{"hello": {"world1,world2"}}},
{`hello:"world1,world2",world3`, "", map[string][]string{"hello": {"world1,world2"}, "world3": {""}}},
{`hello:"world1:world2",world3`, "", map[string][]string{"hello": {"world1:world2"}, "world3": {""}}},
{`hello:"D'Angelo, esquire",foo:bar`, "", map[string][]string{"hello": {"D'Angelo, esquire"}, "foo": {"bar"}}},
{`hello:"world('foo', 'bar')"`, "", map[string][]string{"hello": {"world('foo', 'bar')"}}},
{" hello,foo: bar ", " hello", map[string][]string{"foo": {" bar "}}},
{"foo:bar(hello, world)", "", map[string][]string{"foo": {"bar(hello, world)"}}},
{"foo:bar(hello(), world)", "", map[string][]string{"foo": {"bar(hello(), world)"}}},
{"type:geometry(POINT, 4326)", "", map[string][]string{"type": {"geometry(POINT, 4326)"}}},
{"foo:bar,foo:baz", "", map[string][]string{"foo": []string{"bar", "baz"}}},
}
func TestTagParser(t *testing.T) {
for i, test := range tagTests {
tag := tagparser.Parse(test.tag)
require.Equal(t, test.name, tag.Name, "#%d", i)
require.Equal(t, test.options, tag.Options, "#%d", i)
}
}

22
ch/internal/unsafe.go Normal file
View File

@ -0,0 +1,22 @@
//go:build !appengine
package internal
import (
"unsafe"
)
// BytesToString converts byte slice to string.
func String(b []byte) string {
return *(*string)(unsafe.Pointer(&b))
}
// StringToBytes converts string to byte slice.
func Bytes(s string) []byte {
return *(*[]byte)(unsafe.Pointer(
&struct {
string
Cap int
}{s, len(s)},
))
}

102
ch/internal/util.go Normal file
View File

@ -0,0 +1,102 @@
package internal
import (
"context"
"log"
"math/rand"
"os"
"reflect"
"time"
)
var (
Logger = log.New(os.Stderr, "ch: ", log.LstdFlags|log.Lshortfile)
Warn = log.New(os.Stderr, "WARN: ch: ", log.LstdFlags|log.Lshortfile)
Deprecated = log.New(os.Stderr, "DEPRECATED: ch: ", log.LstdFlags|log.Lshortfile)
)
func Sleep(ctx context.Context, dur time.Duration) error {
t := time.NewTimer(dur)
defer t.Stop()
select {
case <-t.C:
return nil
case <-ctx.Done():
return ctx.Err()
}
}
func Unwrap(err error) error {
u, ok := err.(interface {
Unwrap() error
})
if !ok {
return nil
}
return u.Unwrap()
}
func MakeSliceNextElemFunc(v reflect.Value) func() reflect.Value {
if v.Kind() == reflect.Array {
var pos int
return func() reflect.Value {
v := v.Index(pos)
pos++
return v
}
}
elemType := v.Type().Elem()
if elemType.Kind() == reflect.Ptr {
elemType = elemType.Elem()
return func() reflect.Value {
if v.Len() < v.Cap() {
v.Set(v.Slice(0, v.Len()+1))
elem := v.Index(v.Len() - 1)
if elem.IsNil() {
elem.Set(reflect.New(elemType))
}
return elem.Elem()
}
elem := reflect.New(elemType)
v.Set(reflect.Append(v, elem))
return elem.Elem()
}
}
zero := reflect.Zero(elemType)
return func() reflect.Value {
if v.Len() < v.Cap() {
v.Set(v.Slice(0, v.Len()+1))
return v.Index(v.Len() - 1)
}
v.Set(reflect.Append(v, zero))
return v.Index(v.Len() - 1)
}
}
func RetryBackoff(retry int, minBackoff, maxBackoff time.Duration) time.Duration {
if retry < 0 {
panic("not reached")
}
if minBackoff == 0 {
return 0
}
d := minBackoff << uint(retry)
if d < minBackoff {
return maxBackoff
}
d = minBackoff + time.Duration(rand.Int63n(int64(d)))
if d > maxBackoff || d < minBackoff {
d = maxBackoff
}
return d
}

117
ch/model.go Normal file
View File

@ -0,0 +1,117 @@
package ch
import (
"errors"
"fmt"
"reflect"
"time"
"github.com/uptrace/go-clickhouse/ch/chschema"
)
var errNilModel = errors.New("ch: Model(nil)")
var (
timeType = reflect.TypeOf((*time.Time)(nil)).Elem()
mapType = reflect.TypeOf((*map[string]any)(nil)).Elem()
)
type (
Query = chschema.Query
Model = chschema.Model
)
func newModel(db *DB, values ...any) (Model, error) {
if len(values) > 1 {
return scan(values...), nil
}
v0 := values[0]
switch v0 := v0.(type) {
case Model:
return v0, nil
}
v := reflect.ValueOf(v0)
if !v.IsValid() {
return nil, errNilModel
}
if v.Kind() != reflect.Ptr {
return nil, fmt.Errorf("ch: Model(non-pointer %T)", v0)
}
v = v.Elem()
switch v.Kind() {
case reflect.Struct:
if v.Type() != timeType {
return newStructTableModelValue(db, v), nil
}
case reflect.Slice:
typ := v.Type()
elemType := indirectType(typ.Elem())
if elemType == mapType {
return newSliceMapModel(v), nil
}
if elemType.Kind() == reflect.Struct && elemType != timeType {
return newSliceTableModel(db, v, elemType), nil
}
case reflect.Map:
if v.Type() == mapType {
return newMapModel(v), nil
}
}
return scan(v0), nil
}
func scan(values ...any) Model {
m := &scanModel{
values: make([]reflect.Value, len(values)),
}
for i, v := range values {
m.values[i] = reflect.ValueOf(v).Elem()
}
return m
}
type TableModel interface {
Model
Table() *chschema.Table
Block(fields []*chschema.Field) *chschema.Block
}
func newTableModel(db *DB, value any) (TableModel, error) {
if value, ok := value.(TableModel); ok {
return value, nil
}
v := reflect.ValueOf(value)
if !v.IsValid() {
return nil, errNilModel
}
if v.Kind() != reflect.Ptr {
return nil, fmt.Errorf("ch: Model(non-pointer %T)", value)
}
if v.IsNil() {
typ := v.Type().Elem()
if typ.Kind() == reflect.Struct {
return newStructTableModel(db, chschema.TableForType(typ)), nil
}
return nil, errNilModel
}
v = v.Elem()
switch v.Kind() {
case reflect.Struct:
return newStructTableModelValue(db, v), nil
case reflect.Slice:
elemType := sliceElemType(v)
if elemType.Kind() == reflect.Struct {
return newSliceTableModel(db, v, elemType), nil
}
}
return nil, fmt.Errorf("ch: Model(unsupported %s)", v.Type())
}

46
ch/model_map.go Normal file
View File

@ -0,0 +1,46 @@
package ch
import (
"reflect"
"github.com/uptrace/go-clickhouse/ch/chschema"
)
type mapModel struct {
m map[string]any
columnar bool
}
var _ Model = (*mapModel)(nil)
func newMapModel(v reflect.Value) *mapModel {
if v.IsNil() {
v.Set(reflect.MakeMap(mapType))
}
return &mapModel{
m: v.Interface().(map[string]any),
}
}
func (m *mapModel) SetColumnar(on bool) {
m.columnar = on
}
func (m *mapModel) ScanBlock(block *chschema.Block) error {
if m.columnar {
for _, col := range block.Columns {
set(m.m, col.Name, col.Value())
}
return nil
}
for _, col := range block.Columns {
if col.Len() > 0 {
set(m.m, col.Name, col.Index(0))
} else {
zero := reflect.Zero(col.Columnar.Type()).Interface()
set(m.m, col.Name, zero)
}
}
return nil
}

56
ch/model_scan.go Normal file
View File

@ -0,0 +1,56 @@
package ch
import (
"fmt"
"reflect"
"github.com/uptrace/go-clickhouse/ch/chschema"
)
type columnarModel struct {
columnar bool
}
func (m *columnarModel) SetColumnar(on bool) {
m.columnar = on
}
type scanModel struct {
columnarModel
values []reflect.Value
}
var _ Model = (*scanModel)(nil)
func (m *scanModel) UseQueryRow() bool {
return true
}
func (m *scanModel) ScanBlock(block *chschema.Block) error {
if block.NumRow == 0 {
return nil
}
if block.NumColumn != len(m.values) {
return fmt.Errorf("ch: got %d columns, but Scan has %d values",
block.NumColumn, len(m.values))
}
if m.columnar {
for i, col := range block.Columns {
v := m.values[i]
if v.Kind() == reflect.Interface {
v.Set(reflect.ValueOf(col.Value()))
} else {
v.Set(reflect.AppendSlice(v, reflect.ValueOf(col.Value())))
}
}
return nil
}
for i, col := range block.Columns {
if err := col.ConvertAssign(0, m.values[i]); err != nil {
return err
}
}
return nil
}

57
ch/model_slice_map.go Normal file
View File

@ -0,0 +1,57 @@
package ch
import (
"reflect"
"strings"
"github.com/uptrace/go-clickhouse/ch/chschema"
)
type sliceMapModel struct {
v reflect.Value
slice []map[string]any
}
var _ Model = (*sliceMapModel)(nil)
func newSliceMapModel(v reflect.Value) *sliceMapModel {
return &sliceMapModel{
v: v,
slice: v.Interface().([]map[string]any),
}
}
func (m *sliceMapModel) ScanBlock(block *chschema.Block) error {
for i := 0; i < block.NumRow; i++ {
row := make(map[string]any, block.NumColumn)
for _, col := range block.Columns {
set(row, col.Name, col.Index(i))
}
m.slice = append(m.slice, row)
}
m.v.Set(reflect.ValueOf(m.slice))
return nil
}
func set(m map[string]any, key string, value any) {
const sep = "__"
for {
idx := strings.Index(key, sep)
if idx == -1 {
break
}
subKey := key[:idx]
key = key[idx+len(sep):]
if subMap, ok := m[subKey].(map[string]any); ok {
m = subMap
continue
}
subMap := make(map[string]any)
m[subKey] = subMap
m = subMap
}
m[key] = value
}

78
ch/model_table_slice.go Normal file
View File

@ -0,0 +1,78 @@
package ch
import (
"context"
"reflect"
"github.com/uptrace/go-clickhouse/ch/chschema"
"github.com/uptrace/go-clickhouse/ch/internal"
)
type sliceTableModel struct {
db *DB
table *chschema.Table
slice reflect.Value
nextElem func() reflect.Value
}
var _ TableModel = (*sliceTableModel)(nil)
func newSliceTableModel(db *DB, slice reflect.Value, elemType reflect.Type) TableModel {
return &sliceTableModel{
db: db,
table: chschema.TableForType(elemType),
slice: slice,
nextElem: internal.MakeSliceNextElemFunc(slice),
}
}
func (m *sliceTableModel) Table() *chschema.Table {
return m.table
}
func (m *sliceTableModel) AppendParam(
fmter chschema.Formatter, b []byte, name string,
) ([]byte, bool) {
return b, false
}
func (m *sliceTableModel) ScanBlock(block *chschema.Block) error {
for row := 0; row < block.NumRow; row++ {
elem := m.nextElem()
if err := scanRow(m.db, m.table, elem, block, row); err != nil {
return err
}
}
return nil
}
func (m *sliceTableModel) Block(fields []*chschema.Field) *chschema.Block {
sliceLen := m.slice.Len()
block := chschema.NewBlock(m.table, len(fields), sliceLen)
if sliceLen == 0 {
return block
}
for _, field := range fields {
_ = block.ColumnForField(field)
}
for i := 0; i < sliceLen; i++ {
elem := indirect(m.slice.Index(i))
for _, col := range block.Columns {
col.AppendValue(col.Field.Value(elem))
}
}
return block
}
var _ AfterScanRowHook = (*sliceTableModel)(nil)
func (m *sliceTableModel) AfterScanRow(ctx context.Context) error {
if m.table.HasAfterScanRowHook() {
return callAfterScanRowHookSlice(ctx, m.slice)
}
return nil
}

129
ch/model_table_struct.go Normal file
View File

@ -0,0 +1,129 @@
package ch
import (
"context"
"reflect"
"github.com/uptrace/go-clickhouse/ch/chschema"
)
type structTableModel struct {
db *DB
table *chschema.Table
strct reflect.Value
}
var _ TableModel = (*structTableModel)(nil)
func newStructTableModel(db *DB, table *chschema.Table) *structTableModel {
return &structTableModel{
db: db,
table: table,
}
}
func newStructTableModelValue(db *DB, v reflect.Value) *structTableModel {
return &structTableModel{
db: db,
table: chschema.TableForType(v.Type()),
strct: v,
}
}
func (m *structTableModel) UseQueryRow() bool {
return !m.table.IsColumnar()
}
func (m *structTableModel) Table() *chschema.Table {
return m.table
}
func (m *structTableModel) AppendNamedArg(
fmter chschema.Formatter, b []byte, name string,
) ([]byte, bool) {
field, ok := m.table.FieldMap[name]
if ok {
b = field.AppendValue(fmter, b, m.strct)
return b, true
}
return b, false
}
func (m *structTableModel) ScanBlock(block *chschema.Block) error {
if block.NumRow == 0 {
return nil
}
if m.table.IsColumnar() {
return scanColumns(m.db, m.table, m.strct, block)
}
return scanRow(m.db, m.table, m.strct, block, 0)
}
func scanRow(
db *DB, table *chschema.Table, strct reflect.Value, block *chschema.Block, row int,
) error {
for _, col := range block.Columns {
field := table.FieldMap[col.Name]
if field == nil {
if !db.flags.Has(discardUnknownColumnsFlag) {
return &chschema.UnknownColumnError{
Table: table,
Column: col.Name,
}
}
continue
}
fieldValue := field.Value(strct)
if err := col.ConvertAssign(row, fieldValue); err != nil {
return err
}
}
return nil
}
func scanColumns(db *DB, table *chschema.Table, strct reflect.Value, block *chschema.Block) error {
for _, col := range block.Columns {
field := table.FieldMap[col.Name]
if field == nil {
if !db.flags.Has(discardUnknownColumnsFlag) {
return &chschema.UnknownColumnError{
Table: table,
Column: col.Name,
}
}
continue
}
fieldValue := field.Value(strct)
fieldValue.Set(reflect.AppendSlice(fieldValue, reflect.ValueOf(col.Value())))
}
return nil
}
func (m *structTableModel) Block(fields []*chschema.Field) *chschema.Block {
block := chschema.NewBlock(m.table, len(fields), 1)
for _, field := range fields {
fieldValue := field.Value(m.strct)
col := block.Column(field.CHName, field.CHType)
if m.table.IsColumnar() {
col.Set(fieldValue.Interface())
} else {
col.AppendValue(fieldValue)
}
}
return block
}
var _ AfterScanRowHook = (*structTableModel)(nil)
func (m *structTableModel) AfterScanRow(ctx context.Context) error {
if m.table.HasAfterScanRowHook() {
return callAfterScanRowHook(ctx, m.strct.Addr())
}
return nil
}

426
ch/proto.go Normal file
View File

@ -0,0 +1,426 @@
package ch
import (
"context"
"errors"
"fmt"
"os"
"strings"
"github.com/uptrace/go-clickhouse/ch/chpool"
"github.com/uptrace/go-clickhouse/ch/chproto"
"github.com/uptrace/go-clickhouse/ch/chschema"
)
const (
clientName = "go-clickhouse"
chVersionMajor = 19
chVersionMinor = 17
chVersionPatch = 5
chRevision = 54428
)
func (db *DB) hello(ctx context.Context, cn *chpool.Conn) error {
err := cn.WithWriter(ctx, db.cfg.WriteTimeout, func(wr *chproto.Writer) {
wr.Uvarint(chproto.ClientHello)
writeClientInfo(wr)
wr.String(db.cfg.Database)
wr.String(db.cfg.User)
wr.String(db.cfg.Password)
})
if err != nil {
return err
}
return cn.WithReader(ctx, db.cfg.ReadTimeout, func(rd *chproto.Reader) error {
packet, err := rd.Uvarint()
if err != nil {
return err
}
switch packet {
case chproto.ServerHello:
return cn.ServerInfo.ReadFrom(rd)
case chproto.ServerException:
return readException(rd)
default:
return fmt.Errorf("ch: hello: unexpected packet: %d", packet)
}
})
}
func writeClientInfo(wr *chproto.Writer) {
wr.String(clientName)
wr.Uvarint(chVersionMajor)
wr.Uvarint(chVersionMinor)
wr.Uvarint(chRevision)
}
func readException(rd *chproto.Reader) (err error) {
var exc Error
if exc.Code, err = rd.Int32(); err != nil {
return err
}
if exc.Name, err = rd.String(); err != nil {
return err
}
if exc.Message, err = rd.String(); err != nil {
return err
}
exc.Message = strings.TrimSpace(strings.TrimPrefix(exc.Message, exc.Name+":"))
if exc.StackTrace, err = rd.String(); err != nil {
return err
}
hasNested, err := rd.Bool()
if err != nil {
return err
}
if hasNested {
exc.nested = readException(rd)
}
return &exc
}
func readProfileInfo(rd *chproto.Reader) error {
if _, err := rd.Uvarint(); err != nil {
return err
}
if _, err := rd.Uvarint(); err != nil {
return err
}
if _, err := rd.Uvarint(); err != nil {
return err
}
if _, err := rd.Bool(); err != nil {
return err
}
if _, err := rd.Uvarint(); err != nil {
return err
}
if _, err := rd.Bool(); err != nil {
return err
}
return nil
}
func readProgress(rd *chproto.Reader) error {
if _, err := rd.Uvarint(); err != nil {
return err
}
if _, err := rd.Uvarint(); err != nil {
return err
}
if _, err := rd.Uvarint(); err != nil {
return err
}
if _, err := rd.Uvarint(); err != nil {
return err
}
if _, err := rd.Uvarint(); err != nil {
return err
}
return nil
}
func writePing(wr *chproto.Writer) {
wr.Uvarint(chproto.ClientPing)
}
func readPong(rd *chproto.Reader) error {
for {
packet, err := rd.Uvarint()
if err != nil {
return err
}
switch packet {
case chproto.ServerPong:
return nil
case chproto.ServerException:
return readException(rd)
case chproto.ServerEndOfStream:
return nil
default:
return fmt.Errorf("ch: readPong: unexpected packet: %d", packet)
}
}
}
var hostname string
func (db *DB) writeQuery(wr *chproto.Writer, query string) {
if hostname == "" {
hostname, _ = os.Hostname()
}
wr.Uvarint(chproto.ClientQuery)
wr.String("")
// TODO: use QuerySecondary - https://github.com/ClickHouse/ClickHouse/blob/master/dbms/src/Client/Connection.cpp#L388-L404
wr.Uvarint(chproto.QueryInitial)
wr.String("") // initial user
wr.String("") // initial query id
wr.String("[::ffff:127.0.0.1]:0")
wr.Uvarint(1) // iface type TCP
wr.String(hostname)
wr.String(hostname)
writeClientInfo(wr)
wr.String("") // quota key
wr.Uvarint(chVersionPatch) // client version patch
db.writeSettings(wr)
wr.Uvarint(2)
wr.Uvarint(chproto.CompressionEnabled)
wr.String(query)
}
func (db *DB) writeSettings(wr *chproto.Writer) {
for key, value := range db.cfg.QuerySettings {
wr.String(key)
switch value := value.(type) {
case string:
wr.String(value)
case int:
wr.Uvarint(uint64(value))
case int64:
wr.Uvarint(uint64(value))
case uint64:
wr.Uvarint(value)
case bool:
wr.Bool(value)
default:
panic(fmt.Errorf("%s setting has unsupported type: %T", key, value))
}
}
wr.String("")
}
var emptyBlock chschema.Block
func writeBlock(ctx context.Context, wr *chproto.Writer, block *chschema.Block) {
if block == nil {
block = &emptyBlock
}
wr.Uvarint(chproto.ClientData)
wr.String("")
wr.WithCompression(func() error {
writeBlockInfo(wr)
return block.WriteTo(wr)
})
}
func writeBlockInfo(wr *chproto.Writer) {
wr.Uvarint(1)
wr.Bool(false)
wr.Uvarint(2)
wr.Int32(-1)
wr.Uvarint(0)
}
func readSampleBlock(rd *chproto.Reader) (*chschema.Block, error) {
for {
packet, err := rd.Uvarint()
if err != nil {
return nil, err
}
switch packet {
case chproto.ServerData:
block := new(chschema.Block)
if err := readBlock(rd, block); err != nil {
return nil, err
}
return block, nil
case chproto.ServerTableColumns:
if err := readServerTableColumns(rd); err != nil {
return nil, err
}
case chproto.ServerException:
return nil, readException(rd)
default:
return nil, fmt.Errorf("ch: readSampleBlock: unexpected packet: %d", packet)
}
}
}
func readDataBlocks(rd *chproto.Reader, model Model) (*result, error) {
var res *result
for {
packet, err := rd.Uvarint()
if err != nil {
return nil, err
}
switch packet {
case chproto.ServerData:
block := new(chschema.Block)
if model, ok := model.(TableModel); ok {
block.Table = model.Table()
}
if err := readBlock(rd, block); err != nil {
return nil, err
}
if res == nil {
res = new(result)
}
res.affected += block.NumRow
if model != nil {
if err := model.ScanBlock(block); err != nil {
return nil, err
}
}
case chproto.ServerException:
return nil, readException(rd)
case chproto.ServerProgress:
if err := readProgress(rd); err != nil {
return nil, err
}
case chproto.ServerProfileInfo:
if err := readProfileInfo(rd); err != nil {
return nil, err
}
case chproto.ServerTableColumns:
if err := readServerTableColumns(rd); err != nil {
return nil, err
}
case chproto.ServerEndOfStream:
return res, nil
default:
return nil, fmt.Errorf("ch: readDataBlocks: unexpected packet: %d", packet)
}
}
}
func readPacket(rd *chproto.Reader) (*result, error) {
packet, err := rd.Uvarint()
if err != nil {
return nil, err
}
res := new(result)
switch packet {
case chproto.ServerException:
return nil, readException(rd)
case chproto.ServerProgress:
if err := readProgress(rd); err != nil {
return nil, err
}
return res, nil
case chproto.ServerProfileInfo:
if err := readProfileInfo(rd); err != nil {
return nil, err
}
return res, nil
case chproto.ServerTableColumns:
if err := readServerTableColumns(rd); err != nil {
return nil, err
}
return res, nil
case chproto.ServerEndOfStream:
return res, nil
default:
return nil, fmt.Errorf("ch: readPacket: unexpected packet: %d", packet)
}
}
// TODO: return block
func readBlock(rd *chproto.Reader, block *chschema.Block) error {
if _, err := rd.String(); err != nil {
return err
}
return rd.WithCompression(func() error {
if err := readBlockInfo(rd); err != nil {
return err
}
numColumn, err := rd.Uvarint()
if err != nil {
return err
}
numRow, err := rd.Uvarint()
if err != nil {
return err
}
block.NumColumn = int(numColumn)
block.NumRow = int(numRow)
for i := 0; i < int(numColumn); i++ {
colName, err := rd.String()
if err != nil {
return err
}
if colName == "" {
return errors.New("ch: column has empty name")
}
colType, err := rd.String()
if err != nil {
return err
}
if colType == "" {
return fmt.Errorf("ch: column=%s has empty type", colName)
}
col := block.Column(colName, colType)
if err := col.ReadFrom(rd, int(numRow)); err != nil {
return err
}
}
return nil
})
}
func readBlockInfo(rd *chproto.Reader) error {
if _, err := rd.Uvarint(); err != nil {
return err
}
if _, err := rd.Bool(); err != nil {
return err
}
if _, err := rd.Uvarint(); err != nil {
return err
}
if _, err := rd.Int32(); err != nil {
return err
}
if _, err := rd.Uvarint(); err != nil {
return err
}
return nil
}
func writeCancel(wr *chproto.Writer) {
wr.Uvarint(chproto.ClientCancel)
}
func readServerTableColumns(rd *chproto.Reader) error {
_, err := rd.String()
if err != nil {
return err
}
_, err = rd.String()
if err != nil {
return err
}
return nil
}

403
ch/query_base.go Normal file
View File

@ -0,0 +1,403 @@
package ch
import (
"context"
"database/sql"
"errors"
"fmt"
"github.com/uptrace/go-clickhouse/ch/chschema"
"github.com/uptrace/go-clickhouse/ch/internal"
)
type withQuery struct {
name string
query chschema.QueryAppender
cte bool
}
type baseQuery struct {
db *DB
err error
tableModel TableModel
table *chschema.Table
with []withQuery
modelTableName chschema.QueryWithArgs
tables []chschema.QueryWithArgs
columns []chschema.QueryWithArgs
settings []chschema.QueryWithArgs
flags internal.Flag
}
func (q *baseQuery) DB() *DB {
return q.db
}
func (q *baseQuery) GetModel() Model {
return q.tableModel
}
func (q *baseQuery) GetTableName() string {
if q.table != nil {
return q.table.Name
}
for _, wq := range q.with {
if v, ok := wq.query.(Query); ok {
if model := v.GetModel(); model != nil {
return v.GetTableName()
}
}
}
if q.modelTableName.Query != "" {
return q.modelTableName.Query
}
if len(q.tables) > 0 {
b, _ := q.tables[0].AppendQuery(q.db.fmter, nil)
if len(b) < 64 {
return string(b)
}
}
return ""
}
func (q *baseQuery) setConn(db *DB) {
q.db = db
}
func (q *baseQuery) setErr(err error) {
if q.err == nil {
q.err = err
}
}
func (q *baseQuery) setTableModel(model any) {
tm, err := newTableModel(q.db, model)
if err != nil {
q.setErr(err)
return
}
q.tableModel = tm
q.table = tm.Table()
}
func (q *baseQuery) newModel(values ...any) (Model, error) {
if len(values) > 0 {
return newModel(q.db, values...)
}
return q.tableModel, nil
}
func (q *baseQuery) exec(
ctx context.Context,
iquery Query,
query string,
) (sql.Result, error) {
ctx, event := q.db.beforeQuery(ctx, iquery, query, nil, q.tableModel)
res, err := q.db.query(ctx, nil, query)
q.db.afterQuery(ctx, event, res, err)
return res, err
}
//------------------------------------------------------------------------------
func (q *baseQuery) AppendNamedArg(fmter chschema.Formatter, b []byte, name string) ([]byte, bool) {
return b, false
}
func appendColumns(b []byte, table Safe, fields []*chschema.Field) []byte {
for i, f := range fields {
if i > 0 {
b = append(b, ", "...)
}
if len(table) > 0 {
b = append(b, table...)
b = append(b, '.')
}
b = append(b, f.Column...)
}
return b
}
func formatterWithModel(
fmter chschema.Formatter, model chschema.NamedArgAppender,
) chschema.Formatter {
return fmter.WithArg(model)
}
//------------------------------------------------------------------------------
func (q *baseQuery) addTable(table chschema.QueryWithArgs) {
q.tables = append(q.tables, table)
}
func (q *baseQuery) modelHasTableName() bool {
if !q.modelTableName.IsZero() {
return q.modelTableName.Query != ""
}
return q.table != nil
}
func (q *baseQuery) hasTables() bool {
return q.modelHasTableName() || len(q.tables) > 0
}
func (q *baseQuery) appendTables(fmter chschema.Formatter, b []byte) (_ []byte, err error) {
return q._appendTables(fmter, b, false)
}
func (q *baseQuery) appendTablesWithAlias(fmter chschema.Formatter, b []byte) (_ []byte, err error) {
return q._appendTables(fmter, b, true)
}
func (q *baseQuery) _appendTables(
fmter chschema.Formatter, b []byte, withAlias bool,
) (_ []byte, err error) {
startLen := len(b)
if q.modelHasTableName() {
if !q.modelTableName.IsZero() {
b, err = q.modelTableName.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
} else {
b = fmter.AppendQuery(b, string(q.table.CHName))
if withAlias && q.table.CHAlias != q.table.CHName {
b = append(b, " AS "...)
b = append(b, q.table.CHAlias...)
}
}
}
for _, table := range q.tables {
if len(b) > startLen {
b = append(b, ", "...)
}
b, err = table.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
}
return b, nil
}
func (q *baseQuery) appendFirstTable(fmter chschema.Formatter, b []byte) ([]byte, error) {
return q._appendFirstTable(fmter, b, false)
}
func (q *baseQuery) appendFirstTableWithAlias(
fmter chschema.Formatter, b []byte,
) ([]byte, error) {
return q._appendFirstTable(fmter, b, true)
}
func (q *baseQuery) _appendFirstTable(
fmter chschema.Formatter, b []byte, withAlias bool,
) ([]byte, error) {
if !q.modelTableName.IsZero() {
return q.modelTableName.AppendQuery(fmter, b)
}
if q.table != nil {
b = fmter.AppendQuery(b, string(q.table.CHName))
if withAlias {
b = append(b, " AS "...)
b = append(b, q.table.CHAlias...)
}
return b, nil
}
if len(q.tables) > 0 {
return q.tables[0].AppendQuery(fmter, b)
}
return nil, errors.New("ch: query does not have a table")
}
func (q *baseQuery) hasMultiTables() bool {
if q.modelHasTableName() {
return len(q.tables) >= 1
}
return len(q.tables) >= 2
}
func (q *baseQuery) appendOtherTables(fmter chschema.Formatter, b []byte) (_ []byte, err error) {
tables := q.tables
if !q.modelHasTableName() {
tables = tables[1:]
}
for i, table := range tables {
if i > 0 {
b = append(b, ", "...)
}
b, err = table.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
}
return b, nil
}
//------------------------------------------------------------------------------
func (q *baseQuery) addColumn(column chschema.QueryWithArgs) {
q.columns = append(q.columns, column)
}
func (q *baseQuery) excludeColumn(columns []string) {
if q.columns == nil {
for _, f := range q.table.Fields {
q.columns = append(q.columns, chschema.UnsafeIdent(f.CHName))
}
}
if len(columns) == 1 && columns[0] == "*" {
q.columns = make([]chschema.QueryWithArgs, 0)
return
}
for _, column := range columns {
if !q._excludeColumn(column) {
q.setErr(fmt.Errorf("ch: can't find column=%q", column))
return
}
}
}
func (q *baseQuery) _excludeColumn(column string) bool {
for i, col := range q.columns {
if col.Args == nil && col.Query == column {
q.columns = append(q.columns[:i], q.columns[i+1:]...)
return true
}
}
return false
}
func (q *baseQuery) getFields() ([]*chschema.Field, error) {
if len(q.columns) == 0 {
if q.table == nil {
return nil, nil
}
return q.table.Fields, nil
}
fields := make([]*chschema.Field, 0, len(q.columns))
for _, col := range q.columns {
if col.Args != nil {
continue
}
field, err := q.table.Field(col.Query)
if err != nil {
return nil, err
}
fields = append(fields, field)
}
return fields, nil
}
func (q *baseQuery) appendSettings(fmter chschema.Formatter, b []byte) (_ []byte, err error) {
if len(q.settings) > 0 {
b = append(b, " SETTINGS "...)
for i, opt := range q.settings {
if i > 0 {
b = append(b, ", "...)
}
b, err = opt.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
}
}
return b, nil
}
//------------------------------------------------------------------------------
type WhereQuery struct {
where []chschema.QueryWithSep
}
func (q *WhereQuery) addWhere(where chschema.QueryWithSep) {
q.where = append(q.where, where)
}
func (q *WhereQuery) WhereGroup(sep string, fn func(*WhereQuery)) {
q.addWhereGroup(sep, fn)
}
func (q *WhereQuery) addWhereGroup(sep string, fn func(*WhereQuery)) {
q2 := new(WhereQuery)
fn(q2)
if len(q2.where) > 0 {
q2.where[0].Sep = ""
q.addWhere(chschema.SafeQueryWithSep("", nil, sep+"("))
q.where = append(q.where, q2.where...)
q.addWhere(chschema.SafeQueryWithSep("", nil, ")"))
}
}
//------------------------------------------------------------------------------
type whereBaseQuery struct {
baseQuery
WhereQuery
}
func (q *whereBaseQuery) mustAppendWhere(fmter chschema.Formatter, b []byte) ([]byte, error) {
if len(q.where) == 0 {
err := errors.New("ch: Update and Delete queries require at least one Where")
return nil, err
}
return q.appendWhere(fmter, b)
}
func (q *whereBaseQuery) appendWhere(fmter chschema.Formatter, b []byte) (_ []byte, err error) {
if len(q.where) == 0 {
return b, nil
}
b = append(b, " WHERE "...)
b, err = appendWhere(fmter, b, q.where)
if err != nil {
return nil, err
}
return b, nil
}
func appendWhere(
fmter chschema.Formatter, b []byte, where []chschema.QueryWithSep,
) (_ []byte, err error) {
for i, where := range where {
if i > 0 || where.Sep == "(" {
b = append(b, where.Sep...)
}
if where.Query == "" && where.Args == nil {
continue
}
b = append(b, '(')
b, err = where.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
b = append(b, ')')
}
return b, nil
}

203
ch/query_insert.go Normal file
View File

@ -0,0 +1,203 @@
package ch
import (
"context"
"database/sql"
"errors"
"github.com/uptrace/go-clickhouse/ch/chschema"
"github.com/uptrace/go-clickhouse/ch/internal"
)
type InsertQuery struct {
whereBaseQuery
}
var _ Query = (*InsertQuery)(nil)
func NewInsertQuery(db *DB) *InsertQuery {
return &InsertQuery{
whereBaseQuery: whereBaseQuery{
baseQuery: baseQuery{
db: db,
},
},
}
}
func (q *InsertQuery) Model(model any) *InsertQuery {
q.setTableModel(model)
return q
}
//------------------------------------------------------------------------------
func (q *InsertQuery) Table(tables ...string) *InsertQuery {
for _, table := range tables {
q.addTable(chschema.UnsafeIdent(table))
}
return q
}
func (q *InsertQuery) TableExpr(query string, args ...any) *InsertQuery {
q.addTable(chschema.SafeQuery(query, args))
return q
}
func (q *InsertQuery) ModelTableExpr(query string, args ...any) *InsertQuery {
q.modelTableName = chschema.SafeQuery(query, args)
return q
}
func (q *InsertQuery) Setting(query string, args ...any) *InsertQuery {
q.settings = append(q.settings, chschema.SafeQuery(query, args))
return q
}
//------------------------------------------------------------------------------
func (q *InsertQuery) Column(columns ...string) *InsertQuery {
for _, column := range columns {
q.addColumn(chschema.UnsafeIdent(column))
}
return q
}
func (q *InsertQuery) ColumnExpr(query string, args ...any) *InsertQuery {
q.addColumn(chschema.SafeQuery(query, args))
return q
}
func (q *InsertQuery) ExcludeColumn(columns ...string) *InsertQuery {
q.excludeColumn(columns)
return q
}
//------------------------------------------------------------------------------
func (q *InsertQuery) Where(query string, args ...any) *InsertQuery {
q.addWhere(chschema.SafeQueryWithSep(query, args, " AND "))
return q
}
func (q *InsertQuery) WhereOr(query string, args ...any) *InsertQuery {
q.addWhere(chschema.SafeQueryWithSep(query, args, " OR "))
return q
}
func (q *InsertQuery) WhereGroup(sep string, fn func(*WhereQuery)) *InsertQuery {
q.addWhereGroup(sep, fn)
return q
}
//------------------------------------------------------------------------------
func (q *InsertQuery) Operation() string {
return "INSERT"
}
var _ chschema.QueryAppender = (*InsertQuery)(nil)
func (q *InsertQuery) AppendQuery(fmter chschema.Formatter, b []byte) (_ []byte, err error) {
if q.err != nil {
return nil, q.err
}
b = append(b, "INSERT INTO "...)
b, err = q.appendInsertTable(fmter, b)
if err != nil {
return nil, err
}
fields, err := q.getFields()
if err != nil {
return nil, err
}
if len(fields) > 0 {
b = append(b, " ("...)
b = appendColumns(b, "", fields)
b = append(b, ")"...)
}
b, err = q.appendValues(fmter, b)
if err != nil {
return nil, err
}
b, err = q.appendSettings(fmter, b)
if err != nil {
return nil, err
}
return b, nil
}
func (q *InsertQuery) appendValues(
fmter chschema.Formatter, b []byte,
) (_ []byte, err error) {
if !q.hasMultiTables() {
return append(b, " VALUES"...), nil
}
b = append(b, " SELECT "...)
fields, err := q.getFields()
if err != nil {
return nil, err
}
if len(fields) > 0 {
b = appendColumns(b, "", fields)
} else {
b = append(b, "*"...)
}
b = append(b, " FROM "...)
b, err = q.appendOtherTables(fmter, b)
if err != nil {
return nil, err
}
if len(q.where) > 0 {
b = append(b, " WHERE "...)
b, err = appendWhere(fmter, b, q.where)
if err != nil {
return nil, err
}
}
return b, nil
}
func (q *InsertQuery) appendInsertTable(fmter chschema.Formatter, b []byte) ([]byte, error) {
if !q.modelTableName.IsZero() {
return q.modelTableName.AppendQuery(fmter, b)
}
if q.table != nil {
return fmter.AppendQuery(b, string(q.table.CHInsertName)), nil
}
if len(q.tables) > 0 {
return q.tables[0].AppendQuery(fmter, b)
}
return nil, errors.New("ch: query does not have a table")
}
func (q *InsertQuery) Exec(ctx context.Context) (sql.Result, error) {
queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes())
if err != nil {
return nil, err
}
query := internal.String(queryBytes)
fields, err := q.getFields()
if err != nil {
return nil, err
}
ctx, evt := q.db.beforeQuery(ctx, q, query, nil, q.tableModel)
res, err := q.db.insert(ctx, q.tableModel, query, fields)
q.db.afterQuery(ctx, evt, res, err)
return res, err
}

616
ch/query_select.go Normal file
View File

@ -0,0 +1,616 @@
package ch
import (
"context"
"database/sql"
"errors"
"strconv"
"strings"
"sync"
"github.com/uptrace/go-clickhouse/ch/chschema"
"github.com/uptrace/go-clickhouse/ch/internal"
)
type SelectQuery struct {
whereBaseQuery
sample chschema.QueryWithArgs
distinctOn []chschema.QueryWithArgs
joins []joinQuery
group []chschema.QueryWithArgs
having []chschema.QueryWithArgs
order []chschema.QueryWithArgs
limit int
offset int
final bool
}
var _ Query = (*SelectQuery)(nil)
func NewSelectQuery(db *DB) *SelectQuery {
return &SelectQuery{
whereBaseQuery: whereBaseQuery{
baseQuery: baseQuery{
db: db,
},
},
}
}
func (q *SelectQuery) Operation() string {
return "SELECT"
}
func (q *SelectQuery) Model(model any) *SelectQuery {
q.setTableModel(model)
return q
}
func (q *SelectQuery) Err(err error) *SelectQuery {
q.setErr(err)
return q
}
func (q *SelectQuery) Apply(fn func(*SelectQuery) *SelectQuery) *SelectQuery {
return fn(q)
}
func (q *SelectQuery) WithAlias(name, query string, args ...any) *SelectQuery {
for i := range q.with {
with := &q.with[i]
if with.name == name {
with.query = chschema.SafeQuery(query, args)
return q
}
}
q.with = append(q.with, withQuery{
name: name,
query: chschema.SafeQuery(query, args),
})
return q
}
func (q *SelectQuery) With(name string, subq chschema.QueryAppender) *SelectQuery {
q.with = append(q.with, withQuery{
name: name,
query: subq,
cte: true,
})
return q
}
func (q *SelectQuery) Distinct() *SelectQuery {
q.distinctOn = make([]chschema.QueryWithArgs, 0)
return q
}
func (q *SelectQuery) DistinctOn(query string, args ...any) *SelectQuery {
q.distinctOn = append(q.distinctOn, chschema.SafeQuery(query, args))
return q
}
//------------------------------------------------------------------------------
func (q *SelectQuery) Table(tables ...string) *SelectQuery {
for _, table := range tables {
q.addTable(chschema.UnsafeIdent(table))
}
return q
}
func (q *SelectQuery) TableExpr(query string, args ...any) *SelectQuery {
q.addTable(chschema.SafeQuery(query, args))
return q
}
func (q *SelectQuery) ModelTableExpr(query string, args ...any) *SelectQuery {
q.modelTableName = chschema.SafeQuery(query, args)
return q
}
func (q *SelectQuery) Sample(query string, args ...any) *SelectQuery {
q.sample = chschema.SafeQuery(query, args)
return q
}
//------------------------------------------------------------------------------
func (q *SelectQuery) Column(columns ...string) *SelectQuery {
for _, column := range columns {
q.addColumn(chschema.UnsafeIdent(column))
}
return q
}
func (q *SelectQuery) ColumnExpr(query string, args ...any) *SelectQuery {
q.addColumn(chschema.SafeQuery(query, args))
return q
}
func (q *SelectQuery) ExcludeColumn(columns ...string) *SelectQuery {
q.excludeColumn(columns)
return q
}
//------------------------------------------------------------------------------
func (q *SelectQuery) Join(join string, args ...any) *SelectQuery {
q.joins = append(q.joins, joinQuery{
join: chschema.SafeQuery(join, args),
})
return q
}
func (q *SelectQuery) JoinOn(cond string, args ...any) *SelectQuery {
return q.joinOn(cond, args, " AND ")
}
func (q *SelectQuery) JoinOnOr(cond string, args ...any) *SelectQuery {
return q.joinOn(cond, args, " OR ")
}
func (q *SelectQuery) joinOn(cond string, args []any, sep string) *SelectQuery {
if len(q.joins) == 0 {
q.err = errors.New("ch: query has no joins")
return q
}
j := &q.joins[len(q.joins)-1]
j.on = append(j.on, chschema.SafeQueryWithSep(cond, args, sep))
return q
}
//------------------------------------------------------------------------------
func (q *SelectQuery) Where(query string, args ...any) *SelectQuery {
q.addWhere(chschema.SafeQueryWithSep(query, args, " AND "))
return q
}
func (q *SelectQuery) WhereOr(query string, args ...any) *SelectQuery {
q.addWhere(chschema.SafeQueryWithSep(query, args, " OR "))
return q
}
func (q *SelectQuery) WhereGroup(sep string, fn func(*WhereQuery)) *SelectQuery {
q.addWhereGroup(sep, fn)
return q
}
//------------------------------------------------------------------------------
func (q *SelectQuery) Group(columns ...string) *SelectQuery {
for _, column := range columns {
q.group = append(q.group, chschema.UnsafeIdent(column))
}
return q
}
func (q *SelectQuery) GroupExpr(group string, args ...any) *SelectQuery {
q.group = append(q.group, chschema.SafeQuery(group, args))
return q
}
func (q *SelectQuery) Having(having string, args ...any) *SelectQuery {
q.having = append(q.having, chschema.SafeQuery(having, args))
return q
}
func (q *SelectQuery) Order(orders ...string) *SelectQuery {
for _, order := range orders {
if order == "" {
continue
}
index := strings.IndexByte(order, ' ')
if index == -1 {
q.order = append(q.order, chschema.UnsafeIdent(order))
continue
}
field := order[:index]
sort := order[index+1:]
switch strings.ToUpper(sort) {
case "ASC", "DESC", "ASC NULLS FIRST", "DESC NULLS FIRST",
"ASC NULLS LAST", "DESC NULLS LAST":
q.order = append(q.order, chschema.SafeQuery("? ?", []any{
Ident(field),
Safe(sort),
}))
default:
q.order = append(q.order, chschema.UnsafeIdent(order))
}
}
return q
}
// Order adds sort order to the Query.
func (q *SelectQuery) OrderExpr(order string, args ...any) *SelectQuery {
q.order = append(q.order, chschema.SafeQuery(order, args))
return q
}
func (q *SelectQuery) Limit(limit int) *SelectQuery {
q.limit = limit
return q
}
func (q *SelectQuery) Offset(offset int) *SelectQuery {
q.offset = offset
return q
}
func (q *SelectQuery) Final() *SelectQuery {
q.final = true
return q
}
func (q *SelectQuery) Setting(query string, args ...any) *SelectQuery {
q.settings = append(q.settings, chschema.SafeQuery(query, args))
return q
}
//------------------------------------------------------------------------------
func (q *SelectQuery) String() string {
b, err := q.AppendQuery(q.db.fmter, nil)
if err != nil {
return err.Error()
}
return internal.String(b)
}
func (q *SelectQuery) AppendQuery(fmter chschema.Formatter, b []byte) (_ []byte, err error) {
return q.appendQuery(formatterWithModel(fmter, q), b, false)
}
func (q *SelectQuery) appendQuery(
fmter chschema.Formatter, b []byte, count bool,
) (_ []byte, err error) {
if q.err != nil {
return nil, q.err
}
cteCount := count && (len(q.group) > 0 || len(q.distinctOn) > 0)
if cteCount {
b = append(b, `WITH "_count_wrapper" AS (`...)
}
if len(q.with) > 0 {
b, err = q.appendWith(fmter, b)
if err != nil {
return nil, err
}
}
b = append(b, "SELECT "...)
if len(q.distinctOn) > 0 {
b = append(b, "DISTINCT ON ("...)
for i, app := range q.distinctOn {
if i > 0 {
b = append(b, ", "...)
}
b, err = app.AppendQuery(fmter, b)
}
b = append(b, ") "...)
} else if q.distinctOn != nil {
b = append(b, "DISTINCT "...)
}
if count && !cteCount {
b = append(b, "count()"...)
} else {
b, err = q.appendColumns(fmter, b)
if err != nil {
return nil, err
}
}
if q.hasTables() {
b = append(b, " FROM "...)
b, err = q.appendTablesWithAlias(fmter, b)
if err != nil {
return nil, err
}
}
if !q.sample.IsZero() {
b = append(b, " SAMPLE "...)
b, err = q.sample.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
}
for _, j := range q.joins {
b = append(b, ' ')
b, err = j.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
}
b, err = q.appendWhere(fmter, b)
if err != nil {
return nil, err
}
if len(q.group) > 0 {
b = append(b, " GROUP BY "...)
for i, f := range q.group {
if i > 0 {
b = append(b, ", "...)
}
b, err = f.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
}
}
if len(q.having) > 0 {
b = append(b, " HAVING "...)
for i, f := range q.having {
if i > 0 {
b = append(b, " AND "...)
}
b = append(b, '(')
b, err = f.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
b = append(b, ')')
}
}
if !count {
if len(q.order) > 0 {
b = append(b, " ORDER BY "...)
for i, f := range q.order {
if i > 0 {
b = append(b, ", "...)
}
b, err = f.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
}
}
if q.limit > 0 {
b = append(b, " LIMIT "...)
b = strconv.AppendInt(b, int64(q.limit), 10)
}
if q.offset > 0 {
b = append(b, " OFFSET "...)
b = strconv.AppendInt(b, int64(q.offset), 10)
}
if q.final {
b = append(b, " FINAL"...)
}
} else if cteCount {
b = append(b, `) SELECT `...)
b = append(b, "count()"...)
b = append(b, ` FROM "_count_wrapper"`...)
}
b, err = q.appendSettings(fmter, b)
if err != nil {
return nil, err
}
return b, nil
}
func (q *SelectQuery) appendWith(fmter chschema.Formatter, b []byte) (_ []byte, err error) {
b = append(b, "WITH "...)
for i, with := range q.with {
if i > 0 {
b = append(b, ", "...)
}
if with.cte {
b = chschema.AppendIdent(b, with.name)
b = append(b, " AS "...)
b = append(b, "("...)
}
b, err = with.query.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
if with.cte {
b = append(b, ")"...)
} else {
b = append(b, " AS "...)
b = chschema.AppendIdent(b, with.name)
}
}
b = append(b, ' ')
return b, nil
}
func (q *SelectQuery) appendColumns(fmter chschema.Formatter, b []byte) (_ []byte, err error) {
switch {
case q.columns != nil:
for i, f := range q.columns {
if i > 0 {
b = append(b, ", "...)
}
b, err = f.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
}
case q.table != nil:
b = appendTableColumns(b, q.table.CHAlias, q.table.Fields)
default:
b = append(b, '*')
}
return b, nil
}
func appendTableColumns(b []byte, table chschema.Safe, fields []*chschema.Field) []byte {
for i, f := range fields {
if i > 0 {
b = append(b, ", "...)
}
if len(table) > 0 {
b = append(b, table...)
b = append(b, '.')
}
b = append(b, f.Column...)
}
return b
}
func (q *SelectQuery) Scan(ctx context.Context, values ...any) error {
return q.scan(ctx, false, values...)
}
func (q *SelectQuery) ScanColumns(ctx context.Context, values ...any) error {
return q.scan(ctx, true, values...)
}
func (q *SelectQuery) scan(ctx context.Context, columnar bool, values ...any) error {
if q.err != nil {
return q.err
}
model, err := q.newModel(values...)
if err != nil {
return err
}
if columnar {
model.(interface{ SetColumnar(bool) }).SetColumnar(true)
}
queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes())
if err != nil {
return err
}
query := internal.String(queryBytes)
ctx, evt := q.db.beforeQuery(ctx, q, query, nil, model)
res, err := q.db.query(ctx, model, query)
q.db.afterQuery(ctx, evt, res, err)
if err != nil {
return err
}
if !columnar && useQueryRowModel(model) {
if res.affected == 0 {
return sql.ErrNoRows
}
}
return nil
}
func useQueryRowModel(model Model) bool {
if v, ok := model.(interface{ UseQueryRow() bool }); ok {
return v.UseQueryRow()
}
return false
}
// Count returns number of rows matching the query using count aggregate function.
func (q *SelectQuery) Count(ctx context.Context) (int, error) {
if q.err != nil {
return 0, q.err
}
queryBytes, err := q.appendQuery(q.db.fmter, nil, true)
if err != nil {
return 0, err
}
query := internal.String(queryBytes)
var count uint
err = q.db.QueryRowContext(ctx, query).Scan(&count)
return int(count), err
}
// SelectAndCount runs Select and Count in two goroutines,
// waits for them to finish and returns the result. If query limit is -1
// it does not select any data and only counts the results.
func (q *SelectQuery) ScanAndCount(
ctx context.Context, values ...any,
) (count int, firstErr error) {
if q.err != nil {
return 0, q.err
}
var wg sync.WaitGroup
var mu sync.Mutex
if q.limit >= 0 {
wg.Add(1)
go func() {
defer wg.Done()
err := q.Scan(ctx, values...)
if err != nil {
mu.Lock()
if firstErr == nil {
firstErr = err
}
mu.Unlock()
}
}()
}
wg.Add(1)
go func() {
defer wg.Done()
var err error
count, err = q.Count(ctx)
if err != nil {
mu.Lock()
if firstErr == nil {
firstErr = err
}
mu.Unlock()
}
}()
wg.Wait()
return count, firstErr
}
//------------------------------------------------------------------------------
type joinQuery struct {
join chschema.QueryWithArgs
on []chschema.QueryWithSep
}
func (j *joinQuery) AppendQuery(fmter chschema.Formatter, b []byte) (_ []byte, err error) {
b = append(b, ' ')
b, err = j.join.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
if len(j.on) > 0 {
b = append(b, " ON "...)
for i, on := range j.on {
if i > 0 {
b = append(b, on.Sep...)
}
b = append(b, '(')
b, err = on.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
b = append(b, ')')
}
}
return b, nil
}

223
ch/query_table_create.go Normal file
View File

@ -0,0 +1,223 @@
package ch
import (
"context"
"database/sql"
"github.com/uptrace/go-clickhouse/ch/chschema"
"github.com/uptrace/go-clickhouse/ch/internal"
)
type CreateTableQuery struct {
baseQuery
ifNotExists bool
engine chschema.QueryWithArgs
ttl chschema.QueryWithArgs
partition chschema.QueryWithArgs
order chschema.QueryWithArgs
}
var _ Query = (*CreateTableQuery)(nil)
func NewCreateTableQuery(db *DB) *CreateTableQuery {
return &CreateTableQuery{
baseQuery: baseQuery{
db: db,
},
}
}
func (q *CreateTableQuery) Model(model any) *CreateTableQuery {
q.setTableModel(model)
return q
}
// ------------------------------------------------------------------------------
func (q *CreateTableQuery) Table(tables ...string) *CreateTableQuery {
for _, table := range tables {
q.addTable(chschema.UnsafeIdent(table))
}
return q
}
func (q *CreateTableQuery) TableExpr(query string, args ...any) *CreateTableQuery {
q.addTable(chschema.SafeQuery(query, args))
return q
}
func (q *CreateTableQuery) ModelTableExpr(query string, args ...any) *CreateTableQuery {
q.modelTableName = chschema.SafeQuery(query, args)
return q
}
func (q *CreateTableQuery) ColumnExpr(query string, args ...any) *CreateTableQuery {
q.addColumn(chschema.SafeQuery(query, args))
return q
}
//------------------------------------------------------------------------------
func (q *CreateTableQuery) IfNotExists() *CreateTableQuery {
q.ifNotExists = true
return q
}
func (q *CreateTableQuery) Engine(query string, args ...any) *CreateTableQuery {
q.engine = chschema.SafeQuery(query, args)
return q
}
func (q *CreateTableQuery) TTL(query string, args ...any) *CreateTableQuery {
q.ttl = chschema.SafeQuery(query, args)
return q
}
func (q *CreateTableQuery) Partition(query string, args ...any) *CreateTableQuery {
q.partition = chschema.SafeQuery(query, args)
return q
}
func (q *CreateTableQuery) Order(query string, args ...any) *CreateTableQuery {
q.order = chschema.SafeQuery(query, args)
return q
}
func (q *CreateTableQuery) Setting(query string, args ...any) *CreateTableQuery {
q.settings = append(q.settings, chschema.SafeQuery(query, args))
return q
}
//------------------------------------------------------------------------------
func (q *CreateTableQuery) Operation() string {
return "CREATE TABLE"
}
var _ chschema.QueryAppender = (*CreateTableQuery)(nil)
func (q *CreateTableQuery) AppendQuery(fmter chschema.Formatter, b []byte) (_ []byte, err error) {
if q.err != nil {
return nil, q.err
}
if q.table == nil {
return nil, errNilModel
}
b = append(b, "CREATE TABLE "...)
if q.ifNotExists {
b = append(b, "IF NOT EXISTS "...)
}
b, err = q.appendFirstTable(fmter, b)
if err != nil {
return nil, err
}
b = append(b, " ("...)
for i, field := range q.table.Fields {
if i > 0 {
b = append(b, ", "...)
}
b = append(b, field.CHName...)
b = append(b, " "...)
b = append(b, field.CHType...)
if field.NotNull {
b = append(b, " NOT NULL"...)
}
if field.CHDefault != "" {
b = append(b, " DEFAULT "...)
b = append(b, field.CHDefault...)
}
}
for i, col := range q.columns {
if i > 0 || len(q.table.Fields) > 0 {
b = append(b, ", "...)
}
b, err = col.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
}
b = append(b, ")"...)
b = append(b, " Engine = "...)
if !q.engine.IsZero() {
b, err = q.engine.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
} else if q.table.CHEngine != "" {
b = append(b, q.table.CHEngine...)
} else {
b = append(b, "MergeTree()"...)
}
b, err = q.appendPartition(fmter, b)
if err != nil {
return nil, err
}
if !q.order.IsZero() {
b = append(b, " ORDER BY ("...)
b, err = q.order.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
b = append(b, ')')
} else if len(q.table.PKs) > 0 {
b = append(b, " ORDER BY ("...)
for i, pk := range q.table.PKs {
if i > 0 {
b = append(b, ", "...)
}
b = append(b, pk.CHName...)
}
b = append(b, ')')
} else if q.table.CHEngine == "" {
b = append(b, " ORDER BY tuple()"...)
}
if !q.ttl.IsZero() {
b = append(b, " TTL "...)
b, err = q.ttl.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
}
b, err = q.appendSettings(fmter, b)
if err != nil {
return nil, err
}
return b, nil
}
func (q *CreateTableQuery) appendPartition(fmter chschema.Formatter, b []byte) ([]byte, error) {
if q.partition.IsZero() && q.table.CHPartition == "" {
return b, nil
}
b = append(b, " PARTITION BY "...)
if !q.partition.IsZero() {
return q.partition.AppendQuery(fmter, b)
}
return append(b, q.table.CHPartition...), nil
}
func (q *CreateTableQuery) Exec(ctx context.Context) (sql.Result, error) {
queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes())
if err != nil {
return nil, err
}
query := internal.String(queryBytes)
return q.exec(ctx, q, query)
}

93
ch/query_table_drop.go Normal file
View File

@ -0,0 +1,93 @@
package ch
import (
"context"
"database/sql"
"github.com/uptrace/go-clickhouse/ch/chschema"
"github.com/uptrace/go-clickhouse/ch/internal"
)
type DropTableQuery struct {
baseQuery
ifExists bool
}
var _ Query = (*DropTableQuery)(nil)
func NewDropTableQuery(db *DB) *DropTableQuery {
q := &DropTableQuery{
baseQuery: baseQuery{
db: db,
},
}
return q
}
func (q *DropTableQuery) Model(model any) *DropTableQuery {
q.setTableModel(model)
return q
}
//------------------------------------------------------------------------------
func (q *DropTableQuery) Table(tables ...string) *DropTableQuery {
for _, table := range tables {
q.addTable(chschema.UnsafeIdent(table))
}
return q
}
func (q *DropTableQuery) TableExpr(query string, args ...any) *DropTableQuery {
q.addTable(chschema.SafeQuery(query, args))
return q
}
func (q *DropTableQuery) ModelTableExpr(query string, args ...any) *DropTableQuery {
q.modelTableName = chschema.SafeQuery(query, args)
return q
}
//------------------------------------------------------------------------------
func (q *DropTableQuery) IfExists() *DropTableQuery {
q.ifExists = true
return q
}
//------------------------------------------------------------------------------
func (q *DropTableQuery) Operation() string {
return "DROP TABLE"
}
func (q *DropTableQuery) AppendQuery(fmter chschema.Formatter, b []byte) (_ []byte, err error) {
if q.err != nil {
return nil, q.err
}
b = append(b, "DROP TABLE "...)
if q.ifExists {
b = append(b, "IF EXISTS "...)
}
b, err = q.appendTables(fmter, b)
if err != nil {
return nil, err
}
return b, nil
}
//------------------------------------------------------------------------------
func (q *DropTableQuery) Exec(ctx context.Context, dest ...any) (sql.Result, error) {
queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes())
if err != nil {
return nil, err
}
query := internal.String(queryBytes)
return q.exec(ctx, q, query)
}

View File

@ -0,0 +1,95 @@
package ch
import (
"context"
"database/sql"
"github.com/uptrace/go-clickhouse/ch/chschema"
"github.com/uptrace/go-clickhouse/ch/internal"
)
type TruncateTableQuery struct {
baseQuery
ifExists bool
}
var _ Query = (*TruncateTableQuery)(nil)
func NewTruncateTableQuery(db *DB) *TruncateTableQuery {
q := &TruncateTableQuery{
baseQuery: baseQuery{
db: db,
},
}
return q
}
func (q *TruncateTableQuery) Model(model any) *TruncateTableQuery {
q.setTableModel(model)
return q
}
//------------------------------------------------------------------------------
func (q *TruncateTableQuery) Table(tables ...string) *TruncateTableQuery {
for _, table := range tables {
q.addTable(chschema.UnsafeIdent(table))
}
return q
}
func (q *TruncateTableQuery) TableExpr(query string, args ...any) *TruncateTableQuery {
q.addTable(chschema.SafeQuery(query, args))
return q
}
//------------------------------------------------------------------------------
func (q *TruncateTableQuery) IfExists() *TruncateTableQuery {
q.ifExists = true
return q
}
//------------------------------------------------------------------------------
func (q *TruncateTableQuery) Operation() string {
return "TRUNCATE TABLE"
}
func (q *TruncateTableQuery) AppendQuery(
fmter chschema.Formatter, b []byte,
) (_ []byte, err error) {
if q.err != nil {
return nil, q.err
}
b = append(b, "TRUNCATE TABLE "...)
if q.ifExists {
b = append(b, "IF EXISTS "...)
}
b, err = q.appendTables(fmter, b)
if err != nil {
return nil, err
}
return b, nil
}
//------------------------------------------------------------------------------
func (q *TruncateTableQuery) Exec(ctx context.Context, dest ...any) (sql.Result, error) {
queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes())
if err != nil {
return nil, err
}
query := internal.String(queryBytes)
res, err := q.exec(ctx, q, query)
if err != nil {
return nil, err
}
return res, nil
}

99
ch/query_test.go Normal file
View File

@ -0,0 +1,99 @@
package ch_test
import (
"fmt"
"path/filepath"
"testing"
"time"
"github.com/bradleyjkemp/cupaloy"
"github.com/uptrace/go-clickhouse/ch"
"github.com/uptrace/go-clickhouse/ch/chschema"
)
func TestQuery(t *testing.T) {
type Model struct {
ID uint64
String string
Bytes []byte
}
queries := []func(db *ch.DB) chschema.QueryAppender{
func(db *ch.DB) chschema.QueryAppender {
return db.NewCreateTable().Model((*Model)(nil))
},
func(db *ch.DB) chschema.QueryAppender {
return db.NewDropTable().Model((*Model)(nil))
},
func(db *ch.DB) chschema.QueryAppender {
return db.NewSelect().Model((*Model)(nil))
},
func(db *ch.DB) chschema.QueryAppender {
return db.NewSelect().Model((*Model)(nil)).ExcludeColumn("bytes")
},
func(db *ch.DB) chschema.QueryAppender {
return db.NewInsert().Model(new(Model))
},
func(db *ch.DB) chschema.QueryAppender {
return db.NewTruncateTable().Model(new(Model))
},
func(db *ch.DB) chschema.QueryAppender {
return db.NewSelect().
Model((*Model)(nil)).
Setting("max_rows_to_read = 100")
},
func(db *ch.DB) chschema.QueryAppender {
return db.NewSelect().
Model((*Model)(nil)).
Setting("max_rows_to_read = 100").
Setting("read_overflow_mode = 'break'")
},
func(db *ch.DB) chschema.QueryAppender {
return db.NewInsert().
TableExpr("dest").
TableExpr("src").
Where("_part = ?", "part_name").
Setting("max_threads = 1").
Setting("max_insert_threads = 1").
Setting("max_execution_time = 0")
},
func(db *ch.DB) chschema.QueryAppender {
return db.NewSelect().
Model((*Model)(nil)).
Sample("?", 1000)
},
func(db *ch.DB) chschema.QueryAppender {
type Model struct {
ch.CHModel `ch:"table:spans,partition:toYYYYMM(time)"`
ID uint64
Text string `ch:",lc"` // low cardinality column
Time time.Time `ch:",pk"` // ClickHouse primary key for order by
}
return db.NewCreateTable().Model((*Model)(nil)).
TTL("time + INTERVAL 30 DAY DELETE").
Partition("toDate(time)").
Order("id").
Setting("ttl_only_drop_parts = 1")
},
}
db := chDB()
defer db.Close()
snapshotsDir := filepath.Join("testdata", "snapshots")
snapshot := cupaloy.New(cupaloy.SnapshotSubdirectory(snapshotsDir))
for i, fn := range queries {
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
q := fn(db)
query, err := q.AppendQuery(db.Formatter(), nil)
if err != nil {
snapshot.SnapshotT(t, err.Error())
} else {
snapshot.SnapshotT(t, string(query))
}
})
}
}

31
ch/reflect.go Normal file
View File

@ -0,0 +1,31 @@
package ch
import (
"reflect"
)
func indirect(v reflect.Value) reflect.Value {
switch v.Kind() {
case reflect.Interface:
return indirect(v.Elem())
case reflect.Ptr:
return v.Elem()
default:
return v
}
}
func indirectType(t reflect.Type) reflect.Type {
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
return t
}
func sliceElemType(v reflect.Value) reflect.Type {
elemType := v.Type().Elem()
if elemType.Kind() == reflect.Interface && v.Len() > 0 {
return indirect(v.Index(0).Elem()).Type()
}
return indirectType(elemType)
}

1
ch/testdata/snapshots/TestQuery-0 vendored Normal file
View File

@ -0,0 +1 @@
CREATE TABLE "models" (id UInt64, string String, bytes String) Engine = MergeTree() ORDER BY tuple()

1
ch/testdata/snapshots/TestQuery-1 vendored Normal file
View File

@ -0,0 +1 @@
DROP TABLE "models"

1
ch/testdata/snapshots/TestQuery-10 vendored Normal file
View File

@ -0,0 +1 @@
CREATE TABLE "spans" (id UInt64, text LowCardinality(String), time DateTime) Engine = MergeTree() PARTITION BY toDate(time) ORDER BY (id) TTL time + INTERVAL 30 DAY DELETE SETTINGS ttl_only_drop_parts = 1

1
ch/testdata/snapshots/TestQuery-2 vendored Normal file
View File

@ -0,0 +1 @@
SELECT "model"."id", "model"."string", "model"."bytes" FROM "models" AS "model"

1
ch/testdata/snapshots/TestQuery-3 vendored Normal file
View File

@ -0,0 +1 @@
SELECT "id", "string" FROM "models" AS "model"

1
ch/testdata/snapshots/TestQuery-4 vendored Normal file
View File

@ -0,0 +1 @@
INSERT INTO "models" ("id", "string", "bytes") VALUES

1
ch/testdata/snapshots/TestQuery-5 vendored Normal file
View File

@ -0,0 +1 @@
TRUNCATE TABLE "models"

1
ch/testdata/snapshots/TestQuery-6 vendored Normal file
View File

@ -0,0 +1 @@
SELECT "model"."id", "model"."string", "model"."bytes" FROM "models" AS "model" SETTINGS max_rows_to_read = 100

1
ch/testdata/snapshots/TestQuery-7 vendored Normal file
View File

@ -0,0 +1 @@
SELECT "model"."id", "model"."string", "model"."bytes" FROM "models" AS "model" SETTINGS max_rows_to_read = 100, read_overflow_mode = 'break'

1
ch/testdata/snapshots/TestQuery-8 vendored Normal file
View File

@ -0,0 +1 @@
INSERT INTO dest SELECT * FROM src WHERE (_part = 'part_name') SETTINGS max_threads = 1, max_insert_threads = 1, max_execution_time = 0

1
ch/testdata/snapshots/TestQuery-9 vendored Normal file
View File

@ -0,0 +1 @@
SELECT "model"."id", "model"."string", "model"."bytes" FROM "models" AS "model" SAMPLE 1000

3
chdebug/README.md Normal file
View File

@ -0,0 +1,3 @@
# Logging executed queries with go-clickhouse
See [documentation](https://clickhouse.uptrace.dev/guide/debugging.html) for details.

133
chdebug/debug.go Normal file
View File

@ -0,0 +1,133 @@
package chdebug
import (
"context"
"database/sql"
"fmt"
"io"
"os"
"reflect"
"time"
"github.com/fatih/color"
"github.com/uptrace/go-clickhouse/ch"
)
type Option func(*QueryHook)
// WithEnabled enables/disables the hook.
func WithEnabled(on bool) Option {
return func(h *QueryHook) {
h.enabled = on
}
}
// WithVerbose configures the hook to log all queries
// (by default, only failed queries are logged).
func WithVerbose(on bool) Option {
return func(h *QueryHook) {
h.verbose = on
}
}
// WithWriter sets the log output to an io.Writer
// the default is os.Stderr
func WithWriter(w io.Writer) Option {
return func(h *QueryHook) {
h.writer = w
}
}
// FromEnv configures the hook using the environment variable value.
// For example, WithEnv("CHDEBUG"):
// - CHDEBUG=0 - disables the hook.
// - CHDEBUG=1 - enables the hook.
// - CHDEBUG=2 - enables the hook and verbose mode.
func FromEnv(key string) Option {
if key == "" {
key = "CHDEBUG"
}
return func(h *QueryHook) {
if env, ok := os.LookupEnv(key); ok {
h.enabled = env != "" && env != "0"
h.verbose = env == "2"
}
}
}
type QueryHook struct {
enabled bool
verbose bool
writer io.Writer
}
var _ ch.QueryHook = (*QueryHook)(nil)
func NewQueryHook(opts ...Option) *QueryHook {
h := &QueryHook{
enabled: true,
writer: os.Stderr,
}
for _, opt := range opts {
opt(h)
}
return h
}
func (h *QueryHook) BeforeQuery(ctx context.Context, evt *ch.QueryEvent) context.Context {
return ctx
}
func (h *QueryHook) AfterQuery(ctx context.Context, event *ch.QueryEvent) {
if !h.enabled {
return
}
if !h.verbose {
switch event.Err {
case nil, sql.ErrNoRows:
return
}
}
now := time.Now()
dur := now.Sub(event.StartTime)
args := []any{
"[ch]",
now.Format(" 15:04:05.000 "),
formatOperation(event),
fmt.Sprintf(" %10s ", dur.Round(time.Microsecond)),
event.Query,
}
if event.Err != nil {
typ := reflect.TypeOf(event.Err).String()
args = append(args,
"\t",
color.New(color.BgRed).Sprintf(" %s ", typ+": "+event.Err.Error()),
)
}
fmt.Fprintln(h.writer, args...)
}
func formatOperation(event *ch.QueryEvent) string {
operation := event.Operation()
return operationColor(operation).Sprintf(" %-16s ", operation)
}
func operationColor(operation string) *color.Color {
switch operation {
case "SELECT":
return color.New(color.BgGreen, color.FgHiWhite)
case "INSERT":
return color.New(color.BgBlue, color.FgHiWhite)
case "UPDATE":
return color.New(color.BgYellow, color.FgHiBlack)
case "DELETE":
return color.New(color.BgMagenta, color.FgHiWhite)
default:
return color.New(color.BgWhite, color.FgHiBlack)
}
}

20
chdebug/go.mod Normal file
View File

@ -0,0 +1,20 @@
module github.com/uptrace/go-clickhouse/chdebug
go 1.18
replace github.com/uptrace/go-clickhouse => ./..
require (
github.com/fatih/color v1.13.0
github.com/uptrace/go-clickhouse v0.1.1
)
require (
github.com/codemodus/kace v0.5.1 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/mattn/go-colorable v0.1.12 // indirect
github.com/mattn/go-isatty v0.0.14 // indirect
github.com/pierrec/lz4/v4 v4.1.14 // indirect
golang.org/x/exp v0.0.0-20220317015231-48e79f11773a // indirect
golang.org/x/sys v0.0.0-20220317061510-51cd9980dadf // indirect
)

27
chdebug/go.sum Normal file
View File

@ -0,0 +1,27 @@
github.com/bradleyjkemp/cupaloy v2.3.0+incompatible h1:UafIjBvWQmS9i/xRg+CamMrnLTKNzo+bdmT/oH34c2Y=
github.com/codemodus/kace v0.5.1 h1:4OCsBlE2c/rSJo375ggfnucv9eRzge/U5LrrOZd47HA=
github.com/codemodus/kace v0.5.1/go.mod h1:coddaHoX1ku1YFSe4Ip0mL9kQjJvKkzb9CfIdG1YR04=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/fatih/color v1.13.0 h1:8LOYc1KYPPmyKMuN8QV2DNRWNbLo6LZ0iLs8+mlH53w=
github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/mattn/go-colorable v0.1.9/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc=
github.com/mattn/go-colorable v0.1.12 h1:jF+Du6AlPIjs2BiUiQlKOX0rt3SujHxPnksPKZbaA40=
github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4=
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y=
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
github.com/pierrec/lz4/v4 v4.1.14 h1:+fL8AQEZtz/ijeNnpduH0bROTu0O3NZAlPjQxGn8LwE=
github.com/pierrec/lz4/v4 v4.1.14/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
golang.org/x/exp v0.0.0-20220317015231-48e79f11773a h1:DAzrdbxsb5tXNOhMCSwF7ZdfMbW46hE9fSVO6BsmUZM=
golang.org/x/exp v0.0.0-20220317015231-48e79f11773a/go.mod h1:lgLbSvA5ygNOMpwM/9anMpWVlVJ7Z+cHWq/eFuinpGE=
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220317061510-51cd9980dadf h1:Fm4IcnUL803i92qDlmB0obyHmosDrxZWxJL3gIeNqOw=
golang.org/x/sys v0.0.0-20220317061510-51cd9980dadf/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo=

248
chmigrate/migration.go Normal file
View File

@ -0,0 +1,248 @@
package chmigrate
import (
"bufio"
"bytes"
"context"
"fmt"
"io/fs"
"sort"
"strings"
"time"
"github.com/uptrace/go-clickhouse/ch"
)
type Migration struct {
ch.CHModel `ch:"engine:CollapsingMergeTree(sign)"`
Name string `ch:",pk"`
GroupID int64
MigratedAt time.Time
Sign int8
Up MigrationFunc `ch:"-"`
Down MigrationFunc `ch:"-"`
}
func (m *Migration) String() string {
return m.Name
}
func (m *Migration) IsApplied() bool {
return !m.MigratedAt.IsZero()
}
type MigrationFunc func(ctx context.Context, db *ch.DB) error
func NewSQLMigrationFunc(fsys fs.FS, name string) MigrationFunc {
return func(ctx context.Context, db *ch.DB) error {
f, err := fsys.Open(name)
if err != nil {
return err
}
scanner := bufio.NewScanner(f)
var queries []string
var query []byte
for scanner.Scan() {
b := scanner.Bytes()
const prefix = "--migration:"
if bytes.HasPrefix(b, []byte(prefix)) {
b = b[len(prefix):]
if bytes.Equal(b, []byte("split")) {
queries = append(queries, string(query))
query = query[:0]
continue
}
return fmt.Errorf("ch: unknown directive: %q", b)
}
query = append(query, b...)
query = append(query, '\n')
}
if len(query) > 0 {
queries = append(queries, string(query))
}
if err := scanner.Err(); err != nil {
return err
}
for _, q := range queries {
_, err = db.ExecContext(ctx, q)
if err != nil {
return err
}
}
return nil
}
}
const goTemplate = `package %s
import (
"context"
"fmt"
"github.com/uptrace/go-clickhouse/ch"
)
func init() {
Migrations.MustRegister(func(ctx context.Context, db *ch.DB) error {
fmt.Print(" [up migration] ")
return nil
}, func(ctx context.Context, db *ch.DB) error {
fmt.Print(" [down migration] ")
return nil
})
}
`
const sqlTemplate = `SELECT 1
--migration:split
SELECT 2
`
//------------------------------------------------------------------------------
type MigrationSlice []Migration
func (ms MigrationSlice) String() string {
if len(ms) == 0 {
return "empty"
}
if len(ms) > 5 {
return fmt.Sprintf("%d migrations (%s ... %s)", len(ms), ms[0].Name, ms[len(ms)-1].Name)
}
var sb strings.Builder
for i := range ms {
if i > 0 {
sb.WriteString(", ")
}
sb.WriteString(ms[i].Name)
}
return sb.String()
}
// Applied returns applied migrations in descending order
// (the order is important and is used in Rollback).
func (ms MigrationSlice) Applied() MigrationSlice {
var applied MigrationSlice
for i := range ms {
if ms[i].IsApplied() {
applied = append(applied, ms[i])
}
}
sortDesc(applied)
return applied
}
// Unapplied returns unapplied migrations in ascending order
// (the order is important and is used in Migrate).
func (ms MigrationSlice) Unapplied() MigrationSlice {
var unapplied MigrationSlice
for i := range ms {
if !ms[i].IsApplied() {
unapplied = append(unapplied, ms[i])
}
}
sortAsc(unapplied)
return unapplied
}
// LastGroupID returns the last applied migration group id.
// The id is 0 when there are no migration groups.
func (ms MigrationSlice) LastGroupID() int64 {
var lastGroupID int64
for i := range ms {
groupID := ms[i].GroupID
if groupID > lastGroupID {
lastGroupID = groupID
}
}
return lastGroupID
}
// LastGroup returns the last applied migration group.
func (ms MigrationSlice) LastGroup() *MigrationGroup {
group := &MigrationGroup{
ID: ms.LastGroupID(),
}
if group.ID == 0 {
return group
}
for i := range ms {
if ms[i].GroupID == group.ID {
group.Migrations = append(group.Migrations, ms[i])
}
}
return group
}
type MigrationGroup struct {
ID int64
Migrations MigrationSlice
}
func (g *MigrationGroup) IsZero() bool {
return g.ID == 0 && len(g.Migrations) == 0
}
func (g *MigrationGroup) String() string {
if g.IsZero() {
return "nil"
}
return fmt.Sprintf("group #%d (%s)", g.ID, g.Migrations)
}
type MigrationFile struct {
Name string
Path string
Content string
}
//------------------------------------------------------------------------------
type migrationConfig struct {
nop bool
}
func newMigrationConfig(opts []MigrationOption) *migrationConfig {
cfg := new(migrationConfig)
for _, opt := range opts {
opt(cfg)
}
return cfg
}
type MigrationOption func(cfg *migrationConfig)
func WithNopMigration() MigrationOption {
return func(cfg *migrationConfig) {
cfg.nop = true
}
}
//------------------------------------------------------------------------------
func sortAsc(ms MigrationSlice) {
sort.Slice(ms, func(i, j int) bool {
return ms[i].Name < ms[j].Name
})
}
func sortDesc(ms MigrationSlice) {
sort.Slice(ms, func(i, j int) bool {
return ms[i].Name > ms[j].Name
})
}

168
chmigrate/migrations.go Normal file
View File

@ -0,0 +1,168 @@
package chmigrate
import (
"errors"
"fmt"
"io/fs"
"os"
"path/filepath"
"regexp"
"runtime"
"strings"
)
type MigrationsOption func(m *Migrations)
func WithMigrationsDirectory(directory string) MigrationsOption {
return func(m *Migrations) {
m.explicitDirectory = directory
}
}
type Migrations struct {
ms MigrationSlice
explicitDirectory string
implicitDirectory string
}
func NewMigrations(opts ...MigrationsOption) *Migrations {
m := new(Migrations)
for _, opt := range opts {
opt(m)
}
m.implicitDirectory = filepath.Dir(migrationFile())
return m
}
func (m *Migrations) Sorted() MigrationSlice {
migrations := make(MigrationSlice, len(m.ms))
copy(migrations, m.ms)
sortAsc(migrations)
return migrations
}
func (m *Migrations) MustRegister(up, down MigrationFunc) {
if err := m.Register(up, down); err != nil {
panic(err)
}
}
func (m *Migrations) Register(up, down MigrationFunc) error {
fpath := migrationFile()
name, err := extractMigrationName(fpath)
if err != nil {
return err
}
m.Add(Migration{
Name: name,
Up: up,
Down: down,
})
return nil
}
func (m *Migrations) Add(migration Migration) {
if migration.Name == "" {
panic("migration name is required")
}
m.ms = append(m.ms, migration)
}
func (m *Migrations) DiscoverCaller() error {
dir := filepath.Dir(migrationFile())
return m.Discover(os.DirFS(dir))
}
func (m *Migrations) Discover(fsys fs.FS) error {
return fs.WalkDir(fsys, ".", func(path string, d fs.DirEntry, err error) error {
if err != nil {
return err
}
if d.IsDir() {
return nil
}
if !strings.HasSuffix(path, ".up.sql") && !strings.HasSuffix(path, ".down.sql") {
return nil
}
name, err := extractMigrationName(path)
if err != nil {
return err
}
migration := m.getOrCreateMigration(name)
if err != nil {
return err
}
migrationFunc := NewSQLMigrationFunc(fsys, path)
if strings.HasSuffix(path, ".up.sql") {
migration.Up = migrationFunc
return nil
}
if strings.HasSuffix(path, ".down.sql") {
migration.Down = migrationFunc
return nil
}
return errors.New("chmigrate: not reached")
})
}
func (m *Migrations) getOrCreateMigration(name string) *Migration {
for i := range m.ms {
m := &m.ms[i]
if m.Name == name {
return m
}
}
m.ms = append(m.ms, Migration{Name: name})
return &m.ms[len(m.ms)-1]
}
func (m *Migrations) getDirectory() string {
if m.explicitDirectory != "" {
return m.explicitDirectory
}
if m.implicitDirectory != "" {
return m.implicitDirectory
}
return filepath.Dir(migrationFile())
}
func migrationFile() string {
const depth = 32
var pcs [depth]uintptr
n := runtime.Callers(1, pcs[:])
frames := runtime.CallersFrames(pcs[:n])
for {
f, ok := frames.Next()
if !ok {
break
}
if !strings.Contains(f.Function, "/chmigrate.") {
return f.File
}
}
return ""
}
var fnameRE = regexp.MustCompile(`^(\d{14})_[0-9a-z_\-]+\.`)
func extractMigrationName(fpath string) (string, error) {
fname := filepath.Base(fpath)
matches := fnameRE.FindStringSubmatch(fname)
if matches == nil {
return "", fmt.Errorf("chmigrate: unsupported migration name format: %q", fname)
}
return matches[1], nil
}

379
chmigrate/migrator.go Normal file
View File

@ -0,0 +1,379 @@
package chmigrate
import (
"context"
"errors"
"fmt"
"io/ioutil"
"path/filepath"
"regexp"
"strings"
"time"
"github.com/uptrace/go-clickhouse/ch"
)
type MigratorOption func(m *Migrator)
func WithTableName(table string) MigratorOption {
return func(m *Migrator) {
m.table = table
}
}
func WithLocksTableName(table string) MigratorOption {
return func(m *Migrator) {
m.locksTable = table
}
}
type Migrator struct {
db *ch.DB
migrations *Migrations
ms MigrationSlice
table string
locksTable string
}
func NewMigrator(db *ch.DB, migrations *Migrations, opts ...MigratorOption) *Migrator {
m := &Migrator{
db: db,
migrations: migrations,
ms: migrations.ms,
table: "ch_migrations",
locksTable: "ch_migration_locks",
}
for _, opt := range opts {
opt(m)
}
return m
}
func (m *Migrator) DB() *ch.DB {
return m.db
}
// MigrationsWithStatus returns migrations with status in ascending order.
func (m *Migrator) MigrationsWithStatus(ctx context.Context) (MigrationSlice, error) {
sorted, _, err := m.migrationsWithStatus(ctx)
return sorted, err
}
func (m *Migrator) migrationsWithStatus(ctx context.Context) (MigrationSlice, int64, error) {
sorted := m.migrations.Sorted()
applied, err := m.selectAppliedMigrations(ctx)
if err != nil {
return nil, 0, err
}
appliedMap := migrationMap(applied)
for i := range sorted {
m1 := &sorted[i]
if m2, ok := appliedMap[m1.Name]; ok {
m1.GroupID = m2.GroupID
m1.MigratedAt = m2.MigratedAt
}
}
return sorted, applied.LastGroupID(), nil
}
func (m *Migrator) Init(ctx context.Context) error {
if _, err := m.db.NewCreateTable().
Model((*Migration)(nil)).
ModelTableExpr(m.table).
IfNotExists().
Exec(ctx); err != nil {
return err
}
if _, err := m.db.NewCreateTable().
Model((*migrationLock)(nil)).
ModelTableExpr(m.locksTable).
IfNotExists().
Exec(ctx); err != nil {
return err
}
return nil
}
func (m *Migrator) Reset(ctx context.Context) error {
if _, err := m.db.NewDropTable().
Model((*Migration)(nil)).
ModelTableExpr(m.table).
IfExists().
Exec(ctx); err != nil {
return err
}
if _, err := m.db.NewDropTable().
Model((*migrationLock)(nil)).
ModelTableExpr(m.locksTable).
IfExists().
Exec(ctx); err != nil {
return err
}
return m.Init(ctx)
}
// Migrate runs unapplied migrations. If a migration fails, migrate immediately exits.
func (m *Migrator) Migrate(ctx context.Context, opts ...MigrationOption) (*MigrationGroup, error) {
cfg := newMigrationConfig(opts)
if err := m.validate(); err != nil {
return nil, err
}
if err := m.Lock(ctx); err != nil {
return nil, err
}
defer m.Unlock(ctx) //nolint:errcheck
migrations, lastGroupID, err := m.migrationsWithStatus(ctx)
if err != nil {
return nil, err
}
group := &MigrationGroup{
Migrations: migrations.Unapplied(),
}
if len(group.Migrations) == 0 {
return group, nil
}
group.ID = lastGroupID + 1
for i := range group.Migrations {
migration := &group.Migrations[i]
migration.GroupID = group.ID
// Always mark migration as applied so the rollback has a chance to fix the database.
if err := m.MarkApplied(ctx, migration); err != nil {
return nil, err
}
if !cfg.nop && migration.Up != nil {
if err := migration.Up(ctx, m.db); err != nil {
return group, err
}
}
}
return group, nil
}
func (m *Migrator) Rollback(ctx context.Context, opts ...MigrationOption) (*MigrationGroup, error) {
cfg := newMigrationConfig(opts)
if err := m.validate(); err != nil {
return nil, err
}
if err := m.Lock(ctx); err != nil {
return nil, err
}
defer m.Unlock(ctx) //nolint:errcheck
migrations, err := m.MigrationsWithStatus(ctx)
if err != nil {
return nil, err
}
lastGroup := migrations.LastGroup()
for i := len(lastGroup.Migrations) - 1; i >= 0; i-- {
migration := &lastGroup.Migrations[i]
if !cfg.nop && migration.Down != nil {
if err := migration.Down(ctx, m.db); err != nil {
return nil, err
}
}
if err := m.MarkUnapplied(ctx, migration); err != nil {
return nil, err
}
}
return lastGroup, nil
}
type goMigrationConfig struct {
packageName string
}
type GoMigrationOption func(cfg *goMigrationConfig)
func WithPackageName(name string) GoMigrationOption {
return func(cfg *goMigrationConfig) {
cfg.packageName = name
}
}
// CreateGoMigration creates a Go migration file.
func (m *Migrator) CreateGoMigration(
ctx context.Context, name string, opts ...GoMigrationOption,
) (*MigrationFile, error) {
cfg := &goMigrationConfig{
packageName: "migrations",
}
for _, opt := range opts {
opt(cfg)
}
name, err := m.genMigrationName(name)
if err != nil {
return nil, err
}
fname := name + ".go"
fpath := filepath.Join(m.migrations.getDirectory(), fname)
content := fmt.Sprintf(goTemplate, cfg.packageName)
if err := ioutil.WriteFile(fpath, []byte(content), 0o644); err != nil {
return nil, err
}
mf := &MigrationFile{
Name: fname,
Path: fpath,
Content: content,
}
return mf, nil
}
// CreateSQLMigrations creates an up and down SQL migration files.
func (m *Migrator) CreateSQLMigrations(ctx context.Context, name string) ([]*MigrationFile, error) {
name, err := m.genMigrationName(name)
if err != nil {
return nil, err
}
up, err := m.createSQL(ctx, name+".up.sql")
if err != nil {
return nil, err
}
down, err := m.createSQL(ctx, name+".down.sql")
if err != nil {
return nil, err
}
return []*MigrationFile{up, down}, nil
}
func (m *Migrator) createSQL(ctx context.Context, fname string) (*MigrationFile, error) {
fpath := filepath.Join(m.migrations.getDirectory(), fname)
if err := ioutil.WriteFile(fpath, []byte(sqlTemplate), 0o644); err != nil {
return nil, err
}
mf := &MigrationFile{
Name: fname,
Path: fpath,
Content: goTemplate,
}
return mf, nil
}
var nameRE = regexp.MustCompile(`^[0-9a-z_\-]+$`)
func (m *Migrator) genMigrationName(name string) (string, error) {
const timeFormat = "20060102150405"
if name == "" {
return "", errors.New("chmigrate: migration name can't be empty")
}
if !nameRE.MatchString(name) {
return "", fmt.Errorf("chmigrate: invalid migration name: %q", name)
}
version := time.Now().UTC().Format(timeFormat)
return fmt.Sprintf("%s_%s", version, name), nil
}
// MarkApplied marks the migration as applied (completed).
func (m *Migrator) MarkApplied(ctx context.Context, migration *Migration) error {
migration.Sign = 1
migration.MigratedAt = time.Now()
_, err := m.db.NewInsert().
Model(migration).
ModelTableExpr(m.table).
Exec(ctx)
return err
}
// MarkUnapplied marks the migration as unapplied (new).
func (m *Migrator) MarkUnapplied(ctx context.Context, migration *Migration) error {
migration.Sign = -1
_, err := m.db.NewInsert().
Model(migration).
ModelTableExpr(m.table).
Exec(ctx)
return err
}
// selectAppliedMigrations selects applied (applied) migrations in descending order.
func (m *Migrator) selectAppliedMigrations(ctx context.Context) (MigrationSlice, error) {
var ms MigrationSlice
if err := m.db.NewSelect().
ColumnExpr("*").
Model(&ms).
ModelTableExpr(m.table).
Final().
Scan(ctx); err != nil {
return nil, err
}
return ms, nil
}
func (m *Migrator) formattedTableName(db *ch.DB) string {
return db.Formatter().FormatQuery(m.table)
}
func (m *Migrator) validate() error {
if len(m.ms) == 0 {
return errors.New("chmigrate: there are no any migrations")
}
return nil
}
//------------------------------------------------------------------------------
type migrationLock struct {
A int8
}
func (m *Migrator) Lock(ctx context.Context) error {
if _, err := m.db.ExecContext(
ctx,
"ALTER TABLE ? ADD COLUMN ? Int8",
ch.Safe(m.locksTable), ch.Safe("col1"),
); err != nil {
return fmt.Errorf("chmigrate: migrations table is already locked (%w)", err)
}
return nil
}
func (m *Migrator) Unlock(ctx context.Context) error {
if _, err := m.db.ExecContext(
ctx,
"ALTER TABLE ? DROP COLUMN ?",
ch.Safe(m.locksTable), ch.Safe("col1"),
); err != nil && !strings.Contains(err.Error(), "Cannot find column") {
return fmt.Errorf("chmigrate: migrations table is already unlocked (%w)", err)
}
return nil
}
func migrationMap(ms MigrationSlice) map[string]*Migration {
mp := make(map[string]*Migration)
for i := range ms {
m := &ms[i]
mp[m.Name] = m
}
return mp
}

3
chotel/README.md Normal file
View File

@ -0,0 +1,3 @@
# OpenTelemetry instrumentation for go-clickhouse
See [documentation](https://clickhouse.uptrace.dev/guide/monitoring.html) for details.

106
chotel/chotel.go Normal file
View File

@ -0,0 +1,106 @@
package chotel
import (
"context"
"database/sql"
"runtime"
"strings"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
semconv "go.opentelemetry.io/otel/semconv/v1.7.0"
"go.opentelemetry.io/otel/trace"
"github.com/uptrace/go-clickhouse/ch"
)
var tracer = otel.Tracer("go-clickhouse")
type QueryHook struct{}
var _ ch.QueryHook = (*QueryHook)(nil)
func NewQueryHook() *QueryHook {
return &QueryHook{}
}
func (h *QueryHook) BeforeQuery(
ctx context.Context, evt *ch.QueryEvent,
) context.Context {
if !trace.SpanFromContext(ctx).IsRecording() {
return ctx
}
ctx, _ = tracer.Start(ctx, "", trace.WithSpanKind(trace.SpanKindClient))
return ctx
}
func (h *QueryHook) AfterQuery(ctx context.Context, event *ch.QueryEvent) {
span := trace.SpanFromContext(ctx)
if !span.IsRecording() {
return
}
defer span.End()
operation := event.Operation()
fn, file, line := funcFileLine("go-clickhouse")
span.SetName(operation)
attrs := []attribute.KeyValue{
semconv.CodeFunctionKey.String(fn),
semconv.CodeFilepathKey.String(file),
semconv.CodeLineNumberKey.Int(line),
semconv.DBSystemKey.String("clickhouse"),
semconv.DBOperationKey.String(operation),
semconv.DBStatementKey.String(event.Query),
}
if event.IQuery != nil {
if tableName := event.IQuery.GetTableName(); tableName != "" {
attrs = append(attrs, semconv.DBSQLTableKey.String(tableName))
}
}
span.SetAttributes(attrs...)
switch event.Err {
case nil, sql.ErrNoRows:
default:
span.SetStatus(codes.Error, "")
span.RecordError(event.Err)
}
if event.Result != nil {
numRow, err := event.Result.RowsAffected()
if err == nil {
span.SetAttributes(attribute.Int64("db.rows_affected", numRow))
}
}
}
func funcFileLine(pkg string) (string, string, int) {
const depth = 16
var pcs [depth]uintptr
n := runtime.Callers(3, pcs[:])
ff := runtime.CallersFrames(pcs[:n])
var fn, file string
var line int
for {
f, ok := ff.Next()
if !ok {
break
}
fn, file, line = f.Function, f.File, f.Line
if !strings.Contains(fn, pkg) {
break
}
}
if ind := strings.LastIndexByte(fn, '/'); ind != -1 {
fn = fn[ind+1:]
}
return fn, file, line
}

22
chotel/go.mod Normal file
View File

@ -0,0 +1,22 @@
module github.com/uptrace/go-clickhouse/chotel
go 1.18
replace github.com/uptrace/go-clickhouse => ./..
replace github.com/uptrace/go-clickhouse/chdebug => ../chdebug
require (
github.com/uptrace/go-clickhouse v0.1.1
go.opentelemetry.io/otel v1.5.0
go.opentelemetry.io/otel/trace v1.5.0
)
require (
github.com/codemodus/kace v0.5.1 // indirect
github.com/go-logr/logr v1.2.2 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/pierrec/lz4/v4 v4.1.14 // indirect
golang.org/x/exp v0.0.0-20220317015231-48e79f11773a // indirect
)

34
chotel/go.sum Normal file
View File

@ -0,0 +1,34 @@
github.com/bradleyjkemp/cupaloy v2.3.0+incompatible h1:UafIjBvWQmS9i/xRg+CamMrnLTKNzo+bdmT/oH34c2Y=
github.com/codemodus/kace v0.5.1 h1:4OCsBlE2c/rSJo375ggfnucv9eRzge/U5LrrOZd47HA=
github.com/codemodus/kace v0.5.1/go.mod h1:coddaHoX1ku1YFSe4Ip0mL9kQjJvKkzb9CfIdG1YR04=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/fatih/color v1.13.0 h1:8LOYc1KYPPmyKMuN8QV2DNRWNbLo6LZ0iLs8+mlH53w=
github.com/go-logr/logr v1.2.2 h1:ahHml/yUpnlb96Rp8HCvtYVPY8ZYpxq3g7UYchIYwbs=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/google/go-cmp v0.5.7 h1:81/ik6ipDQS2aGcBfIN5dHDB36BwrStyeAQquSYCV4o=
github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/mattn/go-colorable v0.1.12 h1:jF+Du6AlPIjs2BiUiQlKOX0rt3SujHxPnksPKZbaA40=
github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y=
github.com/pierrec/lz4/v4 v4.1.14 h1:+fL8AQEZtz/ijeNnpduH0bROTu0O3NZAlPjQxGn8LwE=
github.com/pierrec/lz4/v4 v4.1.14/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
go.opentelemetry.io/otel v1.5.0 h1:DhCU8oR2sJH9rfnwPdoV/+BJ7UIN5kXHL8DuSGrPU8E=
go.opentelemetry.io/otel v1.5.0/go.mod h1:Jm/m+rNp/z0eqJc74H7LPwQ3G87qkU/AnnAydAjSAHk=
go.opentelemetry.io/otel/trace v1.5.0 h1:AKQZ9zJsBRFAp7zLdyGNkqG2rToCDIt3i5tcLzQlbmU=
go.opentelemetry.io/otel/trace v1.5.0/go.mod h1:sq55kfhjXYr1zVSyexg0w1mpa03AYXR5eyTkB9NPPdE=
golang.org/x/exp v0.0.0-20220317015231-48e79f11773a h1:DAzrdbxsb5tXNOhMCSwF7ZdfMbW46hE9fSVO6BsmUZM=
golang.org/x/exp v0.0.0-20220317015231-48e79f11773a/go.mod h1:lgLbSvA5ygNOMpwM/9anMpWVlVJ7Z+cHWq/eFuinpGE=
golang.org/x/sys v0.0.0-20220307203707-22a9840ba4d7 h1:8IVLkfbr2cLhv0a/vKq4UFUcJym8RmDoDboxCFWEjYE=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo=

13
example/basic/README.md Normal file
View File

@ -0,0 +1,13 @@
# Basic example
To run this example, you need a ClickHouse database:
```shell
clickhouse-client -q "CREATE DATABASE test"
```
Then run:
```shell
go run .
```

23
example/basic/go.mod Normal file
View File

@ -0,0 +1,23 @@
module github.com/uptrace/go-clickhouse/example/basic
go 1.18
replace github.com/uptrace/go-clickhouse => ../..
replace github.com/uptrace/go-clickhouse/chdebug => ../../chdebug
require (
github.com/uptrace/go-clickhouse v0.1.1
github.com/uptrace/go-clickhouse/chdebug v0.1.1
)
require (
github.com/codemodus/kace v0.5.1 // indirect
github.com/fatih/color v1.13.0 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/mattn/go-colorable v0.1.12 // indirect
github.com/mattn/go-isatty v0.0.14 // indirect
github.com/pierrec/lz4/v4 v4.1.14 // indirect
golang.org/x/exp v0.0.0-20220317015231-48e79f11773a // indirect
golang.org/x/sys v0.0.0-20220317061510-51cd9980dadf // indirect
)

27
example/basic/go.sum Normal file
View File

@ -0,0 +1,27 @@
github.com/bradleyjkemp/cupaloy v2.3.0+incompatible h1:UafIjBvWQmS9i/xRg+CamMrnLTKNzo+bdmT/oH34c2Y=
github.com/codemodus/kace v0.5.1 h1:4OCsBlE2c/rSJo375ggfnucv9eRzge/U5LrrOZd47HA=
github.com/codemodus/kace v0.5.1/go.mod h1:coddaHoX1ku1YFSe4Ip0mL9kQjJvKkzb9CfIdG1YR04=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/fatih/color v1.13.0 h1:8LOYc1KYPPmyKMuN8QV2DNRWNbLo6LZ0iLs8+mlH53w=
github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/mattn/go-colorable v0.1.9/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc=
github.com/mattn/go-colorable v0.1.12 h1:jF+Du6AlPIjs2BiUiQlKOX0rt3SujHxPnksPKZbaA40=
github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4=
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y=
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
github.com/pierrec/lz4/v4 v4.1.14 h1:+fL8AQEZtz/ijeNnpduH0bROTu0O3NZAlPjQxGn8LwE=
github.com/pierrec/lz4/v4 v4.1.14/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
golang.org/x/exp v0.0.0-20220317015231-48e79f11773a h1:DAzrdbxsb5tXNOhMCSwF7ZdfMbW46hE9fSVO6BsmUZM=
golang.org/x/exp v0.0.0-20220317015231-48e79f11773a/go.mod h1:lgLbSvA5ygNOMpwM/9anMpWVlVJ7Z+cHWq/eFuinpGE=
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220317061510-51cd9980dadf h1:Fm4IcnUL803i92qDlmB0obyHmosDrxZWxJL3gIeNqOw=
golang.org/x/sys v0.0.0-20220317061510-51cd9980dadf/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo=

50
example/basic/main.go Normal file
View File

@ -0,0 +1,50 @@
package main
import (
"context"
"fmt"
"time"
"github.com/uptrace/go-clickhouse/ch"
"github.com/uptrace/go-clickhouse/chdebug"
)
type Model struct {
ch.CHModel `ch:"partition:toYYYYMM(time)"`
ID uint64
Text string `ch:",lc"`
Time time.Time `ch:",pk"`
}
func main() {
ctx := context.Background()
db := ch.Connect(ch.WithDatabase("test"))
db.AddQueryHook(chdebug.NewQueryHook(chdebug.WithVerbose(true)))
if err := db.Ping(ctx); err != nil {
panic(err)
}
var num int
if err := db.QueryRowContext(ctx, "SELECT 123").Scan(&num); err != nil {
panic(err)
}
fmt.Println(num)
if err := db.ResetModel(ctx, (*Model)(nil)); err != nil {
panic(err)
}
src := &Model{ID: 1, Text: "hello", Time: time.Now()}
if _, err := db.NewInsert().Model(src).Exec(ctx); err != nil {
panic(err)
}
dest := new(Model)
if err := db.NewSelect().Model(dest).Where("id = ?", src.ID).Limit(1).Scan(ctx); err != nil {
panic(err)
}
fmt.Println(dest)
}

View File

@ -0,0 +1,4 @@
# go-clickhouse benchmark examples
These examples allow to compare performance with
[clickhouse-go](https://github.com/ClickHouse/clickhouse-go/tree/v2/benchmark/v2).

23
example/benchmark/go.mod Normal file
View File

@ -0,0 +1,23 @@
module github.com/uptrace/go-clickhouse/ch/internal/bench
go 1.18
replace github.com/uptrace/go-clickhouse => ../..
replace github.com/uptrace/go-clickhouse/chdebug => ../../chdebug
require (
github.com/uptrace/go-clickhouse v0.1.1
github.com/uptrace/go-clickhouse/chdebug v0.1.1
)
require (
github.com/codemodus/kace v0.5.1 // indirect
github.com/fatih/color v1.13.0 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/mattn/go-colorable v0.1.12 // indirect
github.com/mattn/go-isatty v0.0.14 // indirect
github.com/pierrec/lz4/v4 v4.1.14 // indirect
golang.org/x/exp v0.0.0-20220317015231-48e79f11773a // indirect
golang.org/x/sys v0.0.0-20220317061510-51cd9980dadf // indirect
)

27
example/benchmark/go.sum Normal file
View File

@ -0,0 +1,27 @@
github.com/bradleyjkemp/cupaloy v2.3.0+incompatible h1:UafIjBvWQmS9i/xRg+CamMrnLTKNzo+bdmT/oH34c2Y=
github.com/codemodus/kace v0.5.1 h1:4OCsBlE2c/rSJo375ggfnucv9eRzge/U5LrrOZd47HA=
github.com/codemodus/kace v0.5.1/go.mod h1:coddaHoX1ku1YFSe4Ip0mL9kQjJvKkzb9CfIdG1YR04=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/fatih/color v1.13.0 h1:8LOYc1KYPPmyKMuN8QV2DNRWNbLo6LZ0iLs8+mlH53w=
github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/mattn/go-colorable v0.1.9/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc=
github.com/mattn/go-colorable v0.1.12 h1:jF+Du6AlPIjs2BiUiQlKOX0rt3SujHxPnksPKZbaA40=
github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4=
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y=
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
github.com/pierrec/lz4/v4 v4.1.14 h1:+fL8AQEZtz/ijeNnpduH0bROTu0O3NZAlPjQxGn8LwE=
github.com/pierrec/lz4/v4 v4.1.14/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
golang.org/x/exp v0.0.0-20220317015231-48e79f11773a h1:DAzrdbxsb5tXNOhMCSwF7ZdfMbW46hE9fSVO6BsmUZM=
golang.org/x/exp v0.0.0-20220317015231-48e79f11773a/go.mod h1:lgLbSvA5ygNOMpwM/9anMpWVlVJ7Z+cHWq/eFuinpGE=
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220317061510-51cd9980dadf h1:Fm4IcnUL803i92qDlmB0obyHmosDrxZWxJL3gIeNqOw=
golang.org/x/sys v0.0.0-20220317061510-51cd9980dadf/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo=

Some files were not shown because too many files have changed in this diff Show More