1
0
mirror of https://github.com/MontFerret/ferret.git synced 2025-08-15 20:02:56 +02:00

Refactor logical operator compilation to optimize short-circuit evaluation, add assembly utilities (assembler and disassembler), and migrate integration tests from bytecode to compiler package for improved organization.

This commit is contained in:
Tim Voronov
2025-07-03 16:56:24 -04:00
parent ea827fc2f4
commit a1e98c3c3c
24 changed files with 427 additions and 141 deletions

8
pkg/asm/assembler.go Normal file
View File

@@ -0,0 +1,8 @@
package asm
import "github.com/MontFerret/ferret/pkg/vm"
// TODO: Implement the assembler that converts FASM (Ferret Assembly) code into a Ferret VM program.
func Assemble(fasm string) (*vm.Program, error) {
return new(vm.Program), nil
}

95
pkg/asm/disassembler.go Normal file
View File

@@ -0,0 +1,95 @@
package asm
import (
"bytes"
"fmt"
"text/tabwriter"
"github.com/MontFerret/ferret/pkg/vm"
)
// Disassemble returns a human-readable disassembly of the given program.
func Disassemble(p *vm.Program) string {
labels := collectLabels(p.Bytecode)
var buf bytes.Buffer
w := tabwriter.NewWriter(&buf, 0, 4, 2, ' ', 0)
// Header: params
for _, line := range formatParams(p) {
fmt.Fprintln(w, line)
}
// Body: disassembly
for ip, instr := range p.Bytecode {
if label, ok := labels[ip]; ok {
fmt.Fprintf(w, "%s:\n", label)
}
fmt.Fprintln(w, disasmLine(ip, instr, p, labels))
}
w.Flush()
return buf.String()
}
// collectLabels identifies jump targets and assigns symbolic labels to them.
func collectLabels(bytecode []vm.Instruction) map[int]string {
labels := make(map[int]string)
counter := 0
for _, instr := range bytecode {
switch instr.Opcode {
case vm.OpJump, vm.OpJumpIfFalse, vm.OpJumpIfTrue:
target := int(instr.Operands[0])
if _, ok := labels[target]; !ok {
labels[target] = fmt.Sprintf("@L%d", counter)
counter++
}
default:
// Do nothing for other opcodes
}
}
return labels
}
// disasmLine renders a single instruction into text, with optional constants and location info.
func disasmLine(ip int, instr vm.Instruction, p *vm.Program, labels map[int]string) string {
ops := instr.Operands
var out string
opcode := instr.Opcode
switch opcode {
case vm.OpLoadConst:
cIdx := ops[1].Constant()
comment := constValue(p, cIdx)
out = fmt.Sprintf("%d: %s R%d C%d ; %s", ip, opcode, ops[0], cIdx, comment)
case vm.OpMove:
out = fmt.Sprintf("%d: %s R%d R%d", ip, opcode, ops[0], ops[1])
case vm.OpAdd:
out = fmt.Sprintf("%d: %s R%d R%d R%d", ip, opcode, ops[0], ops[1], ops[2])
case vm.OpJump:
out = fmt.Sprintf("%d: %s %s", ip, opcode, labelOrAddr(int(ops[0]), labels))
case vm.OpJumpIfTrue, vm.OpJumpIfFalse:
out = fmt.Sprintf("%d: %s %s %s", ip, opcode, labelOrAddr(int(ops[0]), labels), ops[1])
case vm.OpReturn:
out = fmt.Sprintf("%d: %s R%d", ip, opcode, ops[0])
default:
out = fmt.Sprintf("%d: %s %v", ip, opcode, ops)
}
if loc := formatLocation(p, ip); loc != "" {
out += " " + loc
}
return out
}

View File

@@ -0,0 +1,70 @@
package asm
import (
"github.com/MontFerret/ferret/pkg/vm"
"reflect"
"testing"
)
func TestDisassemble(t *testing.T) {
type args struct {
p *vm.Program
}
tests := []struct {
name string
args args
want string
}{
// TODO: Add test cases.
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := Disassemble(tt.args.p); got != tt.want {
t.Errorf("Disassemble() = %v, want %v", got, tt.want)
}
})
}
}
func Test_collectLabels(t *testing.T) {
type args struct {
bytecode []vm.Instruction
}
tests := []struct {
name string
args args
want map[int]string
}{
// TODO: Add test cases.
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := collectLabels(tt.args.bytecode); !reflect.DeepEqual(got, tt.want) {
t.Errorf("collectLabels() = %v, want %v", got, tt.want)
}
})
}
}
func Test_disasmLine(t *testing.T) {
type args struct {
ip int
instr vm.Instruction
p *vm.Program
labels map[int]string
}
tests := []struct {
name string
args args
want string
}{
// TODO: Add test cases.
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := disasmLine(tt.args.ip, tt.args.instr, tt.args.p, tt.args.labels); got != tt.want {
t.Errorf("disasmLine() = %v, want %v", got, tt.want)
}
})
}
}

52
pkg/asm/formatter.go Normal file
View File

@@ -0,0 +1,52 @@
package asm
import (
"fmt"
"github.com/MontFerret/ferret/pkg/runtime"
"github.com/MontFerret/ferret/pkg/vm"
)
// labelOrAddr returns a label name if one exists for the given address; otherwise just the number.
func labelOrAddr(pos int, labels map[int]string) string {
if label, ok := labels[pos]; ok {
return label
}
return fmt.Sprintf("%d", pos)
}
// constValue renders the constant at a given index from the program.
func constValue(p *vm.Program, idx int) string {
if idx >= 0 && idx < len(p.Constants) {
constant := p.Constants[idx]
if runtime.IsNumber(constant) {
return fmt.Sprintf("%d", constant)
}
return fmt.Sprintf("%q", constant.String())
}
return "<invalid>"
}
// formatLocation returns a line/col comment if available for the given instruction.
func formatLocation(p *vm.Program, ip int) string {
if ip < len(p.Locations) {
loc := p.Locations[ip]
return fmt.Sprintf("; line %d col %d", loc.Line, loc.Column)
}
return ""
}
// formatParams generates comments mapping register indices to parameter names.
func formatParams(p *vm.Program) []string {
lines := []string{}
for i, name := range p.Params {
lines = append(lines, fmt.Sprintf("; param R%d = %s", i, name))
}
return lines
}

View File

@@ -74,40 +74,51 @@ func (ec *ExprCompiler) compileUnary(ctx fql.IUnaryOperatorContext, parent fql.I
// TODO: Free temporary registers if needed
func (ec *ExprCompiler) compileLogicalAnd(ctx fql.IExpressionContext) vm.Operand {
dst := ec.ctx.Registers.Allocate(core.Temp)
left := ec.Compile(ctx.GetLeft())
ec.ctx.Emitter.EmitMove(dst, left)
end := ec.ctx.Emitter.NewLabel()
rightDone := ec.ctx.Emitter.NewLabel()
dst := ec.ctx.Registers.Allocate(core.Temp)
// If left is false, jump to end
ec.ctx.Emitter.EmitJumpIfFalse(dst, end)
// If left is falsy, jump to end and use left
ec.ctx.Emitter.EmitJumpIfFalse(left, end)
// Otherwise evaluate right and use it
right := ec.Compile(ctx.GetRight())
ec.ctx.Emitter.EmitMove(dst, right)
ec.ctx.Emitter.EmitJump(rightDone)
// Short-circuit: use left as result
ec.ctx.Emitter.MarkLabel(end)
ec.ctx.Emitter.EmitMove(dst, left)
ec.ctx.Emitter.MarkLabel(rightDone)
return dst
}
// TODO: Free temporary registers if needed
func (ec *ExprCompiler) compileLogicalOr(ctx fql.IExpressionContext) vm.Operand {
dst := ec.ctx.Registers.Allocate(core.Temp)
left := ec.Compile(ctx.GetLeft())
ec.ctx.Emitter.EmitMove(dst, left)
end := ec.ctx.Emitter.NewLabel()
rightDone := ec.ctx.Emitter.NewLabel()
dst := ec.ctx.Registers.Allocate(core.Temp)
// If left is true, jump to end
ec.ctx.Emitter.EmitJumpIfTrue(dst, end)
// If left is truthy, short-circuit and skip right
ec.ctx.Emitter.EmitJumpIfTrue(left, end)
// Otherwise evaluate right
right := ec.Compile(ctx.GetRight())
ec.ctx.Emitter.EmitMove(dst, right)
ec.ctx.Emitter.EmitJump(rightDone)
// Short-circuit: use left value
ec.ctx.Emitter.MarkLabel(end)
ec.ctx.Emitter.EmitMove(dst, left)
// Common exit
ec.ctx.Emitter.MarkLabel(rightDone)
return dst
}

View File

@@ -1,35 +0,0 @@
package vm
import (
"strings"
"github.com/MontFerret/ferret/pkg/runtime"
)
func validateParams(env *Environment, program *Program) error {
if len(program.Params) == 0 {
return nil
}
// There might be no errors.
// Thus, we allocate this slice lazily, on a first error.
var missedParams []string
for _, n := range program.Params {
_, exists := env.params[n]
if !exists {
if missedParams == nil {
missedParams = make([]string, 0, len(program.Params))
}
missedParams = append(missedParams, "@"+n)
}
}
if len(missedParams) > 0 {
return runtime.Error(ErrMissedParam, strings.Join(missedParams, ", "))
}
return nil
}

View File

@@ -1,9 +1,5 @@
package vm
import (
"bytes"
)
type Instruction struct {
Opcode Opcode
Operands [3]Operand
@@ -28,20 +24,3 @@ func NewInstruction(opcode Opcode, operands ...Operand) Instruction {
Operands: ops,
}
}
func (i Instruction) String() string {
var buf bytes.Buffer
buf.WriteString(i.Opcode.String())
for idx, operand := range i.Operands {
if operand == 0 && idx > 0 {
break
}
buf.WriteString(" ")
buf.WriteString(operand.String())
}
return buf.String()
}

View File

@@ -231,6 +231,26 @@ func (op Opcode) String() string {
return "CALL"
case OpProtectedCall:
return "PCALL"
case OpCall0:
return "CALL0"
case OpProtectedCall0:
return "PCALL0"
case OpCall1:
return "CALL1"
case OpProtectedCall1:
return "PCALL1"
case OpCall2:
return "CALL2"
case OpProtectedCall2:
return "PCALL2"
case OpCall3:
return "CALL3"
case OpProtectedCall3:
return "PCALL3"
case OpCall4:
return "CALL4"
case OpProtectedCall4:
return "PCALL4"
// Collection Creation
case OpList:

78
pkg/vm/validation.go Normal file
View File

@@ -0,0 +1,78 @@
package vm
import (
"strings"
"github.com/MontFerret/ferret/pkg/runtime"
)
func validate(env *Environment, program *Program) error {
if err := validateParams(env, program); err != nil {
return err
}
if err := validateFunctions(env, program); err != nil {
return err
}
return nil
}
func validateParams(env *Environment, program *Program) error {
if len(program.Params) == 0 {
return nil
}
// There might be no errors.
// Thus, we allocate this slice lazily, on a first error.
var missedParams []string
for _, n := range program.Params {
_, exists := env.params[n]
if !exists {
if missedParams == nil {
missedParams = make([]string, 0, len(program.Params))
}
missedParams = append(missedParams, "@"+n)
}
}
if len(missedParams) > 0 {
return runtime.Error(ErrMissedParam, strings.Join(missedParams, ", "))
}
return nil
}
// TODO: Implement this function.
func validateFunctions(env *Environment, program *Program) error {
//if len(program.Locations) == 0 {
// return nil
//}
//
//// There might be no errors.
//// Thus, we allocate this slice lazily, on a first error.
//var missedFunctions []string
//
//for _, loc := range program.Locations {
// if loc.Function == "" {
// continue
// }
//
// if _, exists := env.functions[loc.Function]; !exists {
// if missedFunctions == nil {
// missedFunctions = make([]string, 0, len(program.Locations))
// }
//
// missedFunctions = append(missedFunctions, loc.Function)
// }
//}
//
//if len(missedFunctions) > 0 {
// return runtime.Error(ErrFunctionNotFound, strings.Join(missedFunctions, ", "))
//}
//
return nil
}

View File

@@ -26,7 +26,7 @@ func New(program *Program) *VM {
func (vm *VM) Run(ctx context.Context, opts []EnvironmentOption) (runtime.Value, error) {
env := newEnvironment(opts)
if err := validateParams(env, vm.program); err != nil {
if err := validate(env, vm.program); err != nil {
return nil, err
}

View File

@@ -1,9 +1,9 @@
package bytecode_test
package compiler_test
import (
"github.com/smartystreets/goconvey/convey"
"fmt"
"github.com/MontFerret/ferret/pkg/vm"
"github.com/smartystreets/goconvey/convey"
)
func CastToProgram(prog any) *vm.Program {
@@ -15,13 +15,26 @@ func CastToProgram(prog any) *vm.Program {
}
func ShouldEqualBytecode(e any, a ...any) string {
expected := CastToProgram(e).Bytecode
actual := CastToProgram(a[0]).Bytecode
expected := CastToProgram(e)
actual := CastToProgram(a[0])
for i := 0; i < len(expected); i++ {
if err := convey.ShouldEqual(actual[i].String(), expected[i].String()); err != "" {
for i := 0; i < len(expected.Bytecode); i++ {
actualIns := actual.Bytecode[i]
expectedIns := expected.Bytecode[i]
if err := convey.ShouldEqual(actualIns.Opcode, expectedIns.Opcode); err != "" {
return err
}
if err := convey.ShouldEqual(len(actualIns.Operands), len(expectedIns.Operands)); err != "" {
return fmt.Sprintf("operends length mismatch at index %d: expected %d, got %d", i, len(expectedIns.Operands), len(actualIns.Operands))
}
for j := 0; j < len(actualIns.Operands); j++ {
if err := convey.ShouldEqual(actualIns.Operands[j], expectedIns.Operands[j]); err != "" {
return fmt.Sprintf("operands mismatch at index %d, operand %d: expected %s, got %s", i, j, expectedIns.Operands[j], actualIns.Operands[j])
}
}
}
return ""

View File

@@ -1,4 +1,4 @@
package bytecode_test
package compiler_test
import (
"testing"

View File

@@ -1,4 +1,4 @@
package bytecode_test
package compiler_test
import (
"testing"

View File

@@ -1,4 +1,4 @@
package bytecode_test
package compiler_test
import (
"testing"

View File

@@ -1,4 +1,4 @@
package bytecode_test
package compiler_test
import (
"testing"

View File

@@ -1,4 +1,4 @@
package bytecode_test
package compiler_test
import (
"testing"

View File

@@ -1,4 +1,4 @@
package bytecode_test
package compiler_test
import (
"github.com/MontFerret/ferret/pkg/vm"

View File

@@ -1,4 +1,4 @@
package bytecode_test
package compiler_test
import (
"testing"

View File

@@ -1,4 +1,4 @@
package bytecode_test
package compiler_test
import (
"testing"

View File

@@ -1,4 +1,4 @@
package bytecode_test
package compiler_test
import (
"testing"

View File

@@ -0,0 +1,40 @@
package compiler_test
import (
"github.com/MontFerret/ferret/pkg/vm"
"testing"
)
func TestLogicalOperators(t *testing.T) {
RunUseCases(t, []UseCase{
SkipByteCodeCase("RETURN 1 AND 0", BC{
I(vm.OpLoadConst, 1, C(0)),
I(vm.OpJumpIfFalse),
I(vm.OpLoadConst, 1, C(1)),
I(vm.OpReturn, 1),
}),
ByteCodeCase("RETURN 1 OR 0", BC{
I(vm.OpLoadConst, 1, C(0)),
I(vm.OpJumpIfFalse),
I(vm.OpLoadConst, 1, C(1)),
I(vm.OpReturn, 1),
}),
//Case("RETURN 1 AND 1", 1),
//Case("RETURN 2 > 1 AND 1 > 0", true),
//Case("RETURN NONE && true", nil),
//Case("RETURN '' && true", ""),
//Case("RETURN true && 23", 23),
//Case("RETURN 1 OR 0", 1),
//Case("RETURN 0 OR 1", 1),
//Case("RETURN 2 OR 1", 2),
//Case("RETURN 2 > 1 OR 1 > 0", true),
//Case("RETURN 2 < 1 OR 1 > 0", true),
//Case("RETURN 1 || 7", 1),
//Case("RETURN 0 || 7", 7),
//Case("RETURN NONE || 'foo'", "foo"),
//Case("RETURN '' || 'foo'", "foo"),
//Case(`RETURN ERROR()? || 'boo'`, "boo"),
//Case(`RETURN !ERROR()? && TRUE`, true),
//Case(`LET u = { valid: false } RETURN u.valid || TRUE`, true),
})
}

View File

@@ -1,4 +1,4 @@
package bytecode_test
package compiler_test
import (
"testing"
@@ -10,61 +10,10 @@ func TestString(t *testing.T) {
RunUseCases(t, []UseCase{
ByteCodeCase(
`
RETURN "
FOO
BAR
"
`, []vm.Instruction{
RETURN "FOO BAR"
`, BC{
I(vm.OpLoadConst, 1, C(0)),
I(vm.OpMove, 0, R(1)),
I(vm.OpReturn, 0),
I(vm.OpReturn, 1),
}, "Should be possible to use multi line string"),
//
// CaseJSON(
// fmt.Sprintf(`
//RETURN %s<!DOCTYPE html>
// <html lang="en">
// <head>
// <meta charset="UTF-8">
// <title>GetTitle</title>
// </head>
// <body>
// Hello world
// </body>
// </html>%s
//`, "`", "`"), `<!DOCTYPE html>
// <html lang="en">
// <head>
// <meta charset="UTF-8">
// <title>GetTitle</title>
// </head>
// <body>
// Hello world
// </body>
// </html>`, "Should be possible to use multi line string with nested strings using backtick"),
//
// CaseJSON(
// fmt.Sprintf(`
//RETURN %s<!DOCTYPE html>
// <html lang="en">
// <head>
// <meta charset="UTF-8">
// <title>GetTitle</title>
// </head>
// <body>
// Hello world
// </body>
// </html>%s
//`, "´", "´"),
// `<!DOCTYPE html>
// <html lang="en">
// <head>
// <meta charset="UTF-8">
// <title>GetTitle</title>
// </head>
// <body>
// Hello world
// </body>
// </html>`, "Should be possible to use multi line string with nested strings using tick"),
})
}

View File

@@ -1,4 +1,4 @@
package bytecode_test
package compiler_test
import (
"github.com/MontFerret/ferret/pkg/vm"

View File

@@ -1,7 +1,8 @@
package bytecode_test
package compiler_test
import (
"fmt"
"github.com/MontFerret/ferret/pkg/asm"
"strings"
"testing"
@@ -64,7 +65,12 @@ func RunUseCasesWith(t *testing.T, c *compiler.Compiler, useCases []UseCase) {
println("")
println("Actual:")
println(actual.String())
println(asm.Disassemble(actual))
if p, ok := useCase.Expected.(*vm.Program); ok {
println("Expected:")
println(asm.Disassemble(p))
}
convey.So(err, convey.ShouldBeNil)