diff --git a/pkg/compiler/compiler_test.go b/pkg/compiler/compiler_test.go index 77cf12ea..990cb15b 100644 --- a/pkg/compiler/compiler_test.go +++ b/pkg/compiler/compiler_test.go @@ -978,6 +978,81 @@ func TestLogicalOperators(t *testing.T) { So(err, ShouldBeNil) So(string(out), ShouldEqual, "true") }) + + Convey("1 || 7 should return 1", t, func() { + c := compiler.New() + + prog, err := c.Compile(` + RETURN 1 || 7 + `) + + So(err, ShouldBeNil) + + out, err := prog.Run(context.Background()) + + So(err, ShouldBeNil) + So(string(out), ShouldEqual, "1") + }) + + Convey("NONE || 'foo' should return 'foo'", t, func() { + c := compiler.New() + + prog, err := c.Compile(` + RETURN NONE || 'foo' + `) + + So(err, ShouldBeNil) + + out, err := prog.Run(context.Background()) + + So(err, ShouldBeNil) + So(string(out), ShouldEqual, `"foo"`) + }) + + Convey("NONE && true should return null", t, func() { + c := compiler.New() + + prog, err := c.Compile(` + RETURN NONE && true + `) + + So(err, ShouldBeNil) + + out, err := prog.Run(context.Background()) + + So(err, ShouldBeNil) + So(string(out), ShouldEqual, `null`) + }) + + Convey("'' && true should return ''", t, func() { + c := compiler.New() + + prog, err := c.Compile(` + RETURN '' && true + `) + + So(err, ShouldBeNil) + + out, err := prog.Run(context.Background()) + + So(err, ShouldBeNil) + So(string(out), ShouldEqual, `""`) + }) + + Convey("true && 23 should return '23", t, func() { + c := compiler.New() + + prog, err := c.Compile(` + RETURN true && 23 + `) + + So(err, ShouldBeNil) + + out, err := prog.Run(context.Background()) + + So(err, ShouldBeNil) + So(string(out), ShouldEqual, `23`) + }) } func TestMathOperators(t *testing.T) { diff --git a/pkg/runtime/expressions/condition.go b/pkg/runtime/expressions/condition.go index 629291eb..9f8f2f8f 100644 --- a/pkg/runtime/expressions/condition.go +++ b/pkg/runtime/expressions/condition.go @@ -42,7 +42,7 @@ func (e *ConditionExpression) Exec(ctx context.Context, scope *core.Scope) (core return values.None, core.SourceError(e.src, err) } - cond := e.evalTestValue(out) + cond := values.ToBoolean(out) var next core.Expression @@ -65,20 +65,3 @@ func (e *ConditionExpression) Exec(ctx context.Context, scope *core.Scope) (core return res, nil } - -func (e *ConditionExpression) evalTestValue(value core.Value) values.Boolean { - switch value.Type() { - case core.BooleanType: - return value.(values.Boolean) - case core.NoneType: - return values.False - case core.StringType: - return values.NewBoolean(value.String() != "") - case core.IntType: - return values.NewBoolean(value.(values.Int) != 0) - case core.FloatType: - return values.NewBoolean(value.(values.Float) != 0) - default: - return values.True - } -} diff --git a/pkg/runtime/expressions/operators/logical.go b/pkg/runtime/expressions/operators/logical.go index b0e55818..9077922b 100644 --- a/pkg/runtime/expressions/operators/logical.go +++ b/pkg/runtime/expressions/operators/logical.go @@ -57,18 +57,22 @@ func (operator *LogicalOperator) Exec(ctx context.Context, scope *core.Scope) (c return nil, err } - left = operator.ensureType(left) - if operator.value == NotType { return Not(left, values.None), nil } - if operator.value == AndType && left == values.False { - return values.False, nil + leftBool := values.ToBoolean(left) + + if operator.value == AndType && leftBool == values.False { + if left.Type() == core.BooleanType { + return values.False, nil + } + + return left, nil } - if operator.value == OrType && left == values.True { - return values.True, nil + if operator.value == OrType && leftBool == values.True { + return left, nil } right, err := operator.right.Exec(ctx, scope) @@ -77,17 +81,5 @@ func (operator *LogicalOperator) Exec(ctx context.Context, scope *core.Scope) (c return nil, err } - return operator.ensureType(right), nil -} - -func (operator *LogicalOperator) ensureType(value core.Value) core.Value { - if value.Type() != core.BooleanType { - if value.Type() == core.NoneType { - return values.False - } - - return values.True - } - - return value + return right, nil } diff --git a/pkg/runtime/expressions/operators/operator.go b/pkg/runtime/expressions/operators/operator.go index 850a7a8d..48405a74 100644 --- a/pkg/runtime/expressions/operators/operator.go +++ b/pkg/runtime/expressions/operators/operator.go @@ -72,13 +72,13 @@ func GreaterOrEqual(left, right core.Value) core.Value { } func Not(left, _ core.Value) core.Value { - if left == values.True { + b := values.ToBoolean(left) + + if b == values.True { return values.False - } else if left == values.False { - return values.True } - return values.False + return values.True } // Adds numbers @@ -86,33 +86,33 @@ func Not(left, _ core.Value) core.Value { func Add(left, right core.Value) core.Value { if left.Type() == core.IntType { if right.Type() == core.IntType { - l := left.Unwrap().(int) - r := right.Unwrap().(int) + l := left.(values.Int) + r := right.(values.Int) - return values.NewInt(l + r) + return l + r } if right.Type() == core.FloatType { - l := left.Unwrap().(int) - r := right.Unwrap().(float64) + l := left.(values.Int) + r := right.(values.Float) - return values.Float(float64(l) + r) + return values.Float(l) + r } } if left.Type() == core.FloatType { if right.Type() == core.FloatType { - l := left.Unwrap().(float64) - r := right.Unwrap().(float64) + l := left.(values.Float) + r := right.(values.Float) - return values.Float(l + r) + return l + r } if right.Type() == core.IntType { - l := left.Unwrap().(float64) - r := right.Unwrap().(int) + l := left.(values.Float) + r := right.(values.Int) - return values.Float(l + float64(r)) + return l + values.Float(r) } } @@ -122,33 +122,33 @@ func Add(left, right core.Value) core.Value { func Subtract(left, right core.Value) core.Value { if left.Type() == core.IntType { if right.Type() == core.IntType { - l := left.Unwrap().(int) - r := right.Unwrap().(int) + l := left.(values.Int) + r := right.(values.Int) - return values.NewInt(l - r) + return l - r } if right.Type() == core.FloatType { - l := left.Unwrap().(int) - r := right.Unwrap().(float64) + l := left.(values.Int) + r := right.(values.Float) - return values.Float(float64(l) - r) + return values.Float(l) - r } } if left.Type() == core.FloatType { if right.Type() == core.FloatType { - l := left.Unwrap().(float64) - r := right.Unwrap().(float64) + l := left.(values.Float) + r := right.(values.Float) - return values.Float(l - r) + return l - r } if right.Type() == core.IntType { - l := left.Unwrap().(float64) - r := right.Unwrap().(int) + l := left.(values.Float) + r := right.(values.Int) - return values.Float(l - float64(r)) + return l - values.Float(r) } } @@ -158,33 +158,33 @@ func Subtract(left, right core.Value) core.Value { func Multiply(left, right core.Value) core.Value { if left.Type() == core.IntType { if right.Type() == core.IntType { - l := left.Unwrap().(int) - r := right.Unwrap().(int) + l := left.(values.Int) + r := right.(values.Int) - return values.NewInt(l * r) + return l * r } if right.Type() == core.FloatType { - l := left.Unwrap().(int) - r := right.Unwrap().(float64) + l := left.(values.Int) + r := right.(values.Float) - return values.Float(float64(l) * r) + return values.Float(l) * r } } if left.Type() == core.FloatType { if right.Type() == core.FloatType { - l := left.Unwrap().(float64) - r := right.Unwrap().(float64) + l := left.(values.Float) + r := right.(values.Float) - return values.Float(l * r) + return l * r } if right.Type() == core.IntType { - l := left.Unwrap().(float64) - r := right.Unwrap().(int) + l := left.(values.Float) + r := right.(values.Int) - return values.Float(l * float64(r)) + return l * values.Float(r) } } @@ -194,33 +194,33 @@ func Multiply(left, right core.Value) core.Value { func Divide(left, right core.Value) core.Value { if left.Type() == core.IntType { if right.Type() == core.IntType { - l := left.Unwrap().(int) - r := right.Unwrap().(int) + l := left.(values.Int) + r := right.(values.Int) - return values.NewInt(l / r) + return l / r } if right.Type() == core.FloatType { - l := left.Unwrap().(int) - r := right.Unwrap().(float64) + l := left.(values.Int) + r := right.(values.Float) - return values.Float(float64(l) / r) + return values.Float(l) / r } } if left.Type() == core.FloatType { if right.Type() == core.FloatType { - l := left.Unwrap().(float64) - r := right.Unwrap().(float64) + l := left.(values.Float) + r := right.(values.Float) - return values.Float(l / r) + return l / r } if right.Type() == core.IntType { - l := left.Unwrap().(float64) - r := right.Unwrap().(int) + l := left.(values.Float) + r := right.(values.Int) - return values.Float(l / float64(r)) + return l / values.Float(r) } } @@ -230,33 +230,33 @@ func Divide(left, right core.Value) core.Value { func Modulus(left, right core.Value) core.Value { if left.Type() == core.IntType { if right.Type() == core.IntType { - l := left.Unwrap().(int) - r := right.Unwrap().(int) + l := left.(values.Int) + r := right.(values.Int) - return values.NewInt(l % r) + return l % r } if right.Type() == core.FloatType { - l := left.Unwrap().(int) - r := right.Unwrap().(float64) + l := left.(values.Int) + r := right.(values.Float) - return values.Float(l % int(r)) + return l % values.Int(r) } } if left.Type() == core.FloatType { if right.Type() == core.FloatType { - l := left.Unwrap().(float64) - r := right.Unwrap().(float64) + l := left.(values.Float) + r := right.(values.Float) - return values.Float(int(l) % int(r)) + return values.Int(l) % values.Int(r) } if right.Type() == core.IntType { - l := left.Unwrap().(float64) - r := right.Unwrap().(int) + l := left.(values.Float) + r := right.(values.Int) - return values.Float(int(l) % r) + return values.Int(l) % r } } @@ -265,15 +265,15 @@ func Modulus(left, right core.Value) core.Value { func Increment(left, _ core.Value) core.Value { if left.Type() == core.IntType { - l := left.Unwrap().(int) + l := left.(values.Int) - return values.NewInt(l + 1) + return l + 1 } if left.Type() == core.FloatType { - l := left.Unwrap().(float64) + l := left.(values.Float) - return values.Float(l + 1) + return l + 1 } return values.None @@ -281,15 +281,15 @@ func Increment(left, _ core.Value) core.Value { func Decrement(left, _ core.Value) core.Value { if left.Type() == core.IntType { - l := left.Unwrap().(int) + l := left.(values.Int) - return values.NewInt(l - 1) + return l - 1 } if left.Type() == core.FloatType { - l := left.Unwrap().(float64) + l := left.(values.Float) - return values.Float(l - 1) + return l - 1 } return values.None diff --git a/pkg/runtime/values/helpers.go b/pkg/runtime/values/helpers.go index a22cd035..8d34c05c 100644 --- a/pkg/runtime/values/helpers.go +++ b/pkg/runtime/values/helpers.go @@ -209,3 +209,20 @@ func Parse(input interface{}) core.Value { return None } + +func ToBoolean(input core.Value) core.Value { + switch input.Type() { + case core.BooleanType: + return input + case core.NoneType: + return False + case core.StringType: + return NewBoolean(input.String() != "") + case core.IntType: + return NewBoolean(input.(Int) != 0) + case core.FloatType: + return NewBoolean(input.(Float) != 0) + default: + return True + } +}