mirror of
https://github.com/MontFerret/ferret.git
synced 2025-01-18 03:22:02 +02:00
Merge pull request #396 from MontFerret/feature/#373-params-check
Added params check before execution
This commit is contained in:
commit
625f78fd98
@ -111,6 +111,38 @@ func TestParam(t *testing.T) {
|
||||
)
|
||||
|
||||
So(string(out), ShouldEqual, `"baz"`)
|
||||
})
|
||||
|
||||
Convey("Should return an error if param values are not passed", t, func() {
|
||||
prog := compiler.New().
|
||||
MustCompile(`
|
||||
LET doc = { foo: { bar: "baz" } }
|
||||
|
||||
RETURN doc.@attr.@subattr
|
||||
`)
|
||||
|
||||
_, err := prog.Run(
|
||||
context.Background(),
|
||||
)
|
||||
|
||||
So(err, ShouldNotBeNil)
|
||||
So(err.Error(), ShouldContainSubstring, runtime.ErrMissedParam.Error())
|
||||
})
|
||||
|
||||
Convey("Should be possible to use in member expression as segments", t, func() {
|
||||
prog := compiler.New().
|
||||
MustCompile(`
|
||||
LET doc = { foo: { bar: "baz" } }
|
||||
|
||||
RETURN doc.@attr.@subattr
|
||||
`)
|
||||
|
||||
_, err := prog.Run(
|
||||
context.Background(),
|
||||
runtime.WithParam("attr", "foo"),
|
||||
)
|
||||
|
||||
So(err, ShouldNotBeNil)
|
||||
So(err.Error(), ShouldContainSubstring, "subattr")
|
||||
})
|
||||
}
|
||||
|
@ -2,47 +2,56 @@ package compiler
|
||||
|
||||
import (
|
||||
"github.com/MontFerret/ferret/pkg/runtime/core"
|
||||
"github.com/MontFerret/ferret/pkg/runtime/values/types"
|
||||
)
|
||||
|
||||
type (
|
||||
globalScope struct {
|
||||
params map[string]struct{}
|
||||
}
|
||||
|
||||
scope struct {
|
||||
global *globalScope
|
||||
parent *scope
|
||||
vars map[string]core.Type
|
||||
vars map[string]struct{}
|
||||
}
|
||||
)
|
||||
|
||||
func newRootScope() *scope {
|
||||
func newGlobalScope() *globalScope {
|
||||
return &globalScope{
|
||||
params: map[string]struct{}{},
|
||||
}
|
||||
}
|
||||
|
||||
func newRootScope(global *globalScope) *scope {
|
||||
return &scope{
|
||||
vars: make(map[string]core.Type),
|
||||
global: global,
|
||||
vars: make(map[string]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func newScope(parent *scope) *scope {
|
||||
s := newRootScope()
|
||||
s := newRootScope(parent.global)
|
||||
s.parent = parent
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *scope) GetVariable(name string) (core.Type, error) {
|
||||
local, exists := s.vars[name]
|
||||
func (s *scope) AddParam(name string) {
|
||||
s.global.params[name] = struct{}{}
|
||||
}
|
||||
|
||||
func (s *scope) HasVariable(name string) bool {
|
||||
_, exists := s.vars[name]
|
||||
|
||||
if exists {
|
||||
return local, nil
|
||||
return true
|
||||
}
|
||||
|
||||
if s.parent != nil {
|
||||
parents, err := s.parent.GetVariable(name)
|
||||
|
||||
if err != nil {
|
||||
return types.None, err
|
||||
}
|
||||
|
||||
return parents, nil
|
||||
return s.parent.HasVariable(name)
|
||||
}
|
||||
|
||||
return types.None, core.Error(ErrVariableNotFound, name)
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *scope) SetVariable(name string) error {
|
||||
@ -53,7 +62,7 @@ func (s *scope) SetVariable(name string) error {
|
||||
}
|
||||
|
||||
// TODO: add type detection
|
||||
s.vars[name] = types.None
|
||||
s.vars[name] = struct{}{}
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -71,7 +80,7 @@ func (s *scope) RemoveVariable(name string) error {
|
||||
}
|
||||
|
||||
func (s *scope) ClearVariables() {
|
||||
s.vars = make(map[string]core.Type)
|
||||
s.vars = make(map[string]struct{})
|
||||
}
|
||||
|
||||
func (s *scope) Fork() *scope {
|
||||
|
@ -38,14 +38,15 @@ func newVisitor(src string, funcs map[string]core.Function) *visitor {
|
||||
|
||||
func (v *visitor) VisitProgram(ctx *fql.ProgramContext) interface{} {
|
||||
return newResultFrom(func() (interface{}, error) {
|
||||
rootScope := newRootScope()
|
||||
block, err := v.doVisitBody(ctx.Body().(*fql.BodyContext), rootScope)
|
||||
gs := newGlobalScope()
|
||||
rs := newRootScope(gs)
|
||||
block, err := v.doVisitBody(ctx.Body().(*fql.BodyContext), rs)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return runtime.NewProgram(v.src, block)
|
||||
return runtime.NewProgram(v.src, block, gs.params)
|
||||
})
|
||||
}
|
||||
|
||||
@ -801,10 +802,8 @@ func (v *visitor) doVisitMember(ctx *fql.MemberContext, scope *scope) (core.Expr
|
||||
if identifier != nil {
|
||||
varName := ctx.Identifier().GetText()
|
||||
|
||||
_, err := scope.GetVariable(varName)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if !scope.HasVariable(varName) {
|
||||
return nil, core.Error(ErrVariableNotFound, varName)
|
||||
}
|
||||
|
||||
exp, err := expressions.NewVariableExpression(v.getSourceMap(ctx), varName)
|
||||
@ -927,10 +926,8 @@ func (v *visitor) doVisitComputedPropertyNameContext(ctx *fql.ComputedPropertyNa
|
||||
func (v *visitor) doVisitShorthandPropertyNameContext(ctx *fql.ShorthandPropertyNameContext, scope *scope) (core.Expression, error) {
|
||||
name := ctx.Variable().GetText()
|
||||
|
||||
_, err := scope.GetVariable(name)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if !scope.HasVariable(name) {
|
||||
return nil, core.Error(ErrVariableNotFound, name)
|
||||
}
|
||||
|
||||
return literals.NewStringLiteral(ctx.Variable().GetText()), nil
|
||||
@ -1005,10 +1002,8 @@ func (v *visitor) doVisitVariable(ctx *fql.VariableContext, scope *scope) (core.
|
||||
name := ctx.Identifier().GetText()
|
||||
|
||||
// check whether the variable is defined
|
||||
_, err := scope.GetVariable(name)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if !scope.HasVariable(name) {
|
||||
return nil, core.Error(ErrVariableNotFound, name)
|
||||
}
|
||||
|
||||
return expressions.NewVariableExpression(v.getSourceMap(ctx), name)
|
||||
@ -1117,9 +1112,11 @@ func (v *visitor) doVisitFunctionCallExpression(context *fql.FunctionCallExpress
|
||||
)
|
||||
}
|
||||
|
||||
func (v *visitor) doVisitParamContext(context *fql.ParamContext, _ *scope) (core.Expression, error) {
|
||||
func (v *visitor) doVisitParamContext(context *fql.ParamContext, s *scope) (core.Expression, error) {
|
||||
name := context.Identifier().GetText()
|
||||
|
||||
s.AddParam(name)
|
||||
|
||||
return expressions.NewParameterExpression(
|
||||
v.getSourceMap(context),
|
||||
name,
|
||||
|
9
pkg/runtime/errors.go
Normal file
9
pkg/runtime/errors.go
Normal file
@ -0,0 +1,9 @@
|
||||
package runtime
|
||||
|
||||
import (
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrMissedParam = errors.New("missed value for parameter(s)")
|
||||
)
|
@ -3,6 +3,7 @@ package runtime
|
||||
import (
|
||||
"context"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
"github.com/MontFerret/ferret/pkg/runtime/core"
|
||||
"github.com/MontFerret/ferret/pkg/runtime/logging"
|
||||
@ -11,11 +12,12 @@ import (
|
||||
)
|
||||
|
||||
type Program struct {
|
||||
src string
|
||||
body core.Expression
|
||||
src string
|
||||
body core.Expression
|
||||
params map[string]struct{}
|
||||
}
|
||||
|
||||
func NewProgram(src string, body core.Expression) (*Program, error) {
|
||||
func NewProgram(src string, body core.Expression, params map[string]struct{}) (*Program, error) {
|
||||
if src == "" {
|
||||
return nil, core.Error(core.ErrMissedArgument, "source")
|
||||
}
|
||||
@ -24,16 +26,33 @@ func NewProgram(src string, body core.Expression) (*Program, error) {
|
||||
return nil, core.Error(core.ErrMissedArgument, "body")
|
||||
}
|
||||
|
||||
return &Program{src, body}, nil
|
||||
return &Program{src, body, params}, nil
|
||||
}
|
||||
|
||||
func (p *Program) Source() string {
|
||||
return p.src
|
||||
}
|
||||
|
||||
func (p *Program) Run(ctx context.Context, setters ...Option) (result []byte, err error) {
|
||||
ctx = NewOptions(setters).WithContext(ctx)
|
||||
func (p *Program) Params() []string {
|
||||
res := make([]string, 0, len(p.params))
|
||||
|
||||
for name := range p.params {
|
||||
res = append(res, name)
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
func (p *Program) Run(ctx context.Context, setters ...Option) (result []byte, err error) {
|
||||
opts := NewOptions(setters)
|
||||
|
||||
err = p.validateParams(opts)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctx = opts.WithContext(ctx)
|
||||
logger := logging.FromContext(ctx)
|
||||
|
||||
defer func() {
|
||||
@ -92,3 +111,31 @@ func (p *Program) MustRun(ctx context.Context, setters ...Option) []byte {
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func (p *Program) validateParams(opts *Options) error {
|
||||
if len(p.params) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// There might be no errors.
|
||||
// Thus, we allocate this slice lazily, on a first error.
|
||||
var missedParams []string
|
||||
|
||||
for n := range p.params {
|
||||
_, exists := opts.params[n]
|
||||
|
||||
if !exists {
|
||||
if missedParams == nil {
|
||||
missedParams = make([]string, 0, len(p.params))
|
||||
}
|
||||
|
||||
missedParams = append(missedParams, "@"+n)
|
||||
}
|
||||
}
|
||||
|
||||
if len(missedParams) > 0 {
|
||||
return core.Error(ErrMissedParam, strings.Join(missedParams, ", "))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user