From d24fac821bfdbf7af1ae3e6c044ec48ec278efbf Mon Sep 17 00:00:00 2001 From: Tim Voronov Date: Tue, 29 Jul 2025 18:35:55 -0400 Subject: [PATCH] Refactor compiler and error handling: introduce multi-error support, improve snippet generation, standardize error listener, update compiler context, and enhance test cases. --- pkg/compiler/compiler.go | 38 ++++---- pkg/compiler/error.go | 2 +- pkg/compiler/error_listener.go | 71 ++++++++++++++ pkg/compiler/internal/context.go | 4 +- pkg/compiler/internal/core/error.go | 12 +-- pkg/compiler/internal/core/error_formatter.go | 21 ++--- pkg/compiler/internal/core/error_handler.go | 34 ++----- pkg/compiler/internal/core/error_helpers.go | 37 ++++++++ .../internal/core/error_recognizer.go | 57 +++++++++++ pkg/compiler/internal/expr.go | 2 +- pkg/compiler/internal/stmt.go | 20 ++++ pkg/compiler/listener.go | 26 ----- pkg/compiler/visitor.go | 5 +- pkg/file/snippet.go | 47 ++++++++++ pkg/file/source.go | 33 ++++--- test/integration/base/assertions.go | 94 +++++++++++++++---- .../compiler/compiler_errors_test.go | 19 +++- test/integration/compiler/shortcuts.go | 1 + test/integration/compiler/test_case.go | 20 ++++ test/integration/vm/vm_for_in_test.go | 4 - 20 files changed, 413 insertions(+), 134 deletions(-) create mode 100644 pkg/compiler/error_listener.go create mode 100644 pkg/compiler/internal/core/error_helpers.go create mode 100644 pkg/compiler/internal/core/error_recognizer.go delete mode 100644 pkg/compiler/listener.go create mode 100644 pkg/file/snippet.go diff --git a/pkg/compiler/compiler.go b/pkg/compiler/compiler.go index debb0b94..922ea0a7 100644 --- a/pkg/compiler/compiler.go +++ b/pkg/compiler/compiler.go @@ -30,40 +30,36 @@ func (c *Compiler) Compile(src *file.Source) (program *vm.Program, err error) { return nil, core.NewEmptyQueryErr(src) } + errorHandler := core.NewErrorHandler(src, 10) + defer func() { if r := recover(); r != nil { - // find out exactly what the error was and set err - // Find out exactly what the error was and set err + var e *CompilationError + + buf := make([]byte, 1024) + n := goruntime.Stack(buf, false) + stackTrace := string(buf[:n]) + + // Find out exactly what the error was and add the e switch x := r.(type) { - case *CompilationError: - err = x - case *AggregatedCompilationErrors: - err = x case string: - buf := make([]byte, 1024) - n := goruntime.Stack(buf, false) - stackTrace := string(buf[:n]) - err = core.NewInternalErr(src, x+"\n"+stackTrace) + e = core.NewInternalErr(src, x+"\n"+stackTrace) case error: - buf := make([]byte, 1024) - n := goruntime.Stack(buf, false) - stackTrace := string(buf[:n]) - err = core.NewInternalErrWith(src, "unknown panic\n"+stackTrace, x) + e = core.NewInternalErrWith(src, "unknown panic\n"+stackTrace, x) default: - buf := make([]byte, 1024) - n := goruntime.Stack(buf, false) - stackTrace := string(buf[:n]) - err = core.NewInternalErr(src, "unknown panic\n"+stackTrace) + e = core.NewInternalErr(src, "unknown panic\n"+stackTrace) } + errorHandler.Add(e) + program = nil + err = errorHandler.Unwrap() } }() + l := NewVisitor(src, errorHandler) p := parser.New(src.Content()) - p.AddErrorListener(newErrorListener()) - - l := NewVisitor(src) + p.AddErrorListener(newErrorListener(l.Ctx.Errors)) p.Visit(l) if l.Ctx.Errors.HasErrors() { diff --git a/pkg/compiler/error.go b/pkg/compiler/error.go index 169a367c..4153ad26 100644 --- a/pkg/compiler/error.go +++ b/pkg/compiler/error.go @@ -4,7 +4,7 @@ import "github.com/MontFerret/ferret/pkg/compiler/internal/core" type ErrorKind = core.ErrorKind type CompilationError = core.CompilationError -type AggregatedCompilationErrors = core.MultiCompilationError +type MultiCompilationError = core.MultiCompilationError var ( UnknownError = core.UnknownError diff --git a/pkg/compiler/error_listener.go b/pkg/compiler/error_listener.go new file mode 100644 index 00000000..849fb428 --- /dev/null +++ b/pkg/compiler/error_listener.go @@ -0,0 +1,71 @@ +package compiler + +import ( + "github.com/antlr4-go/antlr/v4" + + "github.com/MontFerret/ferret/pkg/compiler/internal/core" + "github.com/MontFerret/ferret/pkg/file" +) + +type errorListener struct { + *antlr.DiagnosticErrorListener + + handler *core.ErrorHandler +} + +func newErrorListener(handler *core.ErrorHandler) antlr.ErrorListener { + return &errorListener{ + DiagnosticErrorListener: antlr.NewDiagnosticErrorListener(false), + handler: handler, + } +} + +func (d *errorListener) ReportAttemptingFullContext(recognizer antlr.Parser, dfa *antlr.DFA, startIndex, stopIndex int, conflictingAlts *antlr.BitSet, configs *antlr.ATNConfigSet) { +} + +func (d *errorListener) ReportContextSensitivity(recognizer antlr.Parser, dfa *antlr.DFA, startIndex, stopIndex, prediction int, configs *antlr.ATNConfigSet) { +} + +func (d *errorListener) SyntaxError(_ antlr.Recognizer, offendingSymbol interface{}, line, column int, msg string, e antlr.RecognitionException) { + message, hint := core.ExplainSyntaxError(msg) + + d.handler.Add(&core.CompilationError{ + Message: message, + Kind: core.SyntaxError, + Location: d.findErrorLocation(offendingSymbol, e), + Hint: hint, + Cause: nil, + }) +} + +func (d *errorListener) findErrorLocation(offendingSymbol interface{}, e antlr.RecognitionException) *file.Location { + line := 0 + column := 0 + start := 0 + end := 0 + + if token, ok := offendingSymbol.(antlr.Token); ok { + line = token.GetLine() - 1 + column = token.GetColumn() + start = token.GetStart() + end = token.GetStop() + } + + if line < 0 { + line = 0 + } + + if column < 0 { + column = 0 + } + + if start < 0 { + start = 0 + } + + if end < 0 { + end = 0 + } + + return file.NewLocation(line, column, start, end) +} diff --git a/pkg/compiler/internal/context.go b/pkg/compiler/internal/context.go index 262f57c1..99a535d6 100644 --- a/pkg/compiler/internal/context.go +++ b/pkg/compiler/internal/context.go @@ -25,10 +25,10 @@ type CompilerContext struct { } // NewCompilerContext initializes a new CompilerContext with default values. -func NewCompilerContext(src *file.Source) *CompilerContext { +func NewCompilerContext(src *file.Source, errors *core.ErrorHandler) *CompilerContext { ctx := &CompilerContext{ Source: src, - Errors: core.NewErrorHandler(src, 10), + Errors: errors, Emitter: core.NewEmitter(), Registers: core.NewRegisterAllocator(), Symbols: nil, // set later diff --git a/pkg/compiler/internal/core/error.go b/pkg/compiler/internal/core/error.go index 3ebf38dc..74477ff2 100644 --- a/pkg/compiler/internal/core/error.go +++ b/pkg/compiler/internal/core/error.go @@ -10,12 +10,12 @@ type ( ErrorKind string CompilationError struct { - Message string - Kind ErrorKind - Source *file.Source - Location *file.Location - Hint string - Cause error + Message string `json:"message"` + Kind ErrorKind `json:"kind"` + Source *file.Source `json:"source"` + Location *file.Location `json:"location"` + Hint string `json:"hint"` + Cause error `json:"cause"` } ) diff --git a/pkg/compiler/internal/core/error_formatter.go b/pkg/compiler/internal/core/error_formatter.go index b4b10f26..01990d56 100644 --- a/pkg/compiler/internal/core/error_formatter.go +++ b/pkg/compiler/internal/core/error_formatter.go @@ -14,25 +14,24 @@ func FormatError(out io.Writer, e *CompilationError, indent int) { fmt.Fprintf(out, "%s --> %s:%d:%d\n", prefix, e.Source.Name(), e.Location.Line(), e.Location.Column()) // Determine padding width for line number column - lineNum := e.Location.Line() - lineNoWidth := len(fmt.Sprintf("%d", lineNum)) - - // Pipe line + lineNoWidth := len(fmt.Sprintf("%d", e.Location.Line())) fmt.Fprintf(out, "%s%s\n", prefix, strings.Repeat(" ", lineNoWidth)+" |") - // Code line - lineText, caret := e.Source.Snippet(*e.Location) - fmt.Fprintf(out, "%s%*d | %s\n", prefix, lineNoWidth, lineNum, lineText) + // Multi-line snippet with context + snippetLines := e.Source.Snippet(*e.Location) - // Caret line - fmt.Fprintf(out, "%s%s | %s\n", prefix, strings.Repeat(" ", lineNoWidth), caret) + for _, sl := range snippetLines { + fmt.Fprintf(out, "%s%*d | %s\n", prefix, lineNoWidth, sl.Line, sl.Text) + + if sl.Caret != "" { + fmt.Fprintf(out, "%s%s | %s\n", prefix, strings.Repeat(" ", lineNoWidth), sl.Caret) + } + } - // Hint if e.Hint != "" { fmt.Fprintf(out, "%sHint: %s\n", prefix, e.Hint) } - // Cause if e.Cause != nil { if nested, ok := e.Cause.(*CompilationError); ok { fmt.Fprintf(out, "%sCaused by:\n", prefix) diff --git a/pkg/compiler/internal/core/error_handler.go b/pkg/compiler/internal/core/error_handler.go index 4a061202..86df4031 100644 --- a/pkg/compiler/internal/core/error_handler.go +++ b/pkg/compiler/internal/core/error_handler.go @@ -14,23 +14,6 @@ type ErrorHandler struct { threshold int } -func ParserLocation(ctx antlr.ParserRuleContext) *file.Location { - start := ctx.GetStart() - stop := ctx.GetStop() - - // Defensive: avoid nil dereference - if start == nil || stop == nil { - return file.NewLocation(0, 0, 0, 0) - } - - return file.NewLocation( - start.GetLine(), - start.GetColumn()+1, - start.GetStart(), - stop.GetStop(), - ) -} - func NewErrorHandler(src *file.Source, threshold int) *ErrorHandler { if threshold <= 0 { threshold = 10 @@ -68,20 +51,23 @@ func (h *ErrorHandler) Add(err *CompilationError) { return } + // If the number of errors exceeds the threshold, we stop adding new errors + if len(h.errors) > h.threshold { + return + } + if err.Source == nil { err.Source = h.src } h.errors = append(h.errors, err) - if len(h.errors) >= h.threshold { + if len(h.errors) == h.threshold { h.errors = append(h.errors, &CompilationError{ Message: "Too many errors", Kind: SemanticError, Hint: "Too many errors encountered during compilation.", }) - - panic(h.Unwrap()) } } @@ -89,7 +75,7 @@ func (h *ErrorHandler) UnexpectedToken(ctx antlr.ParserRuleContext) { h.Add(&CompilationError{ Message: fmt.Sprintf("Unexpected token '%s'", ctx.GetText()), Source: h.src, - Location: ParserLocation(ctx), + Location: LocationFromRuleContext(ctx), Kind: SyntaxError, }) } @@ -99,16 +85,16 @@ func (h *ErrorHandler) VariableNotUnique(ctx antlr.ParserRuleContext, name strin h.Add(&CompilationError{ Message: fmt.Sprintf("Variable '%s' is already defined", name), Source: h.src, - Location: ParserLocation(ctx), + Location: LocationFromRuleContext(ctx), Kind: NameError, }) } -func (h *ErrorHandler) VariableNotFound(ctx antlr.ParserRuleContext, name string) { +func (h *ErrorHandler) VariableNotFound(ctx antlr.Token, name string) { h.Add(&CompilationError{ Message: fmt.Sprintf("Variable '%s' is not defined", name), Source: h.src, - Location: ParserLocation(ctx), + Location: LocationFromToken(ctx), Kind: NameError, }) } diff --git a/pkg/compiler/internal/core/error_helpers.go b/pkg/compiler/internal/core/error_helpers.go new file mode 100644 index 00000000..22650b75 --- /dev/null +++ b/pkg/compiler/internal/core/error_helpers.go @@ -0,0 +1,37 @@ +package core + +import ( + "github.com/antlr4-go/antlr/v4" + + "github.com/MontFerret/ferret/pkg/file" +) + +func LocationFromRuleContext(ctx antlr.ParserRuleContext) *file.Location { + start := ctx.GetStart() + stop := ctx.GetStop() + + // Defensive: avoid nil dereference + if start == nil || stop == nil { + return file.EmptyLocation() + } + + return file.NewLocation( + start.GetLine(), + start.GetColumn()+1, + start.GetStart(), + stop.GetStop(), + ) +} + +func LocationFromToken(token antlr.Token) *file.Location { + if token == nil { + return file.EmptyLocation() + } + + return file.NewLocation( + token.GetLine(), + token.GetColumn()+1, + token.GetStart(), + token.GetStop(), + ) +} diff --git a/pkg/compiler/internal/core/error_recognizer.go b/pkg/compiler/internal/core/error_recognizer.go new file mode 100644 index 00000000..06ea560e --- /dev/null +++ b/pkg/compiler/internal/core/error_recognizer.go @@ -0,0 +1,57 @@ +package core + +import ( + "regexp" + "strings" +) + +func ExplainSyntaxError(err string) (msg string, hint string) { + var matched bool + parsers := []func(string) (string, string, bool){ + explainNoViableAltError, + explainExtraneousError, + } + + for _, parser := range parsers { + msg, hint, matched = parser(err) + + if matched { + return + } + } + + msg = "Syntax error" + hint = "Check the syntax of your code. It may be missing a keyword, operator, or punctuation." + + return +} + +func explainExtraneousError(err string) (msg string, hint string, matched bool) { + recognizer := regexp.MustCompile("extraneous input '' expecting") + + if !recognizer.MatchString(err) { + return "", "", false + } + + return "Extraneous input at end of file", "Check the syntax of your code. It may be missing a keyword, operator, or punctuation", true +} + +func explainNoViableAltError(err string) (msg string, hint string, matched bool) { + recognizer := regexp.MustCompile("no viable alternative at input '(\\w+).+'") + + matches := recognizer.FindAllStringSubmatch(err, -1) + + if len(matches) == 0 { + return "", "", false + } + + keyword := matches[0][1] + + switch strings.ToLower(keyword) { + case "return": + msg = "Unexpected 'return' keyword" + hint = "Did you mean to return a value?" + } + + return +} diff --git a/pkg/compiler/internal/expr.go b/pkg/compiler/internal/expr.go index 84981a87..3d2b0cad 100644 --- a/pkg/compiler/internal/expr.go +++ b/pkg/compiler/internal/expr.go @@ -474,7 +474,7 @@ func (c *ExprCompiler) CompileVariable(ctx fql.IVariableContext) vm.Operand { op, _, found := c.ctx.Symbols.Resolve(name) if !found { - c.ctx.Errors.VariableNotFound(ctx, name) + c.ctx.Errors.VariableNotFound(ctx.Identifier().GetSymbol(), name) return vm.NoopOperand } diff --git a/pkg/compiler/internal/stmt.go b/pkg/compiler/internal/stmt.go index b31b279e..850a30fe 100644 --- a/pkg/compiler/internal/stmt.go +++ b/pkg/compiler/internal/stmt.go @@ -25,6 +25,10 @@ func NewStmtCompiler(ctx *CompilerContext) *StmtCompiler { // Parameters: // - ctx: The body context from the AST func (c *StmtCompiler) Compile(ctx fql.IBodyContext) { + if ctx == nil { + return + } + // Process all statements in the body for _, statement := range ctx.AllBodyStatement() { c.CompileBodyStatement(statement) @@ -40,6 +44,10 @@ func (c *StmtCompiler) Compile(ctx fql.IBodyContext) { // Parameters: // - ctx: The body statement context from the AST func (c *StmtCompiler) CompileBodyStatement(ctx fql.IBodyStatementContext) { + if ctx == nil { + return + } + // Handle variable declarations (e.g., LET x = 1) if vd := ctx.VariableDeclaration(); vd != nil { c.CompileVariableDeclaration(vd) @@ -58,6 +66,10 @@ func (c *StmtCompiler) CompileBodyStatement(ctx fql.IBodyStatementContext) { // Parameters: // - ctx: The body expression context from the AST func (c *StmtCompiler) CompileBodyExpression(ctx fql.IBodyExpressionContext) { + if ctx == nil { + return + } + // Handle FOR expressions (e.g., FOR x IN y RETURN z) if fe := ctx.ForExpression(); fe != nil { // Compile the FOR loop and get the destination register @@ -95,6 +107,10 @@ func (c *StmtCompiler) CompileBodyExpression(ctx fql.IBodyExpressionContext) { // - An operand representing the register where the variable value is stored, // or NoopOperand if the variable is ignored func (c *StmtCompiler) CompileVariableDeclaration(ctx fql.IVariableDeclarationContext) vm.Operand { + if ctx == nil { + return vm.NoopOperand + } + // Start with the ignore pseudo-variable as the default name name := core.IgnorePseudoVariable @@ -158,6 +174,10 @@ func (c *StmtCompiler) CompileVariableDeclaration(ctx fql.IVariableDeclarationCo // Returns: // - An operand representing the register where the function call result is stored func (c *StmtCompiler) CompileFunctionCall(ctx fql.IFunctionCallExpressionContext) vm.Operand { + if ctx == nil { + return vm.NoopOperand + } + // Delegate to the expression compiler for function call compilation return c.ctx.ExprCompiler.CompileFunctionCallExpression(ctx) } diff --git a/pkg/compiler/listener.go b/pkg/compiler/listener.go deleted file mode 100644 index 11163e99..00000000 --- a/pkg/compiler/listener.go +++ /dev/null @@ -1,26 +0,0 @@ -package compiler - -import ( - "github.com/antlr4-go/antlr/v4" - "github.com/pkg/errors" -) - -type errorListener struct { - *antlr.DiagnosticErrorListener -} - -func newErrorListener() antlr.ErrorListener { - return &errorListener{ - antlr.NewDiagnosticErrorListener(false), - } -} - -func (d *errorListener) ReportAttemptingFullContext(recognizer antlr.Parser, dfa *antlr.DFA, startIndex, stopIndex int, conflictingAlts *antlr.BitSet, configs *antlr.ATNConfigSet) { -} - -func (d *errorListener) ReportContextSensitivity(recognizer antlr.Parser, dfa *antlr.DFA, startIndex, stopIndex, prediction int, configs *antlr.ATNConfigSet) { -} - -func (d *errorListener) SyntaxError(_ antlr.Recognizer, _ interface{}, line, column int, msg string, _ antlr.RecognitionException) { - panic(errors.Errorf("%s at %d:%d", msg, line, column)) -} diff --git a/pkg/compiler/visitor.go b/pkg/compiler/visitor.go index 93f6288d..9315435b 100644 --- a/pkg/compiler/visitor.go +++ b/pkg/compiler/visitor.go @@ -2,6 +2,7 @@ package compiler import ( "github.com/MontFerret/ferret/pkg/compiler/internal" + "github.com/MontFerret/ferret/pkg/compiler/internal/core" "github.com/MontFerret/ferret/pkg/file" "github.com/MontFerret/ferret/pkg/parser/fql" ) @@ -11,10 +12,10 @@ type Visitor struct { Ctx *internal.CompilerContext } -func NewVisitor(src *file.Source) *Visitor { +func NewVisitor(src *file.Source, errors *core.ErrorHandler) *Visitor { v := new(Visitor) v.BaseFqlParserVisitor = new(fql.BaseFqlParserVisitor) - v.Ctx = internal.NewCompilerContext(src) + v.Ctx = internal.NewCompilerContext(src, errors) return v } diff --git a/pkg/file/snippet.go b/pkg/file/snippet.go new file mode 100644 index 00000000..98bd3a28 --- /dev/null +++ b/pkg/file/snippet.go @@ -0,0 +1,47 @@ +package file + +import "strings" + +type Snippet struct { + Line int + Text string + Caret string +} + +func NewSnippet(src []string, line int) Snippet { + text := src[line-1] + + return Snippet{ + Line: line, + Text: text, + } +} + +func NewSnippetWithCaret(src []string, loc Location) Snippet { + if loc.line <= 0 || loc.line > len(src) { + return Snippet{} + } + + srcLine := src[loc.Line()-1] + runes := []rune(srcLine) + column := loc.Column() + + // Clamp column to within bounds (1-based) + if column < 1 { + column = 1 + } + + if column > len(runes)+1 { + column = len(runes) + 1 + } + + // Caret must align with visual column (accounting for tabs) + visualOffset := computeVisualOffset(srcLine, column) + caretLine := strings.Repeat("_", visualOffset) + "^" + + return Snippet{ + Line: loc.line, + Text: srcLine, + Caret: caretLine, + } +} diff --git a/pkg/file/source.go b/pkg/file/source.go index f4860793..0d305545 100644 --- a/pkg/file/source.go +++ b/pkg/file/source.go @@ -38,32 +38,31 @@ func (s *Source) Content() string { return s.text } -func (s *Source) Snippet(loc Location) (line string, caret string) { - if s.Empty() || loc.Line() <= 0 || loc.Line() > len(s.lines) { - return "", "" +func (s *Source) Snippet(loc Location) []Snippet { + if s.Empty() { + return []Snippet{} } - srcLine := s.lines[loc.Line()-1] - runes := []rune(srcLine) - column := loc.Column() + lineNum := loc.Line() + lines := s.lines + var result []Snippet - // Clamp column to within bounds (1-based) - if column < 1 { - column = 1 - } - if column > len(runes)+1 { - column = len(runes) + 1 + // Show previous line if it exists + if lineNum > 1 { + result = append(result, NewSnippet(lines, lineNum-1)) } - // Caret must align with visual column (accounting for tabs) - visualOffset := s.computeVisualOffset(srcLine, column) + result = append(result, NewSnippetWithCaret(lines, loc)) - caretLine := strings.Repeat(" ", visualOffset) + "^" + // Show next line if it exists + if lineNum < len(lines) { + result = append(result, NewSnippet(lines, lineNum+1)) + } - return srcLine, caretLine + return result } -func (s *Source) computeVisualOffset(line string, column int) int { +func computeVisualOffset(line string, column int) int { runes := []rune(line) offset := 0 tabWidth := 4 diff --git a/test/integration/base/assertions.go b/test/integration/base/assertions.go index 89449275..0180323f 100644 --- a/test/integration/base/assertions.go +++ b/test/integration/base/assertions.go @@ -8,10 +8,18 @@ import ( . "github.com/smartystreets/goconvey/convey" ) -type ExpectedError struct { - Message string - Kind compiler.ErrorKind -} +type ( + ExpectedError struct { + Message string + Kind compiler.ErrorKind + Hint string + } + + ExpectedMultiError struct { + Number int + Errors []*ExpectedError + } +) func ArePtrsEqual(expected, actual any) bool { if expected == nil || actual == nil { @@ -37,35 +45,85 @@ func ShouldHaveSameItems(actual any, expected ...any) string { return "" } -func ShouldBeCompilationError(actual any, expected ...any) string { - err, ok := actual.(*compiler.CompilationError) - - if !ok { +func assertExpectedError(actual *compiler.CompilationError, expected *ExpectedError) string { + if actual == nil { return "expected a compilation error" } + if expected.Kind != "" && actual.Kind != expected.Kind { + return fmt.Sprintf("expected error kind %s, got %s", expected.Kind, actual.Kind) + } + + if expected.Message != "" && actual.Message != expected.Message { + return fmt.Sprintf("expected error message '%s', got '%s'", expected.Message, actual.Message) + } + + if expected.Hint != "" && actual.Hint != expected.Hint { + return fmt.Sprintf("expected error hint '%s', got '%s'", expected.Hint, actual.Hint) + } + + return "" +} + +func assertExpectedErrors(actual *compiler.MultiCompilationError, expected *ExpectedMultiError) string { + if actual == nil { + return "expected a multi compilation error" + } + + if expected.Number > 0 && len(actual.Errors) != expected.Number { + return fmt.Sprintf("expected %d errors, got %d", expected.Number, len(actual.Errors)) + } + + if len(expected.Errors) > 0 { + for i, err := range actual.Errors { + if i >= len(expected.Errors) { + break + } + + msg := assertExpectedError(err, expected.Errors[i]) + + if msg != "" { + return msg + } + } + } + + return "" +} + +func ShouldBeCompilationError(actual any, expected ...any) string { var msg string switch ex := expected[0].(type) { case *ExpectedError: - if ex.Kind != "" { - msg = ShouldEqual(err.Kind, ex.Kind) + err, ok := actual.(*compiler.CompilationError) + + if !ok { + return "expected a compilation error" } - if msg == "" { - msg = ShouldEqual(err.Message, ex.Message) + msg = assertExpectedError(err, ex) + + if msg != "" { + fmt.Println(err.Format()) } break - case string: - msg = ShouldEqual(err.Message, ex) + case *ExpectedMultiError: + err, ok := actual.(*compiler.MultiCompilationError) + + if !ok { + return "expected a multi compilation error" + } + + msg = assertExpectedErrors(err, ex) + + if msg != "" { + fmt.Println(err.Format()) + } default: msg = "expected a compilation error" } - if msg != "" { - fmt.Println(err.Format()) - } - return msg } diff --git a/test/integration/compiler/compiler_errors_test.go b/test/integration/compiler/compiler_errors_test.go index c83e962a..3e997b8e 100644 --- a/test/integration/compiler/compiler_errors_test.go +++ b/test/integration/compiler/compiler_errors_test.go @@ -8,6 +8,23 @@ import ( func TestErrors(t *testing.T) { RunUseCases(t, []UseCase{ + ErrorCase( + ` + LET i = NONE + `, E{ + Kind: compiler.SyntaxError, + Message: "Variable 'i' is already defined", + //Message: "Extraneous input at end of file", + }, "Syntax error: missing return statement"), + ErrorCase( + ` + LET i = NONE + RETURN + `, E{ + Kind: compiler.SyntaxError, + //Message: "Unexpected 'return' keyword", + //Hint: "Did you mean to return a value?", + }, "Syntax error: missing return value"), ErrorCase( ` LET i = NONE @@ -15,7 +32,7 @@ func TestErrors(t *testing.T) { RETURN i `, E{ Kind: compiler.NameError, - Message: "Variable 'i' is already defined", + Message: "Variable '' is already defined", }, "Global variable not unique"), ErrorCase( ` diff --git a/test/integration/compiler/shortcuts.go b/test/integration/compiler/shortcuts.go index 6c3b11b7..9d64fee2 100644 --- a/test/integration/compiler/shortcuts.go +++ b/test/integration/compiler/shortcuts.go @@ -8,6 +8,7 @@ import ( type BC = []vm.Instruction type UseCase = base.TestCase type E = base.ExpectedError +type ME = base.ExpectedMultiError var I = vm.NewInstruction var C = vm.NewConstant diff --git a/test/integration/compiler/test_case.go b/test/integration/compiler/test_case.go index b1758b83..a7a96a07 100644 --- a/test/integration/compiler/test_case.go +++ b/test/integration/compiler/test_case.go @@ -42,6 +42,26 @@ func ErrorCase(expression string, expected base.ExpectedError, desc ...string) U return uc } +func SkipErrorCase(expression string, expected base.ExpectedError, desc ...string) UseCase { + return Skip(ErrorCase(expression, expected, desc...)) +} + +func MultiErrorCase(expression string, expected base.ExpectedMultiError, desc ...string) UseCase { + uc := NewCase(expression, &expected, nil, desc...) + uc.PreAssertion = base.ShouldBeCompilationError + uc.Assertions = []convey.Assertion{ + func(actual any, expected ...any) string { + return "expected compilation error" + }, + } + + return uc +} + +func SkipMultiErrorCase(expression string, expected base.ExpectedMultiError, desc ...string) UseCase { + return Skip(MultiErrorCase(expression, expected, desc...)) +} + func SkipByteCodeCase(expression string, expected []vm.Instruction, desc ...string) UseCase { return Skip(ByteCodeCase(expression, expected, desc...)) } diff --git a/test/integration/vm/vm_for_in_test.go b/test/integration/vm/vm_for_in_test.go index 0087199f..c6af667b 100644 --- a/test/integration/vm/vm_for_in_test.go +++ b/test/integration/vm/vm_for_in_test.go @@ -17,10 +17,6 @@ func TestForIn(t *testing.T) { // ShouldEqualJSON, //}, RunUseCases(t, []UseCase{ - SkipCaseCompilationError(` - FOR foo IN foo - RETURN foo - `, "Should not compile FOR foo IN foo"), CaseArray(` FOR i IN 1..5 RETURN i