1
0
mirror of https://github.com/go-task/task.git synced 2024-12-14 10:52:43 +02:00

Update github.com/mvdan/sh

This commit is contained in:
Andrey Nering 2017-05-17 14:49:27 -03:00
parent b590e74ce6
commit 504723bc19
11 changed files with 474 additions and 409 deletions

View File

@ -70,7 +70,7 @@ func (r *Runner) arithm(expr syntax.ArithmExpr) int {
} }
return binArit(x.Op, r.arithm(x.X), r.arithm(x.Y)) return binArit(x.Op, r.arithm(x.X), r.arithm(x.Y))
default: default:
r.errf("unexpected arithm expr: %T", x) r.runErr(expr.Pos(), "unexpected arithm expr: %T", x)
return 0 return 0
} }
} }

View File

@ -8,6 +8,7 @@ import (
"os/exec" "os/exec"
"path/filepath" "path/filepath"
"strconv" "strconv"
"strings"
"github.com/mvdan/sh/syntax" "github.com/mvdan/sh/syntax"
) )
@ -18,32 +19,31 @@ func isBuiltin(name string) bool {
"echo", "printf", "break", "continue", "pwd", "cd", "echo", "printf", "break", "continue", "pwd", "cd",
"wait", "builtin", "trap", "type", "source", "command", "wait", "builtin", "trap", "type", "source", "command",
"pushd", "popd", "umask", "alias", "unalias", "fg", "bg", "pushd", "popd", "umask", "alias", "unalias", "fg", "bg",
"getopts": "getopts", "eval":
return true return true
} }
return false return false
} }
func (r *Runner) builtin(pos syntax.Pos, name string, args []string) { func (r *Runner) builtinCode(pos syntax.Pos, name string, args []string) int {
exit := 0
switch name { switch name {
case "true", ":": case "true", ":":
case "false": case "false":
exit = 1 return 1
case "exit": case "exit":
switch len(args) { switch len(args) {
case 0: case 0:
r.lastExit()
case 1: case 1:
if n, err := strconv.Atoi(args[0]); err != nil { if n, err := strconv.Atoi(args[0]); err != nil {
r.runErr(pos, "invalid exit code: %q", args[0]) r.runErr(pos, "invalid exit code: %q", args[0])
} else { } else {
exit = n r.exit = n
r.err = ExitCode(n)
} }
default: default:
r.runErr(pos, "exit cannot take multiple arguments") r.runErr(pos, "exit cannot take multiple arguments")
} }
r.lastExit()
return r.exit
case "set": case "set":
r.args = args r.args = args
case "shift": case "shift":
@ -58,8 +58,7 @@ func (r *Runner) builtin(pos syntax.Pos, name string, args []string) {
fallthrough fallthrough
default: default:
r.errf("usage: shift [n]\n") r.errf("usage: shift [n]\n")
exit = 2 return 2
break
} }
if len(r.args) < n { if len(r.args) < n {
n = len(r.args) n = len(r.args)
@ -97,8 +96,7 @@ func (r *Runner) builtin(pos syntax.Pos, name string, args []string) {
case "printf": case "printf":
if len(args) == 0 { if len(args) == 0 {
r.errf("usage: printf format [arguments]\n") r.errf("usage: printf format [arguments]\n")
exit = 2 return 2
break
} }
var a []interface{} var a []interface{}
for _, arg := range args[1:] { for _, arg := range args[1:] {
@ -121,7 +119,7 @@ func (r *Runner) builtin(pos syntax.Pos, name string, args []string) {
fallthrough fallthrough
default: default:
r.errf("usage: break [n]\n") r.errf("usage: break [n]\n")
exit = 2 return 2
} }
case "continue": case "continue":
if !r.inLoop { if !r.inLoop {
@ -139,15 +137,14 @@ func (r *Runner) builtin(pos syntax.Pos, name string, args []string) {
fallthrough fallthrough
default: default:
r.errf("usage: continue [n]\n") r.errf("usage: continue [n]\n")
exit = 2 return 2
} }
case "pwd": case "pwd":
r.outf("%s\n", r.getVar("PWD")) r.outf("%s\n", r.getVar("PWD"))
case "cd": case "cd":
if len(args) > 1 { if len(args) > 1 {
r.errf("usage: cd [dir]\n") r.errf("usage: cd [dir]\n")
exit = 2 return 2
break
} }
var dir string var dir string
if len(args) == 0 { if len(args) == 0 {
@ -160,13 +157,12 @@ func (r *Runner) builtin(pos syntax.Pos, name string, args []string) {
} }
_, err := os.Stat(dir) _, err := os.Stat(dir)
if err != nil { if err != nil {
exit = 1 return 1
break
} }
r.Dir = dir r.Dir = dir
case "wait": case "wait":
if len(args) > 0 { if len(args) > 0 {
r.errf("wait with args not handled yet") r.runErr(pos, "wait with args not handled yet")
break break
} }
r.bgShells.Wait() r.bgShells.Wait()
@ -175,13 +171,16 @@ func (r *Runner) builtin(pos syntax.Pos, name string, args []string) {
break break
} }
if !isBuiltin(args[0]) { if !isBuiltin(args[0]) {
exit = 1 return 1
break
} }
// TODO: pos return r.builtinCode(pos, args[0], args[1:])
r.builtin(0, args[0], args[1:])
case "type": case "type":
anyNotFound := false
for _, arg := range args { for _, arg := range args {
if _, ok := r.funcs[arg]; ok {
r.outf("%s is a function\n", arg)
continue
}
if isBuiltin(arg) { if isBuiltin(arg) {
r.outf("%s is a shell builtin\n", arg) r.outf("%s is a shell builtin\n", arg)
continue continue
@ -190,12 +189,27 @@ func (r *Runner) builtin(pos syntax.Pos, name string, args []string) {
r.outf("%s is %s\n", arg, path) r.outf("%s is %s\n", arg, path)
continue continue
} }
exit = 1
r.errf("type: %s: not found\n", arg) r.errf("type: %s: not found\n", arg)
anyNotFound = true
} }
if anyNotFound {
return 1
}
case "eval":
src := strings.Join(args, " ")
p := syntax.NewParser()
file, err := p.Parse(strings.NewReader(src), "")
if err != nil {
r.errf("eval: %v\n", err)
return 1
}
r2 := *r
r2.File = file
r2.Run()
return r2.exit
case "trap", "source", "command", "pushd", "popd", case "trap", "source", "command", "pushd", "popd",
"umask", "alias", "unalias", "fg", "bg", "getopts": "umask", "alias", "unalias", "fg", "bg", "getopts":
r.errf("unhandled builtin: %s", name) r.runErr(pos, "unhandled builtin: %s", name)
} }
r.exit = exit return 0
} }

View File

@ -262,26 +262,25 @@ func (r *Runner) stmt(st *syntax.Stmt) {
} }
} }
func (r *Runner) assignValue(word *syntax.Word) varValue { func (r *Runner) assignValue(as *syntax.Assign) varValue {
if word == nil { if as.Value != nil {
return nil return r.loneWord(as.Value)
} }
ae, ok := word.Parts[0].(*syntax.ArrayExpr) if as.Array != nil {
if !ok { strs := make([]string, len(as.Array.List))
return r.loneWord(word) for i, w := range as.Array.List {
strs[i] = r.loneWord(w)
}
return strs
} }
strs := make([]string, len(ae.List)) return nil
for i, w := range ae.List {
strs[i] = r.loneWord(w)
}
return strs
} }
func (r *Runner) stmtSync(st *syntax.Stmt) { func (r *Runner) stmtSync(st *syntax.Stmt) {
oldVars := r.cmdVars oldVars := r.cmdVars
for _, as := range st.Assigns { for _, as := range st.Assigns {
name := as.Name.Value name := as.Name.Value
val := r.assignValue(as.Value) val := r.assignValue(as)
if st.Cmd == nil { if st.Cmd == nil {
r.setVar(name, val) r.setVar(name, val)
continue continue
@ -383,22 +382,9 @@ func (r *Runner) cmd(cm syntax.Command) {
case *syntax.WhileClause: case *syntax.WhileClause:
for r.err == nil { for r.err == nil {
r.stmts(x.CondStmts) r.stmts(x.CondStmts)
if r.exit != 0 { stop := (r.exit == 0) == x.Until
r.exit = 0
break
}
if r.loopStmtsBroken(x.DoStmts) {
break
}
}
case *syntax.UntilClause:
for r.err == nil {
r.stmts(x.CondStmts)
if r.exit == 0 {
break
}
r.exit = 0 r.exit = 0
if r.loopStmtsBroken(x.DoStmts) { if stop || r.loopStmtsBroken(x.DoStmts) {
break break
} }
} }
@ -440,9 +426,7 @@ func (r *Runner) cmd(cm syntax.Command) {
for _, pl := range x.List { for _, pl := range x.List {
for _, word := range pl.Patterns { for _, word := range pl.Patterns {
pat := r.loneWord(word) pat := r.loneWord(word)
// TODO: error? if match(pat, str) {
matched, _ := path.Match(pat, str)
if matched {
r.stmts(pl.Stmts) r.stmts(pl.Stmts)
return return
} }
@ -453,7 +437,7 @@ func (r *Runner) cmd(cm syntax.Command) {
r.exit = 1 r.exit = 1
} }
default: default:
r.errf("unhandled command node: %T", x) r.runErr(cm.Pos(), "unhandled command node: %T", x)
} }
} }
@ -463,6 +447,11 @@ func (r *Runner) stmts(stmts []*syntax.Stmt) {
} }
} }
func match(pattern, name string) bool {
matched, _ := path.Match(pattern, name)
return matched
}
func (r *Runner) redir(rd *syntax.Redirect) (io.Closer, error) { func (r *Runner) redir(rd *syntax.Redirect) (io.Closer, error) {
if rd.Hdoc != nil { if rd.Hdoc != nil {
hdoc := r.loneWord(rd.Hdoc) hdoc := r.loneWord(rd.Hdoc)
@ -491,7 +480,7 @@ func (r *Runner) redir(rd *syntax.Redirect) (io.Closer, error) {
} }
return nil, nil return nil, nil
case syntax.DplIn: case syntax.DplIn:
r.errf("unhandled redirect op: %v", rd.Op) r.runErr(rd.Pos(), "unhandled redirect op: %v", rd.Op)
} }
mode := os.O_RDONLY mode := os.O_RDONLY
switch rd.Op { switch rd.Op {
@ -514,7 +503,7 @@ func (r *Runner) redir(rd *syntax.Redirect) (io.Closer, error) {
r.Stdout = f r.Stdout = f
r.Stderr = f r.Stderr = f
default: default:
r.errf("unhandled redirect op: %v", rd.Op) r.runErr(rd.Pos(), "unhandled redirect op: %v", rd.Op)
} }
return f, nil return f, nil
} }
@ -605,7 +594,7 @@ func (r *Runner) wordParts(wps []syntax.WordPart, quoted bool) []string {
case *syntax.ArithmExp: case *syntax.ArithmExp:
curBuf.WriteString(strconv.Itoa(r.arithm(x.X))) curBuf.WriteString(strconv.Itoa(r.arithm(x.X)))
default: default:
r.errf("unhandled word part: %T", x) r.runErr(wp.Pos(), "unhandled word part: %T", x)
} }
} }
flush() flush()
@ -622,7 +611,7 @@ func (r *Runner) call(pos syntax.Pos, name string, args []string) {
return return
} }
if isBuiltin(name) { if isBuiltin(name) {
r.builtin(pos, name, args) r.exit = r.builtinCode(pos, name, args)
return return
} }
cmd := exec.CommandContext(r.Context, name, args...) cmd := exec.CommandContext(r.Context, name, args...)
@ -639,10 +628,9 @@ func (r *Runner) call(pos syntax.Pos, name string, args []string) {
case *exec.ExitError: case *exec.ExitError:
// started, but errored - default to 1 if OS // started, but errored - default to 1 if OS
// doesn't have exit statuses // doesn't have exit statuses
r.exit = 1
if status, ok := x.Sys().(syscall.WaitStatus); ok { if status, ok := x.Sys().(syscall.WaitStatus); ok {
r.exit = status.ExitStatus() r.exit = status.ExitStatus()
} else {
r.exit = 1
} }
case *exec.Error: case *exec.Error:
// did not start // did not start

View File

@ -4,7 +4,6 @@
package interp package interp
import ( import (
"path"
"strconv" "strconv"
"strings" "strings"
"unicode" "unicode"
@ -34,8 +33,8 @@ func (r *Runner) paramExp(pe *syntax.ParamExp) string {
} }
} }
str := varStr(val) str := varStr(val)
if pe.Ind != nil { if pe.Index != nil {
str = r.varInd(val, pe.Ind.Expr) str = r.varInd(val, pe.Index)
} }
switch { switch {
case pe.Length: case pe.Length:
@ -155,9 +154,9 @@ func (r *Runner) paramExp(pe *syntax.ParamExp) string {
} }
str = string(rns) str = string(rns)
case "P", "A", "a": case "P", "A", "a":
r.errf("unhandled @%s param expansion", arg) r.runErr(pe.Pos(), "unhandled @%s param expansion", arg)
default: default:
r.errf("unexpected @%s param expansion", arg) r.runErr(pe.Pos(), "unexpected @%s param expansion", arg)
} }
} }
} }
@ -173,7 +172,7 @@ func removePattern(str, pattern string, fromEnd, longest bool) string {
i = 0 i = 0
} }
for { for {
if m, _ := path.Match(pattern, s); m { if match(pattern, s) {
last = str[i:] last = str[i:]
if fromEnd { if fromEnd {
last = str[:i] last = str[:i]

View File

@ -6,7 +6,6 @@ package interp
import ( import (
"os" "os"
"os/exec" "os/exec"
"path"
"path/filepath" "path/filepath"
"regexp" "regexp"
@ -75,11 +74,9 @@ func (r *Runner) binTest(op syntax.BinTestOperator, x, y string) bool {
case syntax.OrTest: case syntax.OrTest:
return x != "" || y != "" return x != "" || y != ""
case syntax.TsMatch: case syntax.TsMatch:
m, _ := path.Match(y, x) return match(y, x)
return m
case syntax.TsNoMatch: case syntax.TsNoMatch:
m, _ := path.Match(y, x) return !match(y, x)
return !m
case syntax.TsBefore: case syntax.TsBefore:
return x < y return x < y
default: // syntax.TsAfter default: // syntax.TsAfter
@ -154,7 +151,7 @@ func (r *Runner) unTest(op syntax.UnTestOperator, x string) bool {
case syntax.TsNot: case syntax.TsNot:
return x == "" return x == ""
default: default:
r.errf("unhandled unary test op: %v", op) r.runErr(0, "unhandled unary test op: %v", op)
return false return false
} }
} }

View File

@ -46,7 +46,7 @@ func wordBreak(r rune) bool {
return false return false
} }
func (p *parser) rune() rune { func (p *Parser) rune() rune {
retry: retry:
if p.npos < len(p.bs) { if p.npos < len(p.bs) {
if b := p.bs[p.npos]; b < utf8.RuneSelf { if b := p.bs[p.npos]; b < utf8.RuneSelf {
@ -86,7 +86,7 @@ retry:
return p.r return p.r
} }
func (p *parser) unrune(r rune) { func (p *Parser) unrune(r rune) {
if p.r != utf8.RuneSelf { if p.r != utf8.RuneSelf {
p.npos -= utf8.RuneLen(p.r) p.npos -= utf8.RuneLen(p.r)
p.r = r p.r = r
@ -96,7 +96,7 @@ func (p *parser) unrune(r rune) {
// fill reads more bytes from the input src into readBuf. Any bytes that // fill reads more bytes from the input src into readBuf. Any bytes that
// had not yet been used at the end of the buffer are slid into the // had not yet been used at the end of the buffer are slid into the
// beginning of the buffer. // beginning of the buffer.
func (p *parser) fill() { func (p *Parser) fill() {
left := len(p.bs) - p.npos left := len(p.bs) - p.npos
p.offs += p.npos p.offs += p.npos
copy(p.readBuf[:left], p.readBuf[p.npos:]) copy(p.readBuf[:left], p.readBuf[p.npos:])
@ -124,7 +124,7 @@ func (p *parser) fill() {
p.npos = 0 p.npos = 0
} }
func (p *parser) nextKeepSpaces() { func (p *Parser) nextKeepSpaces() {
r := p.r r := p.r
if p.pos = p.getPos(); r > utf8.RuneSelf { if p.pos = p.getPos(); r > utf8.RuneSelf {
p.pos -= Pos(utf8.RuneLen(r) - 1) p.pos -= Pos(utf8.RuneLen(r) - 1)
@ -174,7 +174,7 @@ func (p *parser) nextKeepSpaces() {
} }
} }
func (p *parser) next() { func (p *Parser) next() {
if p.r == utf8.RuneSelf { if p.r == utf8.RuneSelf {
p.tok = _EOF p.tok = _EOF
return return
@ -231,7 +231,7 @@ skipSpace:
for r != utf8.RuneSelf && r != '\n' { for r != utf8.RuneSelf && r != '\n' {
r = p.rune() r = p.rune()
} }
if p.mode&ParseComments > 0 { if p.keepComments {
p.f.Comments = append(p.f.Comments, &Comment{ p.f.Comments = append(p.f.Comments, &Comment{
Hash: p.pos, Hash: p.pos,
Text: p.endLit(), Text: p.endLit(),
@ -282,14 +282,14 @@ skipSpace:
} }
} }
func (p *parser) peekByte(b byte) bool { func (p *Parser) peekByte(b byte) bool {
if p.npos == len(p.bs) && p.readErr == nil { if p.npos == len(p.bs) && p.readErr == nil {
p.fill() p.fill()
} }
return p.npos < len(p.bs) && p.bs[p.npos] == b return p.npos < len(p.bs) && p.bs[p.npos] == b
} }
func (p *parser) regToken(r rune) token { func (p *Parser) regToken(r rune) token {
switch r { switch r {
case '\'': case '\'':
p.rune() p.rune()
@ -432,7 +432,7 @@ func (p *parser) regToken(r rune) token {
} }
} }
func (p *parser) dqToken(r rune) token { func (p *Parser) dqToken(r rune) token {
switch r { switch r {
case '"': case '"':
p.rune() p.rune()
@ -462,7 +462,7 @@ func (p *parser) dqToken(r rune) token {
} }
} }
func (p *parser) paramToken(r rune) token { func (p *Parser) paramToken(r rune) token {
switch r { switch r {
case '}': case '}':
p.rune() p.rune()
@ -537,7 +537,7 @@ func (p *parser) paramToken(r rune) token {
} }
} }
func (p *parser) arithmToken(r rune) token { func (p *Parser) arithmToken(r rune) token {
switch r { switch r {
case '!': case '!':
if p.rune() == '=' { if p.rune() == '=' {
@ -666,7 +666,7 @@ func (p *parser) arithmToken(r rune) token {
} }
} }
func (p *parser) newLit(r rune) { func (p *Parser) newLit(r rune) {
// don't let r == utf8.RuneSelf go to the second case as RuneLen // don't let r == utf8.RuneSelf go to the second case as RuneLen
// would return -1 // would return -1
if r <= utf8.RuneSelf { if r <= utf8.RuneSelf {
@ -678,9 +678,9 @@ func (p *parser) newLit(r rune) {
} }
} }
func (p *parser) discardLit(n int) { p.litBs = p.litBs[:len(p.litBs)-n] } func (p *Parser) discardLit(n int) { p.litBs = p.litBs[:len(p.litBs)-n] }
func (p *parser) endLit() (s string) { func (p *Parser) endLit() (s string) {
if p.r == utf8.RuneSelf { if p.r == utf8.RuneSelf {
s = string(p.litBs) s = string(p.litBs)
} else if len(p.litBs) > 0 { } else if len(p.litBs) > 0 {
@ -690,7 +690,7 @@ func (p *parser) endLit() (s string) {
return return
} }
func (p *parser) advanceLitOther(r rune) { func (p *Parser) advanceLitOther(r rune) {
tok := _LitWord tok := _LitWord
loop: loop:
for p.newLit(r); r != utf8.RuneSelf; r = p.rune() { for p.newLit(r); r != utf8.RuneSelf; r = p.rune() {
@ -756,6 +756,9 @@ loop:
if p.quote&allParamReg != 0 { if p.quote&allParamReg != 0 {
break loop break loop
} }
if r == '[' && p.bash() && p.quote&allArithmExpr != 0 {
break loop
}
case '+': case '+':
if p.quote == paramName && p.peekByte('(') { if p.quote == paramName && p.peekByte('(') {
tok = _Lit tok = _Lit
@ -764,8 +767,7 @@ loop:
fallthrough fallthrough
case '-': case '-':
switch p.quote { switch p.quote {
case paramExpInd, paramExpLen, paramExpOff, case paramExpExp, paramExpRepl, sglQuotes:
paramExpExp, paramExpRepl, sglQuotes:
default: default:
break loop break loop
} }
@ -780,7 +782,7 @@ loop:
p.tok, p.val = tok, p.endLit() p.tok, p.val = tok, p.endLit()
} }
func (p *parser) advanceLitNone(r rune) { func (p *Parser) advanceLitNone(r rune) {
p.asPos = 0 p.asPos = 0
tok := _LitWord tok := _LitWord
loop: loop:
@ -826,12 +828,17 @@ loop:
} }
case '=': case '=':
p.asPos = len(p.litBs) - 1 p.asPos = len(p.litBs) - 1
case '[':
if p.bash() && len(p.litBs) > 1 && p.litBs[0] != '[' {
tok = _Lit
break loop
}
} }
} }
p.tok, p.val = tok, p.endLit() p.tok, p.val = tok, p.endLit()
} }
func (p *parser) advanceLitDquote(r rune) { func (p *Parser) advanceLitDquote(r rune) {
tok := _LitWord tok := _LitWord
loop: loop:
for p.newLit(r); r != utf8.RuneSelf; r = p.rune() { for p.newLit(r); r != utf8.RuneSelf; r = p.rune() {
@ -848,7 +855,7 @@ loop:
p.tok, p.val = tok, p.endLit() p.tok, p.val = tok, p.endLit()
} }
func (p *parser) advanceLitHdoc(r rune) { func (p *Parser) advanceLitHdoc(r rune) {
p.tok = _Lit p.tok = _Lit
p.newLit(r) p.newLit(r)
if p.quote == hdocBodyTabs { if p.quote == hdocBodyTabs {
@ -886,7 +893,7 @@ loop:
} }
} }
func (p *parser) hdocLitWord() *Word { func (p *Parser) hdocLitWord() *Word {
r := p.r r := p.r
p.newLit(r) p.newLit(r)
pos, val := p.getPos(), "" pos, val := p.getPos(), ""
@ -918,7 +925,7 @@ func (p *parser) hdocLitWord() *Word {
return p.word(p.wps(l)) return p.word(p.wps(l))
} }
func (p *parser) advanceLitRe(r rune) { func (p *Parser) advanceLitRe(r rune) {
lparens := 0 lparens := 0
loop: loop:
for p.newLit(r); r != utf8.RuneSelf; r = p.rune() { for p.newLit(r); r != utf8.RuneSelf; r = p.rune() {

View File

@ -145,9 +145,9 @@ func (s *Stmt) End() Pos {
// Command represents all nodes that are simple commands, which are // Command represents all nodes that are simple commands, which are
// directly placed in a Stmt. // directly placed in a Stmt.
// //
// These are *CallExpr, *IfClause, *WhileClause, *UntilClause, // These are *CallExpr, *IfClause, *WhileClause, *ForClause,
// *ForClause, *CaseClause, *Block, *Subshell, *BinaryCmd, *FuncDecl, // *CaseClause, *Block, *Subshell, *BinaryCmd, *FuncDecl, *ArithmCmd,
// *ArithmCmd, *TestClause, *DeclClause, *LetClause, and *CoprocClause. // *TestClause, *DeclClause, *LetClause, and *CoprocClause.
type Command interface { type Command interface {
Node Node
commandNode() commandNode()
@ -156,7 +156,6 @@ type Command interface {
func (*CallExpr) commandNode() {} func (*CallExpr) commandNode() {}
func (*IfClause) commandNode() {} func (*IfClause) commandNode() {}
func (*WhileClause) commandNode() {} func (*WhileClause) commandNode() {}
func (*UntilClause) commandNode() {}
func (*ForClause) commandNode() {} func (*ForClause) commandNode() {}
func (*CaseClause) commandNode() {} func (*CaseClause) commandNode() {}
func (*Block) commandNode() {} func (*Block) commandNode() {}
@ -173,7 +172,9 @@ func (*CoprocClause) commandNode() {}
type Assign struct { type Assign struct {
Append bool Append bool
Name *Lit Name *Lit
Index ArithmExpr
Value *Word Value *Word
Array *ArrayExpr
} }
func (a *Assign) Pos() Pos { func (a *Assign) Pos() Pos {
@ -187,6 +188,12 @@ func (a *Assign) End() Pos {
if a.Value != nil { if a.Value != nil {
return a.Value.End() return a.Value.End()
} }
if a.Array != nil {
return a.Array.End()
}
if a.Index != nil {
return a.Index.End() + 2
}
return a.Name.End() + 1 return a.Name.End() + 1
} }
@ -253,9 +260,10 @@ type Elif struct {
ThenStmts []*Stmt ThenStmts []*Stmt
} }
// WhileClause represents a while clause. // WhileClause represents a while or an until clause.
type WhileClause struct { type WhileClause struct {
While, Do, Done Pos While, Do, Done Pos
Until bool
CondStmts []*Stmt CondStmts []*Stmt
DoStmts []*Stmt DoStmts []*Stmt
} }
@ -263,16 +271,6 @@ type WhileClause struct {
func (w *WhileClause) Pos() Pos { return w.While } func (w *WhileClause) Pos() Pos { return w.While }
func (w *WhileClause) End() Pos { return w.Done + 4 } func (w *WhileClause) End() Pos { return w.Done + 4 }
// UntilClause represents an until clause.
type UntilClause struct {
Until, Do, Done Pos
CondStmts []*Stmt
DoStmts []*Stmt
}
func (u *UntilClause) Pos() Pos { return u.Until }
func (u *UntilClause) End() Pos { return u.Done + 4 }
// ForClause represents a for clause. // ForClause represents a for clause.
type ForClause struct { type ForClause struct {
For, Do, Done Pos For, Do, Done Pos
@ -347,7 +345,7 @@ func (w *Word) End() Pos { return w.Parts[len(w.Parts)-1].End() }
// WordPart represents all nodes that can form a word. // WordPart represents all nodes that can form a word.
// //
// These are *Lit, *SglQuoted, *DblQuoted, *ParamExp, *CmdSubst, // These are *Lit, *SglQuoted, *DblQuoted, *ParamExp, *CmdSubst,
// *ArithmExp, *ProcSubst, *ArrayExpr, and *ExtGlob. // *ArithmExp, *ProcSubst, and *ExtGlob.
type WordPart interface { type WordPart interface {
Node Node
wordPartNode() wordPartNode()
@ -360,7 +358,6 @@ func (*ParamExp) wordPartNode() {}
func (*CmdSubst) wordPartNode() {} func (*CmdSubst) wordPartNode() {}
func (*ArithmExp) wordPartNode() {} func (*ArithmExp) wordPartNode() {}
func (*ProcSubst) wordPartNode() {} func (*ProcSubst) wordPartNode() {}
func (*ArrayExpr) wordPartNode() {}
func (*ExtGlob) wordPartNode() {} func (*ExtGlob) wordPartNode() {}
// Lit represents an unquoted string consisting of characters that were // Lit represents an unquoted string consisting of characters that were
@ -423,7 +420,7 @@ type ParamExp struct {
Indirect bool Indirect bool
Length bool Length bool
Param *Lit Param *Lit
Ind *Index Index ArithmExpr
Slice *Slice Slice *Slice
Repl *Replace Repl *Replace
Exp *Expansion Exp *Expansion
@ -437,12 +434,7 @@ func (p *ParamExp) End() Pos {
return p.Param.End() return p.Param.End()
} }
// Index represents access to an array via an index inside a ParamExp. func (p *ParamExp) nakedIndex() bool { return p.Short && p.Index != nil }
//
// This node will never appear when in PosixConformant mode.
type Index struct {
Expr ArithmExpr
}
// Slice represents character slicing inside a ParamExp. // Slice represents character slicing inside a ParamExp.
// //

View File

@ -8,29 +8,34 @@ import (
"fmt" "fmt"
"io" "io"
"strconv" "strconv"
"sync"
"unicode/utf8" "unicode/utf8"
) )
// ParseMode controls the parser behaviour via a set of flags. func KeepComments(p *Parser) { p.keepComments = true }
type ParseMode uint
type LangVariant int
const ( const (
ParseComments ParseMode = 1 << iota // add comments to the AST LangBash LangVariant = iota
PosixConformant // match the POSIX standard where it differs from bash LangPOSIX
) )
var parserFree = sync.Pool{ func Variant(l LangVariant) func(*Parser) {
New: func() interface{} { return func(p *Parser) { p.lang = l }
return &parser{helperBuf: new(bytes.Buffer)} }
},
func NewParser(options ...func(*Parser)) *Parser {
p := &Parser{helperBuf: new(bytes.Buffer)}
for _, opt := range options {
opt(p)
}
return p
} }
// Parse reads and parses a shell program with an optional name. It // Parse reads and parses a shell program with an optional name. It
// returns the parsed program if no issues were encountered. Otherwise, // returns the parsed program if no issues were encountered. Otherwise,
// an error is returned. // an error is returned.
func Parse(src io.Reader, name string, mode ParseMode) (*File, error) { func (p *Parser) Parse(src io.Reader, name string) (*File, error) {
p := parserFree.Get().(*parser)
p.reset() p.reset()
alloc := &struct { alloc := &struct {
f File f File
@ -39,7 +44,7 @@ func Parse(src io.Reader, name string, mode ParseMode) (*File, error) {
p.f = &alloc.f p.f = &alloc.f
p.f.Name = name p.f.Name = name
p.f.lines = alloc.l[:1] p.f.lines = alloc.l[:1]
p.src, p.mode = src, mode p.src = src
p.rune() p.rune()
p.next() p.next()
p.f.Stmts = p.stmts() p.f.Stmts = p.stmts()
@ -48,18 +53,15 @@ func Parse(src io.Reader, name string, mode ParseMode) (*File, error) {
// trigger it // trigger it
p.doHeredocs() p.doHeredocs()
} }
f, err := p.f, p.err return p.f, p.err
parserFree.Put(p)
return f, err
} }
type parser struct { type Parser struct {
src io.Reader src io.Reader
bs []byte // current chunk of read bytes bs []byte // current chunk of read bytes
r rune r rune
f *File f *File
mode ParseMode
spaced bool // whether tok has whitespace on its left spaced bool // whether tok has whitespace on its left
newLine bool // whether tok is on a new line newLine bool // whether tok is on a new line
@ -77,6 +79,9 @@ type parser struct {
quote quoteState // current lexer state quote quoteState // current lexer state
asPos int // position of '=' in a literal asPos int // position of '=' in a literal
keepComments bool
lang LangVariant
forbidNested bool forbidNested bool
// list of pending heredoc bodies // list of pending heredoc bodies
@ -100,7 +105,7 @@ type parser struct {
const bufSize = 1 << 10 const bufSize = 1 << 10
func (p *parser) reset() { func (p *Parser) reset() {
p.bs = nil p.bs = nil
p.offs, p.npos = 0, 0 p.offs, p.npos = 0, 0
p.r, p.err, p.readErr = 0, nil, nil p.r, p.err, p.readErr = 0, nil, nil
@ -108,9 +113,9 @@ func (p *parser) reset() {
p.heredocs, p.buriedHdocs = p.heredocs[:0], 0 p.heredocs, p.buriedHdocs = p.heredocs[:0], 0
} }
func (p *parser) getPos() Pos { return Pos(p.offs + p.npos) } func (p *Parser) getPos() Pos { return Pos(p.offs + p.npos) }
func (p *parser) lit(pos Pos, val string) *Lit { func (p *Parser) lit(pos Pos, val string) *Lit {
if len(p.litBatch) == 0 { if len(p.litBatch) == 0 {
p.litBatch = make([]Lit, 64) p.litBatch = make([]Lit, 64)
} }
@ -122,7 +127,7 @@ func (p *parser) lit(pos Pos, val string) *Lit {
return l return l
} }
func (p *parser) word(parts []WordPart) *Word { func (p *Parser) word(parts []WordPart) *Word {
if len(p.wordBatch) == 0 { if len(p.wordBatch) == 0 {
p.wordBatch = make([]Word, 32) p.wordBatch = make([]Word, 32)
} }
@ -132,7 +137,7 @@ func (p *parser) word(parts []WordPart) *Word {
return w return w
} }
func (p *parser) wps(wp WordPart) []WordPart { func (p *Parser) wps(wp WordPart) []WordPart {
if len(p.wpsBatch) == 0 { if len(p.wpsBatch) == 0 {
p.wpsBatch = make([]WordPart, 64) p.wpsBatch = make([]WordPart, 64)
} }
@ -142,7 +147,7 @@ func (p *parser) wps(wp WordPart) []WordPart {
return wps return wps
} }
func (p *parser) stmt(pos Pos) *Stmt { func (p *Parser) stmt(pos Pos) *Stmt {
if len(p.stmtBatch) == 0 { if len(p.stmtBatch) == 0 {
p.stmtBatch = make([]Stmt, 16) p.stmtBatch = make([]Stmt, 16)
} }
@ -152,7 +157,7 @@ func (p *parser) stmt(pos Pos) *Stmt {
return s return s
} }
func (p *parser) stList() []*Stmt { func (p *Parser) stList() []*Stmt {
if len(p.stListBatch) == 0 { if len(p.stListBatch) == 0 {
p.stListBatch = make([]*Stmt, 128) p.stListBatch = make([]*Stmt, 128)
} }
@ -166,7 +171,7 @@ type callAlloc struct {
ws [4]*Word ws [4]*Word
} }
func (p *parser) call(w *Word) *CallExpr { func (p *Parser) call(w *Word) *CallExpr {
if len(p.callBatch) == 0 { if len(p.callBatch) == 0 {
p.callBatch = make([]callAlloc, 32) p.callBatch = make([]callAlloc, 32)
} }
@ -208,30 +213,30 @@ const (
allRegTokens = noState | subCmd | subCmdBckquo | hdocWord | switchCase allRegTokens = noState | subCmd | subCmdBckquo | hdocWord | switchCase
allArithmExpr = arithmExpr | arithmExprLet | arithmExprCmd | allArithmExpr = arithmExpr | arithmExprLet | arithmExprCmd |
arithmExprBrack | allParamArith arithmExprBrack | allParamArith
allRbrack = arithmExprBrack | paramExpInd allRbrack = arithmExprBrack | paramExpInd | paramName
allParamArith = paramExpInd | paramExpOff | paramExpLen allParamArith = paramExpInd | paramExpOff | paramExpLen
allParamReg = paramName | paramExpName | allParamArith allParamReg = paramName | paramExpName | allParamArith
allParamExp = allParamReg | paramExpRepl | paramExpExp allParamExp = allParamReg | paramExpRepl | paramExpExp
) )
func (p *parser) bash() bool { return p.mode&PosixConformant == 0 } func (p *Parser) bash() bool { return p.lang == LangBash }
type saveState struct { type saveState struct {
quote quoteState quote quoteState
buriedHdocs int buriedHdocs int
} }
func (p *parser) preNested(quote quoteState) (s saveState) { func (p *Parser) preNested(quote quoteState) (s saveState) {
s.quote, s.buriedHdocs = p.quote, p.buriedHdocs s.quote, s.buriedHdocs = p.quote, p.buriedHdocs
p.buriedHdocs, p.quote = len(p.heredocs), quote p.buriedHdocs, p.quote = len(p.heredocs), quote
return return
} }
func (p *parser) postNested(s saveState) { func (p *Parser) postNested(s saveState) {
p.quote, p.buriedHdocs = s.quote, s.buriedHdocs p.quote, p.buriedHdocs = s.quote, s.buriedHdocs
} }
func (p *parser) unquotedWordBytes(w *Word) ([]byte, bool) { func (p *Parser) unquotedWordBytes(w *Word) ([]byte, bool) {
p.helperBuf.Reset() p.helperBuf.Reset()
didUnquote := false didUnquote := false
for _, wp := range w.Parts { for _, wp := range w.Parts {
@ -242,7 +247,7 @@ func (p *parser) unquotedWordBytes(w *Word) ([]byte, bool) {
return p.helperBuf.Bytes(), didUnquote return p.helperBuf.Bytes(), didUnquote
} }
func (p *parser) unquotedWordPart(buf *bytes.Buffer, wp WordPart, quotes bool) (quoted bool) { func (p *Parser) unquotedWordPart(buf *bytes.Buffer, wp WordPart, quotes bool) (quoted bool) {
switch x := wp.(type) { switch x := wp.(type) {
case *Lit: case *Lit:
for i := 0; i < len(x.Value); i++ { for i := 0; i < len(x.Value); i++ {
@ -267,7 +272,7 @@ func (p *parser) unquotedWordPart(buf *bytes.Buffer, wp WordPart, quotes bool) (
return return
} }
func (p *parser) doHeredocs() { func (p *Parser) doHeredocs() {
old := p.quote old := p.quote
hdocs := p.heredocs[p.buriedHdocs:] hdocs := p.heredocs[p.buriedHdocs:]
p.heredocs = p.heredocs[:p.buriedHdocs] p.heredocs = p.heredocs[:p.buriedHdocs]
@ -295,7 +300,7 @@ func (p *parser) doHeredocs() {
p.quote = old p.quote = old
} }
func (p *parser) got(tok token) bool { func (p *Parser) got(tok token) bool {
if p.tok == tok { if p.tok == tok {
p.next() p.next()
return true return true
@ -303,7 +308,7 @@ func (p *parser) got(tok token) bool {
return false return false
} }
func (p *parser) gotRsrv(val string) bool { func (p *Parser) gotRsrv(val string) bool {
if p.tok == _LitWord && p.val == val { if p.tok == _LitWord && p.val == val {
p.next() p.next()
return true return true
@ -311,7 +316,7 @@ func (p *parser) gotRsrv(val string) bool {
return false return false
} }
func (p *parser) gotSameLine(tok token) bool { func (p *Parser) gotSameLine(tok token) bool {
if !p.newLine && p.tok == tok { if !p.newLine && p.tok == tok {
p.next() p.next()
return true return true
@ -327,16 +332,16 @@ func readableStr(s string) string {
return s return s
} }
func (p *parser) followErr(pos Pos, left, right string) { func (p *Parser) followErr(pos Pos, left, right string) {
leftStr := readableStr(left) leftStr := readableStr(left)
p.posErr(pos, "%s must be followed by %s", leftStr, right) p.posErr(pos, "%s must be followed by %s", leftStr, right)
} }
func (p *parser) followErrExp(pos Pos, left string) { func (p *Parser) followErrExp(pos Pos, left string) {
p.followErr(pos, left, "an expression") p.followErr(pos, left, "an expression")
} }
func (p *parser) follow(lpos Pos, left string, tok token) Pos { func (p *Parser) follow(lpos Pos, left string, tok token) Pos {
pos := p.pos pos := p.pos
if !p.got(tok) { if !p.got(tok) {
p.followErr(lpos, left, tok.String()) p.followErr(lpos, left, tok.String())
@ -344,7 +349,7 @@ func (p *parser) follow(lpos Pos, left string, tok token) Pos {
return pos return pos
} }
func (p *parser) followRsrv(lpos Pos, left, val string) Pos { func (p *Parser) followRsrv(lpos Pos, left, val string) Pos {
pos := p.pos pos := p.pos
if !p.gotRsrv(val) { if !p.gotRsrv(val) {
p.followErr(lpos, left, fmt.Sprintf("%q", val)) p.followErr(lpos, left, fmt.Sprintf("%q", val))
@ -352,7 +357,7 @@ func (p *parser) followRsrv(lpos Pos, left, val string) Pos {
return pos return pos
} }
func (p *parser) followStmts(left string, lpos Pos, stops ...string) []*Stmt { func (p *Parser) followStmts(left string, lpos Pos, stops ...string) []*Stmt {
if p.gotSameLine(semicolon) { if p.gotSameLine(semicolon) {
return nil return nil
} }
@ -363,7 +368,7 @@ func (p *parser) followStmts(left string, lpos Pos, stops ...string) []*Stmt {
return sts return sts
} }
func (p *parser) followWordTok(tok token, pos Pos) *Word { func (p *Parser) followWordTok(tok token, pos Pos) *Word {
w := p.getWord() w := p.getWord()
if w == nil { if w == nil {
p.followErr(pos, tok.String(), "a word") p.followErr(pos, tok.String(), "a word")
@ -371,7 +376,7 @@ func (p *parser) followWordTok(tok token, pos Pos) *Word {
return w return w
} }
func (p *parser) followWord(s string, pos Pos) *Word { func (p *Parser) followWord(s string, pos Pos) *Word {
w := p.getWord() w := p.getWord()
if w == nil { if w == nil {
p.followErr(pos, s, "a word") p.followErr(pos, s, "a word")
@ -379,7 +384,7 @@ func (p *parser) followWord(s string, pos Pos) *Word {
return w return w
} }
func (p *parser) stmtEnd(n Node, start, end string) Pos { func (p *Parser) stmtEnd(n Node, start, end string) Pos {
pos := p.pos pos := p.pos
if !p.gotRsrv(end) { if !p.gotRsrv(end) {
p.posErr(n.Pos(), "%s statement must end with %q", start, end) p.posErr(n.Pos(), "%s statement must end with %q", start, end)
@ -387,17 +392,17 @@ func (p *parser) stmtEnd(n Node, start, end string) Pos {
return pos return pos
} }
func (p *parser) quoteErr(lpos Pos, quote token) { func (p *Parser) quoteErr(lpos Pos, quote token) {
p.posErr(lpos, "reached %s without closing quote %s", p.posErr(lpos, "reached %s without closing quote %s",
p.tok.String(), quote) p.tok.String(), quote)
} }
func (p *parser) matchingErr(lpos Pos, left, right interface{}) { func (p *Parser) matchingErr(lpos Pos, left, right interface{}) {
p.posErr(lpos, "reached %s without matching %s with %s", p.posErr(lpos, "reached %s without matching %s with %s",
p.tok.String(), left, right) p.tok.String(), left, right)
} }
func (p *parser) matched(lpos Pos, left, right token) Pos { func (p *Parser) matched(lpos Pos, left, right token) Pos {
pos := p.pos pos := p.pos
if !p.got(right) { if !p.got(right) {
p.matchingErr(lpos, left, right) p.matchingErr(lpos, left, right)
@ -405,7 +410,7 @@ func (p *parser) matched(lpos Pos, left, right token) Pos {
return pos return pos
} }
func (p *parser) errPass(err error) { func (p *Parser) errPass(err error) {
if p.err == nil { if p.err == nil {
p.err = err p.err = err
p.npos = len(p.bs) + 1 p.npos = len(p.bs) + 1
@ -424,18 +429,18 @@ func (e *ParseError) Error() string {
return fmt.Sprintf("%s: %s", e.Position.String(), e.Text) return fmt.Sprintf("%s: %s", e.Position.String(), e.Text)
} }
func (p *parser) posErr(pos Pos, format string, a ...interface{}) { func (p *Parser) posErr(pos Pos, format string, a ...interface{}) {
p.errPass(&ParseError{ p.errPass(&ParseError{
Position: p.f.Position(pos), Position: p.f.Position(pos),
Text: fmt.Sprintf(format, a...), Text: fmt.Sprintf(format, a...),
}) })
} }
func (p *parser) curErr(format string, a ...interface{}) { func (p *Parser) curErr(format string, a ...interface{}) {
p.posErr(p.pos, format, a...) p.posErr(p.pos, format, a...)
} }
func (p *parser) stmts(stops ...string) (sts []*Stmt) { func (p *Parser) stmts(stops ...string) (sts []*Stmt) {
gotEnd := true gotEnd := true
for p.tok != _EOF { for p.tok != _EOF {
switch p.tok { switch p.tok {
@ -478,7 +483,7 @@ func (p *parser) stmts(stops ...string) (sts []*Stmt) {
return return
} }
func (p *parser) invalidStmtStart() { func (p *Parser) invalidStmtStart() {
switch p.tok { switch p.tok {
case semicolon, and, or, andAnd, orOr: case semicolon, and, or, andAnd, orOr:
p.curErr("%s can only immediately follow a statement", p.tok) p.curErr("%s can only immediately follow a statement", p.tok)
@ -489,14 +494,14 @@ func (p *parser) invalidStmtStart() {
} }
} }
func (p *parser) getWord() *Word { func (p *Parser) getWord() *Word {
if parts := p.wordParts(); len(parts) > 0 { if parts := p.wordParts(); len(parts) > 0 {
return p.word(parts) return p.word(parts)
} }
return nil return nil
} }
func (p *parser) getWordOrEmpty() *Word { func (p *Parser) getWordOrEmpty() *Word {
parts := p.wordParts() parts := p.wordParts()
if len(parts) == 0 { if len(parts) == 0 {
l := p.lit(p.pos, "") l := p.lit(p.pos, "")
@ -506,7 +511,7 @@ func (p *parser) getWordOrEmpty() *Word {
return p.word(parts) return p.word(parts)
} }
func (p *parser) getLit() *Lit { func (p *Parser) getLit() *Lit {
switch p.tok { switch p.tok {
case _Lit, _LitWord, _LitRedir: case _Lit, _LitWord, _LitRedir:
l := p.lit(p.pos, p.val) l := p.lit(p.pos, p.val)
@ -516,7 +521,7 @@ func (p *parser) getLit() *Lit {
return nil return nil
} }
func (p *parser) wordParts() (wps []WordPart) { func (p *Parser) wordParts() (wps []WordPart) {
for { for {
n := p.wordPart() n := p.wordPart()
if n == nil { if n == nil {
@ -533,13 +538,13 @@ func (p *parser) wordParts() (wps []WordPart) {
} }
} }
func (p *parser) ensureNoNested() { func (p *Parser) ensureNoNested() {
if p.forbidNested { if p.forbidNested {
p.curErr("expansions not allowed in heredoc words") p.curErr("expansions not allowed in heredoc words")
} }
} }
func (p *parser) wordPart() WordPart { func (p *Parser) wordPart() WordPart {
switch p.tok { switch p.tok {
case _Lit, _LitWord: case _Lit, _LitWord:
l := p.lit(p.pos, p.val) l := p.lit(p.pos, p.val)
@ -581,27 +586,8 @@ func (p *parser) wordPart() WordPart {
cs.Right = p.matched(cs.Left, leftParen, rightParen) cs.Right = p.matched(cs.Left, leftParen, rightParen)
return cs return cs
case dollar: case dollar:
r := p.r
if r == utf8.RuneSelf || wordBreak(r) || r == '"' || r == '\'' || r == '`' || r == '[' {
l := p.lit(p.pos, "$")
p.next()
return l
}
p.ensureNoNested() p.ensureNoNested()
pe := &ParamExp{Dollar: p.pos, Short: true} return p.shortParamExp()
p.pos++
switch r {
case '@', '*', '#', '$', '?', '!', '0', '-':
p.rune()
p.tok, p.val = _LitWord, string(r)
default:
old := p.quote
p.quote = paramName
p.advanceLitOther(r)
p.quote = old
}
pe.Param = p.getLit()
return pe
case cmdIn, cmdOut: case cmdIn, cmdOut:
p.ensureNoNested() p.ensureNoNested()
ps := &ProcSubst{Op: ProcOperator(p.tok), OpPos: p.pos} ps := &ProcSubst{Op: ProcOperator(p.tok), OpPos: p.pos}
@ -736,7 +722,7 @@ func arithmOpLevel(op BinAritOperator) int {
return -1 return -1
} }
func (p *parser) arithmExpr(ftok token, fpos Pos, level int, compact, tern bool) ArithmExpr { func (p *Parser) arithmExpr(ftok token, fpos Pos, level int, compact, tern bool) ArithmExpr {
if p.tok == _EOF || p.peekArithmEnd() { if p.tok == _EOF || p.peekArithmEnd() {
return nil return nil
} }
@ -804,7 +790,7 @@ func (p *parser) arithmExpr(ftok token, fpos Pos, level int, compact, tern bool)
return b return b
} }
func (p *parser) arithmExprBase(compact bool) ArithmExpr { func (p *Parser) arithmExprBase(compact bool) ArithmExpr {
var x ArithmExpr var x ArithmExpr
switch p.tok { switch p.tok {
case exclMark: case exclMark:
@ -837,16 +823,34 @@ func (p *parser) arithmExprBase(compact bool) ArithmExpr {
if p.next(); compact && p.spaced { if p.next(); compact && p.spaced {
p.followErrExp(ue.OpPos, ue.Op.String()) p.followErrExp(ue.OpPos, ue.Op.String())
} }
ue.X = p.arithmExpr(token(ue.Op), ue.OpPos, 0, compact, false) ue.X = p.arithmExprBase(compact)
if ue.X == nil { if ue.X == nil {
p.followErrExp(ue.OpPos, ue.Op.String()) p.followErrExp(ue.OpPos, ue.Op.String())
} }
x = ue x = ue
case illegalTok, rightBrack, rightBrace, rightParen: case illegalTok, rightBrack, rightBrace, rightParen:
case _LitWord: case _LitWord:
x = p.getLit() l := p.getLit()
case dollar, dollBrace: if p.r != '[' {
x = p.wordPart().(*ParamExp) x = l
break
}
pe := &ParamExp{Dollar: l.ValuePos, Short: true, Param: l}
p.rune()
left := p.pos + 1
old := p.preNested(arithmExprBrack)
p.next()
pe.Index = p.arithmExpr(leftBrack, left, 0, false, false)
if pe.Index == nil {
p.followErrExp(left, "[")
}
p.postNested(old)
p.matched(left, leftBrack, rightBrack)
x = pe
case dollar:
x = p.shortParamExp()
case dollBrace:
x = p.paramExp()
case bckQuote: case bckQuote:
if p.quote == arithmExprLet { if p.quote == arithmExprLet {
return nil return nil
@ -862,7 +866,16 @@ func (p *parser) arithmExprBase(compact bool) ArithmExpr {
return x return x
} }
if p.tok == addAdd || p.tok == subSub { if p.tok == addAdd || p.tok == subSub {
if l, ok := x.(*Lit); !ok || !validIdent(l.Value, p.bash()) { switch y := x.(type) {
case *Lit:
if !validIdent(y.Value, p.bash()) {
p.curErr("%s must follow a name", p.tok.String())
}
case *ParamExp:
if !y.nakedIndex() {
p.curErr("%s must follow a name", p.tok.String())
}
default:
p.curErr("%s must follow a name", p.tok.String()) p.curErr("%s must follow a name", p.tok.String())
} }
u := &UnaryArithm{ u := &UnaryArithm{
@ -877,7 +890,27 @@ func (p *parser) arithmExprBase(compact bool) ArithmExpr {
return x return x
} }
func (p *parser) paramExp() *ParamExp { func (p *Parser) shortParamExp() *ParamExp {
pe := &ParamExp{Dollar: p.pos, Short: true}
p.pos++
switch p.r {
case '@', '*', '#', '$', '?', '!', '0', '-':
p.tok, p.val = _LitWord, string(p.r)
p.rune()
default:
old := p.quote
p.quote = paramName
p.advanceLitOther(p.r)
p.quote = old
if p.val == "" || p.val == "\x80" {
p.posErr(pe.Dollar, "$ must be escaped or followed by a literal")
}
}
pe.Param = p.getLit()
return pe
}
func (p *Parser) paramExp() *ParamExp {
pe := &ParamExp{Dollar: p.pos} pe := &ParamExp{Dollar: p.pos}
old := p.quote old := p.quote
p.quote = paramExpName p.quote = paramExpName
@ -934,15 +967,11 @@ func (p *parser) paramExp() *ParamExp {
p.quote = paramExpInd p.quote = paramExpInd
p.next() p.next()
switch p.tok { switch p.tok {
case star: case star, at:
p.tok, p.val = _LitWord, "*" p.tok, p.val = _LitWord, p.tok.String()
case at:
p.tok, p.val = _LitWord, "@"
} }
pe.Ind = &Index{ pe.Index = p.arithmExpr(leftBrack, lpos, 0, false, false)
Expr: p.arithmExpr(leftBrack, lpos, 0, false, false), if pe.Index == nil {
}
if pe.Ind.Expr == nil {
p.followErrExp(lpos, "[") p.followErrExp(lpos, "[")
} }
p.quote = paramExpName p.quote = paramExpName
@ -1009,11 +1038,11 @@ func (p *parser) paramExp() *ParamExp {
return pe return pe
} }
func (p *parser) peekArithmEnd() bool { func (p *Parser) peekArithmEnd() bool {
return p.tok == rightParen && p.r == ')' return p.tok == rightParen && p.r == ')'
} }
func (p *parser) arithmEnd(ltok token, lpos Pos, old saveState) Pos { func (p *Parser) arithmEnd(ltok token, lpos Pos, old saveState) Pos {
if p.peekArithmEnd() { if p.peekArithmEnd() {
p.rune() p.rune()
} else { } else {
@ -1050,47 +1079,70 @@ func validIdent(val string, bash bool) bool {
return true return true
} }
func (p *parser) hasValidIdent() bool { func (p *Parser) hasValidIdent() bool {
if p.asPos < 1 { if p.asPos > 0 && validIdent(p.val[:p.asPos], p.bash()) {
return false return true
} }
return validIdent(p.val[:p.asPos], p.bash()) return p.tok == _Lit && p.r == '['
} }
func (p *parser) getAssign() *Assign { func (p *Parser) getAssign() *Assign {
as := &Assign{} as := &Assign{}
nameEnd := p.asPos if p.asPos > 0 { // foo=bar
if p.bash() && p.val[p.asPos-1] == '+' { nameEnd := p.asPos
// a+=b if p.bash() && p.val[p.asPos-1] == '+' {
as.Append = true // a+=b
nameEnd-- as.Append = true
nameEnd--
}
as.Name = p.lit(p.pos, p.val[:nameEnd])
// since we're not using the entire p.val
as.Name.ValueEnd = as.Name.ValuePos + Pos(nameEnd)
left := p.lit(p.pos+1, p.val[p.asPos+1:])
if left.Value != "" {
left.ValuePos += Pos(p.asPos)
as.Value = p.word(p.wps(left))
}
if p.next(); p.spaced {
return as
}
} else { // foo[i]=bar
as.Name = p.lit(p.pos, p.val)
// hasValidIdent already checks p.r is '['
p.rune()
left := p.pos + 1
old := p.preNested(arithmExprBrack)
p.next()
as.Index = p.arithmExpr(leftBrack, left, 0, false, false)
if as.Index == nil {
p.followErrExp(left, "[")
}
p.postNested(old)
p.matched(left, leftBrack, rightBrack)
if p.tok == _EOF || p.val[0] != '=' {
p.followErr(as.Pos(), "a[b]", "=")
return nil
}
p.pos++
p.val = p.val[1:]
if p.val == "" {
p.next()
}
} }
as.Name = p.lit(p.pos, p.val[:nameEnd]) if as.Value == nil && p.tok == leftParen {
// since we're not using the entire p.val
as.Name.ValueEnd = as.Name.ValuePos + Pos(nameEnd)
start := p.lit(p.pos+1, p.val[p.asPos+1:])
if start.Value != "" {
start.ValuePos += Pos(p.asPos)
as.Value = p.word(p.wps(start))
}
if p.next(); p.spaced {
return as
}
if start.Value == "" && p.tok == leftParen {
if !p.bash() { if !p.bash() {
p.curErr("arrays are a bash feature") p.curErr("arrays are a bash feature")
} }
ae := &ArrayExpr{Lparen: p.pos} as.Array = &ArrayExpr{Lparen: p.pos}
p.next() p.next()
for p.tok != _EOF && p.tok != rightParen { for p.tok != _EOF && p.tok != rightParen {
if w := p.getWord(); w == nil { if w := p.getWord(); w == nil {
p.curErr("array elements must be words") p.curErr("array elements must be words")
} else { } else {
ae.List = append(ae.List, w) as.Array.List = append(as.Array.List, w)
} }
} }
ae.Rparen = p.matched(ae.Lparen, leftParen, rightParen) as.Array.Rparen = p.matched(as.Array.Lparen, leftParen, rightParen)
as.Value = p.word(p.wps(ae))
} else if !p.newLine && !stopToken(p.tok) { } else if !p.newLine && !stopToken(p.tok) {
if w := p.getWord(); w != nil { if w := p.getWord(); w != nil {
if as.Value == nil { if as.Value == nil {
@ -1103,7 +1155,7 @@ func (p *parser) getAssign() *Assign {
return as return as
} }
func (p *parser) peekRedir() bool { func (p *Parser) peekRedir() bool {
switch p.tok { switch p.tok {
case rdrOut, appOut, rdrIn, dplIn, dplOut, clbOut, rdrInOut, case rdrOut, appOut, rdrIn, dplIn, dplOut, clbOut, rdrInOut,
hdoc, dashHdoc, wordHdoc, rdrAll, appAll, _LitRedir: hdoc, dashHdoc, wordHdoc, rdrAll, appAll, _LitRedir:
@ -1112,7 +1164,7 @@ func (p *parser) peekRedir() bool {
return false return false
} }
func (p *parser) doRedirect(s *Stmt) { func (p *Parser) doRedirect(s *Stmt) {
r := &Redirect{} r := &Redirect{}
r.N = p.getLit() r.N = p.getLit()
r.Op, r.OpPos = RedirOperator(p.tok), p.pos r.Op, r.OpPos = RedirOperator(p.tok), p.pos
@ -1136,7 +1188,7 @@ func (p *parser) doRedirect(s *Stmt) {
s.Redirs = append(s.Redirs, r) s.Redirs = append(s.Redirs, r)
} }
func (p *parser) getStmt(readEnd, binCmd bool) (s *Stmt, gotEnd bool) { func (p *Parser) getStmt(readEnd, binCmd bool) (s *Stmt, gotEnd bool) {
s = p.stmt(p.pos) s = p.stmt(p.pos)
if p.gotRsrv("!") { if p.gotRsrv("!") {
s.Negated = true s.Negated = true
@ -1209,7 +1261,7 @@ preLoop:
return return
} }
func (p *parser) gotStmtPipe(s *Stmt) *Stmt { func (p *Parser) gotStmtPipe(s *Stmt) *Stmt {
switch p.tok { switch p.tok {
case _LitWord: case _LitWord:
switch p.val { switch p.val {
@ -1217,10 +1269,8 @@ func (p *parser) gotStmtPipe(s *Stmt) *Stmt {
s.Cmd = p.block() s.Cmd = p.block()
case "if": case "if":
s.Cmd = p.ifClause() s.Cmd = p.ifClause()
case "while": case "while", "until":
s.Cmd = p.whileClause() s.Cmd = p.whileClause(p.val == "until")
case "until":
s.Cmd = p.untilClause()
case "for": case "for":
s.Cmd = p.forClause() s.Cmd = p.forClause()
case "case": case "case":
@ -1307,7 +1357,7 @@ func (p *parser) gotStmtPipe(s *Stmt) *Stmt {
return s return s
} }
func (p *parser) subshell() *Subshell { func (p *Parser) subshell() *Subshell {
s := &Subshell{Lparen: p.pos} s := &Subshell{Lparen: p.pos}
old := p.preNested(subCmd) old := p.preNested(subCmd)
p.next() p.next()
@ -1317,7 +1367,7 @@ func (p *parser) subshell() *Subshell {
return s return s
} }
func (p *parser) arithmExpCmd() Command { func (p *Parser) arithmExpCmd() Command {
ar := &ArithmCmd{Left: p.pos} ar := &ArithmCmd{Left: p.pos}
old := p.preNested(arithmExprCmd) old := p.preNested(arithmExprCmd)
p.next() p.next()
@ -1326,7 +1376,7 @@ func (p *parser) arithmExpCmd() Command {
return ar return ar
} }
func (p *parser) block() *Block { func (p *Parser) block() *Block {
b := &Block{Lbrace: p.pos} b := &Block{Lbrace: p.pos}
p.next() p.next()
b.Stmts = p.stmts("}") b.Stmts = p.stmts("}")
@ -1337,7 +1387,7 @@ func (p *parser) block() *Block {
return b return b
} }
func (p *parser) ifClause() *IfClause { func (p *Parser) ifClause() *IfClause {
ic := &IfClause{If: p.pos} ic := &IfClause{If: p.pos}
p.next() p.next()
ic.CondStmts = p.followStmts("if", ic.If, "then") ic.CondStmts = p.followStmts("if", ic.If, "then")
@ -1359,27 +1409,23 @@ func (p *parser) ifClause() *IfClause {
return ic return ic
} }
func (p *parser) whileClause() *WhileClause { func (p *Parser) whileClause(until bool) *WhileClause {
wc := &WhileClause{While: p.pos} wc := &WhileClause{While: p.pos, Until: until}
rsrv := "while"
rsrvCond := "while <cond>"
if wc.Until {
rsrv = "until"
rsrvCond = "until <cond>"
}
p.next() p.next()
wc.CondStmts = p.followStmts("while", wc.While, "do") wc.CondStmts = p.followStmts(rsrv, wc.While, "do")
wc.Do = p.followRsrv(wc.While, "while <cond>", "do") wc.Do = p.followRsrv(wc.While, rsrvCond, "do")
wc.DoStmts = p.followStmts("do", wc.Do, "done") wc.DoStmts = p.followStmts("do", wc.Do, "done")
wc.Done = p.stmtEnd(wc, "while", "done") wc.Done = p.stmtEnd(wc, rsrv, "done")
return wc return wc
} }
func (p *parser) untilClause() *UntilClause { func (p *Parser) forClause() *ForClause {
uc := &UntilClause{Until: p.pos}
p.next()
uc.CondStmts = p.followStmts("until", uc.Until, "do")
uc.Do = p.followRsrv(uc.Until, "until <cond>", "do")
uc.DoStmts = p.followStmts("do", uc.Do, "done")
uc.Done = p.stmtEnd(uc, "until", "done")
return uc
}
func (p *parser) forClause() *ForClause {
fc := &ForClause{For: p.pos} fc := &ForClause{For: p.pos}
p.next() p.next()
fc.Loop = p.loop(fc.For) fc.Loop = p.loop(fc.For)
@ -1389,7 +1435,7 @@ func (p *parser) forClause() *ForClause {
return fc return fc
} }
func (p *parser) loop(forPos Pos) Loop { func (p *Parser) loop(forPos Pos) Loop {
if p.tok == dblLeftParen { if p.tok == dblLeftParen {
cl := &CStyleLoop{Lparen: p.pos} cl := &CStyleLoop{Lparen: p.pos}
old := p.preNested(arithmExprCmd) old := p.preNested(arithmExprCmd)
@ -1437,7 +1483,7 @@ func (p *parser) loop(forPos Pos) Loop {
return wi return wi
} }
func (p *parser) caseClause() *CaseClause { func (p *Parser) caseClause() *CaseClause {
cc := &CaseClause{Case: p.pos} cc := &CaseClause{Case: p.pos}
p.next() p.next()
cc.Word = p.followWord("case", cc.Case) cc.Word = p.followWord("case", cc.Case)
@ -1447,7 +1493,7 @@ func (p *parser) caseClause() *CaseClause {
return cc return cc
} }
func (p *parser) patLists() (pls []*PatternList) { func (p *Parser) patLists() (pls []*PatternList) {
for p.tok != _EOF && !(p.tok == _LitWord && p.val == "esac") { for p.tok != _EOF && !(p.tok == _LitWord && p.val == "esac") {
pl := &PatternList{} pl := &PatternList{}
p.got(leftParen) p.got(leftParen)
@ -1481,7 +1527,7 @@ func (p *parser) patLists() (pls []*PatternList) {
return return
} }
func (p *parser) testClause() *TestClause { func (p *Parser) testClause() *TestClause {
tc := &TestClause{Left: p.pos} tc := &TestClause{Left: p.pos}
if p.next(); p.tok == _EOF || p.gotRsrv("]]") { if p.next(); p.tok == _EOF || p.gotRsrv("]]") {
p.posErr(tc.Left, "test clause requires at least one expression") p.posErr(tc.Left, "test clause requires at least one expression")
@ -1494,7 +1540,7 @@ func (p *parser) testClause() *TestClause {
return tc return tc
} }
func (p *parser) testExpr(ftok token, fpos Pos, level int) TestExpr { func (p *Parser) testExpr(ftok token, fpos Pos, level int) TestExpr {
var left TestExpr var left TestExpr
if level > 1 { if level > 1 {
left = p.testExprBase(ftok, fpos) left = p.testExprBase(ftok, fpos)
@ -1517,7 +1563,7 @@ func (p *parser) testExpr(ftok token, fpos Pos, level int) TestExpr {
case _EOF, rightParen: case _EOF, rightParen:
return left return left
case _Lit: case _Lit:
p.curErr("not a valid test operator: %s", p.val) p.curErr("test operator words must consist of a single literal")
default: default:
p.curErr("not a valid test operator: %v", p.tok) p.curErr("not a valid test operator: %v", p.tok)
} }
@ -1555,7 +1601,7 @@ func (p *parser) testExpr(ftok token, fpos Pos, level int) TestExpr {
return b return b
} }
func (p *parser) testExprBase(ftok token, fpos Pos) TestExpr { func (p *Parser) testExprBase(ftok token, fpos Pos) TestExpr {
switch p.tok { switch p.tok {
case _EOF: case _EOF:
return nil return nil
@ -1598,7 +1644,7 @@ func (p *parser) testExprBase(ftok token, fpos Pos) TestExpr {
} }
} }
func (p *parser) declClause() *DeclClause { func (p *Parser) declClause() *DeclClause {
name := p.val name := p.val
ds := &DeclClause{Position: p.pos} ds := &DeclClause{Position: p.pos}
switch name { switch name {
@ -1637,7 +1683,7 @@ func isBashCompoundCommand(tok token, val string) bool {
return false return false
} }
func (p *parser) coprocClause() *CoprocClause { func (p *Parser) coprocClause() *CoprocClause {
cc := &CoprocClause{Coproc: p.pos} cc := &CoprocClause{Coproc: p.pos}
if p.next(); isBashCompoundCommand(p.tok, p.val) { if p.next(); isBashCompoundCommand(p.tok, p.val) {
// has no name // has no name
@ -1669,7 +1715,7 @@ func (p *parser) coprocClause() *CoprocClause {
return cc return cc
} }
func (p *parser) letClause() *LetClause { func (p *Parser) letClause() *LetClause {
lc := &LetClause{Let: p.pos} lc := &LetClause{Let: p.pos}
old := p.preNested(arithmExprLet) old := p.preNested(arithmExprLet)
p.next() p.next()
@ -1690,7 +1736,7 @@ func (p *parser) letClause() *LetClause {
return lc return lc
} }
func (p *parser) bashFuncDecl() *FuncDecl { func (p *Parser) bashFuncDecl() *FuncDecl {
fpos := p.pos fpos := p.pos
p.next() p.next()
if p.tok != _LitWord { if p.tok != _LitWord {
@ -1705,7 +1751,7 @@ func (p *parser) bashFuncDecl() *FuncDecl {
return p.funcDecl(name, fpos) return p.funcDecl(name, fpos)
} }
func (p *parser) callExpr(s *Stmt, w *Word) *CallExpr { func (p *Parser) callExpr(s *Stmt, w *Word) *CallExpr {
ce := p.call(w) ce := p.call(w)
for !p.newLine { for !p.newLine {
switch p.tok { switch p.tok {
@ -1743,7 +1789,7 @@ func (p *parser) callExpr(s *Stmt, w *Word) *CallExpr {
return ce return ce
} }
func (p *parser) funcDecl(name *Lit, pos Pos) *FuncDecl { func (p *Parser) funcDecl(name *Lit, pos Pos) *FuncDecl {
fd := &FuncDecl{ fd := &FuncDecl{
Position: pos, Position: pos,
BashStyle: pos != name.ValuePos, BashStyle: pos != name.ValuePos,

View File

@ -6,54 +6,40 @@ package syntax
import ( import (
"bufio" "bufio"
"io" "io"
"sync" "strings"
) )
// PrintConfig controls how the printing of an AST node will behave. func Indent(spaces int) func(*Printer) {
type PrintConfig struct { return func(p *Printer) { p.indentSpaces = spaces }
// Spaces dictates the indentation style. The default value of 0
// uses tabs, and any positive value uses that number of spaces.
Spaces int
// BinaryNextLine makes binary operators (such as &&, || and |)
// be at the start of a line if the statement that follows them
// is on a separate line. This means that the operator will come
// after an escaped newline.
BinaryNextLine bool
} }
var printerFree = sync.Pool{ func BinaryNextLine(p *Printer) { p.binNextLine = true }
New: func() interface{} {
return &printer{ func NewPrinter(options ...func(*Printer)) *Printer {
bufWriter: bufio.NewWriter(nil), p := &Printer{
lenPrinter: new(printer), bufWriter: bufio.NewWriter(nil),
} lenPrinter: new(Printer),
}, }
for _, opt := range options {
opt(p)
}
return p
} }
// Fprint "pretty-prints" the given AST file to the given writer. // Print "pretty-prints" the given AST file to the given writer.
func (c PrintConfig) Fprint(w io.Writer, f *File) error { func (p *Printer) Print(w io.Writer, f *File) error {
p := printerFree.Get().(*printer)
p.reset() p.reset()
p.PrintConfig = c
p.lines, p.comments = f.lines, f.Comments p.lines, p.comments = f.lines, f.Comments
p.bufWriter.Reset(w) p.bufWriter.Reset(w)
p.stmts(f.Stmts) p.stmts(f.Stmts)
p.commentsUpTo(0) p.commentsUpTo(0)
p.newline(0) p.newline(0)
var err error
if flusher, ok := p.bufWriter.(interface { if flusher, ok := p.bufWriter.(interface {
Flush() error Flush() error
}); ok { }); ok {
err = flusher.Flush() return flusher.Flush()
} }
printerFree.Put(p) return nil
return err
}
// Fprint "pretty-prints" the given AST file to the given writer. It
// calls PrintConfig.Fprint with its default settings.
func Fprint(w io.Writer, f *File) error {
return PrintConfig{}.Fprint(w, f)
} }
type bufWriter interface { type bufWriter interface {
@ -62,10 +48,12 @@ type bufWriter interface {
Reset(io.Writer) Reset(io.Writer)
} }
type printer struct { type Printer struct {
bufWriter bufWriter
PrintConfig indentSpaces int
binNextLine bool
lines []Pos lines []Pos
wantSpace bool wantSpace bool
@ -94,12 +82,12 @@ type printer struct {
// pendingHdocs is the list of pending heredocs to write. // pendingHdocs is the list of pending heredocs to write.
pendingHdocs []*Redirect pendingHdocs []*Redirect
// used in stmtLen to align comments // used in stmtCols to align comments
lenPrinter *printer lenPrinter *Printer
lenCounter byteCounter lenCounter byteCounter
} }
func (p *printer) reset() { func (p *Printer) reset() {
p.wantSpace, p.wantNewline = false, false p.wantSpace, p.wantNewline = false, false
p.commentPadding = 0 p.commentPadding = 0
p.nline, p.nlineIndex = 0, 0 p.nline, p.nlineIndex = 0, 0
@ -109,7 +97,7 @@ func (p *printer) reset() {
p.pendingHdocs = p.pendingHdocs[:0] p.pendingHdocs = p.pendingHdocs[:0]
} }
func (p *printer) incLine() { func (p *Printer) incLine() {
if p.nlineIndex++; p.nlineIndex >= len(p.lines) { if p.nlineIndex++; p.nlineIndex >= len(p.lines) {
p.nline = maxPos p.nline = maxPos
} else { } else {
@ -117,19 +105,19 @@ func (p *printer) incLine() {
} }
} }
func (p *printer) incLines(pos Pos) { func (p *Printer) incLines(pos Pos) {
for p.nline < pos { for p.nline < pos {
p.incLine() p.incLine()
} }
} }
func (p *printer) spaces(n int) { func (p *Printer) spaces(n int) {
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
p.WriteByte(' ') p.WriteByte(' ')
} }
} }
func (p *printer) bslashNewl() { func (p *Printer) bslashNewl() {
if p.wantSpace { if p.wantSpace {
p.WriteByte(' ') p.WriteByte(' ')
} }
@ -138,7 +126,7 @@ func (p *printer) bslashNewl() {
p.incLine() p.incLine()
} }
func (p *printer) spacedString(s string) { func (p *Printer) spacedString(s string) {
if p.wantSpace { if p.wantSpace {
p.WriteByte(' ') p.WriteByte(' ')
} }
@ -146,7 +134,7 @@ func (p *printer) spacedString(s string) {
p.wantSpace = true p.wantSpace = true
} }
func (p *printer) semiOrNewl(s string, pos Pos) { func (p *Printer) semiOrNewl(s string, pos Pos) {
if p.wantNewline { if p.wantNewline {
p.newline(pos) p.newline(pos)
p.indent() p.indent()
@ -161,7 +149,7 @@ func (p *printer) semiOrNewl(s string, pos Pos) {
p.wantSpace = true p.wantSpace = true
} }
func (p *printer) incLevel() { func (p *Printer) incLevel() {
inc := false inc := false
if p.level <= p.lastLevel || len(p.levelIncs) == 0 { if p.level <= p.lastLevel || len(p.levelIncs) == 0 {
p.level++ p.level++
@ -173,27 +161,27 @@ func (p *printer) incLevel() {
p.levelIncs = append(p.levelIncs, inc) p.levelIncs = append(p.levelIncs, inc)
} }
func (p *printer) decLevel() { func (p *Printer) decLevel() {
if p.levelIncs[len(p.levelIncs)-1] { if p.levelIncs[len(p.levelIncs)-1] {
p.level-- p.level--
} }
p.levelIncs = p.levelIncs[:len(p.levelIncs)-1] p.levelIncs = p.levelIncs[:len(p.levelIncs)-1]
} }
func (p *printer) indent() { func (p *Printer) indent() {
p.lastLevel = p.level p.lastLevel = p.level
switch { switch {
case p.level == 0: case p.level == 0:
case p.Spaces == 0: case p.indentSpaces == 0:
for i := 0; i < p.level; i++ { for i := 0; i < p.level; i++ {
p.WriteByte('\t') p.WriteByte('\t')
} }
case p.Spaces > 0: case p.indentSpaces > 0:
p.spaces(p.Spaces * p.level) p.spaces(p.indentSpaces * p.level)
} }
} }
func (p *printer) newline(pos Pos) { func (p *Printer) newline(pos Pos) {
p.wantNewline, p.wantSpace = false, false p.wantNewline, p.wantSpace = false, false
p.WriteByte('\n') p.WriteByte('\n')
if pos > p.nline { if pos > p.nline {
@ -211,7 +199,7 @@ func (p *printer) newline(pos Pos) {
} }
} }
func (p *printer) newlines(pos Pos) { func (p *Printer) newlines(pos Pos) {
p.newline(pos) p.newline(pos)
if pos > p.nline { if pos > p.nline {
// preserve single empty lines // preserve single empty lines
@ -221,14 +209,14 @@ func (p *printer) newlines(pos Pos) {
p.indent() p.indent()
} }
func (p *printer) commentsAndSeparate(pos Pos) { func (p *Printer) commentsAndSeparate(pos Pos) {
p.commentsUpTo(pos) p.commentsUpTo(pos)
if p.wantNewline || pos > p.nline { if p.wantNewline || pos > p.nline {
p.newlines(pos) p.newlines(pos)
} }
} }
func (p *printer) sepTok(s string, pos Pos) { func (p *Printer) sepTok(s string, pos Pos) {
p.level++ p.level++
p.commentsUpTo(pos) p.commentsUpTo(pos)
p.level-- p.level--
@ -239,7 +227,7 @@ func (p *printer) sepTok(s string, pos Pos) {
p.wantSpace = true p.wantSpace = true
} }
func (p *printer) semiRsrv(s string, pos Pos, fallback bool) { func (p *Printer) semiRsrv(s string, pos Pos, fallback bool) {
p.level++ p.level++
p.commentsUpTo(pos) p.commentsUpTo(pos)
p.level-- p.level--
@ -257,14 +245,14 @@ func (p *printer) semiRsrv(s string, pos Pos, fallback bool) {
p.wantSpace = true p.wantSpace = true
} }
func (p *printer) anyCommentsBefore(pos Pos) bool { func (p *Printer) anyCommentsBefore(pos Pos) bool {
if !pos.IsValid() || len(p.comments) < 1 { if !pos.IsValid() || len(p.comments) < 1 {
return false return false
} }
return p.comments[0].Hash < pos return p.comments[0].Hash < pos
} }
func (p *printer) commentsUpTo(pos Pos) { func (p *Printer) commentsUpTo(pos Pos) {
if len(p.comments) < 1 { if len(p.comments) < 1 {
return return
} }
@ -286,7 +274,7 @@ func (p *printer) commentsUpTo(pos Pos) {
p.commentsUpTo(pos) p.commentsUpTo(pos)
} }
func (p *printer) wordPart(wp WordPart) { func (p *Printer) wordPart(wp WordPart) {
switch x := wp.(type) { switch x := wp.(type) {
case *Lit: case *Lit:
p.WriteString(x.Value) p.WriteString(x.Value)
@ -320,13 +308,8 @@ func (p *printer) wordPart(wp WordPart) {
p.paramExp(x) p.paramExp(x)
case *ArithmExp: case *ArithmExp:
p.WriteString("$((") p.WriteString("$((")
p.arithmExpr(x.X, false) p.arithmExpr(x.X, false, false)
p.WriteString("))") p.WriteString("))")
case *ArrayExpr:
p.wantSpace = false
p.WriteByte('(')
p.wordJoin(x.List, false)
p.sepTok(")", x.Rparen)
case *ExtGlob: case *ExtGlob:
p.WriteString(x.Op.String()) p.WriteString(x.Op.String())
p.WriteString(x.Pattern.Value) p.WriteString(x.Pattern.Value)
@ -343,12 +326,20 @@ func (p *printer) wordPart(wp WordPart) {
} }
} }
func (p *printer) paramExp(pe *ParamExp) { func (p *Printer) paramExp(pe *ParamExp) {
if pe.Short { if pe.nakedIndex() { // arr[i]
p.WriteString(pe.Param.Value)
p.WriteByte('[')
p.arithmExpr(pe.Index, false, false)
p.WriteByte(']')
return
}
if pe.Short { // $var
p.WriteByte('$') p.WriteByte('$')
p.WriteString(pe.Param.Value) p.WriteString(pe.Param.Value)
return return
} }
// ${var...}
p.WriteString("${") p.WriteString("${")
switch { switch {
case pe.Length: case pe.Length:
@ -359,23 +350,17 @@ func (p *printer) paramExp(pe *ParamExp) {
if pe.Param != nil { if pe.Param != nil {
p.WriteString(pe.Param.Value) p.WriteString(pe.Param.Value)
} }
if pe.Ind != nil { if pe.Index != nil {
p.WriteByte('[') p.WriteByte('[')
p.arithmExpr(pe.Ind.Expr, false) p.arithmExpr(pe.Index, false, false)
p.WriteByte(']') p.WriteByte(']')
} }
if pe.Slice != nil { if pe.Slice != nil {
p.WriteByte(':') p.WriteByte(':')
if un, ok := pe.Slice.Offset.(*UnaryArithm); ok { p.arithmExpr(pe.Slice.Offset, true, true)
if un.Op == Plus || un.Op == Minus {
// to avoid :+ and :-
p.WriteByte(' ')
}
}
p.arithmExpr(pe.Slice.Offset, true)
if pe.Slice.Length != nil { if pe.Slice.Length != nil {
p.WriteByte(':') p.WriteByte(':')
p.arithmExpr(pe.Slice.Length, true) p.arithmExpr(pe.Slice.Length, true, false)
} }
} else if pe.Repl != nil { } else if pe.Repl != nil {
if pe.Repl.All { if pe.Repl.All {
@ -392,7 +377,7 @@ func (p *printer) paramExp(pe *ParamExp) {
p.WriteByte('}') p.WriteByte('}')
} }
func (p *printer) loop(loop Loop) { func (p *Printer) loop(loop Loop) {
switch x := loop.(type) { switch x := loop.(type) {
case *WordIter: case *WordIter:
p.WriteString(x.Name.Value) p.WriteString(x.Name.Value)
@ -405,16 +390,16 @@ func (p *printer) loop(loop Loop) {
if x.Init == nil { if x.Init == nil {
p.WriteByte(' ') p.WriteByte(' ')
} }
p.arithmExpr(x.Init, false) p.arithmExpr(x.Init, false, false)
p.WriteString("; ") p.WriteString("; ")
p.arithmExpr(x.Cond, false) p.arithmExpr(x.Cond, false, false)
p.WriteString("; ") p.WriteString("; ")
p.arithmExpr(x.Post, false) p.arithmExpr(x.Post, false, false)
p.WriteString("))") p.WriteString("))")
} }
} }
func (p *printer) arithmExpr(expr ArithmExpr, compact bool) { func (p *Printer) arithmExpr(expr ArithmExpr, compact, spacePlusMinus bool) {
switch x := expr.(type) { switch x := expr.(type) {
case *Lit: case *Lit:
p.WriteString(x.Value) p.WriteString(x.Value)
@ -422,34 +407,40 @@ func (p *printer) arithmExpr(expr ArithmExpr, compact bool) {
p.paramExp(x) p.paramExp(x)
case *BinaryArithm: case *BinaryArithm:
if compact { if compact {
p.arithmExpr(x.X, compact) p.arithmExpr(x.X, compact, spacePlusMinus)
p.WriteString(x.Op.String()) p.WriteString(x.Op.String())
p.arithmExpr(x.Y, compact) p.arithmExpr(x.Y, compact, false)
} else { } else {
p.arithmExpr(x.X, compact) p.arithmExpr(x.X, compact, spacePlusMinus)
if x.Op != Comma { if x.Op != Comma {
p.WriteByte(' ') p.WriteByte(' ')
} }
p.WriteString(x.Op.String()) p.WriteString(x.Op.String())
p.WriteByte(' ') p.WriteByte(' ')
p.arithmExpr(x.Y, compact) p.arithmExpr(x.Y, compact, false)
} }
case *UnaryArithm: case *UnaryArithm:
if x.Post { if x.Post {
p.arithmExpr(x.X, compact) p.arithmExpr(x.X, compact, spacePlusMinus)
p.WriteString(x.Op.String()) p.WriteString(x.Op.String())
} else { } else {
if spacePlusMinus {
switch x.Op {
case Plus, Minus:
p.WriteByte(' ')
}
}
p.WriteString(x.Op.String()) p.WriteString(x.Op.String())
p.arithmExpr(x.X, compact) p.arithmExpr(x.X, compact, false)
} }
case *ParenArithm: case *ParenArithm:
p.WriteByte('(') p.WriteByte('(')
p.arithmExpr(x.X, false) p.arithmExpr(x.X, false, false)
p.WriteByte(')') p.WriteByte(')')
} }
} }
func (p *printer) testExpr(expr TestExpr) { func (p *Printer) testExpr(expr TestExpr) {
switch x := expr.(type) { switch x := expr.(type) {
case *Word: case *Word:
p.word(x) p.word(x)
@ -470,14 +461,14 @@ func (p *printer) testExpr(expr TestExpr) {
} }
} }
func (p *printer) word(w *Word) { func (p *Printer) word(w *Word) {
for _, n := range w.Parts { for _, n := range w.Parts {
p.wordPart(n) p.wordPart(n)
} }
p.wantSpace = true p.wantSpace = true
} }
func (p *printer) unquotedWord(w *Word) { func (p *Printer) unquotedWord(w *Word) {
for _, wp := range w.Parts { for _, wp := range w.Parts {
switch x := wp.(type) { switch x := wp.(type) {
case *SglQuoted: case *SglQuoted:
@ -500,7 +491,7 @@ func (p *printer) unquotedWord(w *Word) {
} }
} }
func (p *printer) wordJoin(ws []*Word, backslash bool) { func (p *Printer) wordJoin(ws []*Word, backslash bool) {
anyNewline := false anyNewline := false
for _, w := range ws { for _, w := range ws {
if pos := w.Pos(); pos > p.nline { if pos := w.Pos(); pos > p.nline {
@ -527,7 +518,7 @@ func (p *printer) wordJoin(ws []*Word, backslash bool) {
} }
} }
func (p *printer) stmt(s *Stmt) { func (p *Printer) stmt(s *Stmt) {
if s.Negated { if s.Negated {
p.spacedString("!") p.spacedString("!")
} }
@ -575,7 +566,7 @@ func (p *printer) stmt(s *Stmt) {
} }
} }
func (p *printer) command(cmd Command, redirs []*Redirect) (startRedirs int) { func (p *Printer) command(cmd Command, redirs []*Redirect) (startRedirs int) {
if p.wantSpace { if p.wantSpace {
p.WriteByte(' ') p.WriteByte(' ')
p.wantSpace = false p.wantSpace = false
@ -631,7 +622,11 @@ func (p *printer) command(cmd Command, redirs []*Redirect) (startRedirs int) {
p.nestedStmts(x.Stmts, x.Rparen) p.nestedStmts(x.Stmts, x.Rparen)
p.sepTok(")", x.Rparen) p.sepTok(")", x.Rparen)
case *WhileClause: case *WhileClause:
p.spacedString("while") if x.Until {
p.spacedString("until")
} else {
p.spacedString("while")
}
p.nestedStmts(x.CondStmts, 0) p.nestedStmts(x.CondStmts, 0)
p.semiOrNewl("do", x.Do) p.semiOrNewl("do", x.Do)
p.nestedStmts(x.DoStmts, 0) p.nestedStmts(x.DoStmts, 0)
@ -649,7 +644,7 @@ func (p *printer) command(cmd Command, redirs []*Redirect) (startRedirs int) {
p.incLevel() p.incLevel()
} }
_, p.nestedBinary = x.Y.Cmd.(*BinaryCmd) _, p.nestedBinary = x.Y.Cmd.(*BinaryCmd)
if p.BinaryNextLine { if p.binNextLine {
if len(p.pendingHdocs) == 0 && x.Y.Pos() > p.nline { if len(p.pendingHdocs) == 0 && x.Y.Pos() > p.nline {
p.bslashNewl() p.bslashNewl()
p.indent() p.indent()
@ -665,8 +660,12 @@ func (p *printer) command(cmd Command, redirs []*Redirect) (startRedirs int) {
p.indent() p.indent()
} }
} else { } else {
p.wantSpace = true
p.spacedString(x.Op.String()) p.spacedString(x.Op.String())
if x.Y.Pos() > p.nline { if x.Y.Pos() > p.nline {
if x.OpPos > p.nline {
p.incLines(x.OpPos)
}
p.commentsUpTo(x.Y.Pos()) p.commentsUpTo(x.Y.Pos())
p.newline(0) p.newline(0)
p.indent() p.indent()
@ -720,15 +719,9 @@ func (p *printer) command(cmd Command, redirs []*Redirect) (startRedirs int) {
} }
p.decLevel() p.decLevel()
p.semiRsrv("esac", x.Esac, len(x.List) == 0) p.semiRsrv("esac", x.Esac, len(x.List) == 0)
case *UntilClause:
p.spacedString("until")
p.nestedStmts(x.CondStmts, 0)
p.semiOrNewl("do", x.Do)
p.nestedStmts(x.DoStmts, 0)
p.semiRsrv("done", x.Done, true)
case *ArithmCmd: case *ArithmCmd:
p.WriteString("((") p.WriteString("((")
p.arithmExpr(x.X, false) p.arithmExpr(x.X, false, false)
p.WriteString("))") p.WriteString("))")
case *TestClause: case *TestClause:
p.WriteString("[[ ") p.WriteString("[[ ")
@ -756,7 +749,7 @@ func (p *printer) command(cmd Command, redirs []*Redirect) (startRedirs int) {
p.spacedString("let") p.spacedString("let")
for _, n := range x.Exprs { for _, n := range x.Exprs {
p.WriteByte(' ') p.WriteByte(' ')
p.arithmExpr(n, true) p.arithmExpr(n, true, false)
} }
} }
return startRedirs return startRedirs
@ -772,7 +765,7 @@ func startsWithLparen(s *Stmt) bool {
return false return false
} }
func (p *printer) hasInline(pos, npos, nline Pos) bool { func (p *Printer) hasInline(pos, npos, nline Pos) bool {
for _, c := range p.comments { for _, c := range p.comments {
if c.Hash > nline { if c.Hash > nline {
return false return false
@ -784,7 +777,7 @@ func (p *printer) hasInline(pos, npos, nline Pos) bool {
return false return false
} }
func (p *printer) stmts(stmts []*Stmt) { func (p *Printer) stmts(stmts []*Stmt) {
switch len(stmts) { switch len(stmts) {
case 0: case 0:
return return
@ -836,10 +829,10 @@ func (p *printer) stmts(stmts []*Stmt) {
if j+1 < len(follow) { if j+1 < len(follow) {
npos2 = follow[j+1].Pos() npos2 = follow[j+1].Pos()
} }
if pos2 > nline2 || !p.hasInline(pos2, npos2, nline2) { if !p.hasInline(pos2, npos2, nline2) {
break break
} }
if l := p.stmtLen(s2); l > inlineIndent { if l := p.stmtCols(s2); l > inlineIndent {
inlineIndent = l inlineIndent = l
} }
if ind2++; ind2 >= len(p.lines) { if ind2++; ind2 >= len(p.lines) {
@ -854,7 +847,9 @@ func (p *printer) stmts(stmts []*Stmt) {
} }
} }
if inlineIndent > 0 { if inlineIndent > 0 {
p.commentPadding = inlineIndent - p.stmtLen(s) if l := p.stmtCols(s); l > 0 {
p.commentPadding = inlineIndent - l
}
} }
} }
p.wantNewline = true p.wantNewline = true
@ -863,24 +858,41 @@ func (p *printer) stmts(stmts []*Stmt) {
type byteCounter int type byteCounter int
func (c *byteCounter) WriteByte(b byte) error { func (c *byteCounter) WriteByte(b byte) error {
*c++ switch {
case *c < 0:
case b == '\n':
*c = -1
default:
*c++
}
return nil return nil
} }
func (c *byteCounter) WriteString(s string) (int, error) { func (c *byteCounter) WriteString(s string) (int, error) {
*c += byteCounter(len(s)) switch {
case *c < 0:
case strings.Contains(s, "\n"):
*c = -1
default:
*c += byteCounter(len(s))
}
return 0, nil return 0, nil
} }
func (c *byteCounter) Reset(io.Writer) { *c = 0 } func (c *byteCounter) Reset(io.Writer) { *c = 0 }
func (p *printer) stmtLen(s *Stmt) int { // stmtCols reports the length that s will take when formatted in a
*p.lenPrinter = printer{bufWriter: &p.lenCounter} // single line. If it will span multiple lines, stmtCols will return -1.
func (p *Printer) stmtCols(s *Stmt) int {
*p.lenPrinter = Printer{
bufWriter: &p.lenCounter,
lines: p.lines,
}
p.lenPrinter.bufWriter.Reset(nil) p.lenPrinter.bufWriter.Reset(nil)
p.lenPrinter.incLines(s.Pos()) p.lenPrinter.incLines(s.Pos())
p.lenPrinter.stmt(s) p.lenPrinter.stmt(s)
return int(p.lenCounter) return int(p.lenCounter)
} }
func (p *printer) nestedStmts(stmts []*Stmt, closing Pos) { func (p *Printer) nestedStmts(stmts []*Stmt, closing Pos) {
p.incLevel() p.incLevel()
if len(stmts) == 1 && closing > p.nline && stmts[0].End() <= p.nline { if len(stmts) == 1 && closing > p.nline && stmts[0].End() <= p.nline {
p.newline(0) p.newline(0)
@ -890,7 +902,7 @@ func (p *printer) nestedStmts(stmts []*Stmt, closing Pos) {
p.decLevel() p.decLevel()
} }
func (p *printer) assigns(assigns []*Assign) { func (p *Printer) assigns(assigns []*Assign) {
anyNewline := false anyNewline := false
for _, a := range assigns { for _, a := range assigns {
if a.Pos() > p.nline { if a.Pos() > p.nline {
@ -905,6 +917,11 @@ func (p *printer) assigns(assigns []*Assign) {
} }
if a.Name != nil { if a.Name != nil {
p.WriteString(a.Name.Value) p.WriteString(a.Name.Value)
if a.Index != nil {
p.WriteByte('[')
p.arithmExpr(a.Index, false, false)
p.WriteByte(']')
}
if a.Append { if a.Append {
p.WriteByte('+') p.WriteByte('+')
} }
@ -912,6 +929,11 @@ func (p *printer) assigns(assigns []*Assign) {
} }
if a.Value != nil { if a.Value != nil {
p.word(a.Value) p.word(a.Value)
} else if a.Array != nil {
p.wantSpace = false
p.WriteByte('(')
p.wordJoin(a.Array.List, false)
p.sepTok(")", a.Array.Rparen)
} }
p.wantSpace = true p.wantSpace = true
} }

View File

@ -46,6 +46,9 @@ func Walk(node Node, f func(Node) bool) {
if x.Value != nil { if x.Value != nil {
Walk(x.Value, f) Walk(x.Value, f)
} }
if x.Array != nil {
Walk(x.Array, f)
}
case *Redirect: case *Redirect:
if x.N != nil { if x.N != nil {
Walk(x.N, f) Walk(x.N, f)
@ -71,9 +74,6 @@ func Walk(node Node, f func(Node) bool) {
case *WhileClause: case *WhileClause:
walkStmts(x.CondStmts, f) walkStmts(x.CondStmts, f)
walkStmts(x.DoStmts, f) walkStmts(x.DoStmts, f)
case *UntilClause:
walkStmts(x.CondStmts, f)
walkStmts(x.DoStmts, f)
case *ForClause: case *ForClause:
Walk(x.Loop, f) Walk(x.Loop, f)
walkStmts(x.DoStmts, f) walkStmts(x.DoStmts, f)
@ -112,8 +112,8 @@ func Walk(node Node, f func(Node) bool) {
if x.Param != nil { if x.Param != nil {
Walk(x.Param, f) Walk(x.Param, f)
} }
if x.Ind != nil { if x.Index != nil {
Walk(x.Ind.Expr, f) Walk(x.Index, f)
} }
if x.Repl != nil { if x.Repl != nil {
Walk(x.Repl.Orig, f) Walk(x.Repl.Orig, f)

12
vendor/vendor.json vendored
View File

@ -51,16 +51,16 @@
"revisionTime": "2017-01-24T11:57:57Z" "revisionTime": "2017-01-24T11:57:57Z"
}, },
{ {
"checksumSHA1": "ZjfvXVu+OyeRBysQ8uowAkPD/6o=", "checksumSHA1": "ohm6oyTSFu/xZk5HtTUG7RIONO4=",
"path": "github.com/mvdan/sh/interp", "path": "github.com/mvdan/sh/interp",
"revision": "faf782d3a498f50cfc6aa9071d04e6f1e82e8035", "revision": "380eaf2df0412887a2240f5b15e76ac810ca2e71",
"revisionTime": "2017-04-30T14:10:52Z" "revisionTime": "2017-05-17T16:44:15Z"
}, },
{ {
"checksumSHA1": "RIR7FOsCR78SmOJOUsclJe9lvxo=", "checksumSHA1": "2OcfNJuStj/eAcEPW5yRdU02DCc=",
"path": "github.com/mvdan/sh/syntax", "path": "github.com/mvdan/sh/syntax",
"revision": "faf782d3a498f50cfc6aa9071d04e6f1e82e8035", "revision": "380eaf2df0412887a2240f5b15e76ac810ca2e71",
"revisionTime": "2017-04-30T14:10:52Z" "revisionTime": "2017-05-17T16:44:15Z"
}, },
{ {
"checksumSHA1": "HUXE+Nrcau8FSaVEvPYHMvDjxOE=", "checksumSHA1": "HUXE+Nrcau8FSaVEvPYHMvDjxOE=",