mirror of
https://github.com/uptrace/go-clickhouse.git
synced 2025-06-12 23:37:29 +02:00
Merge pull request #38 from uptrace/feat/bfloat16
feat: add bfloat16 support
This commit is contained in:
commit
7f77517af4
25
ch/bfloat16/bfloat16.go
Normal file
25
ch/bfloat16/bfloat16.go
Normal file
@ -0,0 +1,25 @@
|
||||
package bfloat16
|
||||
|
||||
import (
|
||||
"math"
|
||||
)
|
||||
|
||||
type Map map[T]uint64
|
||||
|
||||
type T uint16
|
||||
|
||||
func From(f float64) T {
|
||||
return FromFloat32(float32(f))
|
||||
}
|
||||
|
||||
func FromFloat32(f float32) T {
|
||||
return T(math.Float32bits(f) >> 16)
|
||||
}
|
||||
|
||||
func (f T) Float32() float32 {
|
||||
return math.Float32frombits(uint32(f) << 16)
|
||||
}
|
||||
|
||||
func (f T) Float64() float64 {
|
||||
return float64(f.Float32())
|
||||
}
|
@ -12,6 +12,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/uptrace/go-clickhouse/ch/bfloat16"
|
||||
"github.com/uptrace/go-clickhouse/ch/chproto"
|
||||
"github.com/uptrace/go-clickhouse/ch/internal"
|
||||
|
||||
@ -1164,3 +1165,68 @@ func newLCKeyType(typ int64) lcKey {
|
||||
panic("not reached")
|
||||
}
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
type BFloat16HistColumn struct {
|
||||
ColumnOf[bfloat16.Map]
|
||||
}
|
||||
|
||||
var _ Columnar = (*BFloat16HistColumn)(nil)
|
||||
|
||||
func NewBFloat16HistColumn(typ reflect.Type, chType string, numRow int) Columnar {
|
||||
return &BFloat16HistColumn{
|
||||
ColumnOf: NewColumnOf[bfloat16.Map](numRow),
|
||||
}
|
||||
}
|
||||
|
||||
func (c BFloat16HistColumn) Type() reflect.Type {
|
||||
return bfloat16MapType
|
||||
}
|
||||
|
||||
func (c *BFloat16HistColumn) ReadFrom(rd *chproto.Reader, numRow int) error {
|
||||
if numRow == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
c.Alloc(numRow)
|
||||
|
||||
for i := range c.Column {
|
||||
n, err := rd.Uvarint()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
data := make(bfloat16.Map, n)
|
||||
|
||||
for j := 0; j < int(n); j++ {
|
||||
value, err := rd.UInt16()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
count, err := rd.UInt64()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
data[bfloat16.T(value)] = count
|
||||
}
|
||||
|
||||
c.Column[i] = data
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c BFloat16HistColumn) WriteTo(wr *chproto.Writer) error {
|
||||
for _, m := range c.Column {
|
||||
wr.Uvarint(uint64(len(m)))
|
||||
|
||||
for k, v := range m {
|
||||
wr.UInt16(uint16(k))
|
||||
wr.UInt64(v)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -8,6 +8,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/uptrace/go-clickhouse/ch/bfloat16"
|
||||
"github.com/uptrace/go-clickhouse/ch/chtype"
|
||||
"github.com/uptrace/go-clickhouse/ch/internal"
|
||||
)
|
||||
@ -41,50 +42,6 @@ var chType = [...]string{
|
||||
reflect.UnsafePointer: "",
|
||||
}
|
||||
|
||||
// keep in sync with ColumnFactory
|
||||
func clickhouseType(typ reflect.Type) string {
|
||||
switch typ {
|
||||
case timeType:
|
||||
return chtype.DateTime
|
||||
case ipType:
|
||||
return chtype.IPv6
|
||||
}
|
||||
|
||||
kind := typ.Kind()
|
||||
switch kind {
|
||||
case reflect.Ptr:
|
||||
if typ.Elem().Kind() == reflect.Struct {
|
||||
return chtype.String
|
||||
}
|
||||
return fmt.Sprintf("Nullable(%s)", clickhouseType(typ.Elem()))
|
||||
case reflect.Slice:
|
||||
switch elem := typ.Elem(); elem.Kind() {
|
||||
case reflect.Ptr:
|
||||
if elem.Elem().Kind() == reflect.Struct {
|
||||
return chtype.String // json
|
||||
}
|
||||
case reflect.Struct:
|
||||
if elem != timeType {
|
||||
return chtype.String // json
|
||||
}
|
||||
case reflect.Uint8:
|
||||
return chtype.String // []byte
|
||||
}
|
||||
|
||||
return "Array(" + clickhouseType(typ.Elem()) + ")"
|
||||
case reflect.Array:
|
||||
if isUUID(typ) {
|
||||
return chtype.UUID
|
||||
}
|
||||
}
|
||||
|
||||
if s := chType[kind]; s != "" {
|
||||
return s
|
||||
}
|
||||
|
||||
panic(fmt.Errorf("ch: unsupported Go type: %s", typ))
|
||||
}
|
||||
|
||||
type NewColumnFunc func(typ reflect.Type, chType string, numRow int) Columnar
|
||||
|
||||
var kindToColumn = [...]NewColumnFunc{
|
||||
@ -141,6 +98,13 @@ func ColumnFactory(typ reflect.Type, chType string) NewColumnFunc {
|
||||
chType = chSubType(chType, "SimpleAggregateFunction(")
|
||||
} else if s := dateTimeType(chType); s != "" {
|
||||
chType = s
|
||||
} else if funcName, _ := aggFuncNameAndType(chType); funcName != "" {
|
||||
switch funcName {
|
||||
case "quantileBFloat16", "quantilesBFloat16":
|
||||
return NewBFloat16HistColumn
|
||||
default:
|
||||
panic(fmt.Errorf("unsupported ClickHouse type: %s", chType))
|
||||
}
|
||||
}
|
||||
|
||||
switch typ {
|
||||
@ -270,12 +234,13 @@ var (
|
||||
float32Type = reflect.TypeOf(float32(0))
|
||||
float64Type = reflect.TypeOf(float64(0))
|
||||
|
||||
stringType = reflect.TypeOf("")
|
||||
bytesType = reflect.TypeOf((*[]byte)(nil)).Elem()
|
||||
uuidType = reflect.TypeOf((*UUID)(nil)).Elem()
|
||||
timeType = reflect.TypeOf((*time.Time)(nil)).Elem()
|
||||
ipType = reflect.TypeOf((*net.IP)(nil)).Elem()
|
||||
ipNetType = reflect.TypeOf((*net.IPNet)(nil)).Elem()
|
||||
stringType = reflect.TypeOf("")
|
||||
bytesType = reflect.TypeOf((*[]byte)(nil)).Elem()
|
||||
uuidType = reflect.TypeOf((*UUID)(nil)).Elem()
|
||||
timeType = reflect.TypeOf((*time.Time)(nil)).Elem()
|
||||
ipType = reflect.TypeOf((*net.IP)(nil)).Elem()
|
||||
ipNetType = reflect.TypeOf((*net.IPNet)(nil)).Elem()
|
||||
bfloat16MapType = reflect.TypeOf((*bfloat16.Map)(nil)).Elem()
|
||||
|
||||
int64SliceType = reflect.TypeOf((*[]int64)(nil)).Elem()
|
||||
uint64SliceType = reflect.TypeOf((*[]uint64)(nil)).Elem()
|
||||
@ -342,6 +307,51 @@ func goType(chType string) reflect.Type {
|
||||
panic(fmt.Errorf("unsupported ClickHouse type=%q", chType))
|
||||
}
|
||||
|
||||
// clickhouseType returns ClickHouse type for the given Go type.
|
||||
// Keep in sync with ColumnFactory.
|
||||
func clickhouseType(typ reflect.Type) string {
|
||||
switch typ {
|
||||
case timeType:
|
||||
return chtype.DateTime
|
||||
case ipType:
|
||||
return chtype.IPv6
|
||||
}
|
||||
|
||||
kind := typ.Kind()
|
||||
switch kind {
|
||||
case reflect.Ptr:
|
||||
if typ.Elem().Kind() == reflect.Struct {
|
||||
return chtype.String
|
||||
}
|
||||
return fmt.Sprintf("Nullable(%s)", clickhouseType(typ.Elem()))
|
||||
case reflect.Slice:
|
||||
switch elem := typ.Elem(); elem.Kind() {
|
||||
case reflect.Ptr:
|
||||
if elem.Elem().Kind() == reflect.Struct {
|
||||
return chtype.String // json
|
||||
}
|
||||
case reflect.Struct:
|
||||
if elem != timeType {
|
||||
return chtype.String // json
|
||||
}
|
||||
case reflect.Uint8:
|
||||
return chtype.String // []byte
|
||||
}
|
||||
|
||||
return "Array(" + clickhouseType(typ.Elem()) + ")"
|
||||
case reflect.Array:
|
||||
if isUUID(typ) {
|
||||
return chtype.UUID
|
||||
}
|
||||
}
|
||||
|
||||
if s := chType[kind]; s != "" {
|
||||
return s
|
||||
}
|
||||
|
||||
panic(fmt.Errorf("ch: unsupported Go type: %s", typ))
|
||||
}
|
||||
|
||||
func chArrayElemType(s string) string {
|
||||
s = chSubType(s, "Array(")
|
||||
if s == "" {
|
||||
@ -401,7 +411,15 @@ func nullableType(s string) string {
|
||||
}
|
||||
|
||||
func aggFuncNameAndType(chType string) (funcName, funcType string) {
|
||||
s := chSubType(chType, "SimpleAggregateFunction(")
|
||||
var s string
|
||||
|
||||
for _, prefix := range []string{"SimpleAggregateFunction(", "AggregateFunction("} {
|
||||
s = chSubType(chType, prefix)
|
||||
if s != "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if s == "" {
|
||||
return "", ""
|
||||
}
|
||||
|
32
ch/config.go
32
ch/config.go
@ -261,25 +261,27 @@ func parseDSN(dsn string) ([]Option, error) {
|
||||
|
||||
switch u.Scheme {
|
||||
case "ch", "clickhouse":
|
||||
if u.Host != "" {
|
||||
addr := u.Host
|
||||
if !strings.Contains(addr, ":") {
|
||||
addr += ":5432"
|
||||
}
|
||||
opts = append(opts, WithAddr(addr))
|
||||
}
|
||||
|
||||
if len(u.Path) > 1 {
|
||||
opts = append(opts, WithDatabase(u.Path[1:]))
|
||||
}
|
||||
|
||||
if host := q.string("host"); host != "" {
|
||||
opts = append(opts, WithAddr(host))
|
||||
}
|
||||
// ok
|
||||
default:
|
||||
return nil, errors.New("ch: unknown scheme: " + u.Scheme)
|
||||
}
|
||||
|
||||
if u.Host != "" {
|
||||
addr := u.Host
|
||||
if !strings.Contains(addr, ":") {
|
||||
addr += ":5432"
|
||||
}
|
||||
opts = append(opts, WithAddr(addr))
|
||||
}
|
||||
|
||||
if len(u.Path) > 1 {
|
||||
opts = append(opts, WithDatabase(u.Path[1:]))
|
||||
}
|
||||
|
||||
if host := q.string("host"); host != "" {
|
||||
opts = append(opts, WithAddr(host))
|
||||
}
|
||||
|
||||
if u.User != nil {
|
||||
opts = append(opts, WithUser(u.User.Username()))
|
||||
if password, ok := u.User.Password(); ok {
|
||||
|
@ -192,19 +192,21 @@ func (q *InsertQuery) Exec(ctx context.Context) (sql.Result, error) {
|
||||
query := internal.String(queryBytes)
|
||||
|
||||
ctx, evt := q.db.beforeQuery(ctx, q, query, nil, q.tableModel)
|
||||
|
||||
var res *result
|
||||
var retErr error
|
||||
|
||||
if q.tableModel != nil {
|
||||
fields, err := q.getFields()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
res, err = q.db.insert(ctx, q.tableModel, query, fields)
|
||||
res, retErr = q.db.insert(ctx, q.tableModel, query, fields)
|
||||
} else {
|
||||
res, err = q.db.exec(ctx, query)
|
||||
res, retErr = q.db.exec(ctx, query)
|
||||
}
|
||||
|
||||
q.db.afterQuery(ctx, evt, res, err)
|
||||
q.db.afterQuery(ctx, evt, res, retErr)
|
||||
|
||||
return res, err
|
||||
return res, retErr
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user