diff --git a/pkg/compiler/compiler_param_test.go b/pkg/compiler/compiler_param_test.go index a44d17a5..5e202d08 100644 --- a/pkg/compiler/compiler_param_test.go +++ b/pkg/compiler/compiler_param_test.go @@ -111,6 +111,38 @@ func TestParam(t *testing.T) { ) So(string(out), ShouldEqual, `"baz"`) + }) + Convey("Should return an error if param values are not passed", t, func() { + prog := compiler.New(). + MustCompile(` + LET doc = { foo: { bar: "baz" } } + + RETURN doc.@attr.@subattr + `) + + _, err := prog.Run( + context.Background(), + ) + + So(err, ShouldNotBeNil) + So(err.Error(), ShouldContainSubstring, runtime.ErrMissedParam.Error()) + }) + + Convey("Should be possible to use in member expression as segments", t, func() { + prog := compiler.New(). + MustCompile(` + LET doc = { foo: { bar: "baz" } } + + RETURN doc.@attr.@subattr + `) + + _, err := prog.Run( + context.Background(), + runtime.WithParam("attr", "foo"), + ) + + So(err, ShouldNotBeNil) + So(err.Error(), ShouldContainSubstring, "subattr") }) } diff --git a/pkg/compiler/scope.go b/pkg/compiler/scope.go index 9acac22e..d7aea34e 100644 --- a/pkg/compiler/scope.go +++ b/pkg/compiler/scope.go @@ -2,47 +2,56 @@ package compiler import ( "github.com/MontFerret/ferret/pkg/runtime/core" - "github.com/MontFerret/ferret/pkg/runtime/values/types" ) type ( + globalScope struct { + params map[string]struct{} + } + scope struct { + global *globalScope parent *scope - vars map[string]core.Type + vars map[string]struct{} } ) -func newRootScope() *scope { +func newGlobalScope() *globalScope { + return &globalScope{ + params: map[string]struct{}{}, + } +} + +func newRootScope(global *globalScope) *scope { return &scope{ - vars: make(map[string]core.Type), + global: global, + vars: make(map[string]struct{}), } } func newScope(parent *scope) *scope { - s := newRootScope() + s := newRootScope(parent.global) s.parent = parent return s } -func (s *scope) GetVariable(name string) (core.Type, error) { - local, exists := s.vars[name] +func (s *scope) AddParam(name string) { + s.global.params[name] = struct{}{} +} + +func (s *scope) HasVariable(name string) bool { + _, exists := s.vars[name] if exists { - return local, nil + return true } if s.parent != nil { - parents, err := s.parent.GetVariable(name) - - if err != nil { - return types.None, err - } - - return parents, nil + return s.parent.HasVariable(name) } - return types.None, core.Error(ErrVariableNotFound, name) + return false } func (s *scope) SetVariable(name string) error { @@ -53,7 +62,7 @@ func (s *scope) SetVariable(name string) error { } // TODO: add type detection - s.vars[name] = types.None + s.vars[name] = struct{}{} return nil } @@ -71,7 +80,7 @@ func (s *scope) RemoveVariable(name string) error { } func (s *scope) ClearVariables() { - s.vars = make(map[string]core.Type) + s.vars = make(map[string]struct{}) } func (s *scope) Fork() *scope { diff --git a/pkg/compiler/visitor.go b/pkg/compiler/visitor.go index dedd6cad..49979537 100644 --- a/pkg/compiler/visitor.go +++ b/pkg/compiler/visitor.go @@ -38,14 +38,15 @@ func newVisitor(src string, funcs map[string]core.Function) *visitor { func (v *visitor) VisitProgram(ctx *fql.ProgramContext) interface{} { return newResultFrom(func() (interface{}, error) { - rootScope := newRootScope() - block, err := v.doVisitBody(ctx.Body().(*fql.BodyContext), rootScope) + gs := newGlobalScope() + rs := newRootScope(gs) + block, err := v.doVisitBody(ctx.Body().(*fql.BodyContext), rs) if err != nil { return nil, err } - return runtime.NewProgram(v.src, block) + return runtime.NewProgram(v.src, block, gs.params) }) } @@ -801,10 +802,8 @@ func (v *visitor) doVisitMember(ctx *fql.MemberContext, scope *scope) (core.Expr if identifier != nil { varName := ctx.Identifier().GetText() - _, err := scope.GetVariable(varName) - - if err != nil { - return nil, err + if !scope.HasVariable(varName) { + return nil, core.Error(ErrVariableNotFound, varName) } exp, err := expressions.NewVariableExpression(v.getSourceMap(ctx), varName) @@ -927,10 +926,8 @@ func (v *visitor) doVisitComputedPropertyNameContext(ctx *fql.ComputedPropertyNa func (v *visitor) doVisitShorthandPropertyNameContext(ctx *fql.ShorthandPropertyNameContext, scope *scope) (core.Expression, error) { name := ctx.Variable().GetText() - _, err := scope.GetVariable(name) - - if err != nil { - return nil, err + if !scope.HasVariable(name) { + return nil, core.Error(ErrVariableNotFound, name) } return literals.NewStringLiteral(ctx.Variable().GetText()), nil @@ -1005,10 +1002,8 @@ func (v *visitor) doVisitVariable(ctx *fql.VariableContext, scope *scope) (core. name := ctx.Identifier().GetText() // check whether the variable is defined - _, err := scope.GetVariable(name) - - if err != nil { - return nil, err + if !scope.HasVariable(name) { + return nil, core.Error(ErrVariableNotFound, name) } return expressions.NewVariableExpression(v.getSourceMap(ctx), name) @@ -1117,9 +1112,11 @@ func (v *visitor) doVisitFunctionCallExpression(context *fql.FunctionCallExpress ) } -func (v *visitor) doVisitParamContext(context *fql.ParamContext, _ *scope) (core.Expression, error) { +func (v *visitor) doVisitParamContext(context *fql.ParamContext, s *scope) (core.Expression, error) { name := context.Identifier().GetText() + s.AddParam(name) + return expressions.NewParameterExpression( v.getSourceMap(context), name, diff --git a/pkg/runtime/errors.go b/pkg/runtime/errors.go new file mode 100644 index 00000000..3d336cb5 --- /dev/null +++ b/pkg/runtime/errors.go @@ -0,0 +1,9 @@ +package runtime + +import ( + "github.com/pkg/errors" +) + +var ( + ErrMissedParam = errors.New("missed value for parameter(s)") +) diff --git a/pkg/runtime/program.go b/pkg/runtime/program.go index d51a46c0..dd4391bd 100644 --- a/pkg/runtime/program.go +++ b/pkg/runtime/program.go @@ -3,6 +3,7 @@ package runtime import ( "context" "runtime" + "strings" "github.com/MontFerret/ferret/pkg/runtime/core" "github.com/MontFerret/ferret/pkg/runtime/logging" @@ -11,11 +12,12 @@ import ( ) type Program struct { - src string - body core.Expression + src string + body core.Expression + params map[string]struct{} } -func NewProgram(src string, body core.Expression) (*Program, error) { +func NewProgram(src string, body core.Expression, params map[string]struct{}) (*Program, error) { if src == "" { return nil, core.Error(core.ErrMissedArgument, "source") } @@ -24,16 +26,33 @@ func NewProgram(src string, body core.Expression) (*Program, error) { return nil, core.Error(core.ErrMissedArgument, "body") } - return &Program{src, body}, nil + return &Program{src, body, params}, nil } func (p *Program) Source() string { return p.src } -func (p *Program) Run(ctx context.Context, setters ...Option) (result []byte, err error) { - ctx = NewOptions(setters).WithContext(ctx) +func (p *Program) Params() []string { + res := make([]string, 0, len(p.params)) + for name := range p.params { + res = append(res, name) + } + + return res +} + +func (p *Program) Run(ctx context.Context, setters ...Option) (result []byte, err error) { + opts := NewOptions(setters) + + err = p.validateParams(opts) + + if err != nil { + return nil, err + } + + ctx = opts.WithContext(ctx) logger := logging.FromContext(ctx) defer func() { @@ -92,3 +111,31 @@ func (p *Program) MustRun(ctx context.Context, setters ...Option) []byte { return out } + +func (p *Program) validateParams(opts *Options) error { + if len(p.params) == 0 { + return nil + } + + // There might be no errors. + // Thus, we allocate this slice lazily, on a first error. + var missedParams []string + + for n := range p.params { + _, exists := opts.params[n] + + if !exists { + if missedParams == nil { + missedParams = make([]string, 0, len(p.params)) + } + + missedParams = append(missedParams, "@"+n) + } + } + + if len(missedParams) > 0 { + return core.Error(ErrMissedParam, strings.Join(missedParams, ", ")) + } + + return nil +}