Skip to content

Commit

Permalink
chore: reuse fetcher client everywhere
Browse files Browse the repository at this point in the history
  • Loading branch information
sweatybridge committed May 2, 2024
1 parent 50e9f07 commit 11940ab
Show file tree
Hide file tree
Showing 15 changed files with 192 additions and 363 deletions.
47 changes: 38 additions & 9 deletions internal/bootstrap/bootstrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"github.com/supabase/cli/internal/utils/flags"
"github.com/supabase/cli/internal/utils/tenant"
"github.com/supabase/cli/pkg/api"
"github.com/supabase/cli/pkg/fetcher"
"golang.org/x/oauth2"
"golang.org/x/term"
)
Expand Down Expand Up @@ -88,9 +89,6 @@ func Run(ctx context.Context, starter StarterTemplate, fsys afero.Fs, options ..
if err := backoff.RetryNotify(func() error {
fmt.Fprintln(os.Stderr, "Linking project...")
keys, err = apiKeys.RunGetApiKeys(ctx, flags.ProjectRef)
if err == nil {
tenant.SetApiKeys(tenant.NewApiKey(keys))
}
return err
}, policy, newErrorCallback()); err != nil {
return err
Expand All @@ -99,7 +97,7 @@ func Run(ctx context.Context, starter StarterTemplate, fsys afero.Fs, options ..
if err := utils.LoadConfigFS(fsys); err != nil {
return err
}
link.LinkServices(ctx, flags.ProjectRef, fsys)
link.LinkServices(ctx, flags.ProjectRef, tenant.NewApiKey(keys).Anon, fsys)
if err := utils.WriteFile(utils.ProjectRefPath, []byte(flags.ProjectRef), fsys); err != nil {
return err
}
Expand Down Expand Up @@ -356,7 +354,7 @@ func downloadSample(ctx context.Context, client *github.Client, templateUrl stri
opts := github.RepositoryContentGetOptions{Ref: ref}
queue := make([]string, 0)
queue = append(queue, root)
jq := utils.NewJobQueue(5)
download := NewDownloader(5, fsys)
for len(queue) > 0 {
contentPath := queue[0]
queue = queue[1:]
Expand All @@ -369,9 +367,7 @@ func downloadSample(ctx context.Context, client *github.Client, templateUrl stri
case "file":
path := strings.TrimPrefix(file.GetPath(), root)
hostPath := filepath.Join(".", filepath.FromSlash(path))
if err := jq.Put(func() error {
return utils.DownloadFile(ctx, hostPath, file.GetDownloadURL(), fsys)
}); err != nil {
if err := download.Start(ctx, hostPath, file.GetDownloadURL()); err != nil {
return err
}
case "dir":
Expand All @@ -381,5 +377,38 @@ func downloadSample(ctx context.Context, client *github.Client, templateUrl stri
}
}
}
return jq.Collect()
return download.Wait()
}

type Downloader struct {
api *fetcher.Fetcher
queue *utils.JobQueue
fsys afero.Fs
}

func NewDownloader(concurrency uint, fsys afero.Fs) *Downloader {
return &Downloader{
api: fetcher.NewFetcher(""),
queue: utils.NewJobQueue(concurrency),
fsys: fsys,
}
}

func (d *Downloader) Start(ctx context.Context, localPath, remotePath string) error {
job := func() error {
resp, err := d.api.Send(ctx, http.MethodGet, remotePath, nil)
if err != nil {
return err
}
defer resp.Body.Close()
if err := afero.WriteReader(d.fsys, localPath, resp.Body); err != nil {
return errors.Errorf("failed to write file: %w", err)
}
return nil
}
return d.queue.Put(job)
}

func (d *Downloader) Wait() error {
return d.queue.Collect()
}
29 changes: 16 additions & 13 deletions internal/link/link.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"github.com/supabase/cli/internal/utils/flags"
"github.com/supabase/cli/internal/utils/tenant"
"github.com/supabase/cli/pkg/api"
"github.com/supabase/cli/pkg/fetcher"
)

var updatedConfig ConfigCopy
Expand All @@ -37,10 +38,11 @@ func (c ConfigCopy) IsEmpty() bool {

func Run(ctx context.Context, projectRef string, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error {
// 1. Check service config
if _, err := tenant.GetApiKeys(ctx, projectRef); err != nil {
keys, err := tenant.GetApiKeys(ctx, projectRef)
if err != nil {
return err
}
LinkServices(ctx, projectRef, fsys)
LinkServices(ctx, projectRef, keys.Anon, fsys)

// 2. Check database connection
config := flags.GetDbConfigOptionalPassword(projectRef)
Expand Down Expand Up @@ -72,7 +74,7 @@ func PostRun(projectRef string, stdout io.Writer, fsys afero.Fs) error {
return nil
}

func LinkServices(ctx context.Context, projectRef string, fsys afero.Fs) {
func LinkServices(ctx context.Context, projectRef, anonKey string, fsys afero.Fs) {
// Ignore non-fatal errors linking services
var wg sync.WaitGroup
wg.Add(6)
Expand All @@ -90,25 +92,26 @@ func LinkServices(ctx context.Context, projectRef string, fsys afero.Fs) {
}()
go func() {
defer wg.Done()
if err := linkPostgrestVersion(ctx, projectRef, fsys); err != nil && viper.GetBool("DEBUG") {
if err := linkPooler(ctx, projectRef, fsys); err != nil && viper.GetBool("DEBUG") {
fmt.Fprintln(os.Stderr, err)
}
}()
api := tenant.NewTenantAPI(ctx, projectRef, anonKey)
go func() {
defer wg.Done()
if err := linkGotrueVersion(ctx, projectRef, fsys); err != nil && viper.GetBool("DEBUG") {
if err := linkPostgrestVersion(ctx, api, fsys); err != nil && viper.GetBool("DEBUG") {
fmt.Fprintln(os.Stderr, err)
}
}()
go func() {
defer wg.Done()
if err := linkStorageVersion(ctx, projectRef, fsys); err != nil && viper.GetBool("DEBUG") {
if err := linkGotrueVersion(ctx, api, fsys); err != nil && viper.GetBool("DEBUG") {
fmt.Fprintln(os.Stderr, err)
}
}()
go func() {
defer wg.Done()
if err := linkPooler(ctx, projectRef, fsys); err != nil && viper.GetBool("DEBUG") {
if err := linkStorageVersion(ctx, api, fsys); err != nil && viper.GetBool("DEBUG") {
fmt.Fprintln(os.Stderr, err)
}
}()
Expand All @@ -127,8 +130,8 @@ func linkPostgrest(ctx context.Context, projectRef string) error {
return nil
}

func linkPostgrestVersion(ctx context.Context, projectRef string, fsys afero.Fs) error {
version, err := tenant.GetPostgrestVersion(ctx, projectRef)
func linkPostgrestVersion(ctx context.Context, api *fetcher.Fetcher, fsys afero.Fs) error {
version, err := tenant.GetPostgrestVersion(ctx, api)
if err != nil {
return err
}
Expand Down Expand Up @@ -160,16 +163,16 @@ func readCsv(line string) []string {
return result
}

func linkGotrueVersion(ctx context.Context, projectRef string, fsys afero.Fs) error {
version, err := tenant.GetGotrueVersion(ctx, projectRef)
func linkGotrueVersion(ctx context.Context, api *fetcher.Fetcher, fsys afero.Fs) error {
version, err := tenant.GetGotrueVersion(ctx, api)
if err != nil {
return err
}
return utils.WriteFile(utils.GotrueVersionPath, []byte(version), fsys)
}

func linkStorageVersion(ctx context.Context, projectRef string, fsys afero.Fs) error {
version, err := tenant.GetStorageVersion(ctx, projectRef)
func linkStorageVersion(ctx context.Context, api *fetcher.Fetcher, fsys afero.Fs) error {
version, err := tenant.GetStorageVersion(ctx, api)
if err != nil {
return err
}
Expand Down
18 changes: 12 additions & 6 deletions internal/services/services.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,31 +58,37 @@ func GetServiceImages() []string {
}

func GetRemoteImages(ctx context.Context, projectRef string) map[string]string {
const cap = 4
linked := make(map[string]string, cap)
linked := make(map[string]string, 4)
var wg sync.WaitGroup
wg.Add(cap)
wg.Add(1)
go func() {
defer wg.Done()
if version, err := tenant.GetDatabaseVersion(ctx, projectRef); err == nil {
linked[utils.Config.Db.Image] = version
}
}()
keys, err := tenant.GetApiKeys(ctx, projectRef)
if err != nil {
wg.Wait()
return linked
}
api := tenant.NewTenantAPI(ctx, projectRef, keys.Anon)
wg.Add(3)
go func() {
defer wg.Done()
if version, err := tenant.GetGotrueVersion(ctx, projectRef); err == nil {
if version, err := tenant.GetGotrueVersion(ctx, api); err == nil {
linked[utils.Config.Auth.Image] = version
}
}()
go func() {
defer wg.Done()
if version, err := tenant.GetPostgrestVersion(ctx, projectRef); err == nil {
if version, err := tenant.GetPostgrestVersion(ctx, api); err == nil {
linked[utils.Config.Api.Image] = version
}
}()
go func() {
defer wg.Done()
if version, err := tenant.GetStorageVersion(ctx, projectRef); err == nil {
if version, err := tenant.GetStorageVersion(ctx, api); err == nil {
linked[utils.Config.Storage.Image] = version
}
}()
Expand Down
9 changes: 9 additions & 0 deletions internal/storage/cp/cp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,15 @@ func TestStorageCP(t *testing.T) {
t.Run("throws error on unsupported operation", func(t *testing.T) {
// Setup in-memory fs
fsys := afero.NewMemMapFs()
// Setup mock api
defer gock.OffAll()
gock.New(utils.DefaultApiHost).
Get("/v1/projects/" + flags.ProjectRef + "/api-keys").
Reply(http.StatusOK).
JSON([]api.ApiKeyResponse{{
Name: "service_role",
ApiKey: "service-key",
}})
// Run test
err := Run(context.Background(), ".", ".", false, 1, fsys)
// Check error
Expand Down
42 changes: 32 additions & 10 deletions internal/utils/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@ import (
"net/http/httptrace"
"net/textproto"
"sync"
"time"

"github.com/go-errors/errors"
"github.com/spf13/viper"
supabase "github.com/supabase/cli/pkg/api"
"github.com/supabase/cli/pkg/fetcher"
)

const (
Expand Down Expand Up @@ -54,11 +56,13 @@ func FallbackLookupIP(ctx context.Context, host string) ([]string, error) {
return []string{host}, nil
}
// Ref: https://developers.cloudflare.com/1.1.1.1/encryption/dns-over-https/make-api-requests/dns-json
url := "https://1.1.1.1/dns-query?name=" + host
data, err := JsonResponse[dnsResponse](ctx, http.MethodGet, url, nil, func(ctx context.Context, req *http.Request) error {
req.Header.Add("accept", "application/dns-json")
return nil
})
api := NewCloudflareAPI()
resp, err := api.Send(ctx, http.MethodGet, "/dns-query?name="+host, nil)
if err != nil {
return nil, err
}
defer resp.Body.Close()
data, err := fetcher.ParseJSON[dnsResponse](resp.Body)
if err != nil {
return nil, err
}
Expand All @@ -77,11 +81,13 @@ func FallbackLookupIP(ctx context.Context, host string) ([]string, error) {

func ResolveCNAME(ctx context.Context, host string) (string, error) {
// Ref: https://developers.cloudflare.com/1.1.1.1/encryption/dns-over-https/make-api-requests/dns-json
url := fmt.Sprintf("https://1.1.1.1/dns-query?name=%s&type=CNAME", host)
data, err := JsonResponse[dnsResponse](ctx, http.MethodGet, url, nil, func(ctx context.Context, req *http.Request) error {
req.Header.Add("accept", "application/dns-json")
return nil
})
api := NewCloudflareAPI()
resp, err := api.Send(ctx, http.MethodGet, "/dns-query?type=CNAME&name="+host, nil)
if err != nil {
return "", err
}
defer resp.Body.Close()
data, err := fetcher.ParseJSON[dnsResponse](resp.Body)
if err != nil {
return "", err
}
Expand All @@ -99,6 +105,22 @@ func ResolveCNAME(ctx context.Context, host string) (string, error) {
return "", errors.Errorf("failed to locate appropriate CNAME record for %s; resolves to %+v", host, serialized)
}

func NewCloudflareAPI() *fetcher.Fetcher {
server := "https://1.1.1.1"
client := &http.Client{
Timeout: 10 * time.Second,
}
header := func(req *http.Request) {
req.Header.Add("accept", "application/dns-json")
}
api := fetcher.NewFetcher(
server,
fetcher.WithHTTPClient(client),
fetcher.WithRequestEditor(header),
)
return api
}

func WithTraceContext(ctx context.Context) context.Context {
trace := &httptrace.ClientTrace{
DNSStart: func(info httptrace.DNSStartInfo) {
Expand Down
Loading

0 comments on commit 11940ab

Please sign in to comment.