package daos import ( "errors" "fmt" "io" "regexp" "strings" "github.com/pocketbase/dbx" "github.com/pocketbase/pocketbase/models" "github.com/pocketbase/pocketbase/models/schema" "github.com/pocketbase/pocketbase/tools/inflector" "github.com/pocketbase/pocketbase/tools/list" "github.com/pocketbase/pocketbase/tools/security" "github.com/pocketbase/pocketbase/tools/tokenizer" "github.com/pocketbase/pocketbase/tools/types" ) // DeleteView drops the specified view name. // // This method is a no-op if a view with the provided name doesn't exist. // // Be aware that this method is vulnerable to SQL injection and the // "name" argument must come only from trusted input! func (dao *Dao) DeleteView(name string) error { _, err := dao.DB().NewQuery(fmt.Sprintf( "DROP VIEW IF EXISTS {{%s}}", name, )).Execute() return err } // SaveView creates (or updates already existing) persistent SQL view. // // Be aware that this method is vulnerable to SQL injection and the // "selectQuery" argument must come only from trusted input! func (dao *Dao) SaveView(name string, selectQuery string) error { return dao.RunInTransaction(func(txDao *Dao) error { // delete old view (if exists) if err := txDao.DeleteView(name); err != nil { return err } trimmed := strings.Trim(selectQuery, ";") // try to eagerly detect multiple inline statements tk := tokenizer.NewFromString(trimmed) tk.Separators(';') if queryParts, _ := tk.ScanAll(); len(queryParts) > 1 { return errors.New("multiple statements are not supported") } // (re)create the view // // note: the query is wrapped in a secondary SELECT as a rudimentary // measure to discourage multiple inline sql statements execution. viewQuery := fmt.Sprintf("CREATE VIEW {{%s}} AS SELECT * FROM (%s)", name, trimmed) if _, err := txDao.DB().NewQuery(viewQuery).Execute(); err != nil { return err } // fetch the view table info to ensure that the view was created // because missing tables or columns won't return an error if _, err := txDao.GetTableInfo(name); err != nil { return err } return nil }) } // CreateViewSchema creates a new view schema from the provided select query. // // There are some caveats: // - The select query must have an "id" column. // - Wildcard ("*") columns are not supported to avoid accidentally leaking sensitive data. func (dao *Dao) CreateViewSchema(selectQuery string) (schema.Schema, error) { result := schema.NewSchema() suggestedFields, err := dao.parseQueryToFields(selectQuery) if err != nil { return result, err } // note wrap in a transaction in case the selectQuery contains // multiple statements allowing us to rollback on any error txErr := dao.RunInTransaction(func(txDao *Dao) error { tempView := "_temp_" + security.PseudorandomString(5) // create a temp view with the provided query if err := txDao.SaveView(tempView, selectQuery); err != nil { return err } defer txDao.DeleteView(tempView) // extract the generated view table info info, err := txDao.GetTableInfo(tempView) if err != nil { return err } var hasId bool for _, row := range info { if row.Name == schema.FieldNameId { hasId = true } if list.ExistInSlice(row.Name, schema.BaseModelFieldNames()) { continue // skip base model fields since they are not part of the schema } var field *schema.SchemaField if f, ok := suggestedFields[row.Name]; ok { field = f.field } else { field = defaultViewField(row.Name) } result.AddField(field) } if !hasId { return errors.New("missing required id column (you ca use `(ROW_NUMBER() OVER()) as id` if you don't have one)") } return nil }) return result, txErr } // FindRecordByViewFile returns the original models.Record of the // provided view collection file. func (dao *Dao) FindRecordByViewFile( viewCollectionNameOrId string, fileFieldName string, filename string, ) (*models.Record, error) { view, err := dao.FindCollectionByNameOrId(viewCollectionNameOrId) if err != nil { return nil, err } if !view.IsView() { return nil, errors.New("not a view collection") } var findFirstNonViewQueryFileField func(int) (*queryField, error) findFirstNonViewQueryFileField = func(level int) (*queryField, error) { // check the level depth to prevent infinite circular recursion // (the limit is arbitrary and may change in the future) if level > 5 { return nil, errors.New("reached the max recursion level of view collection file field queries") } queryFields, err := dao.parseQueryToFields(view.ViewOptions().Query) if err != nil { return nil, err } for _, item := range queryFields { if item.collection == nil || item.original == nil || item.field.Name != fileFieldName { continue } if item.collection.IsView() { view = item.collection fileFieldName = item.original.Name return findFirstNonViewQueryFileField(level + 1) } return item, nil } return nil, errors.New("no query file field found") } qf, err := findFirstNonViewQueryFileField(1) if err != nil { return nil, err } cleanFieldName := inflector.Columnify(qf.original.Name) record := &models.Record{} err = dao.RecordQuery(qf.collection). InnerJoin(fmt.Sprintf( // note: the case is used to normalize the value access `json_each(CASE WHEN json_valid([[%s]]) THEN [[%s]] ELSE json_array([[%s]]) END) as {{_je_file}}`, cleanFieldName, cleanFieldName, cleanFieldName, ), dbx.HashExp{"_je_file.value": filename}). Limit(1). One(record) if err != nil { return nil, err } return record, nil } // ------------------------------------------------------------------- // Raw query to schema helpers // ------------------------------------------------------------------- type queryField struct { // field is the final resolved field. field *schema.SchemaField // collection refers to the original field's collection model. // It could be nil if the found query field is not from a collection schema. collection *models.Collection // original is the original found collection field. // It could be nil if the found query field is not from a collection schema. original *schema.SchemaField } func defaultViewField(name string) *schema.SchemaField { return &schema.SchemaField{ Name: name, Type: schema.FieldTypeJson, } } func (dao *Dao) parseQueryToFields(selectQuery string) (map[string]*queryField, error) { p := new(identifiersParser) if err := p.parse(selectQuery); err != nil { return nil, err } collections, err := dao.findCollectionsByIdentifiers(p.tables) if err != nil { return nil, err } result := make(map[string]*queryField, len(p.columns)) var mainTable identifier if len(p.tables) > 0 { mainTable = p.tables[0] } for _, col := range p.columns { colLower := strings.ToLower(col.original) // numeric expression cast if strings.Contains(colLower, "(") && (strings.HasPrefix(colLower, "count(") || strings.HasPrefix(colLower, "total(") || strings.Contains(colLower, " as numeric") || strings.Contains(colLower, " as real") || strings.Contains(colLower, " as int") || strings.Contains(colLower, " as integer") || strings.Contains(colLower, " as decimal")) { result[col.alias] = &queryField{ field: &schema.SchemaField{ Name: col.alias, Type: schema.FieldTypeNumber, }, } continue } parts := strings.Split(col.original, ".") var fieldName string var collection *models.Collection var isMainTableField bool if len(parts) == 2 { fieldName = parts[1] collection = collections[parts[0]] isMainTableField = parts[0] == mainTable.alias } else { fieldName = parts[0] collection = collections[mainTable.alias] isMainTableField = true } // fallback to the default field if the found column is not from a collection schema if collection == nil { result[col.alias] = &queryField{ field: defaultViewField(col.alias), } continue } if fieldName == "*" { return nil, errors.New("dynamic column names are not supported") } // find the first field by name (case insensitive) var field *schema.SchemaField for _, f := range collection.Schema.Fields() { if strings.EqualFold(f.Name, fieldName) { field = f break } } if field != nil { clone := *field clone.Name = col.alias result[col.alias] = &queryField{ field: &clone, collection: collection, original: field, } continue } if fieldName == schema.FieldNameId && !isMainTableField { // convert to relation since it is a direct id reference to non-maintable collection result[col.alias] = &queryField{ field: &schema.SchemaField{ Name: col.alias, Type: schema.FieldTypeRelation, Options: &schema.RelationOptions{ MaxSelect: types.Pointer(1), CollectionId: collection.Id, }, }, collection: collection, } } else if fieldName == schema.FieldNameCreated || fieldName == schema.FieldNameUpdated { result[col.alias] = &queryField{ field: &schema.SchemaField{ Name: col.alias, Type: schema.FieldTypeDate, }, collection: collection, } } else if fieldName == schema.FieldNameUsername && collection.IsAuth() { result[col.alias] = &queryField{ field: &schema.SchemaField{ Name: col.alias, Type: schema.FieldTypeText, }, collection: collection, } } else if fieldName == schema.FieldNameEmail && collection.IsAuth() { result[col.alias] = &queryField{ field: &schema.SchemaField{ Name: col.alias, Type: schema.FieldTypeEmail, }, collection: collection, } } else if (fieldName == schema.FieldNameVerified || fieldName == schema.FieldNameEmailVisibility) && collection.IsAuth() { result[col.alias] = &queryField{ field: &schema.SchemaField{ Name: col.alias, Type: schema.FieldTypeBool, }, collection: collection, } } else { result[col.alias] = &queryField{ field: defaultViewField(col.alias), collection: collection, } } } return result, nil } func (dao *Dao) findCollectionsByIdentifiers(tables []identifier) (map[string]*models.Collection, error) { names := make([]any, 0, len(tables)) for _, table := range tables { if strings.Contains(table.alias, "(") { continue // skip expressions } names = append(names, table.original) } if len(names) == 0 { return nil, nil } result := make(map[string]*models.Collection, len(names)) collections := make([]*models.Collection, 0, len(names)) err := dao.CollectionQuery(). AndWhere(dbx.In("name", names...)). All(&collections) if err != nil { return nil, err } for _, table := range tables { for _, collection := range collections { if collection.Name == table.original { result[table.alias] = collection } } } return result, nil } // ------------------------------------------------------------------- // Raw query identifiers parser // ------------------------------------------------------------------- var joinReplaceRegex = regexp.MustCompile(`(?im)\s+(inner join|outer join|left join|right join|join)\s+?`) var discardReplaceRegex = regexp.MustCompile(`(?im)\s+(where|group by|having|order|limit|with)\s+?`) var commentsReplaceRegex = regexp.MustCompile(`(?m)(\/\*[\s\S]+\*\/)|(--.+$)`) type identifier struct { original string alias string } type identifiersParser struct { columns []identifier tables []identifier } func (p *identifiersParser) parse(selectQuery string) error { str := strings.Trim(selectQuery, ";") str = joinReplaceRegex.ReplaceAllString(str, " _join_ ") str = discardReplaceRegex.ReplaceAllString(str, " _discard_ ") str = commentsReplaceRegex.ReplaceAllString(str, "") tk := tokenizer.NewFromString(str) tk.Separators(',', ' ', '\n', '\t') tk.KeepSeparator(true) var skip bool var partType string var activeBuilder *strings.Builder var selectParts strings.Builder var fromParts strings.Builder var joinParts strings.Builder for { token, err := tk.Scan() if err != nil { if err != io.EOF { return err } break } trimmed := strings.ToLower(strings.TrimSpace(token)) switch trimmed { case "select": skip = false partType = "select" activeBuilder = &selectParts case "from": skip = false partType = "from" activeBuilder = &fromParts case "_join_": skip = false // the previous part was also a join if partType == "join" { joinParts.WriteString(",") } partType = "join" activeBuilder = &joinParts case "_discard_": // do nothing... skip = true default: isJoin := partType == "join" if isJoin && trimmed == "on" { skip = true } if !skip && activeBuilder != nil { activeBuilder.WriteString(" ") activeBuilder.WriteString(token) } } } selects, err := extractIdentifiers(selectParts.String()) if err != nil { return err } froms, err := extractIdentifiers(fromParts.String()) if err != nil { return err } joins, err := extractIdentifiers(joinParts.String()) if err != nil { return err } p.columns = selects p.tables = froms p.tables = append(p.tables, joins...) return nil } func extractIdentifiers(rawExpression string) ([]identifier, error) { rawTk := tokenizer.NewFromString(rawExpression) rawTk.Separators(',') rawIdentifiers, err := rawTk.ScanAll() if err != nil { return nil, err } result := make([]identifier, 0, len(rawIdentifiers)) for _, rawIdentifier := range rawIdentifiers { tk := tokenizer.NewFromString(rawIdentifier) tk.Separators(' ', '\n', '\t') parts, err := tk.ScanAll() if err != nil { return nil, err } resolved, err := identifierFromParts(parts) if err != nil { return nil, err } result = append(result, resolved) } return result, nil } func identifierFromParts(parts []string) (identifier, error) { var result identifier switch len(parts) { case 3: if !strings.EqualFold(parts[1], "as") { return result, fmt.Errorf(`invalid identifier part - expected "as", got %v`, parts[1]) } result.original = parts[0] result.alias = parts[2] case 2: result.original = parts[0] result.alias = parts[1] case 1: subParts := strings.Split(parts[0], ".") result.original = parts[0] result.alias = subParts[len(subParts)-1] default: return result, fmt.Errorf(`invalid identifier parts %v`, parts) } result.original = trimRawIdentifier(result.original) result.alias = trimRawIdentifier(result.alias) return result, nil } func trimRawIdentifier(rawIdentifier string) string { const trimChars = "`\"[];" parts := strings.Split(rawIdentifier, ".") for i := range parts { parts[i] = strings.Trim(parts[i], trimChars) } return strings.Join(parts, ".") }