diff --git a/ch/query_base.go b/ch/query_base.go index 16bdaf6..b8c3f28 100644 --- a/ch/query_base.go +++ b/ch/query_base.go @@ -32,6 +32,15 @@ type baseQuery struct { flags internal.Flag } +func (q *baseQuery) clone() baseQuery { + clone := *q + clone.with = lazyClone(clone.with) + clone.tables = lazyClone(clone.tables) + clone.columns = lazyClone(clone.columns) + clone.settings = lazyClone(clone.settings) + return clone +} + func (q *baseQuery) DB() *DB { return q.db } @@ -392,51 +401,32 @@ func appendTableColumns(b []byte, table chschema.Safe, fields []*chschema.Field) //------------------------------------------------------------------------------ -type whereBaseQuery struct { - baseQuery - - where []chschema.QueryWithSep +type whereQuery struct { + filters []chschema.QueryWithSep } -func (q *whereBaseQuery) addWhere(where chschema.QueryWithSep) { - q.where = append(q.where, where) +func (q *whereQuery) clone() whereQuery { + clone := *q + clone.filters = lazyClone(clone.filters) + return clone } -func (q *whereBaseQuery) addWhereGroup(sep string, where []chschema.QueryWithSep) { - if len(where) == 0 { +func (q *whereQuery) addFilter(filter chschema.QueryWithSep) { + q.filters = append(q.filters, filter) +} + +func (q *whereQuery) addGroup(sep string, filters []chschema.QueryWithSep) { + if len(filters) == 0 { return } - q.addWhere(chschema.SafeQueryWithSep("", nil, sep)) - q.addWhere(chschema.SafeQueryWithSep("", nil, "(")) + q.addFilter(chschema.SafeQueryWithSep("", nil, sep)) + q.addFilter(chschema.SafeQueryWithSep("", nil, "(")) - where[0].Sep = "" - q.where = append(q.where, where...) + filters[0].Sep = "" + q.filters = append(q.filters, filters...) - q.addWhere(chschema.SafeQueryWithSep("", nil, ")")) -} - -func (q *whereBaseQuery) mustAppendWhere(fmter chschema.Formatter, b []byte) ([]byte, error) { - if len(q.where) == 0 { - err := errors.New("ch: Update and Delete queries require at least one Where") - return nil, err - } - return q.appendWhere(fmter, b) -} - -func (q *whereBaseQuery) appendWhere(fmter chschema.Formatter, b []byte) (_ []byte, err error) { - if len(q.where) == 0 { - return b, nil - } - - b = append(b, " WHERE "...) - - b, err = appendWhere(fmter, b, q.where) - if err != nil { - return nil, err - } - - return b, nil + q.addFilter(chschema.SafeQueryWithSep("", nil, ")")) } func appendWhere( @@ -460,3 +450,11 @@ func appendWhere( } return b, nil } + +func lazyClone[S ~[]E, E any](s S) S { + // Preserve nil in case it matters. + if s == nil { + return nil + } + return s[:len(s):len(s)] +} diff --git a/ch/query_insert.go b/ch/query_insert.go index 88c19a5..ba258f1 100644 --- a/ch/query_insert.go +++ b/ch/query_insert.go @@ -10,17 +10,16 @@ import ( ) type InsertQuery struct { - whereBaseQuery + baseQuery + where whereQuery } var _ Query = (*InsertQuery)(nil) func NewInsertQuery(db *DB) *InsertQuery { return &InsertQuery{ - whereBaseQuery: whereBaseQuery{ - baseQuery: baseQuery{ - db: db, - }, + baseQuery: baseQuery{ + db: db, }, } } @@ -76,12 +75,12 @@ func (q *InsertQuery) ExcludeColumn(columns ...string) *InsertQuery { //------------------------------------------------------------------------------ func (q *InsertQuery) Where(query string, args ...any) *InsertQuery { - q.addWhere(chschema.SafeQueryWithSep(query, args, " AND ")) + q.where.addFilter(chschema.SafeQueryWithSep(query, args, " AND ")) return q } func (q *InsertQuery) WhereOr(query string, args ...any) *InsertQuery { - q.addWhere(chschema.SafeQueryWithSep(query, args, " OR ")) + q.where.addFilter(chschema.SafeQueryWithSep(query, args, " OR ")) return q } @@ -152,10 +151,9 @@ func (q *InsertQuery) appendValues( return nil, err } - if len(q.where) > 0 { + if len(q.where.filters) > 0 { b = append(b, " WHERE "...) - - b, err = appendWhere(fmter, b, q.where) + b, err = appendWhere(fmter, b, q.where.filters) if err != nil { return nil, err } diff --git a/ch/query_select.go b/ch/query_select.go index 05125b0..d6bfd20 100644 --- a/ch/query_select.go +++ b/ch/query_select.go @@ -13,11 +13,13 @@ import ( ) type SelectQuery struct { - whereBaseQuery + baseQuery sample chschema.QueryWithArgs distinctOn []chschema.QueryWithArgs joins []joinQuery + prewhere whereQuery + where whereQuery group []chschema.QueryWithArgs having []chschema.QueryWithArgs order []chschema.QueryWithArgs @@ -30,14 +32,28 @@ var _ Query = (*SelectQuery)(nil) func NewSelectQuery(db *DB) *SelectQuery { return &SelectQuery{ - whereBaseQuery: whereBaseQuery{ - baseQuery: baseQuery{ - db: db, - }, + baseQuery: baseQuery{ + db: db, }, } } +func (q *SelectQuery) Clone() *SelectQuery { + clone := *q + + clone.baseQuery = clone.baseQuery.clone() + clone.prewhere = clone.prewhere.clone() + clone.where = clone.where.clone() + + clone.distinctOn = lazyClone(clone.distinctOn) + clone.joins = lazyClone(clone.joins) + clone.group = lazyClone(clone.group) + clone.having = lazyClone(clone.having) + clone.order = lazyClone(clone.order) + + return &clone +} + func (q *SelectQuery) Operation() string { return "SELECT" } @@ -163,26 +179,52 @@ func (q *SelectQuery) joinOn(cond string, args []any, sep string) *SelectQuery { //------------------------------------------------------------------------------ +func (q *SelectQuery) Prewhere(query string, args ...any) *SelectQuery { + q.prewhere.addFilter(chschema.SafeQueryWithSep(query, args, " AND ")) + return q +} + +func (q *SelectQuery) PrewhereOr(query string, args ...any) *SelectQuery { + q.prewhere.addFilter(chschema.SafeQueryWithSep(query, args, " OR ")) + return q +} + +func (q *SelectQuery) PrewhereGroup(sep string, fn func(*SelectQuery) *SelectQuery) *SelectQuery { + saved := q.prewhere.filters + q.prewhere.filters = nil + + q = fn(q) + + filters := q.prewhere.filters + q.prewhere.filters = saved + + q.prewhere.addGroup(sep, filters) + + return q +} + +//------------------------------------------------------------------------------ + func (q *SelectQuery) Where(query string, args ...any) *SelectQuery { - q.addWhere(chschema.SafeQueryWithSep(query, args, " AND ")) + q.where.addFilter(chschema.SafeQueryWithSep(query, args, " AND ")) return q } func (q *SelectQuery) WhereOr(query string, args ...any) *SelectQuery { - q.addWhere(chschema.SafeQueryWithSep(query, args, " OR ")) + q.where.addFilter(chschema.SafeQueryWithSep(query, args, " OR ")) return q } func (q *SelectQuery) WhereGroup(sep string, fn func(*SelectQuery) *SelectQuery) *SelectQuery { - saved := q.where - q.where = nil + saved := q.where.filters + q.where.filters = nil q = fn(q) - where := q.where - q.where = saved + filters := q.where.filters + q.where.filters = saved - q.addWhereGroup(sep, where) + q.where.addGroup(sep, filters) return q } @@ -344,9 +386,19 @@ func (q *SelectQuery) appendQuery( } } - b, err = q.appendWhere(fmter, b) - if err != nil { - return nil, err + if len(q.prewhere.filters) > 0 { + b = append(b, " PREWHERE "...) + b, err = appendWhere(fmter, b, q.prewhere.filters) + if err != nil { + return nil, err + } + } + if len(q.where.filters) > 0 { + b = append(b, " WHERE "...) + b, err = appendWhere(fmter, b, q.where.filters) + if err != nil { + return nil, err + } } if len(q.group) > 0 { diff --git a/ch/query_view_create.go b/ch/query_view_create.go index 492b9df..3f23536 100644 --- a/ch/query_view_create.go +++ b/ch/query_view_create.go @@ -9,13 +9,14 @@ import ( ) type CreateViewQuery struct { - whereBaseQuery + baseQuery materialized bool ifNotExists bool view chschema.QueryWithArgs cluster chschema.QueryWithArgs to chschema.QueryWithArgs + where whereQuery group []chschema.QueryWithArgs order chschema.QueryWithArgs } @@ -24,10 +25,8 @@ var _ Query = (*CreateViewQuery)(nil) func NewCreateViewQuery(db *DB) *CreateViewQuery { return &CreateViewQuery{ - whereBaseQuery: whereBaseQuery{ - baseQuery: baseQuery{ - db: db, - }, + baseQuery: baseQuery{ + db: db, }, } } @@ -124,29 +123,31 @@ func (q *CreateViewQuery) IfNotExists() *CreateViewQuery { //------------------------------------------------------------------------------ func (q *CreateViewQuery) Where(query string, args ...any) *CreateViewQuery { - q.addWhere(chschema.SafeQueryWithSep(query, args, " AND ")) + q.where.addFilter(chschema.SafeQueryWithSep(query, args, " AND ")) return q } func (q *CreateViewQuery) WhereOr(query string, args ...any) *CreateViewQuery { - q.addWhere(chschema.SafeQueryWithSep(query, args, " OR ")) + q.where.addFilter(chschema.SafeQueryWithSep(query, args, " OR ")) return q } func (q *CreateViewQuery) WhereGroup(sep string, fn func(*CreateViewQuery) *CreateViewQuery) *CreateViewQuery { - saved := q.where - q.where = nil + saved := q.where.filters + q.where.filters = nil q = fn(q) - where := q.where - q.where = saved + filters := q.where.filters + q.where.filters = saved - q.addWhereGroup(sep, where) + q.where.addGroup(sep, filters) return q } +//------------------------------------------------------------------------------ + func (q *CreateViewQuery) Group(columns ...string) *CreateViewQuery { for _, column := range columns { q.group = append(q.group, chschema.UnsafeIdent(column)) @@ -224,9 +225,12 @@ func (q *CreateViewQuery) AppendQuery(fmter chschema.Formatter, b []byte) (_ []b return nil, err } - b, err = q.appendWhere(fmter, b) - if err != nil { - return nil, err + if len(q.where.filters) > 0 { + b = append(b, " WHERE "...) + b, err = appendWhere(fmter, b, q.where.filters) + if err != nil { + return nil, err + } } if len(q.group) > 0 {