1
0
mirror of https://github.com/pocketbase/pocketbase.git synced 2025-02-15 09:12:58 +02:00
pocketbase/core/view.go

619 lines
15 KiB
Go

package core
import (
"errors"
"fmt"
"io"
"regexp"
"strings"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/tools/dbutils"
"github.com/pocketbase/pocketbase/tools/inflector"
"github.com/pocketbase/pocketbase/tools/security"
"github.com/pocketbase/pocketbase/tools/tokenizer"
)
// DeleteView drops the specified view name.
//
// This method is a no-op if a view with the provided name doesn't exist.
//
// NB! Be aware that this method is vulnerable to SQL injection and the
// "name" argument must come only from trusted input!
func (app *BaseApp) DeleteView(name string) error {
_, err := app.DB().NewQuery(fmt.Sprintf(
"DROP VIEW IF EXISTS {{%s}}",
name,
)).Execute()
return err
}
// SaveView creates (or updates already existing) persistent SQL view.
//
// NB! Be aware that this method is vulnerable to SQL injection and the
// "selectQuery" argument must come only from trusted input!
func (app *BaseApp) SaveView(name string, selectQuery string) error {
return app.RunInTransaction(func(txApp App) error {
// delete old view (if exists)
if err := txApp.DeleteView(name); err != nil {
return err
}
selectQuery = strings.Trim(strings.TrimSpace(selectQuery), ";")
// try to loosely detect multiple inline statements
tk := tokenizer.NewFromString(selectQuery)
tk.Separators(';')
if queryParts, _ := tk.ScanAll(); len(queryParts) > 1 {
return errors.New("multiple statements are not supported")
}
// (re)create the view
//
// note: the query is wrapped in a secondary SELECT as a rudimentary
// measure to discourage multiple inline sql statements execution
viewQuery := fmt.Sprintf("CREATE VIEW {{%s}} AS SELECT * FROM (%s)", name, selectQuery)
if _, err := txApp.DB().NewQuery(viewQuery).Execute(); err != nil {
return err
}
// fetch the view table info to ensure that the view was created
// because missing tables or columns won't return an error
if _, err := txApp.TableInfo(name); err != nil {
// manually cleanup previously created view in case the func
// is called in a nested transaction and the error is discarded
txApp.DeleteView(name)
return err
}
return nil
})
}
// CreateViewFields creates a new FieldsList from the provided select query.
//
// There are some caveats:
// - The select query must have an "id" column.
// - Wildcard ("*") columns are not supported to avoid accidentally leaking sensitive data.
func (app *BaseApp) CreateViewFields(selectQuery string) (FieldsList, error) {
result := NewFieldsList()
suggestedFields, err := parseQueryToFields(app, selectQuery)
if err != nil {
return result, err
}
// note wrap in a transaction in case the selectQuery contains
// multiple statements allowing us to rollback on any error
txErr := app.RunInTransaction(func(txApp App) error {
info, err := getQueryTableInfo(txApp, selectQuery)
if err != nil {
return err
}
var hasId bool
for _, row := range info {
if row.Name == FieldNameId {
hasId = true
}
var field Field
if f, ok := suggestedFields[row.Name]; ok {
field = f.field
} else {
field = defaultViewField(row.Name)
}
result.Add(field)
}
if !hasId {
return errors.New("missing required id column (you can use `(ROW_NUMBER() OVER()) as id` if you don't have one)")
}
return nil
})
return result, txErr
}
// FindRecordByViewFile returns the original Record of the provided view collection file.
func (app *BaseApp) FindRecordByViewFile(viewCollectionModelOrIdentifier any, fileFieldName string, filename string) (*Record, error) {
view, err := getCollectionByModelOrIdentifier(app, viewCollectionModelOrIdentifier)
if err != nil {
return nil, err
}
if !view.IsView() {
return nil, errors.New("not a view collection")
}
var findFirstNonViewQueryFileField func(int) (*queryField, error)
findFirstNonViewQueryFileField = func(level int) (*queryField, error) {
// check the level depth to prevent infinite circular recursion
// (the limit is arbitrary and may change in the future)
if level > 5 {
return nil, errors.New("reached the max recursion level of view collection file field queries")
}
queryFields, err := parseQueryToFields(app, view.ViewQuery)
if err != nil {
return nil, err
}
for _, item := range queryFields {
if item.collection == nil ||
item.original == nil ||
item.field.GetName() != fileFieldName {
continue
}
if item.collection.IsView() {
view = item.collection
fileFieldName = item.original.GetName()
return findFirstNonViewQueryFileField(level + 1)
}
return item, nil
}
return nil, errors.New("no query file field found")
}
qf, err := findFirstNonViewQueryFileField(1)
if err != nil {
return nil, err
}
cleanFieldName := inflector.Columnify(qf.original.GetName())
record := &Record{}
query := app.RecordQuery(qf.collection).Limit(1)
if opt, ok := qf.original.(MultiValuer); !ok || !opt.IsMultiple() {
query.AndWhere(dbx.HashExp{cleanFieldName: filename})
} else {
query.InnerJoin(
fmt.Sprintf(`%s as {{_je_file}}`, dbutils.JSONEach(cleanFieldName)),
dbx.HashExp{"_je_file.value": filename},
)
}
if err := query.One(record); err != nil {
return nil, err
}
return record, nil
}
// -------------------------------------------------------------------
// Raw query to schema helpers
// -------------------------------------------------------------------
type queryField struct {
// field is the final resolved field.
field Field
// collection refers to the original field's collection model.
// It could be nil if the found query field is not from a collection
collection *Collection
// original is the original found collection field.
// It could be nil if the found query field is not from a collection
original Field
}
func defaultViewField(name string) Field {
return &JSONField{
Name: name,
MaxSize: 1, // unused for views
}
}
var castRegex = regexp.MustCompile(`(?i)^cast\s*\(.*\s+as\s+(\w+)\s*\)$`)
func parseQueryToFields(app App, selectQuery string) (map[string]*queryField, error) {
p := new(identifiersParser)
if err := p.parse(selectQuery); err != nil {
return nil, err
}
collections, err := findCollectionsByIdentifiers(app, p.tables)
if err != nil {
return nil, err
}
result := make(map[string]*queryField, len(p.columns))
var mainTable identifier
if len(p.tables) > 0 {
mainTable = p.tables[0]
}
for _, col := range p.columns {
colLower := strings.ToLower(col.original)
// pk (always assume text field for now)
if col.alias == FieldNameId {
result[col.alias] = &queryField{
field: &TextField{
Name: col.alias,
System: true,
Required: true,
PrimaryKey: true,
Pattern: `^[a-z0-9]+$`,
},
}
continue
}
// numeric aggregations
if strings.HasPrefix(colLower, "count(") || strings.HasPrefix(colLower, "total(") {
result[col.alias] = &queryField{
field: &NumberField{
Name: col.alias,
},
}
continue
}
castMatch := castRegex.FindStringSubmatch(colLower)
// numeric casts
if len(castMatch) == 2 {
switch castMatch[1] {
case "real", "integer", "int", "decimal", "numeric":
result[col.alias] = &queryField{
field: &NumberField{
Name: col.alias,
},
}
continue
case "text":
result[col.alias] = &queryField{
field: &TextField{
Name: col.alias,
},
}
continue
case "boolean", "bool":
result[col.alias] = &queryField{
field: &BoolField{
Name: col.alias,
},
}
continue
}
}
parts := strings.Split(col.original, ".")
var fieldName string
var collection *Collection
if len(parts) == 2 {
fieldName = parts[1]
collection = collections[parts[0]]
} else {
fieldName = parts[0]
collection = collections[mainTable.alias]
}
// fallback to the default field
if collection == nil {
result[col.alias] = &queryField{
field: defaultViewField(col.alias),
}
continue
}
if fieldName == "*" {
return nil, errors.New("dynamic column names are not supported")
}
// find the first field by name (case insensitive)
var field Field
for _, f := range collection.Fields {
if strings.EqualFold(f.GetName(), fieldName) {
field = f
break
}
}
// fallback to the default field
if field == nil {
result[col.alias] = &queryField{
field: defaultViewField(col.alias),
collection: collection,
}
continue
}
// convert to relation since it is an id reference
if strings.EqualFold(fieldName, FieldNameId) {
result[col.alias] = &queryField{
field: &RelationField{
Name: col.alias,
MaxSelect: 1,
CollectionId: collection.Id,
},
collection: collection,
}
continue
}
// we fetch a brand new collection object to avoid using reflection
// or having a dedicated Clone method for each field type
tempCollection, err := app.FindCollectionByNameOrId(collection.Id)
if err != nil {
return nil, err
}
clone := tempCollection.Fields.GetById(field.GetId())
if clone == nil {
return nil, fmt.Errorf("missing expected field %q (%q) in collection %q", field.GetName(), field.GetId(), tempCollection.Name)
}
// set new random id to prevent duplications if the same field is aliased multiple times
clone.SetId("_clone_" + security.PseudorandomString(4))
clone.SetName(col.alias)
result[col.alias] = &queryField{
original: field,
field: clone,
collection: collection,
}
}
return result, nil
}
func findCollectionsByIdentifiers(app App, tables []identifier) (map[string]*Collection, error) {
names := make([]any, 0, len(tables))
for _, table := range tables {
if strings.Contains(table.alias, "(") {
continue // skip expressions
}
names = append(names, table.original)
}
if len(names) == 0 {
return nil, nil
}
result := make(map[string]*Collection, len(names))
collections := make([]*Collection, 0, len(names))
err := app.CollectionQuery().
AndWhere(dbx.In("name", names...)).
All(&collections)
if err != nil {
return nil, err
}
for _, table := range tables {
for _, collection := range collections {
if collection.Name == table.original {
result[table.alias] = collection
}
}
}
return result, nil
}
func getQueryTableInfo(app App, selectQuery string) ([]*TableInfoRow, error) {
tempView := "_temp_" + security.PseudorandomString(6)
var info []*TableInfoRow
txErr := app.RunInTransaction(func(txApp App) error {
// create a temp view with the provided query
err := txApp.SaveView(tempView, selectQuery)
if err != nil {
return err
}
// extract the generated view table info
info, err = txApp.TableInfo(tempView)
return errors.Join(err, txApp.DeleteView(tempView))
})
if txErr != nil {
return nil, txErr
}
return info, nil
}
// -------------------------------------------------------------------
// Raw query identifiers parser
// -------------------------------------------------------------------
var joinReplaceRegex = regexp.MustCompile(`(?im)\s+(full\s+outer\s+join|left\s+outer\s+join|right\s+outer\s+join|full\s+join|cross\s+join|inner\s+join|outer\s+join|left\s+join|right\s+join|join)\s+?`)
var discardReplaceRegex = regexp.MustCompile(`(?im)\s+(where|group\s+by|having|order|limit|with)\s+?`)
var commentsReplaceRegex = regexp.MustCompile(`(?m)(\/\*[\s\S]+\*\/)|(--.+$)`)
type identifier struct {
original string
alias string
}
type identifiersParser struct {
columns []identifier
tables []identifier
}
func (p *identifiersParser) parse(selectQuery string) error {
str := strings.Trim(strings.TrimSpace(selectQuery), ";")
str = joinReplaceRegex.ReplaceAllString(str, " _join_ ")
str = discardReplaceRegex.ReplaceAllString(str, " _discard_ ")
str = commentsReplaceRegex.ReplaceAllString(str, "")
tk := tokenizer.NewFromString(str)
tk.Separators(',', ' ', '\n', '\t')
tk.KeepSeparator(true)
var skip bool
var partType string
var activeBuilder *strings.Builder
var selectParts strings.Builder
var fromParts strings.Builder
var joinParts strings.Builder
for {
token, err := tk.Scan()
if err != nil {
if err != io.EOF {
return err
}
break
}
trimmed := strings.ToLower(strings.TrimSpace(token))
switch trimmed {
case "select":
skip = false
partType = "select"
activeBuilder = &selectParts
case "distinct":
continue // ignore as it is not important for the identifiers parsing
case "from":
skip = false
partType = "from"
activeBuilder = &fromParts
case "_join_":
skip = false
// the previous part was also a join
if partType == "join" {
joinParts.WriteString(",")
}
partType = "join"
activeBuilder = &joinParts
case "_discard_":
// skip following tokens
skip = true
default:
isJoin := partType == "join"
if isJoin && trimmed == "on" {
skip = true
}
if !skip && activeBuilder != nil {
activeBuilder.WriteString(" ")
activeBuilder.WriteString(token)
}
}
}
selects, err := extractIdentifiers(selectParts.String())
if err != nil {
return err
}
froms, err := extractIdentifiers(fromParts.String())
if err != nil {
return err
}
joins, err := extractIdentifiers(joinParts.String())
if err != nil {
return err
}
p.columns = selects
p.tables = froms
p.tables = append(p.tables, joins...)
return nil
}
func extractIdentifiers(rawExpression string) ([]identifier, error) {
rawTk := tokenizer.NewFromString(rawExpression)
rawTk.Separators(',')
rawIdentifiers, err := rawTk.ScanAll()
if err != nil {
return nil, err
}
result := make([]identifier, 0, len(rawIdentifiers))
for _, rawIdentifier := range rawIdentifiers {
tk := tokenizer.NewFromString(rawIdentifier)
tk.Separators(' ', '\n', '\t')
parts, err := tk.ScanAll()
if err != nil {
return nil, err
}
resolved, err := identifierFromParts(parts)
if err != nil {
return nil, err
}
result = append(result, resolved)
}
return result, nil
}
func identifierFromParts(parts []string) (identifier, error) {
var result identifier
switch len(parts) {
case 3:
if !strings.EqualFold(parts[1], "as") {
return result, fmt.Errorf(`invalid identifier part - expected "as", got %v`, parts[1])
}
result.original = parts[0]
result.alias = parts[2]
case 2:
result.original = parts[0]
result.alias = parts[1]
case 1:
subParts := strings.Split(parts[0], ".")
result.original = parts[0]
result.alias = subParts[len(subParts)-1]
default:
return result, fmt.Errorf(`invalid identifier parts %v`, parts)
}
result.original = trimRawIdentifier(result.original)
// we trim the single quote even though it is not a valid column quote character
// because SQLite allows it if the context expects an identifier and not string literal
// (https://www.sqlite.org/lang_keywords.html)
result.alias = trimRawIdentifier(result.alias, "'")
return result, nil
}
func trimRawIdentifier(rawIdentifier string, extraTrimChars ...string) string {
trimChars := "`\"[];"
if len(extraTrimChars) > 0 {
trimChars += strings.Join(extraTrimChars, "")
}
parts := strings.Split(rawIdentifier, ".")
for i := range parts {
parts[i] = strings.Trim(parts[i], trimChars)
}
return strings.Join(parts, ".")
}