mirror of
https://github.com/pocketbase/pocketbase.git
synced 2025-02-15 09:12:58 +02:00
619 lines
15 KiB
Go
619 lines
15 KiB
Go
package core
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"regexp"
|
|
"strings"
|
|
|
|
"github.com/pocketbase/dbx"
|
|
"github.com/pocketbase/pocketbase/tools/dbutils"
|
|
"github.com/pocketbase/pocketbase/tools/inflector"
|
|
"github.com/pocketbase/pocketbase/tools/security"
|
|
"github.com/pocketbase/pocketbase/tools/tokenizer"
|
|
)
|
|
|
|
// DeleteView drops the specified view name.
|
|
//
|
|
// This method is a no-op if a view with the provided name doesn't exist.
|
|
//
|
|
// NB! Be aware that this method is vulnerable to SQL injection and the
|
|
// "name" argument must come only from trusted input!
|
|
func (app *BaseApp) DeleteView(name string) error {
|
|
_, err := app.DB().NewQuery(fmt.Sprintf(
|
|
"DROP VIEW IF EXISTS {{%s}}",
|
|
name,
|
|
)).Execute()
|
|
|
|
return err
|
|
}
|
|
|
|
// SaveView creates (or updates already existing) persistent SQL view.
|
|
//
|
|
// NB! Be aware that this method is vulnerable to SQL injection and the
|
|
// "selectQuery" argument must come only from trusted input!
|
|
func (app *BaseApp) SaveView(name string, selectQuery string) error {
|
|
return app.RunInTransaction(func(txApp App) error {
|
|
// delete old view (if exists)
|
|
if err := txApp.DeleteView(name); err != nil {
|
|
return err
|
|
}
|
|
|
|
selectQuery = strings.Trim(strings.TrimSpace(selectQuery), ";")
|
|
|
|
// try to loosely detect multiple inline statements
|
|
tk := tokenizer.NewFromString(selectQuery)
|
|
tk.Separators(';')
|
|
if queryParts, _ := tk.ScanAll(); len(queryParts) > 1 {
|
|
return errors.New("multiple statements are not supported")
|
|
}
|
|
|
|
// (re)create the view
|
|
//
|
|
// note: the query is wrapped in a secondary SELECT as a rudimentary
|
|
// measure to discourage multiple inline sql statements execution
|
|
viewQuery := fmt.Sprintf("CREATE VIEW {{%s}} AS SELECT * FROM (%s)", name, selectQuery)
|
|
if _, err := txApp.DB().NewQuery(viewQuery).Execute(); err != nil {
|
|
return err
|
|
}
|
|
|
|
// fetch the view table info to ensure that the view was created
|
|
// because missing tables or columns won't return an error
|
|
if _, err := txApp.TableInfo(name); err != nil {
|
|
// manually cleanup previously created view in case the func
|
|
// is called in a nested transaction and the error is discarded
|
|
txApp.DeleteView(name)
|
|
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
})
|
|
}
|
|
|
|
// CreateViewFields creates a new FieldsList from the provided select query.
|
|
//
|
|
// There are some caveats:
|
|
// - The select query must have an "id" column.
|
|
// - Wildcard ("*") columns are not supported to avoid accidentally leaking sensitive data.
|
|
func (app *BaseApp) CreateViewFields(selectQuery string) (FieldsList, error) {
|
|
result := NewFieldsList()
|
|
|
|
suggestedFields, err := parseQueryToFields(app, selectQuery)
|
|
if err != nil {
|
|
return result, err
|
|
}
|
|
|
|
// note wrap in a transaction in case the selectQuery contains
|
|
// multiple statements allowing us to rollback on any error
|
|
txErr := app.RunInTransaction(func(txApp App) error {
|
|
info, err := getQueryTableInfo(txApp, selectQuery)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
var hasId bool
|
|
|
|
for _, row := range info {
|
|
if row.Name == FieldNameId {
|
|
hasId = true
|
|
}
|
|
|
|
var field Field
|
|
|
|
if f, ok := suggestedFields[row.Name]; ok {
|
|
field = f.field
|
|
} else {
|
|
field = defaultViewField(row.Name)
|
|
}
|
|
|
|
result.Add(field)
|
|
}
|
|
|
|
if !hasId {
|
|
return errors.New("missing required id column (you can use `(ROW_NUMBER() OVER()) as id` if you don't have one)")
|
|
}
|
|
|
|
return nil
|
|
})
|
|
|
|
return result, txErr
|
|
}
|
|
|
|
// FindRecordByViewFile returns the original Record of the provided view collection file.
|
|
func (app *BaseApp) FindRecordByViewFile(viewCollectionModelOrIdentifier any, fileFieldName string, filename string) (*Record, error) {
|
|
view, err := getCollectionByModelOrIdentifier(app, viewCollectionModelOrIdentifier)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if !view.IsView() {
|
|
return nil, errors.New("not a view collection")
|
|
}
|
|
|
|
var findFirstNonViewQueryFileField func(int) (*queryField, error)
|
|
findFirstNonViewQueryFileField = func(level int) (*queryField, error) {
|
|
// check the level depth to prevent infinite circular recursion
|
|
// (the limit is arbitrary and may change in the future)
|
|
if level > 5 {
|
|
return nil, errors.New("reached the max recursion level of view collection file field queries")
|
|
}
|
|
|
|
queryFields, err := parseQueryToFields(app, view.ViewQuery)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for _, item := range queryFields {
|
|
if item.collection == nil ||
|
|
item.original == nil ||
|
|
item.field.GetName() != fileFieldName {
|
|
continue
|
|
}
|
|
|
|
if item.collection.IsView() {
|
|
view = item.collection
|
|
fileFieldName = item.original.GetName()
|
|
return findFirstNonViewQueryFileField(level + 1)
|
|
}
|
|
|
|
return item, nil
|
|
}
|
|
|
|
return nil, errors.New("no query file field found")
|
|
}
|
|
|
|
qf, err := findFirstNonViewQueryFileField(1)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
cleanFieldName := inflector.Columnify(qf.original.GetName())
|
|
|
|
record := &Record{}
|
|
|
|
query := app.RecordQuery(qf.collection).Limit(1)
|
|
|
|
if opt, ok := qf.original.(MultiValuer); !ok || !opt.IsMultiple() {
|
|
query.AndWhere(dbx.HashExp{cleanFieldName: filename})
|
|
} else {
|
|
query.InnerJoin(
|
|
fmt.Sprintf(`%s as {{_je_file}}`, dbutils.JSONEach(cleanFieldName)),
|
|
dbx.HashExp{"_je_file.value": filename},
|
|
)
|
|
}
|
|
|
|
if err := query.One(record); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return record, nil
|
|
}
|
|
|
|
// -------------------------------------------------------------------
|
|
// Raw query to schema helpers
|
|
// -------------------------------------------------------------------
|
|
|
|
type queryField struct {
|
|
// field is the final resolved field.
|
|
field Field
|
|
|
|
// collection refers to the original field's collection model.
|
|
// It could be nil if the found query field is not from a collection
|
|
collection *Collection
|
|
|
|
// original is the original found collection field.
|
|
// It could be nil if the found query field is not from a collection
|
|
original Field
|
|
}
|
|
|
|
func defaultViewField(name string) Field {
|
|
return &JSONField{
|
|
Name: name,
|
|
MaxSize: 1, // unused for views
|
|
}
|
|
}
|
|
|
|
var (
	// castRegex matches a whole "CAST(expr AS type)" column expression and
	// captures the target type name as its first submatch.
	castRegex = regexp.MustCompile(`(?i)^cast\s*\(.*\s+as\s+(\w+)\s*\)$`)
)
|
|
|
|
func parseQueryToFields(app App, selectQuery string) (map[string]*queryField, error) {
|
|
p := new(identifiersParser)
|
|
if err := p.parse(selectQuery); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
collections, err := findCollectionsByIdentifiers(app, p.tables)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
result := make(map[string]*queryField, len(p.columns))
|
|
|
|
var mainTable identifier
|
|
|
|
if len(p.tables) > 0 {
|
|
mainTable = p.tables[0]
|
|
}
|
|
|
|
for _, col := range p.columns {
|
|
colLower := strings.ToLower(col.original)
|
|
|
|
// pk (always assume text field for now)
|
|
if col.alias == FieldNameId {
|
|
result[col.alias] = &queryField{
|
|
field: &TextField{
|
|
Name: col.alias,
|
|
System: true,
|
|
Required: true,
|
|
PrimaryKey: true,
|
|
Pattern: `^[a-z0-9]+$`,
|
|
},
|
|
}
|
|
continue
|
|
}
|
|
|
|
// numeric aggregations
|
|
if strings.HasPrefix(colLower, "count(") || strings.HasPrefix(colLower, "total(") {
|
|
result[col.alias] = &queryField{
|
|
field: &NumberField{
|
|
Name: col.alias,
|
|
},
|
|
}
|
|
continue
|
|
}
|
|
|
|
castMatch := castRegex.FindStringSubmatch(colLower)
|
|
|
|
// numeric casts
|
|
if len(castMatch) == 2 {
|
|
switch castMatch[1] {
|
|
case "real", "integer", "int", "decimal", "numeric":
|
|
result[col.alias] = &queryField{
|
|
field: &NumberField{
|
|
Name: col.alias,
|
|
},
|
|
}
|
|
continue
|
|
case "text":
|
|
result[col.alias] = &queryField{
|
|
field: &TextField{
|
|
Name: col.alias,
|
|
},
|
|
}
|
|
continue
|
|
case "boolean", "bool":
|
|
result[col.alias] = &queryField{
|
|
field: &BoolField{
|
|
Name: col.alias,
|
|
},
|
|
}
|
|
continue
|
|
}
|
|
}
|
|
|
|
parts := strings.Split(col.original, ".")
|
|
|
|
var fieldName string
|
|
var collection *Collection
|
|
|
|
if len(parts) == 2 {
|
|
fieldName = parts[1]
|
|
collection = collections[parts[0]]
|
|
} else {
|
|
fieldName = parts[0]
|
|
collection = collections[mainTable.alias]
|
|
}
|
|
|
|
// fallback to the default field
|
|
if collection == nil {
|
|
result[col.alias] = &queryField{
|
|
field: defaultViewField(col.alias),
|
|
}
|
|
continue
|
|
}
|
|
|
|
if fieldName == "*" {
|
|
return nil, errors.New("dynamic column names are not supported")
|
|
}
|
|
|
|
// find the first field by name (case insensitive)
|
|
var field Field
|
|
for _, f := range collection.Fields {
|
|
if strings.EqualFold(f.GetName(), fieldName) {
|
|
field = f
|
|
break
|
|
}
|
|
}
|
|
|
|
// fallback to the default field
|
|
if field == nil {
|
|
result[col.alias] = &queryField{
|
|
field: defaultViewField(col.alias),
|
|
collection: collection,
|
|
}
|
|
continue
|
|
}
|
|
|
|
// convert to relation since it is an id reference
|
|
if strings.EqualFold(fieldName, FieldNameId) {
|
|
result[col.alias] = &queryField{
|
|
field: &RelationField{
|
|
Name: col.alias,
|
|
MaxSelect: 1,
|
|
CollectionId: collection.Id,
|
|
},
|
|
collection: collection,
|
|
}
|
|
continue
|
|
}
|
|
|
|
// we fetch a brand new collection object to avoid using reflection
|
|
// or having a dedicated Clone method for each field type
|
|
tempCollection, err := app.FindCollectionByNameOrId(collection.Id)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
clone := tempCollection.Fields.GetById(field.GetId())
|
|
if clone == nil {
|
|
return nil, fmt.Errorf("missing expected field %q (%q) in collection %q", field.GetName(), field.GetId(), tempCollection.Name)
|
|
}
|
|
// set new random id to prevent duplications if the same field is aliased multiple times
|
|
clone.SetId("_clone_" + security.PseudorandomString(4))
|
|
clone.SetName(col.alias)
|
|
|
|
result[col.alias] = &queryField{
|
|
original: field,
|
|
field: clone,
|
|
collection: collection,
|
|
}
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
func findCollectionsByIdentifiers(app App, tables []identifier) (map[string]*Collection, error) {
|
|
names := make([]any, 0, len(tables))
|
|
|
|
for _, table := range tables {
|
|
if strings.Contains(table.alias, "(") {
|
|
continue // skip expressions
|
|
}
|
|
names = append(names, table.original)
|
|
}
|
|
|
|
if len(names) == 0 {
|
|
return nil, nil
|
|
}
|
|
|
|
result := make(map[string]*Collection, len(names))
|
|
collections := make([]*Collection, 0, len(names))
|
|
|
|
err := app.CollectionQuery().
|
|
AndWhere(dbx.In("name", names...)).
|
|
All(&collections)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for _, table := range tables {
|
|
for _, collection := range collections {
|
|
if collection.Name == table.original {
|
|
result[table.alias] = collection
|
|
}
|
|
}
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
func getQueryTableInfo(app App, selectQuery string) ([]*TableInfoRow, error) {
|
|
tempView := "_temp_" + security.PseudorandomString(6)
|
|
|
|
var info []*TableInfoRow
|
|
|
|
txErr := app.RunInTransaction(func(txApp App) error {
|
|
// create a temp view with the provided query
|
|
err := txApp.SaveView(tempView, selectQuery)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// extract the generated view table info
|
|
info, err = txApp.TableInfo(tempView)
|
|
|
|
return errors.Join(err, txApp.DeleteView(tempView))
|
|
})
|
|
|
|
if txErr != nil {
|
|
return nil, txErr
|
|
}
|
|
|
|
return info, nil
|
|
}
|
|
|
|
// -------------------------------------------------------------------
|
|
// Raw query identifiers parser
|
|
// -------------------------------------------------------------------
|
|
|
|
var (
	// joinReplaceRegex matches any JOIN variant (including the surrounding
	// whitespace) so that it can be normalized to a single "_join_" marker.
	joinReplaceRegex = regexp.MustCompile(`(?im)\s+(full\s+outer\s+join|left\s+outer\s+join|right\s+outer\s+join|full\s+join|cross\s+join|inner\s+join|outer\s+join|left\s+join|right\s+join|join)\s+?`)

	// discardReplaceRegex matches the clause keywords after which the
	// following tokens are irrelevant for the identifiers extraction.
	discardReplaceRegex = regexp.MustCompile(`(?im)\s+(where|group\s+by|having|order|limit|with)\s+?`)

	// commentsReplaceRegex matches SQL block ("/* ... */") and line ("-- ...") comments.
	//
	// The block comment repetition is non-greedy so that a query with several
	// block comments doesn't get everything between the first "/*" and the
	// last "*/" stripped (which would silently drop real columns). Both
	// repetitions also accept empty comments ("/**/" and a bare "--").
	commentsReplaceRegex = regexp.MustCompile(`(?m)(\/\*[\s\S]*?\*\/)|(--.*$)`)
)
|
|
|
|
// identifier is a single parsed column or table reference of a select query.
type identifier struct {
	// original is the referenced expression/name (with its quotes trimmed),
	// and alias is the name under which it is exposed in the query.
	original, alias string
}
|
|
|
|
type identifiersParser struct {
|
|
columns []identifier
|
|
tables []identifier
|
|
}
|
|
|
|
func (p *identifiersParser) parse(selectQuery string) error {
|
|
str := strings.Trim(strings.TrimSpace(selectQuery), ";")
|
|
str = joinReplaceRegex.ReplaceAllString(str, " _join_ ")
|
|
str = discardReplaceRegex.ReplaceAllString(str, " _discard_ ")
|
|
str = commentsReplaceRegex.ReplaceAllString(str, "")
|
|
|
|
tk := tokenizer.NewFromString(str)
|
|
tk.Separators(',', ' ', '\n', '\t')
|
|
tk.KeepSeparator(true)
|
|
|
|
var skip bool
|
|
var partType string
|
|
var activeBuilder *strings.Builder
|
|
var selectParts strings.Builder
|
|
var fromParts strings.Builder
|
|
var joinParts strings.Builder
|
|
|
|
for {
|
|
token, err := tk.Scan()
|
|
if err != nil {
|
|
if err != io.EOF {
|
|
return err
|
|
}
|
|
break
|
|
}
|
|
|
|
trimmed := strings.ToLower(strings.TrimSpace(token))
|
|
|
|
switch trimmed {
|
|
case "select":
|
|
skip = false
|
|
partType = "select"
|
|
activeBuilder = &selectParts
|
|
case "distinct":
|
|
continue // ignore as it is not important for the identifiers parsing
|
|
case "from":
|
|
skip = false
|
|
partType = "from"
|
|
activeBuilder = &fromParts
|
|
case "_join_":
|
|
skip = false
|
|
|
|
// the previous part was also a join
|
|
if partType == "join" {
|
|
joinParts.WriteString(",")
|
|
}
|
|
|
|
partType = "join"
|
|
activeBuilder = &joinParts
|
|
case "_discard_":
|
|
// skip following tokens
|
|
skip = true
|
|
default:
|
|
isJoin := partType == "join"
|
|
|
|
if isJoin && trimmed == "on" {
|
|
skip = true
|
|
}
|
|
|
|
if !skip && activeBuilder != nil {
|
|
activeBuilder.WriteString(" ")
|
|
activeBuilder.WriteString(token)
|
|
}
|
|
}
|
|
}
|
|
|
|
selects, err := extractIdentifiers(selectParts.String())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
froms, err := extractIdentifiers(fromParts.String())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
joins, err := extractIdentifiers(joinParts.String())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
p.columns = selects
|
|
p.tables = froms
|
|
p.tables = append(p.tables, joins...)
|
|
|
|
return nil
|
|
}
|
|
|
|
func extractIdentifiers(rawExpression string) ([]identifier, error) {
|
|
rawTk := tokenizer.NewFromString(rawExpression)
|
|
rawTk.Separators(',')
|
|
|
|
rawIdentifiers, err := rawTk.ScanAll()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
result := make([]identifier, 0, len(rawIdentifiers))
|
|
|
|
for _, rawIdentifier := range rawIdentifiers {
|
|
tk := tokenizer.NewFromString(rawIdentifier)
|
|
tk.Separators(' ', '\n', '\t')
|
|
|
|
parts, err := tk.ScanAll()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
resolved, err := identifierFromParts(parts)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
result = append(result, resolved)
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
func identifierFromParts(parts []string) (identifier, error) {
|
|
var result identifier
|
|
|
|
switch len(parts) {
|
|
case 3:
|
|
if !strings.EqualFold(parts[1], "as") {
|
|
return result, fmt.Errorf(`invalid identifier part - expected "as", got %v`, parts[1])
|
|
}
|
|
|
|
result.original = parts[0]
|
|
result.alias = parts[2]
|
|
case 2:
|
|
result.original = parts[0]
|
|
result.alias = parts[1]
|
|
case 1:
|
|
subParts := strings.Split(parts[0], ".")
|
|
result.original = parts[0]
|
|
result.alias = subParts[len(subParts)-1]
|
|
default:
|
|
return result, fmt.Errorf(`invalid identifier parts %v`, parts)
|
|
}
|
|
|
|
result.original = trimRawIdentifier(result.original)
|
|
|
|
// we trim the single quote even though it is not a valid column quote character
|
|
// because SQLite allows it if the context expects an identifier and not string literal
|
|
// (https://www.sqlite.org/lang_keywords.html)
|
|
result.alias = trimRawIdentifier(result.alias, "'")
|
|
|
|
return result, nil
|
|
}
|
|
|
|
// trimRawIdentifier strips the identifier quote characters (`, ", [, ]) and
// ";" from both ends of every dot-separated segment of rawIdentifier.
//
// Every character of extraTrimChars is stripped as well.
func trimRawIdentifier(rawIdentifier string, extraTrimChars ...string) string {
	cutset := "`\"[];" + strings.Join(extraTrimChars, "")

	segments := strings.Split(rawIdentifier, ".")
	for i, segment := range segments {
		segments[i] = strings.Trim(segment, cutset)
	}

	return strings.Join(segments, ".")
}
|