Skip to content

Commit 3246dcb

Browse files
authored
Custom proto field types (goadesign#3065)
* Use int32 and int64 for proto integer types The encoding differs than what is used for sint32 and sint64 and it matters for writing servers that are accessed by non Goa clients. * Add "struct:fueld:proto" and "protoc:include" meta * `struct:field:proto` makes it possible to override the proto type generated for a given field. Note that the type must be compatible with the service level type generated by Goa. * `protoc:include` makes it possible to specify include path to be used when invoking `protoc`. This is especially useful when overriding field type with `struct:field:proto` with types that exist in different proto files. * Revert integer proto mapping change Since the type can now be overridden with `struct:field:proto`
1 parent dd802db commit 3246dcb

File tree

5 files changed

+109
-25
lines changed

5 files changed

+109
-25
lines changed

dsl/meta.go

+31
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,25 @@ import (
7676
// })
7777
// })
7878
//
79+
// - "struct:field:proto" overrides the generated protobuf field type. The second
80+
// argument is optional and if present indicates an import path for the proto file
81+
// defining the type.
82+
//
83+
// var Timestamp = Type("Timestamp", func() {
84+
// Description("Google timestamp compatible design")
85+
// Field(1, "seconds", Int64, "Unix timestamp in seconds", func() {
86+
// Meta("struct:field:proto", "int64") // Goa generates sint64 by default
87+
// })
88+
// Field(2, "nanos", Int32, "Unix timestamp in nanoseconds", func() {
89+
// Meta("struct:field:proto", "int32") // Goa generates sint32 by default
90+
// })
91+
// })
92+
//
93+
// var MyType = Type("MyType", func() {
94+
// Field(1, "created_at", Timestamp, func() {
95+
// Meta("struct:field:proto", "google.protobuf.Timestamp", "google/protobuf/timestamp.proto")
96+
// })
97+
// })
7998
//
8099
// - "struct:tag:xxx" sets a generated Go struct field tag and overrides tags
81100
// that Goa would otherwise set. If the metadata value is a slice then the
@@ -89,6 +108,18 @@ import (
89108
// })
90109
// })
91110
//
111+
// - "protoc:include" provides the list of import paths used to invoke protoc.
112+
// Applicable to API and service definitions only. If used on an API definition
113+
// the include paths are used for all services.
114+
//
115+
// var _ = API("myapi", func() {
116+
// Meta("protoc:include", "/usr/include", "/usr/local/include")
117+
// })
118+
//
119+
// var _ = Service("service1", func() {
120+
// Meta("protoc:include", "/usr/local/include/google/protobuf")
121+
// })
122+
//
92123
// - "swagger:generate" DEPRECATED, use "openapi:generate" instead.
93124
//
94125
// - "openapi:generate" specifies whether OpenAPI specification should be

grpc/codegen/proto.go

+23-3
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ func protoFile(genpkg string, svc *expr.GRPCServiceExpr) *codegen.File {
4747
Data: map[string]interface{}{
4848
"ProtoVersion": ProtoVersion,
4949
"Pkg": pkgName(svc, svcName),
50+
"Imports": data.Imports,
5051
},
5152
},
5253
// service definition
@@ -62,10 +63,16 @@ func protoFile(genpkg string, svc *expr.GRPCServiceExpr) *codegen.File {
6263
sections = append(sections, &codegen.SectionTemplate{Name: "grpc-message", Source: messageT, Data: m})
6364
}
6465

66+
runProtoc := func(path string) error {
67+
includes := svc.ServiceExpr.Meta["protoc:include"]
68+
includes = append(includes, expr.Root.API.Meta["protoc:include"]...)
69+
return protoc(path, includes)
70+
}
71+
6572
return &codegen.File{
6673
Path: path,
6774
SectionTemplates: sections,
68-
FinalizeFunc: protoc,
75+
FinalizeFunc: runProtoc,
6976
}
7077
}
7178

@@ -76,11 +83,21 @@ func pkgName(svc *expr.GRPCServiceExpr, svcName string) string {
7683
return codegen.SnakeCase(svcName)
7784
}
7885

79-
func protoc(path string) error {
86+
func protoc(path string, includes []string) error {
8087
dir := filepath.Dir(path)
8188
os.MkdirAll(dir, 0777)
8289

83-
args := []string{"--proto_path", dir, "--go_out", dir, "--go-grpc_out", dir, "--go_opt=paths=source_relative", "--go-grpc_opt=paths=source_relative", path}
90+
args := []string{
91+
path,
92+
"--proto_path", dir,
93+
"--go_out", dir,
94+
"--go-grpc_out", dir,
95+
"--go_opt=paths=source_relative",
96+
"--go-grpc_opt=paths=source_relative",
97+
}
98+
for _, include := range includes {
99+
args = append(args, "-I", include)
100+
}
84101
cmd := exec.Command("protoc", args...)
85102
cmd.Dir = filepath.Dir(path)
86103

@@ -108,6 +125,9 @@ syntax = {{ printf "%q" .ProtoVersion }};
108125
package {{ .Pkg }};
109126
110127
option go_package = "/{{ .Pkg }}pb";
128+
{{- range .Imports }}
129+
import "{{ . }}";
130+
{{- end }}
111131
`
112132

113133
// input: ServiceData

grpc/codegen/proto_test.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ func TestProtoFiles(t *testing.T) {
4242
t.Errorf("%s: got\n%s\ngot vs. expected:\n%s", c.Name, code, codegen.Diff(t, code, c.Code))
4343
}
4444
fpath := codegen.CreateTempFile(t, code)
45-
if err := protoc(fpath); err != nil {
45+
if err := protoc(fpath, nil); err != nil {
4646
t.Fatalf("error occurred when compiling proto file %q: %s", fpath, err)
4747
}
4848
})
@@ -83,7 +83,7 @@ func TestMessageDefSection(t *testing.T) {
8383
t.Errorf("%s: got\n%s\ngot vs. expected:\n%s", c.Name, msgCode, codegen.Diff(t, msgCode, c.Code))
8484
}
8585
fpath := codegen.CreateTempFile(t, code+msgCode)
86-
if err := protoc(fpath); err != nil {
86+
if err := protoc(fpath, nil); err != nil {
8787
t.Fatalf("error occurred when compiling proto file %q: %s", fpath, err)
8888
}
8989
})

grpc/codegen/protobuf.go

+17-9
Original file line numberDiff line numberDiff line change
@@ -231,27 +231,35 @@ func protoBufGoFullTypeName(att *expr.AttributeExpr, pkg string, s *codegen.Name
231231
}
232232
}
233233

234+
// protoType returns the protocol buffer type name for the given attribute.
235+
func protoType(att *expr.AttributeExpr, sd *ServiceData) string {
236+
if protos := att.Meta["struct:field:proto"]; len(protos) > 0 {
237+
return protos[0]
238+
}
239+
return protoBufMessageDef(att, sd)
240+
}
241+
234242
// protoBufMessageDef returns the protocol buffer code that defines a message
235243
// which matches the data structure definition (the part that comes after
236244
// `message foo`). The message is defined using the proto3 syntax.
237245
func protoBufMessageDef(att *expr.AttributeExpr, sd *ServiceData) string {
238246
switch actual := att.Type.(type) {
239247
case expr.Primitive:
240-
return protoBufNativeMessageTypeName(att.Type)
248+
return protoNativeType(att.Type)
241249
case *expr.Array:
242-
return "repeated " + protoBufMessageDef(actual.ElemType, sd)
250+
return "repeated " + protoType(actual.ElemType, sd)
243251
case *expr.Map:
244-
return fmt.Sprintf("map<%s, %s>", protoBufMessageDef(actual.KeyType, sd), protoBufMessageDef(actual.ElemType, sd))
252+
return fmt.Sprintf("map<%s, %s>", protoType(actual.KeyType, sd), protoType(actual.ElemType, sd))
245253
case *expr.Union:
246254
def := "\toneof " + codegen.SnakeCase(protoBufify(actual.Name(), false, false)) + " {"
247255
for _, nat := range actual.Values {
248256
fn := codegen.SnakeCase(protoBufify(nat.Name, false, false))
249257
fnum := rpcTag(nat.Attribute)
250258
var typ string
251259
if prim := getPrimitive(nat.Attribute); prim != nil {
252-
typ = protoBufMessageDef(prim, sd)
260+
typ = protoType(prim, sd)
253261
} else {
254-
typ = protoBufMessageDef(nat.Attribute, sd)
262+
typ = protoType(nat.Attribute, sd)
255263
}
256264
var desc string
257265
if d := nat.Attribute.Description; d != "" {
@@ -287,9 +295,9 @@ func protoBufMessageDef(att *expr.AttributeExpr, sd *ServiceData) string {
287295
fn = codegen.SnakeCase(protoBufify(nat.Name, false, false))
288296
fnum = rpcTag(nat.Attribute)
289297
if prim := getPrimitive(nat.Attribute); prim != nil {
290-
typ = protoBufMessageDef(prim, sd)
298+
typ = protoType(prim, sd)
291299
} else {
292-
typ = protoBufMessageDef(nat.Attribute, sd)
300+
typ = protoType(nat.Attribute, sd)
293301
}
294302
if nat.Attribute.Description != "" {
295303
desc = codegen.Comment(nat.Attribute.Description) + "\n\t"
@@ -372,10 +380,10 @@ func protoBufifyAtt(att *expr.AttributeExpr, name string, upper bool) string {
372380
return protoBufify(name, upper, false)
373381
}
374382

375-
// protoBufNativeMessageTypeName returns the protocol buffer built-in type
383+
// protoNativeType returns the protocol buffer built-in type
376384
// corresponding to the given primitive type. It panics if t is not a
377385
// primitive type.
378-
func protoBufNativeMessageTypeName(t expr.DataType) string {
386+
func protoNativeType(t expr.DataType) string {
379387
switch t.Kind() {
380388
case expr.BooleanKind:
381389
return "bool"

grpc/codegen/service_data.go

+36-11
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ type (
2424
Service *service.Data
2525
// PkgName is the name of the generated package in *.pb.go.
2626
PkgName string
27+
// Imports is the list of proto package imports.
28+
Imports []string
2729
// Name is the service name.
2830
Name string
2931
// Description is the service description.
@@ -466,8 +468,19 @@ func (d ServicesData) analyze(gs *expr.GRPCServiceExpr) *ServiceData {
466468
}
467469

468470
// collect all the nested messages and return the top-level message
471+
// Also collect all proto imports specified via Meta.
469472
collect := func(att *expr.AttributeExpr) *service.UserTypeData {
470-
msgs := collectMessages(att, sd, seen)
473+
msgs, imports := collectMessages(att, sd, seen)
474+
if len(imports) > 0 {
475+
imported := make(map[string]struct{})
476+
for _, imp := range imports {
477+
if _, ok := imported[imp]; ok {
478+
continue
479+
}
480+
imported[imp] = struct{}{}
481+
sd.Imports = append(sd.Imports, imp)
482+
}
483+
}
471484
if len(msgs) > 0 {
472485
sd.Messages = append(sd.Messages, msgs...)
473486
return msgs[0]
@@ -634,17 +647,23 @@ func (d ServicesData) analyze(gs *expr.GRPCServiceExpr) *ServiceData {
634647
}
635648

636649
// collectMessages recurses through the attribute to gather all the messages.
637-
func collectMessages(at *expr.AttributeExpr, sd *ServiceData, seen map[string]struct{}) (data []*service.UserTypeData) {
638-
if at == nil || expr.IsPrimitive(at.Type) {
650+
func collectMessages(at *expr.AttributeExpr, sd *ServiceData, seen map[string]struct{}) (data []*service.UserTypeData, imports []string) {
651+
if at == nil {
639652
return
640653
}
641-
collect := func(at *expr.AttributeExpr) []*service.UserTypeData {
654+
if proto := at.Meta["struct:field:proto"]; len(proto) > 1 {
655+
imports = append(imports, proto[1])
656+
}
657+
if expr.IsPrimitive(at.Type) {
658+
return
659+
}
660+
collect := func(at *expr.AttributeExpr) ([]*service.UserTypeData, []string) {
642661
return collectMessages(at, sd, seen)
643662
}
644663
switch dt := at.Type.(type) {
645664
case expr.UserType:
646665
if _, ok := seen[dt.Name()]; ok {
647-
return nil
666+
return
648667
}
649668
att := dt.Attribute()
650669
if rt, ok := dt.(*expr.ResultTypeExpr); ok {
@@ -662,19 +681,25 @@ func collectMessages(at *expr.AttributeExpr, sd *ServiceData, seen map[string]st
662681
Type: dt,
663682
})
664683
seen[dt.Name()] = struct{}{}
665-
data = append(data, collect(att)...)
684+
d, i := collect(att)
685+
data, imports = append(data, d...), append(imports, i...)
666686
case *expr.Object:
667687
for _, nat := range *dt {
668-
data = append(data, collect(nat.Attribute)...)
688+
d, i := collect(nat.Attribute)
689+
data, imports = append(data, d...), append(imports, i...)
669690
}
670691
case *expr.Array:
671-
data = append(data, collect(dt.ElemType)...)
692+
d, i := collect(dt.ElemType)
693+
data, imports = append(data, d...), append(imports, i...)
672694
case *expr.Map:
673-
data = append(data, collect(dt.KeyType)...)
674-
data = append(data, collect(dt.ElemType)...)
695+
dk, ik := collect(dt.KeyType)
696+
data, imports = append(data, dk...), append(imports, ik...)
697+
de, ie := collect(dt.ElemType)
698+
data, imports = append(data, de...), append(imports, ie...)
675699
case *expr.Union:
676700
for _, nat := range dt.Values {
677-
data = append(data, collect(nat.Attribute)...)
701+
d, i := collect(nat.Attribute)
702+
data, imports = append(data, d...), append(imports, i...)
678703
}
679704
}
680705
return

0 commit comments

Comments
 (0)