mirror of
https://github.com/pocketbase/pocketbase.git
synced 2024-11-28 18:11:17 +02:00
591 lines
15 KiB
Go
591 lines
15 KiB
Go
package daos
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"regexp"
|
|
"strings"
|
|
|
|
"github.com/pocketbase/dbx"
|
|
"github.com/pocketbase/pocketbase/models"
|
|
"github.com/pocketbase/pocketbase/models/schema"
|
|
"github.com/pocketbase/pocketbase/tools/inflector"
|
|
"github.com/pocketbase/pocketbase/tools/list"
|
|
"github.com/pocketbase/pocketbase/tools/security"
|
|
"github.com/pocketbase/pocketbase/tools/tokenizer"
|
|
"github.com/pocketbase/pocketbase/tools/types"
|
|
)
|
|
|
|
// DeleteView drops the specified view name.
|
|
//
|
|
// This method is a no-op if a view with the provided name doesn't exist.
|
|
//
|
|
// Be aware that this method is vulnerable to SQL injection and the
|
|
// "name" argument must come only from trusted input!
|
|
func (dao *Dao) DeleteView(name string) error {
|
|
_, err := dao.DB().NewQuery(fmt.Sprintf(
|
|
"DROP VIEW IF EXISTS {{%s}}",
|
|
name,
|
|
)).Execute()
|
|
|
|
return err
|
|
}
|
|
|
|
// SaveView creates (or updates already existing) persistent SQL view.
|
|
//
|
|
// Be aware that this method is vulnerable to SQL injection and the
|
|
// "selectQuery" argument must come only from trusted input!
|
|
func (dao *Dao) SaveView(name string, selectQuery string) error {
|
|
return dao.RunInTransaction(func(txDao *Dao) error {
|
|
// delete old view (if exists)
|
|
if err := txDao.DeleteView(name); err != nil {
|
|
return err
|
|
}
|
|
|
|
trimmed := strings.Trim(selectQuery, ";")
|
|
|
|
// try to eagerly detect multiple inline statements
|
|
tk := tokenizer.NewFromString(trimmed)
|
|
tk.Separators(';')
|
|
if queryParts, _ := tk.ScanAll(); len(queryParts) > 1 {
|
|
return errors.New("multiple statements are not supported")
|
|
}
|
|
|
|
// (re)create the view
|
|
//
|
|
// note: the query is wrapped in a secondary SELECT as a rudimentary
|
|
// measure to discourage multiple inline sql statements execution.
|
|
viewQuery := fmt.Sprintf("CREATE VIEW {{%s}} AS SELECT * FROM (%s)", name, trimmed)
|
|
if _, err := txDao.DB().NewQuery(viewQuery).Execute(); err != nil {
|
|
return err
|
|
}
|
|
|
|
// fetch the view table info to ensure that the view was created
|
|
// because missing tables or columns won't return an error
|
|
if _, err := txDao.TableInfo(name); err != nil {
|
|
// manually cleanup previously created view in case the func
|
|
// is called in a nested transaction and the error is discarded
|
|
txDao.DeleteView(name)
|
|
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
})
|
|
}
|
|
|
|
// CreateViewSchema creates a new view schema from the provided select query.
|
|
//
|
|
// There are some caveats:
|
|
// - The select query must have an "id" column.
|
|
// - Wildcard ("*") columns are not supported to avoid accidentally leaking sensitive data.
|
|
func (dao *Dao) CreateViewSchema(selectQuery string) (schema.Schema, error) {
|
|
result := schema.NewSchema()
|
|
|
|
suggestedFields, err := dao.parseQueryToFields(selectQuery)
|
|
if err != nil {
|
|
return result, err
|
|
}
|
|
|
|
// note wrap in a transaction in case the selectQuery contains
|
|
// multiple statements allowing us to rollback on any error
|
|
txErr := dao.RunInTransaction(func(txDao *Dao) error {
|
|
tempView := "_temp_" + security.PseudorandomString(5)
|
|
// create a temp view with the provided query
|
|
if err := txDao.SaveView(tempView, selectQuery); err != nil {
|
|
return err
|
|
}
|
|
defer txDao.DeleteView(tempView)
|
|
|
|
// extract the generated view table info
|
|
info, err := txDao.TableInfo(tempView)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
var hasId bool
|
|
|
|
for _, row := range info {
|
|
if row.Name == schema.FieldNameId {
|
|
hasId = true
|
|
}
|
|
|
|
if list.ExistInSlice(row.Name, schema.BaseModelFieldNames()) {
|
|
continue // skip base model fields since they are not part of the schema
|
|
}
|
|
|
|
var field *schema.SchemaField
|
|
|
|
if f, ok := suggestedFields[row.Name]; ok {
|
|
field = f.field
|
|
} else {
|
|
field = defaultViewField(row.Name)
|
|
}
|
|
|
|
result.AddField(field)
|
|
}
|
|
|
|
if !hasId {
|
|
return errors.New("missing required id column (you ca use `(ROW_NUMBER() OVER()) as id` if you don't have one)")
|
|
}
|
|
|
|
return nil
|
|
})
|
|
|
|
return result, txErr
|
|
}
|
|
|
|
// FindRecordByViewFile returns the original models.Record of the
|
|
// provided view collection file.
|
|
func (dao *Dao) FindRecordByViewFile(
|
|
viewCollectionNameOrId string,
|
|
fileFieldName string,
|
|
filename string,
|
|
) (*models.Record, error) {
|
|
view, err := dao.FindCollectionByNameOrId(viewCollectionNameOrId)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if !view.IsView() {
|
|
return nil, errors.New("not a view collection")
|
|
}
|
|
|
|
var findFirstNonViewQueryFileField func(int) (*queryField, error)
|
|
findFirstNonViewQueryFileField = func(level int) (*queryField, error) {
|
|
// check the level depth to prevent infinite circular recursion
|
|
// (the limit is arbitrary and may change in the future)
|
|
if level > 5 {
|
|
return nil, errors.New("reached the max recursion level of view collection file field queries")
|
|
}
|
|
|
|
queryFields, err := dao.parseQueryToFields(view.ViewOptions().Query)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for _, item := range queryFields {
|
|
if item.collection == nil ||
|
|
item.original == nil ||
|
|
item.field.Name != fileFieldName {
|
|
continue
|
|
}
|
|
|
|
if item.collection.IsView() {
|
|
view = item.collection
|
|
fileFieldName = item.original.Name
|
|
return findFirstNonViewQueryFileField(level + 1)
|
|
}
|
|
|
|
return item, nil
|
|
}
|
|
|
|
return nil, errors.New("no query file field found")
|
|
}
|
|
|
|
qf, err := findFirstNonViewQueryFileField(1)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
cleanFieldName := inflector.Columnify(qf.original.Name)
|
|
|
|
record := &models.Record{}
|
|
|
|
query := dao.RecordQuery(qf.collection).Limit(1)
|
|
|
|
if opt, ok := qf.original.Options.(schema.MultiValuer); !ok || !opt.IsMultiple() {
|
|
query.AndWhere(dbx.HashExp{cleanFieldName: filename})
|
|
} else {
|
|
query.InnerJoin(fmt.Sprintf(
|
|
`json_each(CASE WHEN json_valid([[%s]]) THEN [[%s]] ELSE json_array([[%s]]) END) as {{_je_file}}`,
|
|
cleanFieldName, cleanFieldName, cleanFieldName,
|
|
), dbx.HashExp{"_je_file.value": filename})
|
|
}
|
|
|
|
if err := query.One(record); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return record, nil
|
|
}
|
|
|
|
// -------------------------------------------------------------------
|
|
// Raw query to schema helpers
|
|
// -------------------------------------------------------------------
|
|
|
|
type queryField struct {
|
|
// field is the final resolved field.
|
|
field *schema.SchemaField
|
|
|
|
// collection refers to the original field's collection model.
|
|
// It could be nil if the found query field is not from a collection schema.
|
|
collection *models.Collection
|
|
|
|
// original is the original found collection field.
|
|
// It could be nil if the found query field is not from a collection schema.
|
|
original *schema.SchemaField
|
|
}
|
|
|
|
func defaultViewField(name string) *schema.SchemaField {
|
|
return &schema.SchemaField{
|
|
Name: name,
|
|
Type: schema.FieldTypeJson,
|
|
}
|
|
}
|
|
|
|
func (dao *Dao) parseQueryToFields(selectQuery string) (map[string]*queryField, error) {
|
|
p := new(identifiersParser)
|
|
if err := p.parse(selectQuery); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
collections, err := dao.findCollectionsByIdentifiers(p.tables)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
result := make(map[string]*queryField, len(p.columns))
|
|
|
|
var mainTable identifier
|
|
|
|
if len(p.tables) > 0 {
|
|
mainTable = p.tables[0]
|
|
}
|
|
|
|
for _, col := range p.columns {
|
|
colLower := strings.ToLower(col.original)
|
|
|
|
// numeric expression cast
|
|
if strings.Contains(colLower, "(") &&
|
|
(strings.HasPrefix(colLower, "count(") ||
|
|
strings.HasPrefix(colLower, "total(") ||
|
|
strings.Contains(colLower, " as numeric") ||
|
|
strings.Contains(colLower, " as real") ||
|
|
strings.Contains(colLower, " as int") ||
|
|
strings.Contains(colLower, " as integer") ||
|
|
strings.Contains(colLower, " as decimal")) {
|
|
result[col.alias] = &queryField{
|
|
field: &schema.SchemaField{
|
|
Name: col.alias,
|
|
Type: schema.FieldTypeNumber,
|
|
},
|
|
}
|
|
continue
|
|
}
|
|
|
|
parts := strings.Split(col.original, ".")
|
|
|
|
var fieldName string
|
|
var collection *models.Collection
|
|
|
|
if len(parts) == 2 {
|
|
fieldName = parts[1]
|
|
collection = collections[parts[0]]
|
|
} else {
|
|
fieldName = parts[0]
|
|
collection = collections[mainTable.alias]
|
|
}
|
|
|
|
// fallback to the default field if the found column is not from a collection schema
|
|
if collection == nil {
|
|
result[col.alias] = &queryField{
|
|
field: defaultViewField(col.alias),
|
|
}
|
|
continue
|
|
}
|
|
|
|
if fieldName == "*" {
|
|
return nil, errors.New("dynamic column names are not supported")
|
|
}
|
|
|
|
// find the first field by name (case insensitive)
|
|
var field *schema.SchemaField
|
|
for _, f := range collection.Schema.Fields() {
|
|
if strings.EqualFold(f.Name, fieldName) {
|
|
field = f
|
|
break
|
|
}
|
|
}
|
|
|
|
if field != nil {
|
|
clone := *field
|
|
clone.Id = "" // unset to prevent duplications if the same field is aliased multiple times
|
|
clone.Name = col.alias
|
|
result[col.alias] = &queryField{
|
|
field: &clone,
|
|
collection: collection,
|
|
original: field,
|
|
}
|
|
continue
|
|
}
|
|
|
|
if fieldName == schema.FieldNameId {
|
|
// convert to relation since it is a direct id reference
|
|
result[col.alias] = &queryField{
|
|
field: &schema.SchemaField{
|
|
Name: col.alias,
|
|
Type: schema.FieldTypeRelation,
|
|
Options: &schema.RelationOptions{
|
|
MaxSelect: types.Pointer(1),
|
|
CollectionId: collection.Id,
|
|
},
|
|
},
|
|
collection: collection,
|
|
}
|
|
} else if fieldName == schema.FieldNameCreated || fieldName == schema.FieldNameUpdated {
|
|
result[col.alias] = &queryField{
|
|
field: &schema.SchemaField{
|
|
Name: col.alias,
|
|
Type: schema.FieldTypeDate,
|
|
},
|
|
collection: collection,
|
|
}
|
|
} else if fieldName == schema.FieldNameUsername && collection.IsAuth() {
|
|
result[col.alias] = &queryField{
|
|
field: &schema.SchemaField{
|
|
Name: col.alias,
|
|
Type: schema.FieldTypeText,
|
|
},
|
|
collection: collection,
|
|
}
|
|
} else if fieldName == schema.FieldNameEmail && collection.IsAuth() {
|
|
result[col.alias] = &queryField{
|
|
field: &schema.SchemaField{
|
|
Name: col.alias,
|
|
Type: schema.FieldTypeEmail,
|
|
},
|
|
collection: collection,
|
|
}
|
|
} else if (fieldName == schema.FieldNameVerified || fieldName == schema.FieldNameEmailVisibility) && collection.IsAuth() {
|
|
result[col.alias] = &queryField{
|
|
field: &schema.SchemaField{
|
|
Name: col.alias,
|
|
Type: schema.FieldTypeBool,
|
|
},
|
|
collection: collection,
|
|
}
|
|
} else {
|
|
result[col.alias] = &queryField{
|
|
field: defaultViewField(col.alias),
|
|
collection: collection,
|
|
}
|
|
}
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
func (dao *Dao) findCollectionsByIdentifiers(tables []identifier) (map[string]*models.Collection, error) {
|
|
names := make([]any, 0, len(tables))
|
|
|
|
for _, table := range tables {
|
|
if strings.Contains(table.alias, "(") {
|
|
continue // skip expressions
|
|
}
|
|
names = append(names, table.original)
|
|
}
|
|
|
|
if len(names) == 0 {
|
|
return nil, nil
|
|
}
|
|
|
|
result := make(map[string]*models.Collection, len(names))
|
|
collections := make([]*models.Collection, 0, len(names))
|
|
|
|
err := dao.CollectionQuery().
|
|
AndWhere(dbx.In("name", names...)).
|
|
All(&collections)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for _, table := range tables {
|
|
for _, collection := range collections {
|
|
if collection.Name == table.original {
|
|
result[table.alias] = collection
|
|
}
|
|
}
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
// -------------------------------------------------------------------
|
|
// Raw query identifiers parser
|
|
// -------------------------------------------------------------------
|
|
|
|
var joinReplaceRegex = regexp.MustCompile(`(?im)\s+(inner join|outer join|left join|right join|join)\s+?`)
|
|
var discardReplaceRegex = regexp.MustCompile(`(?im)\s+(where|group by|having|order|limit|with)\s+?`)
|
|
var commentsReplaceRegex = regexp.MustCompile(`(?m)(\/\*[\s\S]+\*\/)|(--.+$)`)
|
|
|
|
type identifier struct {
|
|
original string
|
|
alias string
|
|
}
|
|
|
|
type identifiersParser struct {
|
|
columns []identifier
|
|
tables []identifier
|
|
}
|
|
|
|
func (p *identifiersParser) parse(selectQuery string) error {
|
|
str := strings.Trim(selectQuery, ";")
|
|
str = joinReplaceRegex.ReplaceAllString(str, " _join_ ")
|
|
str = discardReplaceRegex.ReplaceAllString(str, " _discard_ ")
|
|
str = commentsReplaceRegex.ReplaceAllString(str, "")
|
|
|
|
tk := tokenizer.NewFromString(str)
|
|
tk.Separators(',', ' ', '\n', '\t')
|
|
tk.KeepSeparator(true)
|
|
|
|
var skip bool
|
|
var partType string
|
|
var activeBuilder *strings.Builder
|
|
var selectParts strings.Builder
|
|
var fromParts strings.Builder
|
|
var joinParts strings.Builder
|
|
|
|
for {
|
|
token, err := tk.Scan()
|
|
if err != nil {
|
|
if err != io.EOF {
|
|
return err
|
|
}
|
|
break
|
|
}
|
|
|
|
trimmed := strings.ToLower(strings.TrimSpace(token))
|
|
|
|
switch trimmed {
|
|
case "select":
|
|
skip = false
|
|
partType = "select"
|
|
activeBuilder = &selectParts
|
|
case "distinct":
|
|
continue // ignore as it is not important for the identifiers parsing
|
|
case "from":
|
|
skip = false
|
|
partType = "from"
|
|
activeBuilder = &fromParts
|
|
case "_join_":
|
|
skip = false
|
|
|
|
// the previous part was also a join
|
|
if partType == "join" {
|
|
joinParts.WriteString(",")
|
|
}
|
|
|
|
partType = "join"
|
|
activeBuilder = &joinParts
|
|
case "_discard_":
|
|
// skip following tokens
|
|
skip = true
|
|
default:
|
|
isJoin := partType == "join"
|
|
|
|
if isJoin && trimmed == "on" {
|
|
skip = true
|
|
}
|
|
|
|
if !skip && activeBuilder != nil {
|
|
activeBuilder.WriteString(" ")
|
|
activeBuilder.WriteString(token)
|
|
}
|
|
}
|
|
}
|
|
|
|
selects, err := extractIdentifiers(selectParts.String())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
froms, err := extractIdentifiers(fromParts.String())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
joins, err := extractIdentifiers(joinParts.String())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
p.columns = selects
|
|
p.tables = froms
|
|
p.tables = append(p.tables, joins...)
|
|
|
|
return nil
|
|
}
|
|
|
|
func extractIdentifiers(rawExpression string) ([]identifier, error) {
|
|
rawTk := tokenizer.NewFromString(rawExpression)
|
|
rawTk.Separators(',')
|
|
|
|
rawIdentifiers, err := rawTk.ScanAll()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
result := make([]identifier, 0, len(rawIdentifiers))
|
|
|
|
for _, rawIdentifier := range rawIdentifiers {
|
|
tk := tokenizer.NewFromString(rawIdentifier)
|
|
tk.Separators(' ', '\n', '\t')
|
|
|
|
parts, err := tk.ScanAll()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
resolved, err := identifierFromParts(parts)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
result = append(result, resolved)
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
func identifierFromParts(parts []string) (identifier, error) {
|
|
var result identifier
|
|
|
|
switch len(parts) {
|
|
case 3:
|
|
if !strings.EqualFold(parts[1], "as") {
|
|
return result, fmt.Errorf(`invalid identifier part - expected "as", got %v`, parts[1])
|
|
}
|
|
|
|
result.original = parts[0]
|
|
result.alias = parts[2]
|
|
case 2:
|
|
result.original = parts[0]
|
|
result.alias = parts[1]
|
|
case 1:
|
|
subParts := strings.Split(parts[0], ".")
|
|
result.original = parts[0]
|
|
result.alias = subParts[len(subParts)-1]
|
|
default:
|
|
return result, fmt.Errorf(`invalid identifier parts %v`, parts)
|
|
}
|
|
|
|
result.original = trimRawIdentifier(result.original)
|
|
result.alias = trimRawIdentifier(result.alias)
|
|
|
|
return result, nil
|
|
}
|
|
|
|
func trimRawIdentifier(rawIdentifier string) string {
|
|
const trimChars = "`\"[];"
|
|
|
|
parts := strings.Split(rawIdentifier, ".")
|
|
|
|
for i := range parts {
|
|
parts[i] = strings.Trim(parts[i], trimChars)
|
|
}
|
|
|
|
return strings.Join(parts, ".")
|
|
}
|