Skip to content

Commit

Permalink
PR feedback: check uint bounds
Browse files Browse the repository at this point in the history
cobra does check the boundaries
  • Loading branch information
vroldanbet committed Apr 29, 2024
1 parent c557965 commit 4217e06
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 6 deletions.
7 changes: 4 additions & 3 deletions internal/commands/relationship.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func RegisterRelationshipCmd(rootCmd *cobra.Command) *cobra.Command {
relationshipCmd.AddCommand(bulkDeleteCmd)
bulkDeleteCmd.Flags().Bool("force", false, "force deletion of all elements in batches defined by <optional-limit>")
bulkDeleteCmd.Flags().String("subject-filter", "", "optional subject filter")
bulkDeleteCmd.Flags().Uint("optional-limit", 1000, "the max amount of elements to delete. If you want to delete all in batches of size <optional-limit>, set --force to true")
bulkDeleteCmd.Flags().Uint32("optional-limit", 1000, "the max amount of elements to delete. If you want to delete all in batches of size <optional-limit>, set --force to true")
bulkDeleteCmd.Flags().Bool("estimate-count", true, "estimate the count of relationships to be deleted")
_ = bulkDeleteCmd.Flags().MarkDeprecated("estimate-count", "no longer used, make use of --optional-limit instead")
return relationshipCmd
Expand Down Expand Up @@ -133,12 +133,13 @@ func bulkDeleteRelationships(cmd *cobra.Command, args []string) error {
}()

allowPartialDeletions := cobrautil.MustGetBool(cmd, "force")
optionalLimit := cobrautil.MustGetUint(cmd, "optional-limit")
optionalLimit := cobrautil.MustGetUint32(cmd, "optional-limit")

var resp *v1.DeleteRelationshipsResponse
for {
delRequest := &v1.DeleteRelationshipsRequest{
RelationshipFilter: filter,
OptionalLimit: uint32(optionalLimit),
OptionalLimit: optionalLimit,
OptionalAllowPartialDeletions: allowPartialDeletions,
}
log.Trace().Interface("request", delRequest).Msg("deleting relationships")
Expand Down
6 changes: 3 additions & 3 deletions internal/commands/relationship_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,7 @@ func TestBulkDeleteForcing(t *testing.T) {
client.NewClient = zedtesting.ClientFromConn(conn)
testCmd := zedtesting.CreateTestCobraCommandWithFlagValue(t,
zedtesting.StringFlag{FlagName: "subject-filter"},
zedtesting.UintFlag{FlagName: "optional-limit", FlagValue: 1},
zedtesting.UintFlag32{FlagName: "optional-limit", FlagValue: 1},
zedtesting.BoolFlag{FlagName: "force", FlagValue: true})
c, err := client.NewClient(testCmd)
require.NoError(t, err)
Expand Down Expand Up @@ -629,7 +629,7 @@ func TestBulkDeleteManyForcing(t *testing.T) {
client.NewClient = zedtesting.ClientFromConn(conn)
testCmd := zedtesting.CreateTestCobraCommandWithFlagValue(t,
zedtesting.StringFlag{FlagName: "subject-filter"},
zedtesting.UintFlag{FlagName: "optional-limit", FlagValue: 1},
zedtesting.UintFlag32{FlagName: "optional-limit", FlagValue: 1},
zedtesting.BoolFlag{FlagName: "force", FlagValue: true})
c, err := client.NewClient(testCmd)
require.NoError(t, err)
Expand Down Expand Up @@ -671,7 +671,7 @@ func TestBulkDeleteNotForcing(t *testing.T) {
client.NewClient = zedtesting.ClientFromConn(conn)
testCmd := zedtesting.CreateTestCobraCommandWithFlagValue(t,
zedtesting.StringFlag{FlagName: "subject-filter"},
zedtesting.UintFlag{FlagName: "optional-limit", FlagValue: 1},
zedtesting.UintFlag32{FlagName: "optional-limit", FlagValue: 1},
zedtesting.BoolFlag{FlagName: "force", FlagValue: false})
c, err := client.NewClient(testCmd)
require.NoError(t, err)
Expand Down
7 changes: 7 additions & 0 deletions internal/testing/test_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@ type UintFlag struct {
FlagValue uint
}

type UintFlag32 struct {
FlagName string
FlagValue uint32
}

type DurationFlag struct {
FlagName string
FlagValue time.Duration
Expand All @@ -100,6 +105,8 @@ func CreateTestCobraCommandWithFlagValue(t *testing.T, flagAndValues ...any) *co
c.Flags().Int(f.FlagName, f.FlagValue, "")
case UintFlag:
c.Flags().Uint(f.FlagName, f.FlagValue, "")
case UintFlag32:
c.Flags().Uint32(f.FlagName, f.FlagValue, "")
case DurationFlag:
c.Flags().Duration(f.FlagName, f.FlagValue, "")
default:
Expand Down

0 comments on commit 4217e06

Please sign in to comment.