You've already forked pocketbase
							
							
				mirror of
				https://github.com/pocketbase/pocketbase.git
				synced 2025-10-31 16:47:43 +02:00 
			
		
		
		
	added support to filter request.user.profile relation fields
This commit is contained in:
		| @@ -52,10 +52,9 @@ func (api *recordApi) list(c echo.Context) error { | ||||
| 		return rest.NewForbiddenError("Only admins can perform this action.", nil) | ||||
| 	} | ||||
|  | ||||
| 	// forbid user/guest defined non-relational joins (aka. @collection.*) | ||||
| 	queryStr := c.QueryString() | ||||
| 	if admin == nil && queryStr != "" && (strings.Contains(queryStr, "@collection") || strings.Contains(queryStr, "%40collection")) { | ||||
| 		return rest.NewForbiddenError("Only admins can filter by @collection.", nil) | ||||
| 	// forbid users and guests to query special filter/sort fields | ||||
| 	if err := api.checkForForbiddenQueryFields(c); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	requestData := api.exportRequestData(c) | ||||
| @@ -63,14 +62,15 @@ func (api *recordApi) list(c echo.Context) error { | ||||
| 	fieldsResolver := resolvers.NewRecordFieldResolver(api.app.Dao(), collection, requestData) | ||||
|  | ||||
| 	searchProvider := search.NewProvider(fieldsResolver). | ||||
| 		Query(api.app.Dao().RecordQuery(collection)) | ||||
| 		Query(api.app.Dao().RecordQuery(collection)). | ||||
| 		CountColumn(fmt.Sprintf("%s.id", api.app.Dao().DB().QuoteSimpleColumnName(collection.Name))) | ||||
|  | ||||
| 	if admin == nil && collection.ListRule != nil { | ||||
| 		searchProvider.AddFilter(search.FilterData(*collection.ListRule)) | ||||
| 	} | ||||
|  | ||||
| 	var rawRecords = []dbx.NullStringMap{} | ||||
| 	result, err := searchProvider.ParseAndExec(queryStr, &rawRecords) | ||||
| 	result, err := searchProvider.ParseAndExec(c.QueryString(), &rawRecords) | ||||
| 	if err != nil { | ||||
| 		return rest.NewBadRequestError("Invalid filter parameters.", err) | ||||
| 	} | ||||
| @@ -407,6 +407,24 @@ func (api *recordApi) exportRequestData(c echo.Context) map[string]any { | ||||
| 	return result | ||||
| } | ||||
|  | ||||
| func (api *recordApi) checkForForbiddenQueryFields(c echo.Context) error { | ||||
| 	admin, _ := c.Get(ContextAdminKey).(*models.Admin) | ||||
| 	if admin != nil { | ||||
| 		return nil // admins are allowed to query everything | ||||
| 	} | ||||
|  | ||||
| 	decodedQuery := c.QueryParam(search.FilterQueryParam) + c.QueryParam(search.SortQueryParam) | ||||
| 	forbiddenFields := []string{"@collection.", "@request."} | ||||
|  | ||||
| 	for _, field := range forbiddenFields { | ||||
| 		if strings.Contains(decodedQuery, field) { | ||||
| 			return rest.NewForbiddenError("Only admins can filter by @collection and @request query params", nil) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (api *recordApi) expandFunc(c echo.Context, requestData map[string]any) daos.ExpandFetchFunc { | ||||
| 	admin, _ := c.Get(ContextAdminKey).(*models.Admin) | ||||
|  | ||||
|   | ||||
| @@ -20,6 +20,7 @@ import ( | ||||
| var _ search.FieldResolver = (*RecordFieldResolver)(nil) | ||||
|  | ||||
| type join struct { | ||||
| 	id    string | ||||
| 	table string | ||||
| 	on    dbx.Expression | ||||
| } | ||||
| @@ -36,7 +37,7 @@ type RecordFieldResolver struct { | ||||
| 	baseCollection    *models.Collection | ||||
| 	allowedFields     []string | ||||
| 	requestData       map[string]any | ||||
| 	joins             map[string]join | ||||
| 	joins             []join // we cannot use a map because the insertion order is not preserved | ||||
| 	loadedCollections []*models.Collection | ||||
| } | ||||
|  | ||||
| @@ -50,7 +51,7 @@ func NewRecordFieldResolver( | ||||
| 		dao:               dao, | ||||
| 		baseCollection:    baseCollection, | ||||
| 		requestData:       requestData, | ||||
| 		joins:             make(map[string]join), | ||||
| 		joins:             []join{}, | ||||
| 		loadedCollections: []*models.Collection{baseCollection}, | ||||
| 		allowedFields: []string{ | ||||
| 			`^\w+[\w\.]*$`, | ||||
| @@ -85,6 +86,7 @@ func (r *RecordFieldResolver) UpdateQuery(query *dbx.SelectQuery) error { | ||||
| //	id | ||||
| //	project.screen.status | ||||
| //	@request.status | ||||
| //	@request.user.profile.someRelation.name | ||||
| //	@collection.product.name | ||||
| func (r *RecordFieldResolver) Resolve(fieldName string) (resultName string, placeholderParams dbx.Params, err error) { | ||||
| 	if len(r.allowedFields) > 0 && !list.ExistInSliceWithRegex(fieldName, r.allowedFields) { | ||||
| @@ -93,15 +95,6 @@ func (r *RecordFieldResolver) Resolve(fieldName string) (resultName string, plac | ||||
|  | ||||
| 	props := strings.Split(fieldName, ".") | ||||
|  | ||||
| 	// check for @request field | ||||
| 	if props[0] == "@request" { | ||||
| 		if len(props) == 1 { | ||||
| 			return "", nil, fmt.Errorf("Invalid @request data field path in %q.", fieldName) | ||||
| 		} | ||||
|  | ||||
| 		return r.resolveRequestField(props[1:]...) | ||||
| 	} | ||||
|  | ||||
| 	currentCollectionName := r.baseCollection.Name | ||||
| 	currentTableAlias := currentCollectionName | ||||
|  | ||||
| @@ -113,16 +106,54 @@ func (r *RecordFieldResolver) Resolve(fieldName string) (resultName string, plac | ||||
| 		} | ||||
|  | ||||
| 		currentCollectionName = props[1] | ||||
| 		currentTableAlias = "c_" + currentCollectionName | ||||
| 		currentTableAlias = "__collection_" + currentCollectionName | ||||
|  | ||||
| 		collection, err := r.loadCollection(currentCollectionName) | ||||
| 		if err != nil { | ||||
| 			return "", nil, fmt.Errorf("Failed to load collection %q from field path %q.", currentCollectionName, fieldName) | ||||
| 		} | ||||
|  | ||||
| 		r.addJoin(collection.Name, currentTableAlias, "", "", "") | ||||
| 		r.addJoin(collection.Name, currentTableAlias, nil) | ||||
|  | ||||
| 		props = props[2:] // leave only the collection fields | ||||
| 	} else if props[0] == "@request" { | ||||
| 		// check for @request field | ||||
| 		if len(props) == 1 { | ||||
| 			return "", nil, fmt.Errorf("Invalid @request data field path in %q.", fieldName) | ||||
| 		} | ||||
|  | ||||
| 		// not a profile relational field | ||||
| 		if len(props) <= 4 || !strings.HasPrefix(fieldName, "@request.user.profile.") { | ||||
| 			return r.resolveStaticRequestField(props[1:]...) | ||||
| 		} | ||||
|  | ||||
| 		// resolve the profile collection fields | ||||
| 		currentCollectionName = models.ProfileCollectionName | ||||
| 		currentTableAlias = "__user_" + currentCollectionName | ||||
|  | ||||
| 		collection, err := r.loadCollection(currentCollectionName) | ||||
| 		if err != nil { | ||||
| 			return "", nil, fmt.Errorf("Failed to load collection %q from field path %q.", currentCollectionName, fieldName) | ||||
| 		} | ||||
|  | ||||
| 		profileIdPlaceholder, profileIdPlaceholderParam, err := r.resolveStaticRequestField("user", "profile", "id") | ||||
| 		if err != nil { | ||||
| 			return "", nil, fmt.Errorf("Failed to resolve @request.user.profile.id path in %q.", fieldName) | ||||
| 		} | ||||
| 		if strings.ToLower(profileIdPlaceholder) == "null" { | ||||
| 			// the user doesn't have an associated profile | ||||
| 			return "NULL", nil, nil | ||||
| 		} | ||||
|  | ||||
| 		// join the profile collection | ||||
| 		r.addJoin(collection.Name, currentTableAlias, dbx.NewExp(fmt.Sprintf( | ||||
| 			// aka. profiles.id = profileId | ||||
| 			"[[%s.id]] = %s", | ||||
| 			inflector.Columnify(currentTableAlias), | ||||
| 			profileIdPlaceholder, | ||||
| 		), profileIdPlaceholderParam)) | ||||
|  | ||||
| 		props = props[3:] // leave only the profile fields | ||||
| 	} | ||||
|  | ||||
| 	baseModelFields := schema.ReservedFieldNames() | ||||
| @@ -173,9 +204,14 @@ func (r *RecordFieldResolver) Resolve(fieldName string) (resultName string, plac | ||||
| 		r.addJoin( | ||||
| 			newCollectionName, | ||||
| 			newTableAlias, | ||||
| 			"id", | ||||
| 			currentTableAlias, | ||||
| 			field.Name, | ||||
| 			dbx.NewExp(fmt.Sprintf( | ||||
| 				// 'LIKE' expr is used to handle the case when the reference field supports multiple values (aka. is json array) | ||||
| 				"[[%s.%s]] LIKE ('%%' || [[%s.%s]] || '%%')", | ||||
| 				inflector.Columnify(currentTableAlias), | ||||
| 				inflector.Columnify(field.Name), | ||||
| 				inflector.Columnify(newTableAlias), | ||||
| 				inflector.Columnify("id"), | ||||
| 			)), | ||||
| 		) | ||||
|  | ||||
| 		currentCollectionName = newCollectionName | ||||
| @@ -185,7 +221,7 @@ func (r *RecordFieldResolver) Resolve(fieldName string) (resultName string, plac | ||||
| 	return "", nil, fmt.Errorf("Failed to resolve field %q.", fieldName) | ||||
| } | ||||
|  | ||||
| func (r *RecordFieldResolver) resolveRequestField(path ...string) (resultName string, placeholderParams dbx.Params, err error) { | ||||
| func (r *RecordFieldResolver) resolveStaticRequestField(path ...string) (resultName string, placeholderParams dbx.Params, err error) { | ||||
| 	// ignore error because requestData is dynamic and some of the | ||||
| 	// lookup keys may not be defined for the request | ||||
| 	resultVal, _ := extractNestedMapVal(r.requestData, path...) | ||||
| @@ -259,24 +295,27 @@ func (r *RecordFieldResolver) loadCollection(collectionNameOrId string) (*models | ||||
| 	return collection, nil | ||||
| } | ||||
|  | ||||
| func (r *RecordFieldResolver) addJoin(tableName, tableAlias, fieldName, ref, refFieldName string) { | ||||
| 	table := fmt.Sprintf( | ||||
| func (r *RecordFieldResolver) addJoin(tableName string, tableAlias string, on dbx.Expression) { | ||||
| 	tableExpr := fmt.Sprintf( | ||||
| 		"%s %s", | ||||
| 		inflector.Columnify(tableName), | ||||
| 		inflector.Columnify(tableAlias), | ||||
| 	) | ||||
|  | ||||
| 	var on dbx.Expression | ||||
| 	if ref != "" { | ||||
| 		on = dbx.NewExp(fmt.Sprintf( | ||||
| 			// 'LIKE' expr is used to handle the case when the reference field supports multiple values (aka. is json array) | ||||
| 			"[[%s.%s]] LIKE ('%%' || [[%s.%s]] || '%%')", | ||||
| 			inflector.Columnify(ref), | ||||
| 			inflector.Columnify(refFieldName), | ||||
| 			inflector.Columnify(tableAlias), | ||||
| 			inflector.Columnify(fieldName), | ||||
| 		)) | ||||
| 	join := join{ | ||||
| 		id:    tableAlias, | ||||
| 		table: tableExpr, | ||||
| 		on:    on, | ||||
| 	} | ||||
|  | ||||
| 	r.joins[tableAlias] = join{table, on} | ||||
| 	// replace existing join | ||||
| 	for i, j := range r.joins { | ||||
| 		if j.id == join.id { | ||||
| 			r.joins[i] = join | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// register new join | ||||
| 	r.joins = append(r.joins, join) | ||||
| } | ||||
|   | ||||
| @@ -2,11 +2,12 @@ package resolvers_test | ||||
|  | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"strings" | ||||
| 	"regexp" | ||||
| 	"testing" | ||||
|  | ||||
| 	"github.com/pocketbase/pocketbase/resolvers" | ||||
| 	"github.com/pocketbase/pocketbase/tests" | ||||
| 	"github.com/pocketbase/pocketbase/tools/list" | ||||
| ) | ||||
|  | ||||
| func TestRecordFieldResolverUpdateQuery(t *testing.T) { | ||||
| @@ -18,54 +19,94 @@ func TestRecordFieldResolverUpdateQuery(t *testing.T) { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	requestData := map[string]any{ | ||||
| 		"user": map[string]any{ | ||||
| 			"id": "4d0197cc-2b4a-3f83-a26b-d77bc8423d3c", | ||||
| 			"profile": map[string]any{ | ||||
| 				"id":   "d13f60a4-5765-48c7-9e1d-3e782340f833", | ||||
| 				"name": "test", | ||||
| 			}, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	scenarios := []struct { | ||||
| 		fieldName        string | ||||
| 		expectQueryParts []string // we are matching parts of the query | ||||
| 		// since joins are added with map iteration and the order is not guaranteed | ||||
| 		name        string | ||||
| 		fields      []string | ||||
| 		expectQuery string | ||||
| 	}{ | ||||
| 		// missing field | ||||
| 		{"", []string{ | ||||
| 		{ | ||||
| 			"missing field", | ||||
| 			[]string{""}, | ||||
| 			"SELECT `demo4`.* FROM `demo4`", | ||||
| 		}}, | ||||
| 		// non relation field | ||||
| 		{"title", []string{ | ||||
| 		}, | ||||
| 		{ | ||||
| 			"non relation field", | ||||
| 			[]string{"title"}, | ||||
| 			"SELECT `demo4`.* FROM `demo4`", | ||||
| 		}}, | ||||
| 		// incomplete rel | ||||
| 		{"onerel", []string{ | ||||
| 		}, | ||||
| 		{ | ||||
| 			"incomplete rel", | ||||
| 			[]string{"onerel"}, | ||||
| 			"SELECT `demo4`.* FROM `demo4`", | ||||
| 		}}, | ||||
| 		// single rel | ||||
| 		{"onerel.title", []string{ | ||||
| 			"SELECT DISTINCT `demo4`.* FROM `demo4`", | ||||
| 			" LEFT JOIN `demo4` `demo4_onerel` ON [[demo4.onerel]] LIKE ('%' || [[demo4_onerel.id]] || '%')", | ||||
| 		}}, | ||||
| 		// nested incomplete rels | ||||
| 		{"manyrels.onerel", []string{ | ||||
| 			"SELECT DISTINCT `demo4`.* FROM `demo4`", | ||||
| 			" LEFT JOIN `demo4` `demo4_manyrels` ON [[demo4.manyrels]] LIKE ('%' || [[demo4_manyrels.id]] || '%')", | ||||
| 		}}, | ||||
| 		// nested complete rels | ||||
| 		{"manyrels.onerel.title", []string{ | ||||
| 			"SELECT DISTINCT `demo4`.* FROM `demo4`", | ||||
| 			" LEFT JOIN `demo4` `demo4_manyrels` ON [[demo4.manyrels]] LIKE ('%' || [[demo4_manyrels.id]] || '%')", | ||||
| 			" LEFT JOIN `demo4` `demo4_manyrels_onerel` ON [[demo4_manyrels.onerel]] LIKE ('%' || [[demo4_manyrels_onerel.id]] || '%')", | ||||
| 		}}, | ||||
| 		// // repeated nested rels | ||||
| 		{"manyrels.onerel.manyrels.onerel.title", []string{ | ||||
| 			"SELECT DISTINCT `demo4`.* FROM `demo4`", | ||||
| 			" LEFT JOIN `demo4` `demo4_manyrels` ON [[demo4.manyrels]] LIKE ('%' || [[demo4_manyrels.id]] || '%')", | ||||
| 			" LEFT JOIN `demo4` `demo4_manyrels_onerel` ON [[demo4_manyrels.onerel]] LIKE ('%' || [[demo4_manyrels_onerel.id]] || '%')", | ||||
| 			" LEFT JOIN `demo4` `demo4_manyrels_onerel_manyrels` ON [[demo4_manyrels_onerel.manyrels]] LIKE ('%' || [[demo4_manyrels_onerel_manyrels.id]] || '%')", | ||||
| 			" LEFT JOIN `demo4` `demo4_manyrels_onerel_manyrels_onerel` ON [[demo4_manyrels_onerel_manyrels.onerel]] LIKE ('%' || [[demo4_manyrels_onerel_manyrels_onerel.id]] || '%')", | ||||
| 		}}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			"single rel", | ||||
| 			[]string{"onerel.title"}, | ||||
| 			"SELECT DISTINCT `demo4`.* FROM `demo4` LEFT JOIN `demo4` `demo4_onerel` ON [[demo4.onerel]] LIKE ('%' || [[demo4_onerel.id]] || '%')", | ||||
| 		}, | ||||
| 		{ | ||||
| 			"non-relation field + single rel", | ||||
| 			[]string{"title", "onerel.title"}, | ||||
| 			"SELECT DISTINCT `demo4`.* FROM `demo4` LEFT JOIN `demo4` `demo4_onerel` ON [[demo4.onerel]] LIKE ('%' || [[demo4_onerel.id]] || '%')", | ||||
| 		}, | ||||
| 		{ | ||||
| 			"nested incomplete rels", | ||||
| 			[]string{"manyrels.onerel"}, | ||||
| 			"SELECT DISTINCT `demo4`.* FROM `demo4` LEFT JOIN `demo4` `demo4_manyrels` ON [[demo4.manyrels]] LIKE ('%' || [[demo4_manyrels.id]] || '%')", | ||||
| 		}, | ||||
| 		{ | ||||
| 			"nested complete rels", | ||||
| 			[]string{"manyrels.onerel.title"}, | ||||
| 			"SELECT DISTINCT `demo4`.* FROM `demo4` LEFT JOIN `demo4` `demo4_manyrels` ON [[demo4.manyrels]] LIKE ('%' || [[demo4_manyrels.id]] || '%') LEFT JOIN `demo4` `demo4_manyrels_onerel` ON [[demo4_manyrels.onerel]] LIKE ('%' || [[demo4_manyrels_onerel.id]] || '%')", | ||||
| 		}, | ||||
| 		{ | ||||
| 			"repeated nested rels", | ||||
| 			[]string{"manyrels.onerel.manyrels.onerel.title"}, | ||||
| 			"SELECT DISTINCT `demo4`.* FROM `demo4` LEFT JOIN `demo4` `demo4_manyrels` ON [[demo4.manyrels]] LIKE ('%' || [[demo4_manyrels.id]] || '%') LEFT JOIN `demo4` `demo4_manyrels_onerel` ON [[demo4_manyrels.onerel]] LIKE ('%' || [[demo4_manyrels_onerel.id]] || '%') LEFT JOIN `demo4` `demo4_manyrels_onerel_manyrels` ON [[demo4_manyrels_onerel.manyrels]] LIKE ('%' || [[demo4_manyrels_onerel_manyrels.id]] || '%') LEFT JOIN `demo4` `demo4_manyrels_onerel_manyrels_onerel` ON [[demo4_manyrels_onerel_manyrels.onerel]] LIKE ('%' || [[demo4_manyrels_onerel_manyrels_onerel.id]] || '%')", | ||||
| 		}, | ||||
| 		{ | ||||
| 			"multiple rels", | ||||
| 			[]string{"manyrels.title", "onerel.onefile"}, | ||||
| 			"SELECT DISTINCT `demo4`.* FROM `demo4` LEFT JOIN `demo4` `demo4_manyrels` ON [[demo4.manyrels]] LIKE ('%' || [[demo4_manyrels.id]] || '%') LEFT JOIN `demo4` `demo4_onerel` ON [[demo4.onerel]] LIKE ('%' || [[demo4_onerel.id]] || '%')", | ||||
| 		}, | ||||
| 		{ | ||||
| 			"@collection join", | ||||
| 			[]string{"@collection.demo.title", "@collection.demo2.text", "@collection.demo.file"}, | ||||
| 			"SELECT DISTINCT `demo4`.* FROM `demo4` LEFT JOIN `demo` `__collection_demo` LEFT JOIN `demo2` `__collection_demo2`", | ||||
| 		}, | ||||
| 		{ | ||||
| 			"static @request.user.profile fields", | ||||
| 			[]string{"@request.user.id", "@request.user.profile.id", "@request.data.demo"}, | ||||
| 			"SELECT `demo4`.* FROM `demo4`", | ||||
| 		}, | ||||
| 		{ | ||||
| 			"relational @request.user.profile fields", | ||||
| 			[]string{"@request.user.profile.rel.id", "@request.user.profile.rel.name"}, | ||||
| 			"^" + | ||||
| 				regexp.QuoteMeta("SELECT DISTINCT `demo4`.* FROM `demo4` LEFT JOIN `profiles` `__user_profiles` ON [[__user_profiles.id]] =") + | ||||
| 				" {:.*} " + | ||||
| 				regexp.QuoteMeta("LEFT JOIN `profiles` `__user_profiles_rel` ON [[__user_profiles.rel]] LIKE ('%' || [[__user_profiles_rel.id]] || '%')") + | ||||
| 				"$", | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	for i, s := range scenarios { | ||||
| 		query := app.Dao().RecordQuery(collection) | ||||
|  | ||||
| 		r := resolvers.NewRecordFieldResolver(app.Dao(), collection, nil) | ||||
| 		r.Resolve(s.fieldName) | ||||
| 		r := resolvers.NewRecordFieldResolver(app.Dao(), collection, requestData) | ||||
| 		for _, field := range s.fields { | ||||
| 			r.Resolve(field) | ||||
| 		} | ||||
|  | ||||
| 		if err := r.UpdateQuery(query); err != nil { | ||||
| 			t.Errorf("(%d) UpdateQuery failed with error %v", i, err) | ||||
| @@ -74,16 +115,8 @@ func TestRecordFieldResolverUpdateQuery(t *testing.T) { | ||||
|  | ||||
| 		rawQuery := query.Build().SQL() | ||||
|  | ||||
| 		partsLength := 0 | ||||
| 		for _, part := range s.expectQueryParts { | ||||
| 			partsLength += len(part) | ||||
| 			if !strings.Contains(rawQuery, part) { | ||||
| 				t.Errorf("(%d) Part %v is missing from query \n%v", i, part, rawQuery) | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		if partsLength != len(rawQuery) { | ||||
| 			t.Errorf("(%d) Expected %d characters, got %d in \n%v", i, partsLength, len(rawQuery), rawQuery) | ||||
| 		if !list.ExistInSliceWithRegex(rawQuery, []string{s.expectQuery}) { | ||||
| 			t.Errorf("(%d) Expected query\n %v \ngot:\n %v", i, s.expectQuery, rawQuery) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| @@ -97,7 +130,16 @@ func TestRecordFieldResolverResolveSchemaFields(t *testing.T) { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	r := resolvers.NewRecordFieldResolver(app.Dao(), collection, nil) | ||||
| 	requestData := map[string]any{ | ||||
| 		"user": map[string]any{ | ||||
| 			"id": "4d0197cc-2b4a-3f83-a26b-d77bc8423d3c", | ||||
| 			"profile": map[string]any{ | ||||
| 				"id": "d13f60a4-5765-48c7-9e1d-3e782340f833", | ||||
| 			}, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	r := resolvers.NewRecordFieldResolver(app.Dao(), collection, requestData) | ||||
|  | ||||
| 	scenarios := []struct { | ||||
| 		fieldName   string | ||||
| @@ -118,42 +160,45 @@ func TestRecordFieldResolverResolveSchemaFields(t *testing.T) { | ||||
| 		{"manyrels.unknown", true, ""}, | ||||
| 		{"manyrels.title", false, "[[demo4_manyrels.title]]"}, | ||||
| 		{"manyrels.onerel.manyrels.onefile", false, "[[demo4_manyrels_onerel_manyrels.onefile]]"}, | ||||
| 		// @request.user.profile relation join: | ||||
| 		{"@request.user.profile.rel.name", false, "[[__user_profiles_rel.name]]"}, | ||||
| 		// @collection fieds: | ||||
| 		{"@collect", true, ""}, | ||||
| 		{"collection.demo4.title", true, ""}, | ||||
| 		{"@collection", true, ""}, | ||||
| 		{"@collection.unknown", true, ""}, | ||||
| 		{"@collection.demo", true, ""}, | ||||
| 		{"@collection.demo.", true, ""}, | ||||
| 		{"@collection.demo.title", false, "[[c_demo.title]]"}, | ||||
| 		{"@collection.demo4.title", false, "[[c_demo4.title]]"}, | ||||
| 		{"@collection.demo4.id", false, "[[c_demo4.id]]"}, | ||||
| 		{"@collection.demo4.created", false, "[[c_demo4.created]]"}, | ||||
| 		{"@collection.demo4.updated", false, "[[c_demo4.updated]]"}, | ||||
| 		{"@collection.demo.title", false, "[[__collection_demo.title]]"}, | ||||
| 		{"@collection.demo4.title", false, "[[__collection_demo4.title]]"}, | ||||
| 		{"@collection.demo4.id", false, "[[__collection_demo4.id]]"}, | ||||
| 		{"@collection.demo4.created", false, "[[__collection_demo4.created]]"}, | ||||
| 		{"@collection.demo4.updated", false, "[[__collection_demo4.updated]]"}, | ||||
| 		{"@collection.demo4.manyrels.missing", true, ""}, | ||||
| 		{"@collection.demo4.manyrels.onerel.manyrels.onerel.onefile", false, "[[c_demo4_manyrels_onerel_manyrels_onerel.onefile]]"}, | ||||
| 		{"@collection.demo4.manyrels.onerel.manyrels.onerel.onefile", false, "[[__collection_demo4_manyrels_onerel_manyrels_onerel.onefile]]"}, | ||||
| 	} | ||||
|  | ||||
| 	for i, s := range scenarios { | ||||
| 	for _, s := range scenarios { | ||||
| 		name, params, err := r.Resolve(s.fieldName) | ||||
|  | ||||
| 		hasErr := err != nil | ||||
| 		if hasErr != s.expectError { | ||||
| 			t.Errorf("(%d) Expected hasErr %v, got %v (%v)", i, s.expectError, hasErr, err) | ||||
| 			t.Errorf("(%q) Expected hasErr %v, got %v (%v)", s.fieldName, s.expectError, hasErr, err) | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		if name != s.expectName { | ||||
| 			t.Errorf("(%d) Expected name %q, got %q", i, s.expectName, name) | ||||
| 			t.Errorf("(%q) Expected name %q, got %q", s.fieldName, s.expectName, name) | ||||
| 		} | ||||
|  | ||||
| 		// params should be empty for non @request fields | ||||
| 		if len(params) != 0 { | ||||
| 			t.Errorf("(%d) Expected 0 params, got %v", i, params) | ||||
| 			t.Errorf("(%q) Expected 0 params, got %v", s.fieldName, params) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestRecordFieldResolverResolveRequestDataFields(t *testing.T) { | ||||
| func TestRecordFieldResolverResolveStaticRequestDataFields(t *testing.T) { | ||||
| 	app, _ := tests.NewTestApp() | ||||
| 	defer app.Cleanup() | ||||
|  | ||||
| @@ -171,7 +216,13 @@ func TestRecordFieldResolverResolveRequestDataFields(t *testing.T) { | ||||
| 			"b": 456, | ||||
| 			"c": map[string]int{"sub": 1}, | ||||
| 		}, | ||||
| 		"user": nil, | ||||
| 		"user": map[string]any{ | ||||
| 			"id": "4d0197cc-2b4a-3f83-a26b-d77bc8423d3c", | ||||
| 			"profile": map[string]any{ | ||||
| 				"id":   "d13f60a4-5765-48c7-9e1d-3e782340f833", | ||||
| 				"name": "test", | ||||
| 			}, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	r := resolvers.NewRecordFieldResolver(app.Dao(), collection, requestData) | ||||
| @@ -194,7 +245,9 @@ func TestRecordFieldResolverResolveRequestDataFields(t *testing.T) { | ||||
| 		{"@request.data.b.missing", false, ``}, | ||||
| 		{"@request.data.c", false, `"{\"sub\":1}"`}, | ||||
| 		{"@request.user", true, ""}, | ||||
| 		{"@request.user.id", false, ""}, | ||||
| 		{"@request.user.id", false, `"4d0197cc-2b4a-3f83-a26b-d77bc8423d3c"`}, | ||||
| 		{"@request.user.profile", false, `"{\"id\":\"d13f60a4-5765-48c7-9e1d-3e782340f833\",\"name\":\"test\"}"`}, | ||||
| 		{"@request.user.profile.name", false, `"test"`}, | ||||
| 	} | ||||
|  | ||||
| 	for i, s := range scenarios { | ||||
|   | ||||
										
											Binary file not shown.
										
									
								
							| @@ -37,6 +37,7 @@ type Provider struct { | ||||
| 	query         *dbx.SelectQuery | ||||
| 	page          int | ||||
| 	perPage       int | ||||
| 	countColumn   string | ||||
| 	sort          []SortField | ||||
| 	filter        []FilterData | ||||
| } | ||||
| @@ -67,6 +68,13 @@ func (s *Provider) Query(query *dbx.SelectQuery) *Provider { | ||||
| 	return s | ||||
| } | ||||
|  | ||||
| // CountColumn specifies an optional distinct column to use in the | ||||
| // SELECT COUNT query. | ||||
| func (s *Provider) CountColumn(countColumn string) *Provider { | ||||
| 	s.countColumn = countColumn | ||||
| 	return s | ||||
| } | ||||
|  | ||||
| // Page sets the `page` field of the current search provider. | ||||
| // | ||||
| // Normalization on the `page` value is done during `Exec()`. | ||||
| @@ -190,7 +198,11 @@ func (s *Provider) Exec(items any) (*Result, error) { | ||||
| 	// count | ||||
| 	var totalCount int64 | ||||
| 	countQuery := modelsQuery | ||||
| 	if err := countQuery.Select("count(*)").Row(&totalCount); err != nil { | ||||
| 	countQuery.Distinct(false).Select("COUNT(*)") | ||||
| 	if s.countColumn != "" { | ||||
| 		countQuery.Select("COUNT(DISTINCT(" + s.countColumn + "))") | ||||
| 	} | ||||
| 	if err := countQuery.Row(&totalCount); err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
|   | ||||
| @@ -60,6 +60,15 @@ func TestProviderPerPage(t *testing.T) { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestProviderCountColumn(t *testing.T) { | ||||
| 	r := &testFieldResolver{} | ||||
| 	p := NewProvider(r).CountColumn("test") | ||||
|  | ||||
| 	if p.countColumn != "test" { | ||||
| 		t.Fatalf("Expected distinct count column %v, got %v", "test", p.countColumn) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestProviderSort(t *testing.T) { | ||||
| 	initialSort := []SortField{{"test1", SortAsc}, {"test2", SortAsc}} | ||||
| 	r := &testFieldResolver{} | ||||
| @@ -214,6 +223,7 @@ func TestProviderExecNonEmptyQuery(t *testing.T) { | ||||
| 		perPage       int | ||||
| 		sort          []SortField | ||||
| 		filter        []FilterData | ||||
| 		countColumn   string | ||||
| 		expectError   bool | ||||
| 		expectResult  string | ||||
| 		expectQueries []string | ||||
| @@ -224,10 +234,11 @@ func TestProviderExecNonEmptyQuery(t *testing.T) { | ||||
| 			10, | ||||
| 			[]SortField{}, | ||||
| 			[]FilterData{}, | ||||
| 			"", | ||||
| 			false, | ||||
| 			`{"page":1,"perPage":10,"totalItems":2,"items":[{"test1":1,"test2":"test2.1","test3":""},{"test1":2,"test2":"test2.2","test3":""}]}`, | ||||
| 			[]string{ | ||||
| 				"SELECT count(*) FROM `test` WHERE NOT (`test1` IS NULL) ORDER BY `test1` ASC", | ||||
| 				"SELECT COUNT(*) FROM `test` WHERE NOT (`test1` IS NULL) ORDER BY `test1` ASC", | ||||
| 				"SELECT * FROM `test` WHERE NOT (`test1` IS NULL) ORDER BY `test1` ASC LIMIT 10", | ||||
| 			}, | ||||
| 		}, | ||||
| @@ -237,10 +248,11 @@ func TestProviderExecNonEmptyQuery(t *testing.T) { | ||||
| 			0,  // fallback to default | ||||
| 			[]SortField{}, | ||||
| 			[]FilterData{}, | ||||
| 			"", | ||||
| 			false, | ||||
| 			`{"page":1,"perPage":30,"totalItems":2,"items":[{"test1":1,"test2":"test2.1","test3":""},{"test1":2,"test2":"test2.2","test3":""}]}`, | ||||
| 			[]string{ | ||||
| 				"SELECT count(*) FROM `test` WHERE NOT (`test1` IS NULL) ORDER BY `test1` ASC", | ||||
| 				"SELECT COUNT(*) FROM `test` WHERE NOT (`test1` IS NULL) ORDER BY `test1` ASC", | ||||
| 				"SELECT * FROM `test` WHERE NOT (`test1` IS NULL) ORDER BY `test1` ASC LIMIT 30", | ||||
| 			}, | ||||
| 		}, | ||||
| @@ -250,6 +262,7 @@ func TestProviderExecNonEmptyQuery(t *testing.T) { | ||||
| 			10, | ||||
| 			[]SortField{{"unknown", SortAsc}}, | ||||
| 			[]FilterData{}, | ||||
| 			"", | ||||
| 			true, | ||||
| 			"", | ||||
| 			nil, | ||||
| @@ -260,6 +273,7 @@ func TestProviderExecNonEmptyQuery(t *testing.T) { | ||||
| 			10, | ||||
| 			[]SortField{}, | ||||
| 			[]FilterData{"test2 = 'test2.1'", "invalid"}, | ||||
| 			"", | ||||
| 			true, | ||||
| 			"", | ||||
| 			nil, | ||||
| @@ -270,10 +284,11 @@ func TestProviderExecNonEmptyQuery(t *testing.T) { | ||||
| 			5555, // will be limited by MaxPerPage | ||||
| 			[]SortField{{"test2", SortDesc}}, | ||||
| 			[]FilterData{"test2 != null", "test1 >= 2"}, | ||||
| 			"", | ||||
| 			false, | ||||
| 			`{"page":1,"perPage":` + fmt.Sprint(MaxPerPage) + `,"totalItems":1,"items":[{"test1":2,"test2":"test2.2","test3":""}]}`, | ||||
| 			[]string{ | ||||
| 				"SELECT count(*) FROM `test` WHERE ((NOT (`test1` IS NULL)) AND (COALESCE(test2, '') != COALESCE(null, ''))) AND (test1 >= '2') ORDER BY `test1` ASC, `test2` DESC", | ||||
| 				"SELECT COUNT(*) FROM `test` WHERE ((NOT (`test1` IS NULL)) AND (COALESCE(test2, '') != COALESCE(null, ''))) AND (test1 >= '2') ORDER BY `test1` ASC, `test2` DESC", | ||||
| 				"SELECT * FROM `test` WHERE ((NOT (`test1` IS NULL)) AND (COALESCE(test2, '') != COALESCE(null, ''))) AND (test1 >= '2') ORDER BY `test1` ASC, `test2` DESC LIMIT 200", | ||||
| 			}, | ||||
| 		}, | ||||
| @@ -283,10 +298,11 @@ func TestProviderExecNonEmptyQuery(t *testing.T) { | ||||
| 			10, | ||||
| 			[]SortField{{"test3", SortAsc}}, | ||||
| 			[]FilterData{"test3 != ''"}, | ||||
| 			"", | ||||
| 			false, | ||||
| 			`{"page":1,"perPage":10,"totalItems":0,"items":[]}`, | ||||
| 			[]string{ | ||||
| 				"SELECT count(*) FROM `test` WHERE (NOT (`test1` IS NULL)) AND (COALESCE(test3, '') != COALESCE('', '')) ORDER BY `test1` ASC, `test3` ASC", | ||||
| 				"SELECT COUNT(*) FROM `test` WHERE (NOT (`test1` IS NULL)) AND (COALESCE(test3, '') != COALESCE('', '')) ORDER BY `test1` ASC, `test3` ASC", | ||||
| 				"SELECT * FROM `test` WHERE (NOT (`test1` IS NULL)) AND (COALESCE(test3, '') != COALESCE('', '')) ORDER BY `test1` ASC, `test3` ASC LIMIT 10", | ||||
| 			}, | ||||
| 		}, | ||||
| @@ -296,10 +312,25 @@ func TestProviderExecNonEmptyQuery(t *testing.T) { | ||||
| 			1, | ||||
| 			[]SortField{}, | ||||
| 			[]FilterData{}, | ||||
| 			"", | ||||
| 			false, | ||||
| 			`{"page":2,"perPage":1,"totalItems":2,"items":[{"test1":2,"test2":"test2.2","test3":""}]}`, | ||||
| 			[]string{ | ||||
| 				"SELECT count(*) FROM `test` WHERE NOT (`test1` IS NULL) ORDER BY `test1` ASC", | ||||
| 				"SELECT COUNT(*) FROM `test` WHERE NOT (`test1` IS NULL) ORDER BY `test1` ASC", | ||||
| 				"SELECT * FROM `test` WHERE NOT (`test1` IS NULL) ORDER BY `test1` ASC LIMIT 1 OFFSET 1", | ||||
| 			}, | ||||
| 		}, | ||||
| 		// distinct count column | ||||
| 		{ | ||||
| 			3, | ||||
| 			1, | ||||
| 			[]SortField{}, | ||||
| 			[]FilterData{}, | ||||
| 			"test.test1", | ||||
| 			false, | ||||
| 			`{"page":2,"perPage":1,"totalItems":2,"items":[{"test1":2,"test2":"test2.2","test3":""}]}`, | ||||
| 			[]string{ | ||||
| 				"SELECT COUNT(DISTINCT(test.test1)) FROM `test` WHERE NOT (`test1` IS NULL) ORDER BY `test1` ASC", | ||||
| 				"SELECT * FROM `test` WHERE NOT (`test1` IS NULL) ORDER BY `test1` ASC LIMIT 1 OFFSET 1", | ||||
| 			}, | ||||
| 		}, | ||||
| @@ -314,7 +345,8 @@ func TestProviderExecNonEmptyQuery(t *testing.T) { | ||||
| 			Page(s.page). | ||||
| 			PerPage(s.perPage). | ||||
| 			Sort(s.sort). | ||||
| 			Filter(s.filter) | ||||
| 			Filter(s.filter). | ||||
| 			CountColumn(s.countColumn) | ||||
|  | ||||
| 		result, err := p.Exec(&[]testTableStruct{}) | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user