diff --git a/ch/bfloat16/bfloat16.go b/ch/bfloat16/bfloat16.go new file mode 100644 index 0000000..1c57a27 --- /dev/null +++ b/ch/bfloat16/bfloat16.go @@ -0,0 +1,25 @@ +package bfloat16 + +import ( + "math" +) + +type Map map[T]uint64 + +type T uint16 + +func From(f float64) T { + return FromFloat32(float32(f)) +} + +func FromFloat32(f float32) T { + return T(math.Float32bits(f) >> 16) +} + +func (f T) Float32() float32 { + return math.Float32frombits(uint32(f) << 16) +} + +func (f T) Float64() float64 { + return float64(f.Float32()) +} diff --git a/ch/chschema/column.go b/ch/chschema/column.go index 6a52777..75fd4d4 100644 --- a/ch/chschema/column.go +++ b/ch/chschema/column.go @@ -12,6 +12,7 @@ import ( "strings" "time" + "github.com/uptrace/go-clickhouse/ch/bfloat16" "github.com/uptrace/go-clickhouse/ch/chproto" "github.com/uptrace/go-clickhouse/ch/internal" @@ -1164,3 +1165,68 @@ func newLCKeyType(typ int64) lcKey { panic("not reached") } } + +//------------------------------------------------------------------------------ + +type BFloat16HistColumn struct { + ColumnOf[bfloat16.Map] +} + +var _ Columnar = (*BFloat16HistColumn)(nil) + +func NewBFloat16HistColumn(typ reflect.Type, chType string, numRow int) Columnar { + return &BFloat16HistColumn{ + ColumnOf: NewColumnOf[bfloat16.Map](numRow), + } +} + +func (c BFloat16HistColumn) Type() reflect.Type { + return bfloat16MapType +} + +func (c *BFloat16HistColumn) ReadFrom(rd *chproto.Reader, numRow int) error { + if numRow == 0 { + return nil + } + + c.Alloc(numRow) + + for i := range c.Column { + n, err := rd.Uvarint() + if err != nil { + return err + } + + data := make(bfloat16.Map, n) + + for j := 0; j < int(n); j++ { + value, err := rd.UInt16() + if err != nil { + return err + } + + count, err := rd.UInt64() + if err != nil { + return err + } + + data[bfloat16.T(value)] = count + } + + c.Column[i] = data + } + + return nil +} + +func (c BFloat16HistColumn) WriteTo(wr *chproto.Writer) error { + for _, m := range c.Column { + wr.Uvarint(uint64(len(m))) + + for k, v := range m { + wr.UInt16(uint16(k)) + wr.UInt64(v) + } + } + return nil +} diff --git a/ch/chschema/types.go b/ch/chschema/types.go index 3a210ce..f8aa5b9 100644 --- a/ch/chschema/types.go +++ b/ch/chschema/types.go @@ -8,6 +8,7 @@ import ( "strings" "time" + "github.com/uptrace/go-clickhouse/ch/bfloat16" "github.com/uptrace/go-clickhouse/ch/chtype" "github.com/uptrace/go-clickhouse/ch/internal" ) @@ -41,50 +42,6 @@ var chType = [...]string{ reflect.UnsafePointer: "", } -// keep in sync with ColumnFactory -func clickhouseType(typ reflect.Type) string { - switch typ { - case timeType: - return chtype.DateTime - case ipType: - return chtype.IPv6 - } - - kind := typ.Kind() - switch kind { - case reflect.Ptr: - if typ.Elem().Kind() == reflect.Struct { - return chtype.String - } - return fmt.Sprintf("Nullable(%s)", clickhouseType(typ.Elem())) - case reflect.Slice: - switch elem := typ.Elem(); elem.Kind() { - case reflect.Ptr: - if elem.Elem().Kind() == reflect.Struct { - return chtype.String // json - } - case reflect.Struct: - if elem != timeType { - return chtype.String // json - } - case reflect.Uint8: - return chtype.String // []byte - } - - return "Array(" + clickhouseType(typ.Elem()) + ")" - case reflect.Array: - if isUUID(typ) { - return chtype.UUID - } - } - - if s := chType[kind]; s != "" { - return s - } - - panic(fmt.Errorf("ch: unsupported Go type: %s", typ)) -} - type NewColumnFunc func(typ reflect.Type, chType string, numRow int) Columnar var kindToColumn = [...]NewColumnFunc{ @@ -141,6 +98,13 @@ func ColumnFactory(typ reflect.Type, chType string) NewColumnFunc { chType = chSubType(chType, "SimpleAggregateFunction(") } else if s := dateTimeType(chType); s != "" { chType = s + } else if funcName, _ := aggFuncNameAndType(chType); funcName != "" { + switch funcName { + case "quantileBFloat16", "quantilesBFloat16": + return NewBFloat16HistColumn + default: + panic(fmt.Errorf("unsupported ClickHouse type: %s", chType)) + } } switch typ { @@ -270,12 +234,13 @@ var ( float32Type = reflect.TypeOf(float32(0)) float64Type = reflect.TypeOf(float64(0)) - stringType = reflect.TypeOf("") - bytesType = reflect.TypeOf((*[]byte)(nil)).Elem() - uuidType = reflect.TypeOf((*UUID)(nil)).Elem() - timeType = reflect.TypeOf((*time.Time)(nil)).Elem() - ipType = reflect.TypeOf((*net.IP)(nil)).Elem() - ipNetType = reflect.TypeOf((*net.IPNet)(nil)).Elem() + stringType = reflect.TypeOf("") + bytesType = reflect.TypeOf((*[]byte)(nil)).Elem() + uuidType = reflect.TypeOf((*UUID)(nil)).Elem() + timeType = reflect.TypeOf((*time.Time)(nil)).Elem() + ipType = reflect.TypeOf((*net.IP)(nil)).Elem() + ipNetType = reflect.TypeOf((*net.IPNet)(nil)).Elem() + bfloat16MapType = reflect.TypeOf((*bfloat16.Map)(nil)).Elem() int64SliceType = reflect.TypeOf((*[]int64)(nil)).Elem() uint64SliceType = reflect.TypeOf((*[]uint64)(nil)).Elem() @@ -342,6 +307,51 @@ func goType(chType string) reflect.Type { panic(fmt.Errorf("unsupported ClickHouse type=%q", chType)) } +// clickhouseType returns ClickHouse type for the given Go type. +// Keep in sync with ColumnFactory. +func clickhouseType(typ reflect.Type) string { + switch typ { + case timeType: + return chtype.DateTime + case ipType: + return chtype.IPv6 + } + + kind := typ.Kind() + switch kind { + case reflect.Ptr: + if typ.Elem().Kind() == reflect.Struct { + return chtype.String + } + return fmt.Sprintf("Nullable(%s)", clickhouseType(typ.Elem())) + case reflect.Slice: + switch elem := typ.Elem(); elem.Kind() { + case reflect.Ptr: + if elem.Elem().Kind() == reflect.Struct { + return chtype.String // json + } + case reflect.Struct: + if elem != timeType { + return chtype.String // json + } + case reflect.Uint8: + return chtype.String // []byte + } + + return "Array(" + clickhouseType(typ.Elem()) + ")" + case reflect.Array: + if isUUID(typ) { + return chtype.UUID + } + } + + if s := chType[kind]; s != "" { + return s + } + + panic(fmt.Errorf("ch: unsupported Go type: %s", typ)) +} + func chArrayElemType(s string) string { s = chSubType(s, "Array(") if s == "" { @@ -401,7 +411,15 @@ func nullableType(s string) string { } func aggFuncNameAndType(chType string) (funcName, funcType string) { - s := chSubType(chType, "SimpleAggregateFunction(") + var s string + + for _, prefix := range []string{"SimpleAggregateFunction(", "AggregateFunction("} { + s = chSubType(chType, prefix) + if s != "" { + break + } + } + if s == "" { return "", "" } diff --git a/ch/config.go b/ch/config.go index f29e1d7..6aed087 100644 --- a/ch/config.go +++ b/ch/config.go @@ -261,25 +261,27 @@ func parseDSN(dsn string) ([]Option, error) { switch u.Scheme { case "ch", "clickhouse": - if u.Host != "" { - addr := u.Host - if !strings.Contains(addr, ":") { - addr += ":5432" - } - opts = append(opts, WithAddr(addr)) - } - - if len(u.Path) > 1 { - opts = append(opts, WithDatabase(u.Path[1:])) - } - - if host := q.string("host"); host != "" { - opts = append(opts, WithAddr(host)) - } + // ok default: return nil, errors.New("ch: unknown scheme: " + u.Scheme) } + if u.Host != "" { + addr := u.Host + if !strings.Contains(addr, ":") { + addr += ":5432" + } + opts = append(opts, WithAddr(addr)) + } + + if len(u.Path) > 1 { + opts = append(opts, WithDatabase(u.Path[1:])) + } + + if host := q.string("host"); host != "" { + opts = append(opts, WithAddr(host)) + } + if u.User != nil { opts = append(opts, WithUser(u.User.Username())) if password, ok := u.User.Password(); ok { diff --git a/ch/query_insert.go b/ch/query_insert.go index 0fdd003..e3aa805 100644 --- a/ch/query_insert.go +++ b/ch/query_insert.go @@ -192,19 +192,21 @@ func (q *InsertQuery) Exec(ctx context.Context) (sql.Result, error) { query := internal.String(queryBytes) ctx, evt := q.db.beforeQuery(ctx, q, query, nil, q.tableModel) + var res *result + var retErr error if q.tableModel != nil { fields, err := q.getFields() if err != nil { return nil, err } - res, err = q.db.insert(ctx, q.tableModel, query, fields) + res, retErr = q.db.insert(ctx, q.tableModel, query, fields) } else { - res, err = q.db.exec(ctx, query) + res, retErr = q.db.exec(ctx, query) } - q.db.afterQuery(ctx, evt, res, err) + q.db.afterQuery(ctx, evt, res, retErr) - return res, err + return res, retErr }