1
0
mirror of https://github.com/MontFerret/ferret.git synced 2025-08-13 19:52:52 +02:00

Refactor compiler and error handling: introduce multi-error support, improve snippet generation, standardize error listener, update compiler context, and enhance test cases.

This commit is contained in:
Tim Voronov
2025-07-29 18:35:55 -04:00
parent b2163ebc14
commit d24fac821b
20 changed files with 413 additions and 134 deletions

View File

@@ -30,40 +30,36 @@ func (c *Compiler) Compile(src *file.Source) (program *vm.Program, err error) {
return nil, core.NewEmptyQueryErr(src)
}
errorHandler := core.NewErrorHandler(src, 10)
defer func() {
if r := recover(); r != nil {
// find out exactly what the error was and set err
// Find out exactly what the error was and set err
var e *CompilationError
buf := make([]byte, 1024)
n := goruntime.Stack(buf, false)
stackTrace := string(buf[:n])
// Find out exactly what the error was and add the e
switch x := r.(type) {
case *CompilationError:
err = x
case *AggregatedCompilationErrors:
err = x
case string:
buf := make([]byte, 1024)
n := goruntime.Stack(buf, false)
stackTrace := string(buf[:n])
err = core.NewInternalErr(src, x+"\n"+stackTrace)
e = core.NewInternalErr(src, x+"\n"+stackTrace)
case error:
buf := make([]byte, 1024)
n := goruntime.Stack(buf, false)
stackTrace := string(buf[:n])
err = core.NewInternalErrWith(src, "unknown panic\n"+stackTrace, x)
e = core.NewInternalErrWith(src, "unknown panic\n"+stackTrace, x)
default:
buf := make([]byte, 1024)
n := goruntime.Stack(buf, false)
stackTrace := string(buf[:n])
err = core.NewInternalErr(src, "unknown panic\n"+stackTrace)
e = core.NewInternalErr(src, "unknown panic\n"+stackTrace)
}
errorHandler.Add(e)
program = nil
err = errorHandler.Unwrap()
}
}()
l := NewVisitor(src, errorHandler)
p := parser.New(src.Content())
p.AddErrorListener(newErrorListener())
l := NewVisitor(src)
p.AddErrorListener(newErrorListener(l.Ctx.Errors))
p.Visit(l)
if l.Ctx.Errors.HasErrors() {

View File

@@ -4,7 +4,7 @@ import "github.com/MontFerret/ferret/pkg/compiler/internal/core"
type ErrorKind = core.ErrorKind
type CompilationError = core.CompilationError
type AggregatedCompilationErrors = core.MultiCompilationError
type MultiCompilationError = core.MultiCompilationError
var (
UnknownError = core.UnknownError

View File

@@ -0,0 +1,71 @@
package compiler
import (
"github.com/antlr4-go/antlr/v4"
"github.com/MontFerret/ferret/pkg/compiler/internal/core"
"github.com/MontFerret/ferret/pkg/file"
)
// errorListener is the ANTLR error listener installed on the parser by the
// compiler. It embeds antlr.DiagnosticErrorListener and carries the error
// handler that SyntaxError reports to.
type errorListener struct {
*antlr.DiagnosticErrorListener
// handler receives the CompilationError built for each reported syntax error.
handler *core.ErrorHandler
}
// newErrorListener returns an antlr.ErrorListener that records syntax errors
// in the given handler. The embedded diagnostic listener is created with
// antlr.NewDiagnosticErrorListener(false).
func newErrorListener(handler *core.ErrorHandler) antlr.ErrorListener {
	listener := new(errorListener)
	listener.DiagnosticErrorListener = antlr.NewDiagnosticErrorListener(false)
	listener.handler = handler

	return listener
}
// ReportAttemptingFullContext is a no-op: the body is empty, so nothing is
// recorded or printed when this callback fires.
// NOTE(review): presumably defined to shadow the embedded
// DiagnosticErrorListener's implementation — confirm against the antlr runtime.
func (d *errorListener) ReportAttemptingFullContext(recognizer antlr.Parser, dfa *antlr.DFA, startIndex, stopIndex int, conflictingAlts *antlr.BitSet, configs *antlr.ATNConfigSet) {
}
// ReportContextSensitivity is a no-op: the body is empty, so nothing is
// recorded or printed when this callback fires.
// NOTE(review): presumably defined to shadow the embedded
// DiagnosticErrorListener's implementation — confirm against the antlr runtime.
func (d *errorListener) ReportContextSensitivity(recognizer antlr.Parser, dfa *antlr.DFA, startIndex, stopIndex, prediction int, configs *antlr.ATNConfigSet) {
}
// SyntaxError converts a parser syntax error into a core.CompilationError of
// kind core.SyntaxError and adds it to the listener's handler.
//
// The raw parser message is passed through core.ExplainSyntaxError to obtain
// the message and hint; the location is derived from offendingSymbol. The
// line and column arguments are not used.
func (d *errorListener) SyntaxError(_ antlr.Recognizer, offendingSymbol interface{}, line, column int, msg string, e antlr.RecognitionException) {
	message, hint := core.ExplainSyntaxError(msg)
	location := d.findErrorLocation(offendingSymbol, e)

	syntaxErr := &core.CompilationError{
		Kind:     core.SyntaxError,
		Message:  message,
		Hint:     hint,
		Location: location,
	}

	d.handler.Add(syntaxErr)
}
// findErrorLocation builds a file.Location for a syntax error.
//
// When offendingSymbol is an antlr.Token, the location is taken from the
// token: its line minus one, its column, and its start/stop offsets, each
// clamped so it is never negative. For any other value the location is all
// zeros. The recognition exception e is not consulted.
//
// NOTE(review): core.LocationFromToken uses GetLine() and GetColumn()+1 for
// the same token fields, whereas this uses GetLine()-1 and GetColumn().
// Confirm which base file.Location expects.
func (d *errorListener) findErrorLocation(offendingSymbol interface{}, e antlr.RecognitionException) *file.Location {
	token, isToken := offendingSymbol.(antlr.Token)
	if !isToken {
		return file.NewLocation(0, 0, 0, 0)
	}

	// nonNegative floors a coordinate at zero.
	nonNegative := func(v int) int {
		if v < 0 {
			return 0
		}

		return v
	}

	return file.NewLocation(
		nonNegative(token.GetLine()-1),
		nonNegative(token.GetColumn()),
		nonNegative(token.GetStart()),
		nonNegative(token.GetStop()),
	)
}

View File

@@ -25,10 +25,10 @@ type CompilerContext struct {
}
// NewCompilerContext initializes a new CompilerContext with default values.
func NewCompilerContext(src *file.Source) *CompilerContext {
func NewCompilerContext(src *file.Source, errors *core.ErrorHandler) *CompilerContext {
ctx := &CompilerContext{
Source: src,
Errors: core.NewErrorHandler(src, 10),
Errors: errors,
Emitter: core.NewEmitter(),
Registers: core.NewRegisterAllocator(),
Symbols: nil, // set later

View File

@@ -10,12 +10,12 @@ type (
ErrorKind string
CompilationError struct {
Message string
Kind ErrorKind
Source *file.Source
Location *file.Location
Hint string
Cause error
Message string `json:"message"`
Kind ErrorKind `json:"kind"`
Source *file.Source `json:"source"`
Location *file.Location `json:"location"`
Hint string `json:"hint"`
Cause error `json:"cause"`
}
)

View File

@@ -14,25 +14,24 @@ func FormatError(out io.Writer, e *CompilationError, indent int) {
fmt.Fprintf(out, "%s --> %s:%d:%d\n", prefix, e.Source.Name(), e.Location.Line(), e.Location.Column())
// Determine padding width for line number column
lineNum := e.Location.Line()
lineNoWidth := len(fmt.Sprintf("%d", lineNum))
// Pipe line
lineNoWidth := len(fmt.Sprintf("%d", e.Location.Line()))
fmt.Fprintf(out, "%s%s\n", prefix, strings.Repeat(" ", lineNoWidth)+" |")
// Code line
lineText, caret := e.Source.Snippet(*e.Location)
fmt.Fprintf(out, "%s%*d | %s\n", prefix, lineNoWidth, lineNum, lineText)
// Multi-line snippet with context
snippetLines := e.Source.Snippet(*e.Location)
// Caret line
fmt.Fprintf(out, "%s%s | %s\n", prefix, strings.Repeat(" ", lineNoWidth), caret)
for _, sl := range snippetLines {
fmt.Fprintf(out, "%s%*d | %s\n", prefix, lineNoWidth, sl.Line, sl.Text)
if sl.Caret != "" {
fmt.Fprintf(out, "%s%s | %s\n", prefix, strings.Repeat(" ", lineNoWidth), sl.Caret)
}
}
// Hint
if e.Hint != "" {
fmt.Fprintf(out, "%sHint: %s\n", prefix, e.Hint)
}
// Cause
if e.Cause != nil {
if nested, ok := e.Cause.(*CompilationError); ok {
fmt.Fprintf(out, "%sCaused by:\n", prefix)

View File

@@ -14,23 +14,6 @@ type ErrorHandler struct {
threshold int
}
func ParserLocation(ctx antlr.ParserRuleContext) *file.Location {
start := ctx.GetStart()
stop := ctx.GetStop()
// Defensive: avoid nil dereference
if start == nil || stop == nil {
return file.NewLocation(0, 0, 0, 0)
}
return file.NewLocation(
start.GetLine(),
start.GetColumn()+1,
start.GetStart(),
stop.GetStop(),
)
}
func NewErrorHandler(src *file.Source, threshold int) *ErrorHandler {
if threshold <= 0 {
threshold = 10
@@ -68,20 +51,23 @@ func (h *ErrorHandler) Add(err *CompilationError) {
return
}
// If the number of errors exceeds the threshold, we stop adding new errors
if len(h.errors) > h.threshold {
return
}
if err.Source == nil {
err.Source = h.src
}
h.errors = append(h.errors, err)
if len(h.errors) >= h.threshold {
if len(h.errors) == h.threshold {
h.errors = append(h.errors, &CompilationError{
Message: "Too many errors",
Kind: SemanticError,
Hint: "Too many errors encountered during compilation.",
})
panic(h.Unwrap())
}
}
@@ -89,7 +75,7 @@ func (h *ErrorHandler) UnexpectedToken(ctx antlr.ParserRuleContext) {
h.Add(&CompilationError{
Message: fmt.Sprintf("Unexpected token '%s'", ctx.GetText()),
Source: h.src,
Location: ParserLocation(ctx),
Location: LocationFromRuleContext(ctx),
Kind: SyntaxError,
})
}
@@ -99,16 +85,16 @@ func (h *ErrorHandler) VariableNotUnique(ctx antlr.ParserRuleContext, name strin
h.Add(&CompilationError{
Message: fmt.Sprintf("Variable '%s' is already defined", name),
Source: h.src,
Location: ParserLocation(ctx),
Location: LocationFromRuleContext(ctx),
Kind: NameError,
})
}
func (h *ErrorHandler) VariableNotFound(ctx antlr.ParserRuleContext, name string) {
func (h *ErrorHandler) VariableNotFound(ctx antlr.Token, name string) {
h.Add(&CompilationError{
Message: fmt.Sprintf("Variable '%s' is not defined", name),
Source: h.src,
Location: ParserLocation(ctx),
Location: LocationFromToken(ctx),
Kind: NameError,
})
}

View File

@@ -0,0 +1,37 @@
package core
import (
"github.com/antlr4-go/antlr/v4"
"github.com/MontFerret/ferret/pkg/file"
)
// LocationFromRuleContext builds a file.Location spanning a parser rule.
//
// The line and column come from the rule's start token (the column is shifted
// by +1), the start offset from the start token and the end offset from the
// stop token. It returns file.EmptyLocation() when ctx is nil or when either
// boundary token is missing, mirroring the nil handling of LocationFromToken.
func LocationFromRuleContext(ctx antlr.ParserRuleContext) *file.Location {
	// A nil interface value would panic on the method calls below.
	if ctx == nil {
		return file.EmptyLocation()
	}

	start := ctx.GetStart()
	stop := ctx.GetStop()

	// Defensive: avoid nil dereference
	if start == nil || stop == nil {
		return file.EmptyLocation()
	}

	return file.NewLocation(
		start.GetLine(),
		start.GetColumn()+1,
		start.GetStart(),
		stop.GetStop(),
	)
}
// LocationFromToken builds a file.Location covering a single token.
//
// The line is the token's line, the column is the token's column plus one,
// and the offsets are the token's start and stop. A nil token yields
// file.EmptyLocation().
func LocationFromToken(token antlr.Token) *file.Location {
	if token != nil {
		line, column := token.GetLine(), token.GetColumn()+1

		return file.NewLocation(line, column, token.GetStart(), token.GetStop())
	}

	return file.EmptyLocation()
}

View File

@@ -0,0 +1,57 @@
package core
import (
"regexp"
"strings"
)
// ExplainSyntaxError turns a raw parser error message into a message and a
// hint.
//
// The explainers are tried in order (no-viable-alternative first, then
// extraneous-input); the first one that reports a match supplies the result.
// When none matches, a generic "Syntax error" message and hint are returned.
func ExplainSyntaxError(err string) (msg string, hint string) {
	explainers := []func(string) (string, string, bool){
		explainNoViableAltError,
		explainExtraneousError,
	}

	for _, explain := range explainers {
		if m, h, ok := explain(err); ok {
			return m, h
		}
	}

	return "Syntax error", "Check the syntax of your code. It may be missing a keyword, operator, or punctuation."
}
// explainExtraneousError recognizes the parser's "extraneous input '<EOF>'
// expecting" message.
//
// It returns a message, a hint and matched=true when err contains that text,
// and empty strings with matched=false otherwise. The marker has no pattern
// metacharacters, so a plain substring test is used instead of compiling a
// regular expression on every call.
func explainExtraneousError(err string) (msg string, hint string, matched bool) {
	if !strings.Contains(err, "extraneous input '<EOF>' expecting") {
		return "", "", false
	}

	return "Extraneous input at end of file", "Check the syntax of your code. It may be missing a keyword, operator, or punctuation", true
}
// noViableAltPattern captures the leading word of the quoted input in a
// "no viable alternative at input '...'" parser message. The tail is `.*`
// rather than `.+` so that a lone keyword such as 'RETURN' is captured whole
// instead of losing its last letter to backtracking. Compiled once at
// package scope.
var noViableAltPattern = regexp.MustCompile(`no viable alternative at input '(\w+).*'`)

// explainNoViableAltError recognizes "no viable alternative" parser messages
// and explains the ones it knows about.
//
// It returns matched=true together with a message and hint when the leading
// word of the offending input is the RETURN keyword (case-insensitive). For
// any other message, or any other leading word, it returns empty strings and
// matched=false so the caller can fall back to another explanation.
func explainNoViableAltError(err string) (msg string, hint string, matched bool) {
	groups := noViableAltPattern.FindStringSubmatch(err)
	if groups == nil {
		return "", "", false
	}

	switch strings.ToLower(groups[1]) {
	case "return":
		// Report the match explicitly; without matched=true the caller
		// discards the message and hint.
		return "Unexpected 'return' keyword", "Did you mean to return a value?", true
	default:
		return "", "", false
	}
}

View File

@@ -474,7 +474,7 @@ func (c *ExprCompiler) CompileVariable(ctx fql.IVariableContext) vm.Operand {
op, _, found := c.ctx.Symbols.Resolve(name)
if !found {
c.ctx.Errors.VariableNotFound(ctx, name)
c.ctx.Errors.VariableNotFound(ctx.Identifier().GetSymbol(), name)
return vm.NoopOperand
}

View File

@@ -25,6 +25,10 @@ func NewStmtCompiler(ctx *CompilerContext) *StmtCompiler {
// Parameters:
// - ctx: The body context from the AST
func (c *StmtCompiler) Compile(ctx fql.IBodyContext) {
if ctx == nil {
return
}
// Process all statements in the body
for _, statement := range ctx.AllBodyStatement() {
c.CompileBodyStatement(statement)
@@ -40,6 +44,10 @@ func (c *StmtCompiler) Compile(ctx fql.IBodyContext) {
// Parameters:
// - ctx: The body statement context from the AST
func (c *StmtCompiler) CompileBodyStatement(ctx fql.IBodyStatementContext) {
if ctx == nil {
return
}
// Handle variable declarations (e.g., LET x = 1)
if vd := ctx.VariableDeclaration(); vd != nil {
c.CompileVariableDeclaration(vd)
@@ -58,6 +66,10 @@ func (c *StmtCompiler) CompileBodyStatement(ctx fql.IBodyStatementContext) {
// Parameters:
// - ctx: The body expression context from the AST
func (c *StmtCompiler) CompileBodyExpression(ctx fql.IBodyExpressionContext) {
if ctx == nil {
return
}
// Handle FOR expressions (e.g., FOR x IN y RETURN z)
if fe := ctx.ForExpression(); fe != nil {
// Compile the FOR loop and get the destination register
@@ -95,6 +107,10 @@ func (c *StmtCompiler) CompileBodyExpression(ctx fql.IBodyExpressionContext) {
// - An operand representing the register where the variable value is stored,
// or NoopOperand if the variable is ignored
func (c *StmtCompiler) CompileVariableDeclaration(ctx fql.IVariableDeclarationContext) vm.Operand {
if ctx == nil {
return vm.NoopOperand
}
// Start with the ignore pseudo-variable as the default name
name := core.IgnorePseudoVariable
@@ -158,6 +174,10 @@ func (c *StmtCompiler) CompileVariableDeclaration(ctx fql.IVariableDeclarationCo
// Returns:
// - An operand representing the register where the function call result is stored
func (c *StmtCompiler) CompileFunctionCall(ctx fql.IFunctionCallExpressionContext) vm.Operand {
if ctx == nil {
return vm.NoopOperand
}
// Delegate to the expression compiler for function call compilation
return c.ctx.ExprCompiler.CompileFunctionCallExpression(ctx)
}

View File

@@ -1,26 +0,0 @@
package compiler
import (
"github.com/antlr4-go/antlr/v4"
"github.com/pkg/errors"
)
type errorListener struct {
*antlr.DiagnosticErrorListener
}
func newErrorListener() antlr.ErrorListener {
return &errorListener{
antlr.NewDiagnosticErrorListener(false),
}
}
func (d *errorListener) ReportAttemptingFullContext(recognizer antlr.Parser, dfa *antlr.DFA, startIndex, stopIndex int, conflictingAlts *antlr.BitSet, configs *antlr.ATNConfigSet) {
}
func (d *errorListener) ReportContextSensitivity(recognizer antlr.Parser, dfa *antlr.DFA, startIndex, stopIndex, prediction int, configs *antlr.ATNConfigSet) {
}
func (d *errorListener) SyntaxError(_ antlr.Recognizer, _ interface{}, line, column int, msg string, _ antlr.RecognitionException) {
panic(errors.Errorf("%s at %d:%d", msg, line, column))
}

View File

@@ -2,6 +2,7 @@ package compiler
import (
"github.com/MontFerret/ferret/pkg/compiler/internal"
"github.com/MontFerret/ferret/pkg/compiler/internal/core"
"github.com/MontFerret/ferret/pkg/file"
"github.com/MontFerret/ferret/pkg/parser/fql"
)
@@ -11,10 +12,10 @@ type Visitor struct {
Ctx *internal.CompilerContext
}
func NewVisitor(src *file.Source) *Visitor {
func NewVisitor(src *file.Source, errors *core.ErrorHandler) *Visitor {
v := new(Visitor)
v.BaseFqlParserVisitor = new(fql.BaseFqlParserVisitor)
v.Ctx = internal.NewCompilerContext(src)
v.Ctx = internal.NewCompilerContext(src, errors)
return v
}

47
pkg/file/snippet.go Normal file
View File

@@ -0,0 +1,47 @@
package file
import "strings"
// Snippet is one source line prepared for display in an error report.
type Snippet struct {
	// Line is the 1-based line number of Text within the source.
	Line int
	// Text is the content of the source line.
	Text string
	// Caret is the marker line printed under Text; it is empty for lines
	// that are shown only as context.
	Caret string
}

// NewSnippet returns the snippet for the given 1-based line of src, without a
// caret.
//
// A line outside src (less than 1 or greater than len(src)) yields the zero
// Snippet instead of panicking, so callers that ask for the neighbours of an
// out-of-range location still get a usable result.
func NewSnippet(src []string, line int) Snippet {
	if line < 1 || line > len(src) {
		return Snippet{}
	}

	return Snippet{
		Line: line,
		Text: src[line-1],
	}
}
// NewSnippetWithCaret returns the snippet for the line named by loc, with a
// caret line pointing at loc's column.
//
// The column is clamped to the range 1..(rune count of the line + 1). The
// caret line is made of '_' characters up to the offset computed by
// computeVisualOffset, followed by '^'. A location whose line lies outside
// src yields the zero Snippet.
func NewSnippetWithCaret(src []string, loc Location) Snippet {
	if outOfRange := loc.line <= 0 || loc.line > len(src); outOfRange {
		return Snippet{}
	}

	text := src[loc.Line()-1]
	maxColumn := len([]rune(text)) + 1

	// Keep the 1-based column inside the line, allowing one past the end.
	col := loc.Column()
	switch {
	case col < 1:
		col = 1
	case col > maxColumn:
		col = maxColumn
	}

	offset := computeVisualOffset(text, col)

	return Snippet{
		Line:  loc.line,
		Text:  text,
		Caret: strings.Repeat("_", offset) + "^",
	}
}

View File

@@ -38,32 +38,31 @@ func (s *Source) Content() string {
return s.text
}
func (s *Source) Snippet(loc Location) (line string, caret string) {
if s.Empty() || loc.Line() <= 0 || loc.Line() > len(s.lines) {
return "", ""
func (s *Source) Snippet(loc Location) []Snippet {
if s.Empty() {
return []Snippet{}
}
srcLine := s.lines[loc.Line()-1]
runes := []rune(srcLine)
column := loc.Column()
lineNum := loc.Line()
lines := s.lines
var result []Snippet
// Clamp column to within bounds (1-based)
if column < 1 {
column = 1
}
if column > len(runes)+1 {
column = len(runes) + 1
// Show previous line if it exists
if lineNum > 1 {
result = append(result, NewSnippet(lines, lineNum-1))
}
// Caret must align with visual column (accounting for tabs)
visualOffset := s.computeVisualOffset(srcLine, column)
result = append(result, NewSnippetWithCaret(lines, loc))
caretLine := strings.Repeat(" ", visualOffset) + "^"
// Show next line if it exists
if lineNum < len(lines) {
result = append(result, NewSnippet(lines, lineNum+1))
}
return srcLine, caretLine
return result
}
func (s *Source) computeVisualOffset(line string, column int) int {
func computeVisualOffset(line string, column int) int {
runes := []rune(line)
offset := 0
tabWidth := 4

View File

@@ -8,10 +8,18 @@ import (
. "github.com/smartystreets/goconvey/convey"
)
type ExpectedError struct {
Message string
Kind compiler.ErrorKind
}
type (
ExpectedError struct {
Message string
Kind compiler.ErrorKind
Hint string
}
ExpectedMultiError struct {
Number int
Errors []*ExpectedError
}
)
func ArePtrsEqual(expected, actual any) bool {
if expected == nil || actual == nil {
@@ -37,35 +45,85 @@ func ShouldHaveSameItems(actual any, expected ...any) string {
return ""
}
func ShouldBeCompilationError(actual any, expected ...any) string {
err, ok := actual.(*compiler.CompilationError)
if !ok {
// assertExpectedError compares a compilation error against an expectation.
//
// Kind, Message and Hint are checked in that order, and a field is compared
// only when the expectation sets it to a non-empty value. It returns an empty
// string when everything matches, otherwise a description of the first
// mismatch. A nil actual error is reported as a failure.
func assertExpectedError(actual *compiler.CompilationError, expected *ExpectedError) string {
	switch {
	case actual == nil:
		return "expected a compilation error"
	case expected.Kind != "" && actual.Kind != expected.Kind:
		return fmt.Sprintf("expected error kind %s, got %s", expected.Kind, actual.Kind)
	case expected.Message != "" && actual.Message != expected.Message:
		return fmt.Sprintf("expected error message '%s', got '%s'", expected.Message, actual.Message)
	case expected.Hint != "" && actual.Hint != expected.Hint:
		return fmt.Sprintf("expected error hint '%s', got '%s'", expected.Hint, actual.Hint)
	}

	return ""
}
// assertExpectedErrors compares an aggregated compilation error against an
// expectation.
//
// When expected.Number is positive, the number of actual errors must equal
// it. Each entry of expected.Errors is then compared, by position, with the
// actual error at the same index; actual errors beyond the listed
// expectations are not checked. Fewer actual errors than listed expectations
// is reported as a failure rather than silently skipping the unmatched
// expectations. It returns an empty string on success, otherwise a
// description of the first mismatch.
func assertExpectedErrors(actual *compiler.MultiCompilationError, expected *ExpectedMultiError) string {
	if actual == nil {
		return "expected a multi compilation error"
	}

	if expected.Number > 0 && len(actual.Errors) != expected.Number {
		return fmt.Sprintf("expected %d errors, got %d", expected.Number, len(actual.Errors))
	}

	// Every listed expectation needs an actual error to be compared with.
	if len(actual.Errors) < len(expected.Errors) {
		return fmt.Sprintf("expected at least %d errors, got %d", len(expected.Errors), len(actual.Errors))
	}

	for i, want := range expected.Errors {
		if msg := assertExpectedError(actual.Errors[i], want); msg != "" {
			return msg
		}
	}

	return ""
}
func ShouldBeCompilationError(actual any, expected ...any) string {
var msg string
switch ex := expected[0].(type) {
case *ExpectedError:
if ex.Kind != "" {
msg = ShouldEqual(err.Kind, ex.Kind)
err, ok := actual.(*compiler.CompilationError)
if !ok {
return "expected a compilation error"
}
if msg == "" {
msg = ShouldEqual(err.Message, ex.Message)
msg = assertExpectedError(err, ex)
if msg != "" {
fmt.Println(err.Format())
}
break
case string:
msg = ShouldEqual(err.Message, ex)
case *ExpectedMultiError:
err, ok := actual.(*compiler.MultiCompilationError)
if !ok {
return "expected a multi compilation error"
}
msg = assertExpectedErrors(err, ex)
if msg != "" {
fmt.Println(err.Format())
}
default:
msg = "expected a compilation error"
}
if msg != "" {
fmt.Println(err.Format())
}
return msg
}

View File

@@ -8,6 +8,23 @@ import (
func TestErrors(t *testing.T) {
RunUseCases(t, []UseCase{
ErrorCase(
`
LET i = NONE
`, E{
Kind: compiler.SyntaxError,
Message: "Variable 'i' is already defined",
//Message: "Extraneous input at end of file",
}, "Syntax error: missing return statement"),
ErrorCase(
`
LET i = NONE
RETURN
`, E{
Kind: compiler.SyntaxError,
//Message: "Unexpected 'return' keyword",
//Hint: "Did you mean to return a value?",
}, "Syntax error: missing return value"),
ErrorCase(
`
LET i = NONE
@@ -15,7 +32,7 @@ func TestErrors(t *testing.T) {
RETURN i
`, E{
Kind: compiler.NameError,
Message: "Variable 'i' is already defined",
Message: "Variable '' is already defined",
}, "Global variable not unique"),
ErrorCase(
`

View File

@@ -8,6 +8,7 @@ import (
type BC = []vm.Instruction
type UseCase = base.TestCase
type E = base.ExpectedError
type ME = base.ExpectedMultiError
var I = vm.NewInstruction
var C = vm.NewConstant

View File

@@ -42,6 +42,26 @@ func ErrorCase(expression string, expected base.ExpectedError, desc ...string) U
return uc
}
// SkipErrorCase builds the same use case as ErrorCase with the given
// arguments and passes it through Skip.
func SkipErrorCase(expression string, expected base.ExpectedError, desc ...string) UseCase {
return Skip(ErrorCase(expression, expected, desc...))
}
// MultiErrorCase builds a use case whose expected value is a pointer to the
// given multi-error expectation.
//
// base.ShouldBeCompilationError is installed as the pre-assertion, and the
// only regular assertion always returns "expected compilation error".
// NOTE(review): presumably the regular assertion is reached only when
// compilation unexpectedly succeeds — confirm against the use-case runner.
func MultiErrorCase(expression string, expected base.ExpectedMultiError, desc ...string) UseCase {
	alwaysFail := func(actual any, expected ...any) string {
		return "expected compilation error"
	}

	uc := NewCase(expression, &expected, nil, desc...)
	uc.PreAssertion = base.ShouldBeCompilationError
	uc.Assertions = []convey.Assertion{alwaysFail}

	return uc
}
// SkipMultiErrorCase builds the same use case as MultiErrorCase with the
// given arguments and passes it through Skip.
func SkipMultiErrorCase(expression string, expected base.ExpectedMultiError, desc ...string) UseCase {
return Skip(MultiErrorCase(expression, expected, desc...))
}
// SkipByteCodeCase builds the same use case as ByteCodeCase with the given
// arguments and passes it through Skip.
func SkipByteCodeCase(expression string, expected []vm.Instruction, desc ...string) UseCase {
return Skip(ByteCodeCase(expression, expected, desc...))
}

View File

@@ -17,10 +17,6 @@ func TestForIn(t *testing.T) {
// ShouldEqualJSON,
//},
RunUseCases(t, []UseCase{
SkipCaseCompilationError(`
FOR foo IN foo
RETURN foo
`, "Should not compile FOR foo IN foo"),
CaseArray(`
FOR i IN 1..5
RETURN i