Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(codegen/golang): Allow exporting models to a different package #3874

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions docs/howto/separate-models-file.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Separating models file

By default, sqlc uses a single package to place all the generated code. But you may want to separate
the generated models file into another package for loose coupling purposes in your project.

To do this, you can use the following configuration:

```yaml
version: "2"
sql:
- engine: "postgresql"
queries: "queries.sql"
schema: "schema.sql"
gen:
go:
out: "internal/" # Base directory for the generated files. You can also just use "."
sql_package: "pgx/v5"
package: "sqlcrepo"
output_batch_file_name: "db/sqlcrepo/batch.go"
output_db_file_name: "db/sqlcrepo/db.go"
output_querier_file_name: "db/sqlcrepo/querier.go"
output_copyfrom_file_name: "db/sqlcrepo/copyfrom.go"
output_query_files_directory: "db/sqlcrepo/"
output_models_file_name: "business/entities/models.go"
output_models_package: "entities"
models_package_import_path: "example.com/project/module-path/internal/business/entities"
```

This configuration will generate files in the `internal/db/sqlcrepo` directory with `sqlcrepo`
package name, except for the models file which will be generated in the `internal/business/entities`
directory. The generated models file will use the package name `entities` and it will be imported in
the other generated files using the given
`"example.com/project/module-path/internal/business/entities"` import path when needed.

The generated files will look like this:

```
my-app/
├── internal/
│ ├── db/
│ │ └── sqlcrepo/
│ │ ├── db.go
│ │ └── queries.sql.go
│ └── business/
│ └── entities/
│ └── models.go
├── queries.sql
├── schema.sql
└── sqlc.yaml
```
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ code ever again.
howto/embedding.md
howto/overrides.md
howto/rename.md
howto/separate-models-file.md

.. toctree::
:maxdepth: 3
Expand Down
6 changes: 6 additions & 0 deletions docs/reference/config.md
Original file line number Diff line number Diff line change
Expand Up @@ -183,10 +183,16 @@ The `gen` mapping supports the following keys:
- Customize the name of the db file. Defaults to `db.go`.
- `output_models_file_name`:
- Customize the name of the models file. Defaults to `models.go`.
- `output_models_package`:
- Package name of the models file. Used when models file is in a different package. Defaults to value of `package` option.
- `models_package_import_path`:
- Import path of the models package when models file is in a different package. Optional.
- `output_querier_file_name`:
- Customize the name of the querier file. Defaults to `querier.go`.
- `output_copyfrom_file_name`:
- Customize the name of the copyfrom file. Defaults to `copyfrom.go`.
- `output_query_files_directory`:
- Directory where the generated query files will be placed. Defaults to the value of `out` option.
- `output_files_suffix`:
- If specified the suffix will be added to the name of the generated files.
- `query_parameter_limit`:
Expand Down
48 changes: 34 additions & 14 deletions internal/codegen/golang/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"errors"
"fmt"
"go/format"
"path/filepath"
"strings"
"text/template"

Expand Down Expand Up @@ -122,7 +123,7 @@ func Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.Generat
}

if options.OmitUnusedStructs {
enums, structs = filterUnusedStructs(enums, structs, queries)
enums, structs = filterUnusedStructs(options, enums, structs, queries)
}

if err := validate(options, enums, structs, queries); err != nil {
Expand Down Expand Up @@ -211,6 +212,7 @@ func generate(req *plugin.GenerateRequest, options *opts.Options, enums []Enum,
"imports": i.Imports,
"hasImports": i.HasImports,
"hasPrefix": strings.HasPrefix,
"trimPrefix": strings.TrimPrefix,

// These methods are Go specific, they do not belong in the codegen package
// (as that is language independent)
Expand All @@ -232,14 +234,15 @@ func generate(req *plugin.GenerateRequest, options *opts.Options, enums []Enum,

output := map[string]string{}

execute := func(name, templateName string) error {
execute := func(name, packageName, templateName string) error {
imports := i.Imports(name)
replacedQueries := replaceConflictedArg(imports, queries)

var b bytes.Buffer
w := bufio.NewWriter(&b)
tctx.SourceName = name
tctx.GoQueries = replacedQueries
tctx.Package = packageName
err := tmpl.ExecuteTemplate(w, templateName, &tctx)
w.Flush()
if err != nil {
Expand All @@ -251,8 +254,13 @@ func generate(req *plugin.GenerateRequest, options *opts.Options, enums []Enum,
return fmt.Errorf("source error: %w", err)
}

if templateName == "queryFile" && options.OutputFilesSuffix != "" {
name += options.OutputFilesSuffix
if templateName == "queryFile" {
if options.OutputQueryFilesDirectory != "" {
name = filepath.Join(options.OutputQueryFilesDirectory, name)
}
if options.OutputFilesSuffix != "" {
name += options.OutputFilesSuffix
}
}

if !strings.HasSuffix(name, ".go") {
Expand Down Expand Up @@ -284,24 +292,29 @@ func generate(req *plugin.GenerateRequest, options *opts.Options, enums []Enum,
batchFileName = options.OutputBatchFileName
}

if err := execute(dbFileName, "dbFile"); err != nil {
modelsPackageName := options.Package
if options.OutputModelsPackage != "" {
modelsPackageName = options.OutputModelsPackage
}

if err := execute(dbFileName, options.Package, "dbFile"); err != nil {
return nil, err
}
if err := execute(modelsFileName, "modelsFile"); err != nil {
if err := execute(modelsFileName, modelsPackageName, "modelsFile"); err != nil {
return nil, err
}
if options.EmitInterface {
if err := execute(querierFileName, "interfaceFile"); err != nil {
if err := execute(querierFileName, options.Package, "interfaceFile"); err != nil {
return nil, err
}
}
if tctx.UsesCopyFrom {
if err := execute(copyfromFileName, "copyfromFile"); err != nil {
if err := execute(copyfromFileName, options.Package, "copyfromFile"); err != nil {
return nil, err
}
}
if tctx.UsesBatch {
if err := execute(batchFileName, "batchFile"); err != nil {
if err := execute(batchFileName, options.Package, "batchFile"); err != nil {
return nil, err
}
}
Expand All @@ -312,7 +325,7 @@ func generate(req *plugin.GenerateRequest, options *opts.Options, enums []Enum,
}

for source := range files {
if err := execute(source, "queryFile"); err != nil {
if err := execute(source, options.Package, "queryFile"); err != nil {
return nil, err
}
}
Expand Down Expand Up @@ -362,7 +375,7 @@ func checkNoTimesForMySQLCopyFrom(queries []Query) error {
return nil
}

func filterUnusedStructs(enums []Enum, structs []Struct, queries []Query) ([]Enum, []Struct) {
func filterUnusedStructs(options *opts.Options, enums []Enum, structs []Struct, queries []Query) ([]Enum, []Struct) {
keepTypes := make(map[string]struct{})

for _, query := range queries {
Expand All @@ -389,16 +402,23 @@ func filterUnusedStructs(enums []Enum, structs []Struct, queries []Query) ([]Enu

keepEnums := make([]Enum, 0, len(enums))
for _, enum := range enums {
_, keep := keepTypes[enum.Name]
_, keepNull := keepTypes["Null"+enum.Name]
var enumType string
if options.ModelsPackageImportPath != "" {
enumType = options.OutputModelsPackage + "." + enum.Name
} else {
enumType = enum.Name
}

_, keep := keepTypes[enumType]
_, keepNull := keepTypes["Null"+enumType]
if keep || keepNull {
keepEnums = append(keepEnums, enum)
}
}

keepStructs := make([]Struct, 0, len(structs))
for _, st := range structs {
if _, ok := keepTypes[st.Name]; ok {
if _, ok := keepTypes[st.Type()]; ok {
keepStructs = append(keepStructs, st)
}
}
Expand Down
53 changes: 47 additions & 6 deletions internal/codegen/golang/imports.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ var pqtypeTypes = map[string]struct{}{
"pqtype.NullRawMessage": {},
}

func buildImports(options *opts.Options, queries []Query, uses func(string) bool) (map[string]struct{}, map[ImportSpec]struct{}) {
func buildImports(options *opts.Options, queries []Query, outputFile OutputFile, uses func(string) bool) (map[string]struct{}, map[ImportSpec]struct{}) {
pkg := make(map[ImportSpec]struct{})
std := make(map[string]struct{})

Expand Down Expand Up @@ -243,11 +243,52 @@ func buildImports(options *opts.Options, queries []Query, uses func(string) bool
}
}

requiresModelsPackageImport := func() bool {
if options.ModelsPackageImportPath == "" {
return false
}

for _, q := range queries {
// Check if the return type is from models package (possibly a model struct or an enum)
if q.hasRetType() && strings.HasPrefix(q.Ret.Type(), options.OutputModelsPackage+".") {
return true
}

// Check if the return type struct contains a type from models package (possibly an enum field or an embedded struct)
if outputFile != OutputFileInterface && q.hasRetType() && q.Ret.IsStruct() {
for _, f := range q.Ret.Struct.Fields {
if strings.HasPrefix(f.Type, options.OutputModelsPackage+".") {
return true
}
}
}

// Check if the argument type is from models package (possibly an enum)
if !q.Arg.isEmpty() && strings.HasPrefix(q.Arg.Type(), options.OutputModelsPackage+".") {
return true
}

// Check if the argument struct contains a type from models package (possibly an enum field)
if outputFile != OutputFileInterface && !q.Arg.isEmpty() && q.Arg.IsStruct() {
for _, f := range q.Arg.Struct.Fields {
if strings.HasPrefix(f.Type, options.OutputModelsPackage+".") {
return true
}
}
}

}
return false
}
if requiresModelsPackageImport() {
pkg[ImportSpec{Path: options.ModelsPackageImportPath}] = struct{}{}
}

return std, pkg
}

func (i *importer) interfaceImports() fileImports {
std, pkg := buildImports(i.Options, i.Queries, func(name string) bool {
std, pkg := buildImports(i.Options, i.Queries, OutputFileInterface, func(name string) bool {
for _, q := range i.Queries {
if q.hasRetType() {
if usesBatch([]Query{q}) {
Expand All @@ -272,7 +313,7 @@ func (i *importer) interfaceImports() fileImports {
}

func (i *importer) modelImports() fileImports {
std, pkg := buildImports(i.Options, nil, i.usesType)
std, pkg := buildImports(i.Options, nil, OutputFileModel, i.usesType)

if len(i.Enums) > 0 {
std["fmt"] = struct{}{}
Expand Down Expand Up @@ -311,7 +352,7 @@ func (i *importer) queryImports(filename string) fileImports {
}
}

std, pkg := buildImports(i.Options, gq, func(name string) bool {
std, pkg := buildImports(i.Options, gq, OutputFileQuery, func(name string) bool {
for _, q := range gq {
if q.hasRetType() {
if q.Ret.EmitStruct() {
Expand Down Expand Up @@ -412,7 +453,7 @@ func (i *importer) copyfromImports() fileImports {
copyFromQueries = append(copyFromQueries, q)
}
}
std, pkg := buildImports(i.Options, copyFromQueries, func(name string) bool {
std, pkg := buildImports(i.Options, copyFromQueries, OutputFileCopyfrom, func(name string) bool {
for _, q := range copyFromQueries {
if q.hasRetType() {
if strings.HasPrefix(q.Ret.Type(), name) {
Expand Down Expand Up @@ -447,7 +488,7 @@ func (i *importer) batchImports() fileImports {
batchQueries = append(batchQueries, q)
}
}
std, pkg := buildImports(i.Options, batchQueries, func(name string) bool {
std, pkg := buildImports(i.Options, batchQueries, OutputFileBatch, func(name string) bool {
for _, q := range batchQueries {
if q.hasRetType() {
if q.Ret.EmitStruct() {
Expand Down
10 changes: 9 additions & 1 deletion internal/codegen/golang/opts/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,11 @@ type Options struct {
OutputBatchFileName string `json:"output_batch_file_name,omitempty" yaml:"output_batch_file_name"`
OutputDbFileName string `json:"output_db_file_name,omitempty" yaml:"output_db_file_name"`
OutputModelsFileName string `json:"output_models_file_name,omitempty" yaml:"output_models_file_name"`
OutputModelsPackage string `json:"output_models_package,omitempty" yaml:"output_models_package"`
ModelsPackageImportPath string `json:"models_package_import_path,omitempty" yaml:"models_package_import_path"`
OutputQuerierFileName string `json:"output_querier_file_name,omitempty" yaml:"output_querier_file_name"`
OutputCopyfromFileName string `json:"output_copyfrom_file_name,omitempty" yaml:"output_copyfrom_file_name"`
OutputQueryFilesDirectory string `json:"output_query_files_directory,omitempty" yaml:"output_query_files_directory"`
OutputFilesSuffix string `json:"output_files_suffix,omitempty" yaml:"output_files_suffix"`
InflectionExcludeTableNames []string `json:"inflection_exclude_table_names,omitempty" yaml:"inflection_exclude_table_names"`
QueryParameterLimit *int32 `json:"query_parameter_limit,omitempty" yaml:"query_parameter_limit"`
Expand Down Expand Up @@ -150,6 +153,11 @@ func ValidateOpts(opts *Options) error {
if *opts.QueryParameterLimit < 0 {
return fmt.Errorf("invalid options: query parameter limit must not be negative")
}

if opts.OutputModelsPackage != "" && opts.ModelsPackageImportPath == "" {
return fmt.Errorf("invalid options: models_package_import_path must be set when output_models_package is used")
}
if opts.ModelsPackageImportPath != "" && opts.OutputModelsPackage == "" {
return fmt.Errorf("invalid options: output_models_package must be set when models_package_import_path is used")
}
return nil
}
12 changes: 12 additions & 0 deletions internal/codegen/golang/output_file.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package golang

type OutputFile string

const (
OutputFileModel OutputFile = "modelFile"
OutputFileQuery OutputFile = "queryFile"
OutputFileDb OutputFile = "dbFile"
OutputFileInterface OutputFile = "interfaceFile"
OutputFileCopyfrom OutputFile = "copyfromFile"
OutputFileBatch OutputFile = "batchFile"
)
15 changes: 11 additions & 4 deletions internal/codegen/golang/postgresql_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -571,17 +571,24 @@ func postgresType(req *plugin.GenerateRequest, options *opts.Options, col *plugi

for _, enum := range schema.Enums {
if rel.Name == enum.Name && rel.Schema == schema.Name {
enumName := ""
if notNull {
if schema.Name == req.Catalog.DefaultSchema {
return StructName(enum.Name, options)
enumName = StructName(enum.Name, options)
} else {
enumName = StructName(schema.Name+"_"+enum.Name, options)
}
return StructName(schema.Name+"_"+enum.Name, options)
} else {
if schema.Name == req.Catalog.DefaultSchema {
return "Null" + StructName(enum.Name, options)
enumName = "Null" + StructName(enum.Name, options)
} else {
enumName = "Null" + StructName(schema.Name+"_"+enum.Name, options)
}
return "Null" + StructName(schema.Name+"_"+enum.Name, options)
}
if options.ModelsPackageImportPath != "" {
return options.OutputModelsPackage + "." + enumName
}
return enumName
}
}

Expand Down
2 changes: 1 addition & 1 deletion internal/codegen/golang/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ func (v QueryValue) Type() string {
return v.Typ
}
if v.Struct != nil {
return v.Struct.Name
return v.Struct.Type()
}
panic("no type for QueryValue: " + v.Name)
}
Expand Down
Loading
Loading