Skip to content

Commit

Permalink
[usage] implement CancelSubscription
Browse files Browse the repository at this point in the history
  • Loading branch information
svenefftinge authored and roboquat committed Sep 16, 2022
1 parent 1b76bca commit 65812bf
Show file tree
Hide file tree
Showing 7 changed files with 121 additions and 35 deletions.
67 changes: 50 additions & 17 deletions components/usage/pkg/apiv1/billing.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,18 @@ import (
"gorm.io/gorm"
)

func NewBillingService(stripeClient *stripe.Client, conn *gorm.DB) *BillingService {
func NewBillingService(stripeClient *stripe.Client, conn *gorm.DB, ccManager *db.CostCenterManager) *BillingService {
return &BillingService{
stripeClient: stripeClient,
conn: conn,
ccManager: ccManager,
}
}

type BillingService struct {
conn *gorm.DB
stripeClient *stripe.Client
ccManager *db.CostCenterManager

v1.UnimplementedBillingServiceServer
}
Expand All @@ -43,8 +45,20 @@ func (s *BillingService) ReconcileInvoices(ctx context.Context, in *v1.Reconcile
return nil, status.Errorf(codes.Internal, "Failed to reconcile invoices.")
}

creditSummaryForTeams := map[db.AttributionID]int64{}
//TODO (se) make it one query
stripeBalances := []db.Balance{}
for _, balance := range balances {
costCenter, err := s.ccManager.GetOrCreateCostCenter(ctx, balance.AttributionID)
if err != nil {
return nil, err
}
if costCenter.BillingStrategy == db.CostCenter_Stripe {
stripeBalances = append(stripeBalances, balance)
}
}

creditSummaryForTeams := map[db.AttributionID]int64{}
for _, balance := range stripeBalances {
creditSummaryForTeams[balance.AttributionID] = int64(math.Ceil(balance.CreditCents.ToCredits()))
}

Expand All @@ -70,22 +84,11 @@ func (s *BillingService) FinalizeInvoice(ctx context.Context, in *v1.FinalizeInv
return nil, status.Errorf(codes.NotFound, "Failed to get invoice with ID %s: %s", in.GetInvoiceId(), err.Error())
}

customer := invoice.Customer
if customer == nil {
logger.Error("No customer information available for invoice.")
return nil, status.Errorf(codes.Internal, "Failed to retrieve customer details from invoice.")
}
logger = logger.WithField("stripe_customer", customer.ID).WithField("stripe_customer_name", customer.Name)

teamID, found := customer.Metadata[stripe.AttributionIDMetadataKey]
if !found {
logger.Error("Failed to find teamID from subscription metadata.")
return nil, status.Errorf(codes.Internal, "Failed to extra teamID from Stripe subscription.")
attributionID, err := stripe.GetAttributionID(ctx, invoice.Customer)
if err != nil {
return nil, err
}
logger = logger.WithField("team_id", teamID)

// To support individual `user`s, we'll need to also extract the `userId` from metadata here and handle separately.
attributionID := db.NewTeamAttributionID(teamID)
logger = logger.WithField("attributionID", attributionID)
finalizedAt := time.Unix(invoice.StatusTransitions.FinalizedAt, 0)

logger = logger.
Expand Down Expand Up @@ -127,6 +130,36 @@ func (s *BillingService) FinalizeInvoice(ctx context.Context, in *v1.FinalizeInv
return &v1.FinalizeInvoiceResponse{}, nil
}

func (s *BillingService) CancelSubscription(ctx context.Context, in *v1.CancelSubscriptionRequest) (*v1.CancelSubscriptionResponse, error) {
logger := log.WithField("subscription_id", in.GetSubscriptionId())
logger.Infof("Subscription ended. Setting cost center back to free.")
if in.GetSubscriptionId() == "" {
return nil, status.Errorf(codes.InvalidArgument, "subscriptionId is required")
}

subscription, err := s.stripeClient.GetSubscriptionWithCustomer(ctx, in.GetSubscriptionId())
if err != nil {
return nil, err
}

attributionID, err := stripe.GetAttributionID(ctx, subscription.Customer)
if err != nil {
return nil, err
}

costCenter, err := s.ccManager.GetOrCreateCostCenter(ctx, attributionID)
if err != nil {
return nil, err
}

costCenter.BillingStrategy = db.CostCenter_Other
_, err = s.ccManager.UpdateCostCenter(ctx, costCenter)
if err != nil {
return nil, err
}
return &v1.CancelSubscriptionResponse{}, nil
}

func (s *BillingService) GetUpcomingInvoice(ctx context.Context, in *v1.GetUpcomingInvoiceRequest) (*v1.GetUpcomingInvoiceResponse, error) {
if in.GetTeamId() == "" && in.GetUserId() == "" {
return nil, status.Errorf(codes.InvalidArgument, "teamId or userId is required")
Expand Down
6 changes: 6 additions & 0 deletions components/usage/pkg/apiv1/billing_noop.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package apiv1

import (
"context"

"github.com/gitpod-io/gitpod/common-go/log"
v1 "github.com/gitpod-io/gitpod/usage-api/v1"
)
Expand All @@ -19,3 +20,8 @@ func (s *BillingServiceNoop) ReconcileInvoices(_ context.Context, _ *v1.Reconcil
log.Infof("ReconcileInvoices RPC invoked in no-op mode, no invoices will be updated.")
return &v1.ReconcileInvoicesResponse{}, nil
}

func (s *BillingServiceNoop) CancelSubscription(ctx context.Context, in *v1.CancelSubscriptionRequest) (*v1.CancelSubscriptionResponse, error) {
log.Infof("ReconcileInvoices RPC invoked in no-op mode, no invoices will be updated.")
return &v1.CancelSubscriptionResponse{}, nil
}
30 changes: 19 additions & 11 deletions components/usage/pkg/db/cost_center.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ type CostCenterManager struct {
// This method creates a codt center and stores it in the DB if there is no preexisting one.
func (c *CostCenterManager) GetOrCreateCostCenter(ctx context.Context, attributionID AttributionID) (CostCenter, error) {
logger := log.WithField("attributionId", attributionID)
logger.Info("Get or create CostCenter")

result, err := getCostCenter(ctx, c.conn, attributionID)
if err != nil {
Expand Down Expand Up @@ -112,13 +113,15 @@ func (c *CostCenterManager) UpdateCostCenter(ctx context.Context, costCenter Cos

now := time.Now()

// we don't allow setting the creationTime or the nextBillingTime from outside
costCenter.CreationTime = existingCostCenter.CreationTime
// we always update the creationTime
costCenter.CreationTime = NewVarcharTime(now)
// we don't allow setting the nextBillingTime from outside
costCenter.NextBillingTime = existingCostCenter.NextBillingTime

// Do we have a billing strategy update?
if costCenter.BillingStrategy != existingCostCenter.BillingStrategy {
if existingCostCenter.BillingStrategy == CostCenter_Other {
switch costCenter.BillingStrategy {
case CostCenter_Stripe:
// moving to stripe -> let's run a finalization
finalizationUsage, err := c.ComputeInvoiceUsageRecord(ctx, costCenter.ID)
if err != nil {
Expand All @@ -130,12 +133,22 @@ func (c *CostCenterManager) UpdateCostCenter(ctx context.Context, costCenter Cos
return CostCenter{}, err
}
}
// we don't manage stripe billing cycle
costCenter.NextBillingTime = VarcharTime{}

case CostCenter_Other:
// cancelled from stripe reset the spending limit
if costCenter.ID.IsEntity(AttributionEntity_Team) {
costCenter.SpendingLimit = c.cfg.ForTeams
} else {
costCenter.SpendingLimit = c.cfg.ForUsers
}
// see you next month
costCenter.NextBillingTime = NewVarcharTime(now.AddDate(0, 1, 0))
}
c.updateNextBillingTime(&costCenter, now)
}

// we update the creationTime
costCenter.CreationTime = NewVarcharTime(now)
log.WithField("cost_center", costCenter).Info("saving cost center.")
db := c.conn.Save(&costCenter)
if db.Error != nil {
return CostCenter{}, fmt.Errorf("failed to save cost center for attributionID %s: %w", costCenter.ID, db.Error)
Expand Down Expand Up @@ -163,8 +176,3 @@ func (c *CostCenterManager) ComputeInvoiceUsageRecord(ctx context.Context, attri
Draft: false,
}, nil
}

func (c *CostCenterManager) updateNextBillingTime(costCenter *CostCenter, now time.Time) {
nextMonth := NewVarcharTime(time.Now().AddDate(0, 1, 0))
costCenter.NextBillingTime = nextMonth
}
18 changes: 13 additions & 5 deletions components/usage/pkg/db/cost_center_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,20 +83,28 @@ func TestCostCenterManager_UpdateCostCenter(t *testing.T) {
func TestSaveCostCenterMovedToStripe(t *testing.T) {
conn := dbtest.ConnectForTests(t)
mnr := db.NewCostCenterManager(conn, db.DefaultSpendingLimit{
ForTeams: 0,
ForTeams: 20,
ForUsers: 500,
})
team := db.NewTeamAttributionID(uuid.New().String())
cleanUp(t, conn, team)
teamCC, err := mnr.GetOrCreateCostCenter(context.Background(), team)
require.NoError(t, err)
require.Equal(t, int32(0), teamCC.SpendingLimit)
require.Equal(t, int32(20), teamCC.SpendingLimit)

teamCC.BillingStrategy = db.CostCenter_Stripe
newTeamCC, err := mnr.UpdateCostCenter(context.Background(), teamCC)
teamCC.SpendingLimit = 400050
teamCC, err = mnr.UpdateCostCenter(context.Background(), teamCC)
require.NoError(t, err)
require.Equal(t, db.CostCenter_Stripe, teamCC.BillingStrategy)
require.Equal(t, db.VarcharTime{}, teamCC.NextBillingTime)
require.Equal(t, int32(400050), teamCC.SpendingLimit)

teamCC.BillingStrategy = db.CostCenter_Other
teamCC, err = mnr.UpdateCostCenter(context.Background(), teamCC)
require.NoError(t, err)
require.Equal(t, db.CostCenter_Stripe, newTeamCC.BillingStrategy)
require.Equal(t, newTeamCC.CreationTime.Time().AddDate(0, 1, 0).Truncate(time.Second), newTeamCC.NextBillingTime.Time().Truncate(time.Second))
require.Equal(t, teamCC.CreationTime.Time().AddDate(0, 1, 0).Truncate(time.Second), teamCC.NextBillingTime.Time().Truncate(time.Second))
require.Equal(t, int32(20), teamCC.SpendingLimit)
}

func cleanUp(t *testing.T, conn *gorm.DB, attributionIds ...db.AttributionID) {
Expand Down
4 changes: 4 additions & 0 deletions components/usage/pkg/db/workspace_instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,10 @@ func ParseAttributionID(s string) (AttributionID, error) {
if len(tokens) != 2 {
return "", fmt.Errorf("attribution ID (%s) does not have two parts", s)
}
_, err := uuid.Parse(tokens[1])
if err != nil {
return "", fmt.Errorf("The uuid part of attribution ID (%s) is not a valid UUID. %w", tokens[1], err)
}

switch tokens[0] {
case AttributionEntity_Team:
Expand Down
2 changes: 1 addition & 1 deletion components/usage/pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ func registerGRPCServices(srv *baseserver.Server, conn *gorm.DB, stripeClient *s
if stripeClient == nil {
v1.RegisterBillingServiceServer(srv.GRPC(), &apiv1.BillingServiceNoop{})
} else {
v1.RegisterBillingServiceServer(srv.GRPC(), apiv1.NewBillingService(stripeClient, conn))
v1.RegisterBillingServiceServer(srv.GRPC(), apiv1.NewBillingService(stripeClient, conn, ccManager))
}
return nil
}
29 changes: 28 additions & 1 deletion components/usage/pkg/stripe/stripe.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,13 @@ import (
"context"
"encoding/json"
"fmt"
"github.com/gitpod-io/gitpod/usage/pkg/db"
"os"
"strings"

"github.com/gitpod-io/gitpod/usage/pkg/db"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"

"github.com/gitpod-io/gitpod/common-go/log"
"github.com/stripe/stripe-go/v72"
"github.com/stripe/stripe-go/v72/client"
Expand Down Expand Up @@ -244,6 +247,30 @@ func (c *Client) GetInvoiceWithCustomer(ctx context.Context, invoiceID string) (
return invoice, nil
}

func (c *Client) GetSubscriptionWithCustomer(ctx context.Context, subscriptionID string) (*stripe.Subscription, error) {
if subscriptionID == "" {
return nil, fmt.Errorf("no subscriptionID specified")
}

subscription, err := c.sc.Subscriptions.Get(subscriptionID, &stripe.SubscriptionParams{
Params: stripe.Params{
Expand: []*string{stripe.String("customer")},
},
})
if err != nil {
return nil, fmt.Errorf("failed to get subscription %s: %w", subscriptionID, err)
}
return subscription, nil
}

func GetAttributionID(ctx context.Context, customer *stripe.Customer) (db.AttributionID, error) {
if customer == nil {
log.Error("No customer information available for invoice.")
return "", status.Errorf(codes.Internal, "Failed to retrieve customer details from invoice.")
}
return db.ParseAttributionID(customer.Metadata[AttributionIDMetadataKey])
}

// queriesForCustomersWithAttributionIDs constructs Stripe query strings to find the Stripe Customer for each teamId
// It returns multiple queries, each being a big disjunction of subclauses so that we can process multiple teamIds in one query.
// `clausesPerQuery` is a limit enforced by the Stripe API.
Expand Down

0 comments on commit 65812bf

Please sign in to comment.