1
0
mirror of https://github.com/go-kratos/kratos.git synced 2025-01-10 00:29:01 +02:00
kratos/tool/pkg/common.go
jiankuny 0ffb6233bc
add bts-gen & mc-gen (#96)
add genbts & genmc
2019-05-14 13:57:02 +08:00

150 lines
3.4 KiB
Go

package pkg
import (
"fmt"
"go/ast"
"go/format"
"go/parser"
"go/token"
"io/ioutil"
"log"
"os"
"regexp"
"strings"
)
// Source source
type Source struct {
Fset *token.FileSet
Src string
F *ast.File
}
// NewSource new source
func NewSource(src string) *Source {
s := &Source{
Fset: token.NewFileSet(),
Src: src,
}
f, err := parser.ParseFile(s.Fset, "", src, 0)
if err != nil {
log.Fatal("无法解析源文件")
}
s.F = f
return s
}
// ExprString expr string
func (s *Source) ExprString(typ ast.Expr) string {
fset := s.Fset
s1 := fset.Position(typ.Pos()).Offset
s2 := fset.Position(typ.End()).Offset
return s.Src[s1:s2]
}
// pkgPath package path
func (s *Source) pkgPath(name string) (res string) {
for _, im := range s.F.Imports {
if im.Name != nil && im.Name.Name == name {
return im.Path.Value
}
}
for _, im := range s.F.Imports {
if strings.HasSuffix(im.Path.Value, name+"\"") {
return im.Path.Value
}
}
return
}
// GetDef get define code
func (s *Source) GetDef(name string) string {
c := s.F.Scope.Lookup(name).Decl.(*ast.TypeSpec).Type.(*ast.InterfaceType)
s1 := s.Fset.Position(c.Pos()).Offset
s2 := s.Fset.Position(c.End()).Offset
line := s.Fset.Position(c.Pos()).Line
lines := []string{strings.Split(s.Src, "\n")[line-1]}
for _, l := range strings.Split(s.Src[s1:s2], "\n")[1:] {
lines = append(lines, "\t"+l)
}
return strings.Join(lines, "\n")
}
// RegexpReplace replace regexp
func RegexpReplace(reg, src, temp string) string {
result := []byte{}
pattern := regexp.MustCompile(reg)
for _, submatches := range pattern.FindAllStringSubmatchIndex(src, -1) {
result = pattern.ExpandString(result, temp, src, submatches)
}
return string(result)
}
// formatPackage format package
func formatPackage(name, path string) (res string) {
if path != "" {
if strings.HasSuffix(path, name+"\"") {
res = path
return
}
res = fmt.Sprintf("%s %s", name, path)
}
return
}
// SourceText get source file text
func SourceText() string {
file := os.Getenv("GOFILE")
data, err := ioutil.ReadFile(file)
if err != nil {
log.Fatal("请使用go generate执行", file)
}
return string(data)
}
// FormatCode format code
func FormatCode(source string) string {
src, err := format.Source([]byte(source))
if err != nil {
// Should never happen, but can arise when developing this code.
// The user can compile the output to see the error.
log.Printf("warning: 输出文件不合法: %s", err)
log.Printf("warning: 详细错误请编译查看")
return source
}
return string(src)
}
// Packages get import packages
func (s *Source) Packages(f *ast.Field) (res []string) {
fs := f.Type.(*ast.FuncType).Params.List
fs = append(fs, f.Type.(*ast.FuncType).Results.List...)
var types []string
resMap := make(map[string]bool)
for _, field := range fs {
if p, ok := field.Type.(*ast.MapType); ok {
types = append(types, s.ExprString(p.Key))
types = append(types, s.ExprString(p.Value))
} else if p, ok := field.Type.(*ast.ArrayType); ok {
types = append(types, s.ExprString(p.Elt))
} else {
types = append(types, s.ExprString(field.Type))
}
}
for _, t := range types {
name := RegexpReplace(`(?P<pkg>\w+)\.\w+`, t, "$pkg")
if name == "" {
continue
}
pkg := formatPackage(name, s.pkgPath(name))
if !resMap[pkg] {
resMap[pkg] = true
}
}
for pkg := range resMap {
res = append(res, pkg)
}
return
}