package core import ( "errors" "fmt" "io" "regexp" "strings" "github.com/pocketbase/dbx" "github.com/pocketbase/pocketbase/tools/dbutils" "github.com/pocketbase/pocketbase/tools/inflector" "github.com/pocketbase/pocketbase/tools/security" "github.com/pocketbase/pocketbase/tools/tokenizer" ) // DeleteView drops the specified view name. // // This method is a no-op if a view with the provided name doesn't exist. // // NB! Be aware that this method is vulnerable to SQL injection and the // "name" argument must come only from trusted input! func (app *BaseApp) DeleteView(name string) error { _, err := app.DB().NewQuery(fmt.Sprintf( "DROP VIEW IF EXISTS {{%s}}", name, )).Execute() return err } // SaveView creates (or updates already existing) persistent SQL view. // // NB! Be aware that this method is vulnerable to SQL injection and the // "selectQuery" argument must come only from trusted input! func (app *BaseApp) SaveView(name string, selectQuery string) error { return app.RunInTransaction(func(txApp App) error { // delete old view (if exists) if err := txApp.DeleteView(name); err != nil { return err } selectQuery = strings.Trim(strings.TrimSpace(selectQuery), ";") // try to loosely detect multiple inline statements tk := tokenizer.NewFromString(selectQuery) tk.Separators(';') if queryParts, _ := tk.ScanAll(); len(queryParts) > 1 { return errors.New("multiple statements are not supported") } // (re)create the view // // note: the query is wrapped in a secondary SELECT as a rudimentary // measure to discourage multiple inline sql statements execution viewQuery := fmt.Sprintf("CREATE VIEW {{%s}} AS SELECT * FROM (%s)", name, selectQuery) if _, err := txApp.DB().NewQuery(viewQuery).Execute(); err != nil { return err } // fetch the view table info to ensure that the view was created // because missing tables or columns won't return an error if _, err := txApp.TableInfo(name); err != nil { // manually cleanup previously created view in case the func // is called in a nested transaction and the error is discarded txApp.DeleteView(name) return err } return nil }) } // CreateViewFields creates a new FieldsList from the provided select query. // // There are some caveats: // - The select query must have an "id" column. // - Wildcard ("*") columns are not supported to avoid accidentally leaking sensitive data. func (app *BaseApp) CreateViewFields(selectQuery string) (FieldsList, error) { result := NewFieldsList() suggestedFields, err := parseQueryToFields(app, selectQuery) if err != nil { return result, err } // note wrap in a transaction in case the selectQuery contains // multiple statements allowing us to rollback on any error txErr := app.RunInTransaction(func(txApp App) error { info, err := getQueryTableInfo(txApp, selectQuery) if err != nil { return err } var hasId bool for _, row := range info { if row.Name == FieldNameId { hasId = true } var field Field if f, ok := suggestedFields[row.Name]; ok { field = f.field } else { field = defaultViewField(row.Name) } result.Add(field) } if !hasId { return errors.New("missing required id column (you can use `(ROW_NUMBER() OVER()) as id` if you don't have one)") } return nil }) return result, txErr } // FindRecordByViewFile returns the original Record of the provided view collection file. func (app *BaseApp) FindRecordByViewFile(viewCollectionModelOrIdentifier any, fileFieldName string, filename string) (*Record, error) { view, err := getCollectionByModelOrIdentifier(app, viewCollectionModelOrIdentifier) if err != nil { return nil, err } if !view.IsView() { return nil, errors.New("not a view collection") } var findFirstNonViewQueryFileField func(int) (*queryField, error) findFirstNonViewQueryFileField = func(level int) (*queryField, error) { // check the level depth to prevent infinite circular recursion // (the limit is arbitrary and may change in the future) if level > 5 { return nil, errors.New("reached the max recursion level of view collection file field queries") } queryFields, err := parseQueryToFields(app, view.ViewQuery) if err != nil { return nil, err } for _, item := range queryFields { if item.collection == nil || item.original == nil || item.field.GetName() != fileFieldName { continue } if item.collection.IsView() { view = item.collection fileFieldName = item.original.GetName() return findFirstNonViewQueryFileField(level + 1) } return item, nil } return nil, errors.New("no query file field found") } qf, err := findFirstNonViewQueryFileField(1) if err != nil { return nil, err } cleanFieldName := inflector.Columnify(qf.original.GetName()) record := &Record{} query := app.RecordQuery(qf.collection).Limit(1) if opt, ok := qf.original.(MultiValuer); !ok || !opt.IsMultiple() { query.AndWhere(dbx.HashExp{cleanFieldName: filename}) } else { query.InnerJoin( fmt.Sprintf(`%s as {{_je_file}}`, dbutils.JSONEach(cleanFieldName)), dbx.HashExp{"_je_file.value": filename}, ) } if err := query.One(record); err != nil { return nil, err } return record, nil } // ------------------------------------------------------------------- // Raw query to schema helpers // ------------------------------------------------------------------- type queryField struct { // field is the final resolved field. field Field // collection refers to the original field's collection model. // It could be nil if the found query field is not from a collection collection *Collection // original is the original found collection field. // It could be nil if the found query field is not from a collection original Field } func defaultViewField(name string) Field { return &JSONField{ Name: name, MaxSize: 1, // unused for views } } var castRegex = regexp.MustCompile(`(?i)^cast\s*\(.*\s+as\s+(\w+)\s*\)$`) func parseQueryToFields(app App, selectQuery string) (map[string]*queryField, error) { p := new(identifiersParser) if err := p.parse(selectQuery); err != nil { return nil, err } collections, err := findCollectionsByIdentifiers(app, p.tables) if err != nil { return nil, err } result := make(map[string]*queryField, len(p.columns)) var mainTable identifier if len(p.tables) > 0 { mainTable = p.tables[0] } for _, col := range p.columns { colLower := strings.ToLower(col.original) // pk (always assume text field for now) if col.alias == FieldNameId { result[col.alias] = &queryField{ field: &TextField{ Name: col.alias, System: true, Required: true, PrimaryKey: true, }, } continue } // numeric aggregations if strings.HasPrefix(colLower, "count(") || strings.HasPrefix(colLower, "total(") { result[col.alias] = &queryField{ field: &NumberField{ Name: col.alias, }, } continue } castMatch := castRegex.FindStringSubmatch(colLower) // numeric casts if len(castMatch) == 2 { switch castMatch[1] { case "real", "integer", "int", "decimal", "numeric": result[col.alias] = &queryField{ field: &NumberField{ Name: col.alias, }, } continue case "text": result[col.alias] = &queryField{ field: &TextField{ Name: col.alias, }, } continue case "boolean", "bool": result[col.alias] = &queryField{ field: &BoolField{ Name: col.alias, }, } continue } } parts := strings.Split(col.original, ".") var fieldName string var collection *Collection if len(parts) == 2 { fieldName = parts[1] collection = collections[parts[0]] } else { fieldName = parts[0] collection = collections[mainTable.alias] } // fallback to the default field if collection == nil { result[col.alias] = &queryField{ field: defaultViewField(col.alias), } continue } if fieldName == "*" { return nil, errors.New("dynamic column names are not supported") } // find the first field by name (case insensitive) var field Field for _, f := range collection.Fields { if strings.EqualFold(f.GetName(), fieldName) { field = f break } } // fallback to the default field if field == nil { result[col.alias] = &queryField{ field: defaultViewField(col.alias), collection: collection, } continue } // convert to relation since it is an id reference if strings.EqualFold(fieldName, FieldNameId) { result[col.alias] = &queryField{ field: &RelationField{ Name: col.alias, MaxSelect: 1, CollectionId: collection.Id, }, collection: collection, } continue } // we fetch a brand new collection object to avoid using reflection // or having a dedicated Clone method for each field type tempCollection, err := app.FindCollectionByNameOrId(collection.Id) if err != nil { return nil, err } clone := tempCollection.Fields.GetById(field.GetId()) if clone == nil { return nil, fmt.Errorf("missing expected field %q (%q) in collection %q", field.GetName(), field.GetId(), tempCollection.Name) } // set new random id to prevent duplications if the same field is aliased multiple times clone.SetId("_clone_" + security.PseudorandomString(4)) clone.SetName(col.alias) result[col.alias] = &queryField{ original: field, field: clone, collection: collection, } } return result, nil } func findCollectionsByIdentifiers(app App, tables []identifier) (map[string]*Collection, error) { names := make([]any, 0, len(tables)) for _, table := range tables { if strings.Contains(table.alias, "(") { continue // skip expressions } names = append(names, table.original) } if len(names) == 0 { return nil, nil } result := make(map[string]*Collection, len(names)) collections := make([]*Collection, 0, len(names)) err := app.CollectionQuery(). AndWhere(dbx.In("name", names...)). All(&collections) if err != nil { return nil, err } for _, table := range tables { for _, collection := range collections { if collection.Name == table.original { result[table.alias] = collection } } } return result, nil } func getQueryTableInfo(app App, selectQuery string) ([]*TableInfoRow, error) { tempView := "_temp_" + security.PseudorandomString(6) var info []*TableInfoRow txErr := app.RunInTransaction(func(txApp App) error { // create a temp view with the provided query err := txApp.SaveView(tempView, selectQuery) if err != nil { return err } // extract the generated view table info info, err = txApp.TableInfo(tempView) return errors.Join(err, txApp.DeleteView(tempView)) }) if txErr != nil { return nil, txErr } return info, nil } // ------------------------------------------------------------------- // Raw query identifiers parser // ------------------------------------------------------------------- var joinReplaceRegex = regexp.MustCompile(`(?im)\s+(full\s+outer\s+join|left\s+outer\s+join|right\s+outer\s+join|full\s+join|cross\s+join|inner\s+join|outer\s+join|left\s+join|right\s+join|join)\s+?`) var discardReplaceRegex = regexp.MustCompile(`(?im)\s+(where|group\s+by|having|order|limit|with)\s+?`) var commentsReplaceRegex = regexp.MustCompile(`(?m)(\/\*[\s\S]+\*\/)|(--.+$)`) type identifier struct { original string alias string } type identifiersParser struct { columns []identifier tables []identifier } func (p *identifiersParser) parse(selectQuery string) error { str := strings.Trim(strings.TrimSpace(selectQuery), ";") str = joinReplaceRegex.ReplaceAllString(str, " _join_ ") str = discardReplaceRegex.ReplaceAllString(str, " _discard_ ") str = commentsReplaceRegex.ReplaceAllString(str, "") tk := tokenizer.NewFromString(str) tk.Separators(',', ' ', '\n', '\t') tk.KeepSeparator(true) var skip bool var partType string var activeBuilder *strings.Builder var selectParts strings.Builder var fromParts strings.Builder var joinParts strings.Builder for { token, err := tk.Scan() if err != nil { if err != io.EOF { return err } break } trimmed := strings.ToLower(strings.TrimSpace(token)) switch trimmed { case "select": skip = false partType = "select" activeBuilder = &selectParts case "distinct": continue // ignore as it is not important for the identifiers parsing case "from": skip = false partType = "from" activeBuilder = &fromParts case "_join_": skip = false // the previous part was also a join if partType == "join" { joinParts.WriteString(",") } partType = "join" activeBuilder = &joinParts case "_discard_": // skip following tokens skip = true default: isJoin := partType == "join" if isJoin && trimmed == "on" { skip = true } if !skip && activeBuilder != nil { activeBuilder.WriteString(" ") activeBuilder.WriteString(token) } } } selects, err := extractIdentifiers(selectParts.String()) if err != nil { return err } froms, err := extractIdentifiers(fromParts.String()) if err != nil { return err } joins, err := extractIdentifiers(joinParts.String()) if err != nil { return err } p.columns = selects p.tables = froms p.tables = append(p.tables, joins...) return nil } func extractIdentifiers(rawExpression string) ([]identifier, error) { rawTk := tokenizer.NewFromString(rawExpression) rawTk.Separators(',') rawIdentifiers, err := rawTk.ScanAll() if err != nil { return nil, err } result := make([]identifier, 0, len(rawIdentifiers)) for _, rawIdentifier := range rawIdentifiers { tk := tokenizer.NewFromString(rawIdentifier) tk.Separators(' ', '\n', '\t') parts, err := tk.ScanAll() if err != nil { return nil, err } resolved, err := identifierFromParts(parts) if err != nil { return nil, err } result = append(result, resolved) } return result, nil } func identifierFromParts(parts []string) (identifier, error) { var result identifier switch len(parts) { case 3: if !strings.EqualFold(parts[1], "as") { return result, fmt.Errorf(`invalid identifier part - expected "as", got %v`, parts[1]) } result.original = parts[0] result.alias = parts[2] case 2: result.original = parts[0] result.alias = parts[1] case 1: subParts := strings.Split(parts[0], ".") result.original = parts[0] result.alias = subParts[len(subParts)-1] default: return result, fmt.Errorf(`invalid identifier parts %v`, parts) } result.original = trimRawIdentifier(result.original) // we trim the single quote even though it is not a valid column quote character // because SQLite allows it if the context expects an identifier and not string literal // (https://www.sqlite.org/lang_keywords.html) result.alias = trimRawIdentifier(result.alias, "'") return result, nil } func trimRawIdentifier(rawIdentifier string, extraTrimChars ...string) string { trimChars := "`\"[];" if len(extraTrimChars) > 0 { trimChars += strings.Join(extraTrimChars, "") } parts := strings.Split(rawIdentifier, ".") for i := range parts { parts[i] = strings.Trim(parts[i], trimChars) } return strings.Join(parts, ".") }