From 65812bf67f5e367319b842d74c20ab251637483e Mon Sep 17 00:00:00 2001 From: Sven Efftinge Date: Fri, 16 Sep 2022 14:10:49 +0000 Subject: [PATCH] [usage] implement CancelSubscription --- components/usage/pkg/apiv1/billing.go | 67 ++++++++++++++----- components/usage/pkg/apiv1/billing_noop.go | 6 ++ components/usage/pkg/db/cost_center.go | 30 ++++++--- components/usage/pkg/db/cost_center_test.go | 18 +++-- components/usage/pkg/db/workspace_instance.go | 4 ++ components/usage/pkg/server/server.go | 2 +- components/usage/pkg/stripe/stripe.go | 29 +++++++- 7 files changed, 121 insertions(+), 35 deletions(-) diff --git a/components/usage/pkg/apiv1/billing.go b/components/usage/pkg/apiv1/billing.go index fe7ce866df5249..d1b12cbf907b21 100644 --- a/components/usage/pkg/apiv1/billing.go +++ b/components/usage/pkg/apiv1/billing.go @@ -22,16 +22,18 @@ import ( "gorm.io/gorm" ) -func NewBillingService(stripeClient *stripe.Client, conn *gorm.DB) *BillingService { +func NewBillingService(stripeClient *stripe.Client, conn *gorm.DB, ccManager *db.CostCenterManager) *BillingService { return &BillingService{ stripeClient: stripeClient, conn: conn, + ccManager: ccManager, } } type BillingService struct { conn *gorm.DB stripeClient *stripe.Client + ccManager *db.CostCenterManager v1.UnimplementedBillingServiceServer } @@ -43,8 +45,20 @@ func (s *BillingService) ReconcileInvoices(ctx context.Context, in *v1.Reconcile return nil, status.Errorf(codes.Internal, "Failed to reconcile invoices.") } - creditSummaryForTeams := map[db.AttributionID]int64{} + //TODO (se) make it one query + stripeBalances := []db.Balance{} for _, balance := range balances { + costCenter, err := s.ccManager.GetOrCreateCostCenter(ctx, balance.AttributionID) + if err != nil { + return nil, err + } + if costCenter.BillingStrategy == db.CostCenter_Stripe { + stripeBalances = append(stripeBalances, balance) + } + } + + creditSummaryForTeams := map[db.AttributionID]int64{} + for _, balance := range stripeBalances { creditSummaryForTeams[balance.AttributionID] = int64(math.Ceil(balance.CreditCents.ToCredits())) } @@ -70,22 +84,11 @@ func (s *BillingService) FinalizeInvoice(ctx context.Context, in *v1.FinalizeInv return nil, status.Errorf(codes.NotFound, "Failed to get invoice with ID %s: %s", in.GetInvoiceId(), err.Error()) } - customer := invoice.Customer - if customer == nil { - logger.Error("No customer information available for invoice.") - return nil, status.Errorf(codes.Internal, "Failed to retrieve customer details from invoice.") - } - logger = logger.WithField("stripe_customer", customer.ID).WithField("stripe_customer_name", customer.Name) - - teamID, found := customer.Metadata[stripe.AttributionIDMetadataKey] - if !found { - logger.Error("Failed to find teamID from subscription metadata.") - return nil, status.Errorf(codes.Internal, "Failed to extra teamID from Stripe subscription.") + attributionID, err := stripe.GetAttributionID(ctx, invoice.Customer) + if err != nil { + return nil, err } - logger = logger.WithField("team_id", teamID) - - // To support individual `user`s, we'll need to also extract the `userId` from metadata here and handle separately. - attributionID := db.NewTeamAttributionID(teamID) + logger = logger.WithField("attributionID", attributionID) finalizedAt := time.Unix(invoice.StatusTransitions.FinalizedAt, 0) logger = logger. @@ -127,6 +130,36 @@ func (s *BillingService) FinalizeInvoice(ctx context.Context, in *v1.FinalizeInv return &v1.FinalizeInvoiceResponse{}, nil } +func (s *BillingService) CancelSubscription(ctx context.Context, in *v1.CancelSubscriptionRequest) (*v1.CancelSubscriptionResponse, error) { + logger := log.WithField("subscription_id", in.GetSubscriptionId()) + logger.Infof("Subscription ended. Setting cost center back to free.") + if in.GetSubscriptionId() == "" { + return nil, status.Errorf(codes.InvalidArgument, "subscriptionId is required") + } + + subscription, err := s.stripeClient.GetSubscriptionWithCustomer(ctx, in.GetSubscriptionId()) + if err != nil { + return nil, err + } + + attributionID, err := stripe.GetAttributionID(ctx, subscription.Customer) + if err != nil { + return nil, err + } + + costCenter, err := s.ccManager.GetOrCreateCostCenter(ctx, attributionID) + if err != nil { + return nil, err + } + + costCenter.BillingStrategy = db.CostCenter_Other + _, err = s.ccManager.UpdateCostCenter(ctx, costCenter) + if err != nil { + return nil, err + } + return &v1.CancelSubscriptionResponse{}, nil +} + func (s *BillingService) GetUpcomingInvoice(ctx context.Context, in *v1.GetUpcomingInvoiceRequest) (*v1.GetUpcomingInvoiceResponse, error) { if in.GetTeamId() == "" && in.GetUserId() == "" { return nil, status.Errorf(codes.InvalidArgument, "teamId or userId is required") diff --git a/components/usage/pkg/apiv1/billing_noop.go b/components/usage/pkg/apiv1/billing_noop.go index 4d4ef59b2cce08..33874d2a668a3e 100644 --- a/components/usage/pkg/apiv1/billing_noop.go +++ b/components/usage/pkg/apiv1/billing_noop.go @@ -6,6 +6,7 @@ package apiv1 import ( "context" + "github.com/gitpod-io/gitpod/common-go/log" v1 "github.com/gitpod-io/gitpod/usage-api/v1" ) @@ -19,3 +20,8 @@ func (s *BillingServiceNoop) ReconcileInvoices(_ context.Context, _ *v1.Reconcil log.Infof("ReconcileInvoices RPC invoked in no-op mode, no invoices will be updated.") return &v1.ReconcileInvoicesResponse{}, nil } + +func (s *BillingServiceNoop) CancelSubscription(ctx context.Context, in *v1.CancelSubscriptionRequest) (*v1.CancelSubscriptionResponse, error) { + log.Infof("ReconcileInvoices RPC invoked in no-op mode, no invoices will be updated.") + return &v1.CancelSubscriptionResponse{}, nil +} diff --git a/components/usage/pkg/db/cost_center.go b/components/usage/pkg/db/cost_center.go index 7af5182e9405c6..c542ca1f8fa1a4 100644 --- a/components/usage/pkg/db/cost_center.go +++ b/components/usage/pkg/db/cost_center.go @@ -59,6 +59,7 @@ type CostCenterManager struct { // This method creates a codt center and stores it in the DB if there is no preexisting one. func (c *CostCenterManager) GetOrCreateCostCenter(ctx context.Context, attributionID AttributionID) (CostCenter, error) { logger := log.WithField("attributionId", attributionID) + logger.Info("Get or create CostCenter") result, err := getCostCenter(ctx, c.conn, attributionID) if err != nil { @@ -112,13 +113,15 @@ func (c *CostCenterManager) UpdateCostCenter(ctx context.Context, costCenter Cos now := time.Now() - // we don't allow setting the creationTime or the nextBillingTime from outside - costCenter.CreationTime = existingCostCenter.CreationTime + // we always update the creationTime + costCenter.CreationTime = NewVarcharTime(now) + // we don't allow setting the nextBillingTime from outside costCenter.NextBillingTime = existingCostCenter.NextBillingTime // Do we have a billing strategy update? if costCenter.BillingStrategy != existingCostCenter.BillingStrategy { - if existingCostCenter.BillingStrategy == CostCenter_Other { + switch costCenter.BillingStrategy { + case CostCenter_Stripe: // moving to stripe -> let's run a finalization finalizationUsage, err := c.ComputeInvoiceUsageRecord(ctx, costCenter.ID) if err != nil { @@ -130,12 +133,22 @@ func (c *CostCenterManager) UpdateCostCenter(ctx context.Context, costCenter Cos return CostCenter{}, err } } + // we don't manage stripe billing cycle + costCenter.NextBillingTime = VarcharTime{} + + case CostCenter_Other: + // cancelled from stripe reset the spending limit + if costCenter.ID.IsEntity(AttributionEntity_Team) { + costCenter.SpendingLimit = c.cfg.ForTeams + } else { + costCenter.SpendingLimit = c.cfg.ForUsers + } + // see you next month + costCenter.NextBillingTime = NewVarcharTime(now.AddDate(0, 1, 0)) } - c.updateNextBillingTime(&costCenter, now) } - // we update the creationTime - costCenter.CreationTime = NewVarcharTime(now) + log.WithField("cost_center", costCenter).Info("saving cost center.") db := c.conn.Save(&costCenter) if db.Error != nil { return CostCenter{}, fmt.Errorf("failed to save cost center for attributionID %s: %w", costCenter.ID, db.Error) @@ -163,8 +176,3 @@ func (c *CostCenterManager) ComputeInvoiceUsageRecord(ctx context.Context, attri Draft: false, }, nil } - -func (c *CostCenterManager) updateNextBillingTime(costCenter *CostCenter, now time.Time) { - nextMonth := NewVarcharTime(time.Now().AddDate(0, 1, 0)) - costCenter.NextBillingTime = nextMonth -} diff --git a/components/usage/pkg/db/cost_center_test.go b/components/usage/pkg/db/cost_center_test.go index b467e053310565..12bc6a891cf863 100644 --- a/components/usage/pkg/db/cost_center_test.go +++ b/components/usage/pkg/db/cost_center_test.go @@ -83,20 +83,28 @@ func TestCostCenterManager_UpdateCostCenter(t *testing.T) { func TestSaveCostCenterMovedToStripe(t *testing.T) { conn := dbtest.ConnectForTests(t) mnr := db.NewCostCenterManager(conn, db.DefaultSpendingLimit{ - ForTeams: 0, + ForTeams: 20, ForUsers: 500, }) team := db.NewTeamAttributionID(uuid.New().String()) cleanUp(t, conn, team) teamCC, err := mnr.GetOrCreateCostCenter(context.Background(), team) require.NoError(t, err) - require.Equal(t, int32(0), teamCC.SpendingLimit) + require.Equal(t, int32(20), teamCC.SpendingLimit) teamCC.BillingStrategy = db.CostCenter_Stripe - newTeamCC, err := mnr.UpdateCostCenter(context.Background(), teamCC) + teamCC.SpendingLimit = 400050 + teamCC, err = mnr.UpdateCostCenter(context.Background(), teamCC) + require.NoError(t, err) + require.Equal(t, db.CostCenter_Stripe, teamCC.BillingStrategy) + require.Equal(t, db.VarcharTime{}, teamCC.NextBillingTime) + require.Equal(t, int32(400050), teamCC.SpendingLimit) + + teamCC.BillingStrategy = db.CostCenter_Other + teamCC, err = mnr.UpdateCostCenter(context.Background(), teamCC) require.NoError(t, err) - require.Equal(t, db.CostCenter_Stripe, newTeamCC.BillingStrategy) - require.Equal(t, newTeamCC.CreationTime.Time().AddDate(0, 1, 0).Truncate(time.Second), newTeamCC.NextBillingTime.Time().Truncate(time.Second)) + require.Equal(t, teamCC.CreationTime.Time().AddDate(0, 1, 0).Truncate(time.Second), teamCC.NextBillingTime.Time().Truncate(time.Second)) + require.Equal(t, int32(20), teamCC.SpendingLimit) } func cleanUp(t *testing.T, conn *gorm.DB, attributionIds ...db.AttributionID) { diff --git a/components/usage/pkg/db/workspace_instance.go b/components/usage/pkg/db/workspace_instance.go index dbee2c163294f0..f622057581aa35 100644 --- a/components/usage/pkg/db/workspace_instance.go +++ b/components/usage/pkg/db/workspace_instance.go @@ -187,6 +187,10 @@ func ParseAttributionID(s string) (AttributionID, error) { if len(tokens) != 2 { return "", fmt.Errorf("attribution ID (%s) does not have two parts", s) } + _, err := uuid.Parse(tokens[1]) + if err != nil { + return "", fmt.Errorf("The uuid part of attribution ID (%s) is not a valid UUID. %w", tokens[1], err) + } switch tokens[0] { case AttributionEntity_Team: diff --git a/components/usage/pkg/server/server.go b/components/usage/pkg/server/server.go index 91c50ea8681e5b..f295b9c0c547ed 100644 --- a/components/usage/pkg/server/server.go +++ b/components/usage/pkg/server/server.go @@ -158,7 +158,7 @@ func registerGRPCServices(srv *baseserver.Server, conn *gorm.DB, stripeClient *s if stripeClient == nil { v1.RegisterBillingServiceServer(srv.GRPC(), &apiv1.BillingServiceNoop{}) } else { - v1.RegisterBillingServiceServer(srv.GRPC(), apiv1.NewBillingService(stripeClient, conn)) + v1.RegisterBillingServiceServer(srv.GRPC(), apiv1.NewBillingService(stripeClient, conn, ccManager)) } return nil } diff --git a/components/usage/pkg/stripe/stripe.go b/components/usage/pkg/stripe/stripe.go index 12afd66362f747..655a0d7478f63c 100644 --- a/components/usage/pkg/stripe/stripe.go +++ b/components/usage/pkg/stripe/stripe.go @@ -8,10 +8,13 @@ import ( "context" "encoding/json" "fmt" - "github.com/gitpod-io/gitpod/usage/pkg/db" "os" "strings" + "github.com/gitpod-io/gitpod/usage/pkg/db" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "github.com/gitpod-io/gitpod/common-go/log" "github.com/stripe/stripe-go/v72" "github.com/stripe/stripe-go/v72/client" @@ -244,6 +247,30 @@ func (c *Client) GetInvoiceWithCustomer(ctx context.Context, invoiceID string) ( return invoice, nil } +func (c *Client) GetSubscriptionWithCustomer(ctx context.Context, subscriptionID string) (*stripe.Subscription, error) { + if subscriptionID == "" { + return nil, fmt.Errorf("no subscriptionID specified") + } + + subscription, err := c.sc.Subscriptions.Get(subscriptionID, &stripe.SubscriptionParams{ + Params: stripe.Params{ + Expand: []*string{stripe.String("customer")}, + }, + }) + if err != nil { + return nil, fmt.Errorf("failed to get subscription %s: %w", subscriptionID, err) + } + return subscription, nil +} + +func GetAttributionID(ctx context.Context, customer *stripe.Customer) (db.AttributionID, error) { + if customer == nil { + log.Error("No customer information available for invoice.") + return "", status.Errorf(codes.Internal, "Failed to retrieve customer details from invoice.") + } + return db.ParseAttributionID(customer.Metadata[AttributionIDMetadataKey]) +} + // queriesForCustomersWithAttributionIDs constructs Stripe query strings to find the Stripe Customer for each teamId // It returns multiple queries, each being a big disjunction of subclauses so that we can process multiple teamIds in one query. // `clausesPerQuery` is a limit enforced by the Stripe API.