1
0
mirror of https://github.com/MontFerret/ferret.git synced 2025-01-18 03:22:02 +02:00

Merge pull request #396 from MontFerret/feature/#373-params-check

Added params check before execution
This commit is contained in:
Tim Voronov 2019-10-17 16:54:59 -04:00 committed by GitHub
commit 625f78fd98
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 134 additions and 40 deletions

View File

@ -111,6 +111,38 @@ func TestParam(t *testing.T) {
)
So(string(out), ShouldEqual, `"baz"`)
})
Convey("Should return an error if param values are not passed", t, func() {
prog := compiler.New().
MustCompile(`
LET doc = { foo: { bar: "baz" } }
RETURN doc.@attr.@subattr
`)
_, err := prog.Run(
context.Background(),
)
So(err, ShouldNotBeNil)
So(err.Error(), ShouldContainSubstring, runtime.ErrMissedParam.Error())
})
Convey("Should be possible to use in member expression as segments", t, func() {
prog := compiler.New().
MustCompile(`
LET doc = { foo: { bar: "baz" } }
RETURN doc.@attr.@subattr
`)
_, err := prog.Run(
context.Background(),
runtime.WithParam("attr", "foo"),
)
So(err, ShouldNotBeNil)
So(err.Error(), ShouldContainSubstring, "subattr")
})
}

View File

@ -2,47 +2,56 @@ package compiler
import (
"github.com/MontFerret/ferret/pkg/runtime/core"
"github.com/MontFerret/ferret/pkg/runtime/values/types"
)
type (
globalScope struct {
params map[string]struct{}
}
scope struct {
global *globalScope
parent *scope
vars map[string]core.Type
vars map[string]struct{}
}
)
func newRootScope() *scope {
func newGlobalScope() *globalScope {
return &globalScope{
params: map[string]struct{}{},
}
}
func newRootScope(global *globalScope) *scope {
return &scope{
vars: make(map[string]core.Type),
global: global,
vars: make(map[string]struct{}),
}
}
func newScope(parent *scope) *scope {
s := newRootScope()
s := newRootScope(parent.global)
s.parent = parent
return s
}
func (s *scope) GetVariable(name string) (core.Type, error) {
local, exists := s.vars[name]
func (s *scope) AddParam(name string) {
s.global.params[name] = struct{}{}
}
func (s *scope) HasVariable(name string) bool {
_, exists := s.vars[name]
if exists {
return local, nil
return true
}
if s.parent != nil {
parents, err := s.parent.GetVariable(name)
if err != nil {
return types.None, err
}
return parents, nil
return s.parent.HasVariable(name)
}
return types.None, core.Error(ErrVariableNotFound, name)
return false
}
func (s *scope) SetVariable(name string) error {
@ -53,7 +62,7 @@ func (s *scope) SetVariable(name string) error {
}
// TODO: add type detection
s.vars[name] = types.None
s.vars[name] = struct{}{}
return nil
}
@ -71,7 +80,7 @@ func (s *scope) RemoveVariable(name string) error {
}
func (s *scope) ClearVariables() {
s.vars = make(map[string]core.Type)
s.vars = make(map[string]struct{})
}
func (s *scope) Fork() *scope {

View File

@ -38,14 +38,15 @@ func newVisitor(src string, funcs map[string]core.Function) *visitor {
func (v *visitor) VisitProgram(ctx *fql.ProgramContext) interface{} {
return newResultFrom(func() (interface{}, error) {
rootScope := newRootScope()
block, err := v.doVisitBody(ctx.Body().(*fql.BodyContext), rootScope)
gs := newGlobalScope()
rs := newRootScope(gs)
block, err := v.doVisitBody(ctx.Body().(*fql.BodyContext), rs)
if err != nil {
return nil, err
}
return runtime.NewProgram(v.src, block)
return runtime.NewProgram(v.src, block, gs.params)
})
}
@ -801,10 +802,8 @@ func (v *visitor) doVisitMember(ctx *fql.MemberContext, scope *scope) (core.Expr
if identifier != nil {
varName := ctx.Identifier().GetText()
_, err := scope.GetVariable(varName)
if err != nil {
return nil, err
if !scope.HasVariable(varName) {
return nil, core.Error(ErrVariableNotFound, varName)
}
exp, err := expressions.NewVariableExpression(v.getSourceMap(ctx), varName)
@ -927,10 +926,8 @@ func (v *visitor) doVisitComputedPropertyNameContext(ctx *fql.ComputedPropertyNa
func (v *visitor) doVisitShorthandPropertyNameContext(ctx *fql.ShorthandPropertyNameContext, scope *scope) (core.Expression, error) {
name := ctx.Variable().GetText()
_, err := scope.GetVariable(name)
if err != nil {
return nil, err
if !scope.HasVariable(name) {
return nil, core.Error(ErrVariableNotFound, name)
}
return literals.NewStringLiteral(ctx.Variable().GetText()), nil
@ -1005,10 +1002,8 @@ func (v *visitor) doVisitVariable(ctx *fql.VariableContext, scope *scope) (core.
name := ctx.Identifier().GetText()
// check whether the variable is defined
_, err := scope.GetVariable(name)
if err != nil {
return nil, err
if !scope.HasVariable(name) {
return nil, core.Error(ErrVariableNotFound, name)
}
return expressions.NewVariableExpression(v.getSourceMap(ctx), name)
@ -1117,9 +1112,11 @@ func (v *visitor) doVisitFunctionCallExpression(context *fql.FunctionCallExpress
)
}
func (v *visitor) doVisitParamContext(context *fql.ParamContext, _ *scope) (core.Expression, error) {
func (v *visitor) doVisitParamContext(context *fql.ParamContext, s *scope) (core.Expression, error) {
name := context.Identifier().GetText()
s.AddParam(name)
return expressions.NewParameterExpression(
v.getSourceMap(context),
name,

9
pkg/runtime/errors.go Normal file
View File

@ -0,0 +1,9 @@
package runtime
import (
"github.com/pkg/errors"
)
var (
ErrMissedParam = errors.New("missed value for parameter(s)")
)

View File

@ -3,6 +3,7 @@ package runtime
import (
"context"
"runtime"
"strings"
"github.com/MontFerret/ferret/pkg/runtime/core"
"github.com/MontFerret/ferret/pkg/runtime/logging"
@ -11,11 +12,12 @@ import (
)
type Program struct {
src string
body core.Expression
src string
body core.Expression
params map[string]struct{}
}
func NewProgram(src string, body core.Expression) (*Program, error) {
func NewProgram(src string, body core.Expression, params map[string]struct{}) (*Program, error) {
if src == "" {
return nil, core.Error(core.ErrMissedArgument, "source")
}
@ -24,16 +26,33 @@ func NewProgram(src string, body core.Expression) (*Program, error) {
return nil, core.Error(core.ErrMissedArgument, "body")
}
return &Program{src, body}, nil
return &Program{src, body, params}, nil
}
func (p *Program) Source() string {
return p.src
}
func (p *Program) Run(ctx context.Context, setters ...Option) (result []byte, err error) {
ctx = NewOptions(setters).WithContext(ctx)
func (p *Program) Params() []string {
res := make([]string, 0, len(p.params))
for name := range p.params {
res = append(res, name)
}
return res
}
func (p *Program) Run(ctx context.Context, setters ...Option) (result []byte, err error) {
opts := NewOptions(setters)
err = p.validateParams(opts)
if err != nil {
return nil, err
}
ctx = opts.WithContext(ctx)
logger := logging.FromContext(ctx)
defer func() {
@ -92,3 +111,31 @@ func (p *Program) MustRun(ctx context.Context, setters ...Option) []byte {
return out
}
func (p *Program) validateParams(opts *Options) error {
if len(p.params) == 0 {
return nil
}
// There might be no errors.
// Thus, we allocate this slice lazily, on a first error.
var missedParams []string
for n := range p.params {
_, exists := opts.params[n]
if !exists {
if missedParams == nil {
missedParams = make([]string, 0, len(p.params))
}
missedParams = append(missedParams, "@"+n)
}
}
if len(missedParams) > 0 {
return core.Error(ErrMissedParam, strings.Join(missedParams, ", "))
}
return nil
}