Skip to content

Commit

Permalink
fix: fix duplicate method name
Browse files Browse the repository at this point in the history
  • Loading branch information
idersec authored and tr1v3r committed Sep 2, 2021
1 parent 9ac4321 commit dd15929
Show file tree
Hide file tree
Showing 5 changed files with 116 additions and 93 deletions.
4 changes: 2 additions & 2 deletions generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ type genInfo struct {
func (i *genInfo) AppendMethods(methods []*check.InterfaceMethod) error {
for _, newMethod := range methods {
for _, infoMethod := range i.Interfaces {
if infoMethod.MethodName == newMethod.MethodName && infoMethod.InterfaceName != newMethod.InterfaceName {
if infoMethod.IsRepeatInterfaceMethod(newMethod) {
return fmt.Errorf("can not generate method with the same name from different interface:%s.%s and %s.%s", infoMethod.InterfaceName, infoMethod.MethodName, newMethod.InterfaceName, newMethod.MethodName)
}
}
Expand Down Expand Up @@ -206,7 +206,7 @@ func (g *Generator) apply(fc interface{}, structs []*check.BaseStruct) {
panic("panic with check interface error")
}

err = g.readInterfaceSet.ParseFile(interfacePaths)
err = g.readInterfaceSet.ParseFile(interfacePaths, check.GetNames(structs))
if err != nil {
g.db.Logger.Error(context.Background(), "can not parser interface file: %s", err)
panic("panic with parser interface file error")
Expand Down
126 changes: 64 additions & 62 deletions internal/check/checkinterface.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,28 +35,33 @@ type InterfaceMethod struct {
}

// HasSqlData has variable or not
func (f *InterfaceMethod) HasSqlData() bool {
return len(f.SqlData) > 0
func (m *InterfaceMethod) HasSqlData() bool {
return len(m.SqlData) > 0
}

// HasGotPoint parameter has pointer or not
func (f *InterfaceMethod) HasGotPoint() bool {
return !f.HasNeedNewResult()
func (m *InterfaceMethod) HasGotPoint() bool {
return !m.HasNeedNewResult()
}

// HasNeedNewResult need pointer or not
func (f *InterfaceMethod) HasNeedNewResult() bool {
return !f.ResultData.IsArray && ((f.ResultData.IsNull() && f.ResultData.IsTime()) || f.ResultData.IsMap())
func (m *InterfaceMethod) HasNeedNewResult() bool {
return !m.ResultData.IsArray && ((m.ResultData.IsNull() && m.ResultData.IsTime()) || m.ResultData.IsMap())
}

// IsRepeatInterfaceMethod check different interface has same mame method
func (m *InterfaceMethod) IsRepeatInterfaceMethod(newMethod *InterfaceMethod) bool {
return m.MethodName == newMethod.MethodName && m.InterfaceName != newMethod.InterfaceName && m.MethodStruct == newMethod.MethodStruct
}

//GetParamInTmpl return param list
func (f *InterfaceMethod) GetParamInTmpl() string {
return paramToString(f.Params)
func (m *InterfaceMethod) GetParamInTmpl() string {
return paramToString(m.Params)
}

// GetResultParamInTmpl return result list
func (f *InterfaceMethod) GetResultParamInTmpl() string {
return paramToString(f.Result)
func (m *InterfaceMethod) GetResultParamInTmpl() string {
return paramToString(m.Result)
}

// paramToString param list to string used in tmpl
Expand All @@ -80,98 +85,95 @@ func paramToString(params []parser.Param) string {
}

// checkParams check all parameters
func (f *InterfaceMethod) checkParams(params []parser.Param) (err error) {
func (m *InterfaceMethod) checkParams(params []parser.Param) (err error) {
paramList := make([]parser.Param, len(params))
for i, r := range params {
if r.Package == "UNDEFINED" {
r.Package = f.OriginStruct.Package
for i, param := range params {
if param.Package == "UNDEFINED" {
param.Package = m.OriginStruct.Package
}
if r.IsMap() || r.IsGenM() || r.IsError() || r.IsNull() {
return fmt.Errorf("type error on interface [%s] param: [%s]", f.InterfaceName, r.Name)
if param.IsMap() || param.IsGenM() || param.IsError() || param.IsNull() {
return fmt.Errorf("type error on interface [%s] param: [%s]", m.InterfaceName, param.Name)
}
paramList[i] = r
paramList[i] = param
}
f.Params = paramList
m.Params = paramList
return
}

// checkResult check all parameters and replace gen.T by target structure. Parameters must be one of int/string/struct/map
func (f *InterfaceMethod) checkResult(result []parser.Param) (err error) {
func (m *InterfaceMethod) checkResult(result []parser.Param) (err error) {
resList := make([]parser.Param, len(result))
var hasError bool
for i, param := range result {
if param.Package == "UNDEFINED" {
param.Package = f.Package
param.Package = m.Package
}
if param.IsGenM() {
param.Type = "map[string]interface{}"
param.Package = ""
}
if param.InMainPkg() {
return fmt.Errorf("query method cannot return struct of main package in [%s.%s]", f.InterfaceName, f.MethodName)
}
switch {
case param.InMainPkg():
return fmt.Errorf("query method cannot return struct of main package in [%s.%s]", m.InterfaceName, m.MethodName)
case param.IsError():
if hasError {
return fmt.Errorf("query method cannot return more than 1 error value in [%s.%s]", f.InterfaceName, f.MethodName)
return fmt.Errorf("query method cannot return more than 1 error value in [%s.%s]", m.InterfaceName, m.MethodName)
}
param.SetName("err")
f.ExecuteResult = "err"
m.ExecuteResult = "err"
hasError = true
case param.Eq(f.OriginStruct) || param.IsGenT():
if !f.ResultData.IsNull() {
return fmt.Errorf("query method cannot return more than 1 data value in [%s.%s]", f.InterfaceName, f.MethodName)
case param.Eq(m.OriginStruct) || param.IsGenT():
if !m.ResultData.IsNull() {
return fmt.Errorf("query method cannot return more than 1 data value in [%s.%s]", m.InterfaceName, m.MethodName)
}
param.SetName("result")
param.Type = f.OriginStruct.Type
param.Package = f.OriginStruct.Package
param.Type = m.OriginStruct.Type
param.Package = m.OriginStruct.Package
param.IsPointer = true
f.ResultData = param
m.ResultData = param
case param.IsInterface():
return fmt.Errorf("query method can not return interface in [%s.%s]", f.InterfaceName, f.MethodName)
return fmt.Errorf("query method can not return interface in [%s.%s]", m.InterfaceName, m.MethodName)
default:
if !f.ResultData.IsNull() {
return fmt.Errorf("query method cannot return more than 1 data value in [%s.%s]", f.InterfaceName, f.MethodName)
if !m.ResultData.IsNull() {
return fmt.Errorf("query method cannot return more than 1 data value in [%s.%s]", m.InterfaceName, m.MethodName)
}
if param.Package == "" && !(param.AllowType() || param.IsMap() || param.IsTime()) {
param.Package = f.Package
param.Package = m.Package
}

param.SetName("result")
f.ResultData = param
//return fmt.Errorf("illegal parameter:%s.%s on struct %s.%s generated method %s", param.Package, param.Type, f.OriginStruct.Package, f.OriginStruct.Type, f.MethodName)
m.ResultData = param
}
resList[i] = param
}
f.Result = resList
m.Result = resList
return
}

// checkSQL get sql from comment and check it
func (f *InterfaceMethod) checkSQL() (err error) {
f.SqlString = f.parseDocString()
if err = f.sqlStateCheck(); err != nil {
err = fmt.Errorf("interface %s member method %s check sql err:%w", f.InterfaceName, f.MethodName, err)
func (m *InterfaceMethod) checkSQL() (err error) {
m.SqlString = m.parseDocString()
if err = m.sqlStateCheck(); err != nil {
err = fmt.Errorf("interface %s member method %s check sql err:%w", m.InterfaceName, m.MethodName, err)
}
return
}

func (f *InterfaceMethod) parseDocString() string {
docString := strings.TrimSpace(f.Doc)
func (m *InterfaceMethod) parseDocString() string {
docString := strings.TrimSpace(m.Doc)
switch {
case strings.HasPrefix(strings.ToLower(docString), "sql("):
docString = docString[4 : len(docString)-1]
f.GormOption = "Raw"
if f.ResultData.IsNull() {
f.GormOption = "Exec"
m.GormOption = "Raw"
if m.ResultData.IsNull() {
m.GormOption = "Exec"
}
case strings.HasPrefix(strings.ToLower(docString), "where("):
docString = docString[6 : len(docString)-1]
f.GormOption = "Where"
m.GormOption = "Where"
default:
f.GormOption = "Raw"
if f.ResultData.IsNull() {
f.GormOption = "Exec"
m.GormOption = "Raw"
if m.ResultData.IsNull() {
m.GormOption = "Exec"
}
}

Expand All @@ -183,8 +185,8 @@ func (f *InterfaceMethod) parseDocString() string {
}

// sqlStateCheck check sql with an adeterministic finite automaton
func (f *InterfaceMethod) sqlStateCheck() error {
sqlString := f.SqlString
func (m *InterfaceMethod) sqlStateCheck() error {
sqlString := m.SqlString
result := NewSlices()
var buf sql
for i := 0; !strOutrange(i, sqlString); i++ {
Expand Down Expand Up @@ -238,7 +240,7 @@ func (f *InterfaceMethod) sqlStateCheck() error {
i++

sqlClause := buf.Dump()
part, err := checkTemplate(sqlClause, f.Params)
part, err := checkTemplate(sqlClause, m.Params)
if err != nil {
return fmt.Errorf("sql [%s] dynamic template %s err:%w", sqlString, sqlClause, err)
}
Expand All @@ -258,7 +260,7 @@ func (f *InterfaceMethod) sqlStateCheck() error {
for ; ; i++ {
if strOutrange(i, sqlString) || isEnd(sqlString[i]) {
varString := buf.Dump()
params, err := f.methodParams(varString, status)
params, err := m.methodParams(varString, status)
if err != nil {
return fmt.Errorf("sql [%s] varable %s err:%s", sqlString, varString, err)
}
Expand All @@ -284,24 +286,24 @@ func (f *InterfaceMethod) sqlStateCheck() error {
if err != nil {
return fmt.Errorf("sql [%s] parser err:%w", sqlString, err)
}
f.SqlTmplList = result.tmpl
m.SqlTmplList = result.tmpl
return nil
}

// methodParams return extrenal parameters, table name
func (f *InterfaceMethod) methodParams(param string, s Status) (result slice, err error) {
for _, p := range f.Params {
func (m *InterfaceMethod) methodParams(param string, s Status) (result slice, err error) {
for _, p := range m.Params {
if p.Name == param {
var str string
switch s {
case DATA:
str = fmt.Sprintf("\"@%s\"", param)
f.SqlData = append(f.SqlData, param)
m.SqlData = append(m.SqlData, param)
case VARIABLE:
if p.Type != "string" {
err = fmt.Errorf("variable name must be string :%s type is %s", param, p.Type)
}
str = fmt.Sprintf("%s.Quote(%s)", f.S, param)
str = fmt.Sprintf("%s.Quote(%s)", m.S, param)
}
result = slice{
Type: s,
Expand All @@ -313,7 +315,7 @@ func (f *InterfaceMethod) methodParams(param string, s Status) (result slice, er
if param == "table" {
result = slice{
Type: SQL,
Value: strconv.Quote(f.Table),
Value: strconv.Quote(m.Table),
}
return
}
Expand Down
7 changes: 7 additions & 0 deletions internal/check/checkstruct.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,13 @@ func (b *BaseStruct) fixMember() {
}
}

func GetNames(bases []*BaseStruct) (res []string) {
for _, base := range bases {
res = append(res, base.StructName)
}
return res
}

func isStructType(data reflect.Value) bool {
return data.Kind() == reflect.Struct ||
(data.Kind() == reflect.Ptr && data.Elem().Kind() == reflect.Struct)
Expand Down
47 changes: 25 additions & 22 deletions internal/check/export.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,29 +53,32 @@ func CheckStructs(db *gorm.DB, structs ...interface{}) (bases []*BaseStruct, err
// CheckInterface check the legitimacy of interfaces
func CheckInterface(f *parser.InterfaceSet, s *BaseStruct) (checkResults []*InterfaceMethod, err error) {
for _, interfaceInfo := range f.Interfaces {
for _, method := range interfaceInfo.Methods {
t := &InterfaceMethod{
S: s.S,
MethodStruct: s.NewStructName,
OriginStruct: s.StructInfo,
MethodName: method.MethodName,
Params: method.Params,
Doc: method.Doc,
ExecuteResult: "_",
Table: s.TableName,
InterfaceName: interfaceInfo.Name,
Package: getPackageName(interfaceInfo.Package),
}
if err = t.checkParams(method.Params); err != nil {
return
}
if err = t.checkResult(method.Result); err != nil {
return
}
if err = t.checkSQL(); err != nil {
return
if interfaceInfo.IsMatchStruct(s.StructName) {
for _, method := range interfaceInfo.Methods {

t := &InterfaceMethod{
S: s.S,
MethodStruct: s.NewStructName,
OriginStruct: s.StructInfo,
MethodName: method.MethodName,
Params: method.Params,
Doc: method.Doc,
ExecuteResult: "_",
Table: s.TableName,
InterfaceName: interfaceInfo.Name,
Package: getPackageName(interfaceInfo.Package),
}
if err = t.checkParams(method.Params); err != nil {
return
}
if err = t.checkResult(method.Result); err != nil {
return
}
if err = t.checkSQL(); err != nil {
return
}
checkResults = append(checkResults, t)
}
checkResults = append(checkResults, t)
}
}
return
Expand Down
25 changes: 18 additions & 7 deletions internal/parser/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,20 @@ type InterfaceSet struct {

// InterfaceInfo ...
type InterfaceInfo struct {
Name string
Doc string
Methods []*Method
Package string
Name string
Doc string
Methods []*Method
Package string
ApplyStruct []string
}

func (i *InterfaceInfo) IsMatchStruct(name string) bool {
for _, s := range i.ApplyStruct {
if s == name {
return true
}
}
return false
}

// Method interface's method
Expand All @@ -32,15 +42,15 @@ type Method struct {
}

// ParseFile get interface's info from source file
func (i *InterfaceSet) ParseFile(paths []*InterfacePath) error {
func (i *InterfaceSet) ParseFile(paths []*InterfacePath, structNames []string) error {
for _, path := range paths {
for _, file := range path.Files {
absFilePath, err := filepath.Abs(file)
if err != nil {
return fmt.Errorf("file not found:%s", file)
}

err = i.getInterfaceFromFile(absFilePath, path.Name, path.FullName)
err = i.getInterfaceFromFile(absFilePath, path.Name, path.FullName, structNames)
if err != nil {
return fmt.Errorf("can't get interface from %s:%s", path.FullName, err)
}
Expand Down Expand Up @@ -80,7 +90,7 @@ func (i *InterfaceSet) Visit(n ast.Node) (w ast.Visitor) {

// getInterfaceFromFile get interfaces
// get all interfaces from file and compare with specified name
func (i *InterfaceSet) getInterfaceFromFile(filename string, name, Package string) error {
func (i *InterfaceSet) getInterfaceFromFile(filename string, name, Package string, structNames []string) error {
fileset := token.NewFileSet()
f, err := parser.ParseFile(fileset, filename, nil, parser.ParseComments)
if err != nil {
Expand All @@ -93,6 +103,7 @@ func (i *InterfaceSet) getInterfaceFromFile(filename string, name, Package strin
for _, info := range astResult.Interfaces {
if name == info.Name {
info.Package = Package
info.ApplyStruct = structNames
i.Interfaces = append(i.Interfaces, info)
}
}
Expand Down

0 comments on commit dd15929

Please sign in to comment.