forked from sourcegraph/zoekt
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
grpc: port messagesize interceptors and raise default client message …
…size to 90mb (sourcegraph#640)
- Loading branch information
Showing
4 changed files
with
298 additions
and
33 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
package messagesize | ||
|
||
import ( | ||
"fmt" | ||
"math" | ||
"os" | ||
|
||
"google.golang.org/grpc" | ||
|
||
"github.com/dustin/go-humanize" | ||
) | ||
|
||
var ( | ||
smallestAllowedMaxMessageSize = uint64(4 * 1024 * 1024) // 4 MB: There isn't a scenario where we'd want to dip below the default of 4MB. | ||
largestAllowedMaxMessageSize = uint64(math.MaxInt) // This is the largest allowed value for the type accepted by the grpc.MaxSize[...] options. | ||
|
||
envClientMessageSize = getEnv("GRPC_CLIENT_MAX_MESSAGE_SIZE", messageSizeDisabled) // set the maximum message size for gRPC clients (ex: "40MB") | ||
envServerMessageSize = getEnv("GRPC_SERVER_MAX_MESSAGE_SIZE", messageSizeDisabled) // set the maximum message size for gRPC servers (ex: "40MB") | ||
|
||
messageSizeDisabled = "message_size_disabled" // sentinel value for when the message size env var isn't set | ||
) | ||
|
||
// MustGetClientMessageSizeFromEnv returns a slice of grpc.DialOptions that set the maximum message size for gRPC clients if | ||
// the "SRC_GRPC_CLIENT_MAX_MESSAGE_SIZE" environment variable is set to a valid size value (ex: "40 MB"). | ||
// | ||
// If the environment variable isn't set, it returns nil. | ||
// If the size value in the environment variable is invalid (too small, not parsable, etc.), it panics. | ||
func MustGetClientMessageSizeFromEnv() []grpc.DialOption { | ||
if envClientMessageSize == messageSizeDisabled { | ||
return nil | ||
} | ||
|
||
messageSize, err := getMessageSizeBytesFromString(envClientMessageSize, smallestAllowedMaxMessageSize, largestAllowedMaxMessageSize) | ||
if err != nil { | ||
panic(fmt.Sprintf("failed to get gRPC client message size: %s", err)) | ||
} | ||
|
||
return []grpc.DialOption{ | ||
grpc.WithDefaultCallOptions( | ||
grpc.MaxCallRecvMsgSize(messageSize), | ||
grpc.MaxCallSendMsgSize(messageSize), | ||
), | ||
} | ||
} | ||
|
||
// MustGetServerMessageSizeFromEnv returns a slice of grpc.ServerOption that set the maximum message size for gRPC servers if | ||
// the "SRC_GRPC_SERVER_MAX_MESSAGE_SIZE" environment variable is set to a valid size value (ex: "40 MB"). | ||
// | ||
// If the environment variable isn't set, it returns nil. | ||
// If the size value in the environment variable is invalid (too small, not parsable, etc.), it panics. | ||
func MustGetServerMessageSizeFromEnv() []grpc.ServerOption { | ||
if envServerMessageSize == messageSizeDisabled { | ||
return nil | ||
} | ||
|
||
messageSize, err := getMessageSizeBytesFromString(envServerMessageSize, smallestAllowedMaxMessageSize, largestAllowedMaxMessageSize) | ||
if err != nil { | ||
panic(fmt.Sprintf("failed to get gRPC server message size: %s", err)) | ||
} | ||
|
||
return []grpc.ServerOption{ | ||
grpc.MaxRecvMsgSize(messageSize), | ||
grpc.MaxSendMsgSize(messageSize), | ||
} | ||
} | ||
|
||
// getMessageSizeBytesFromEnv parses rawSize returns the message size in bytes within the range [minSize, maxSize]. | ||
// | ||
// If rawSize isn't a valid size is not set or the value is outside the allowed range, it returns an error. | ||
func getMessageSizeBytesFromString(rawSize string, minSize, maxSize uint64) (size int, err error) { | ||
sizeBytes, err := humanize.ParseBytes(rawSize) | ||
if err != nil { | ||
return 0, &parseError{ | ||
rawSize: rawSize, | ||
err: err, | ||
} | ||
} | ||
|
||
if sizeBytes < minSize || sizeBytes > maxSize { | ||
return 0, &sizeOutOfRangeError{ | ||
size: humanize.IBytes(sizeBytes), | ||
min: humanize.IBytes(minSize), | ||
max: humanize.IBytes(maxSize), | ||
} | ||
} | ||
|
||
return int(sizeBytes), nil | ||
} | ||
|
||
// parseError occurs when the environment variable's value cannot be parsed as a byte size. | ||
type parseError struct { | ||
// rawSize is the raw size string that was attempted to be parsed | ||
rawSize string | ||
// err is the error that occurred while parsing rawSize | ||
err error | ||
} | ||
|
||
func (e *parseError) Error() string { | ||
return fmt.Sprintf("failed to parse %q as bytes: %s", e.rawSize, e.err) | ||
} | ||
|
||
func (e *parseError) Unwrap() error { | ||
return e.err | ||
} | ||
|
||
// sizeOutOfRangeError occurs when the environment variable's value is outside of the allowed range. | ||
type sizeOutOfRangeError struct { | ||
// size is the size that was out of range | ||
size string | ||
// min is the minimum allowed size | ||
min string | ||
// max is the maximum allowed size | ||
max string | ||
} | ||
|
||
func (e *sizeOutOfRangeError) Error() string { | ||
return fmt.Sprintf("size %s is outside of allowed range [%s, %s]", e.size, e.min, e.max) | ||
} | ||
|
||
func getEnv(key string, defaultValue string) string { | ||
value, ok := os.LookupEnv(key) | ||
if !ok { | ||
return defaultValue | ||
} | ||
|
||
return value | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
package messagesize | ||
|
||
import ( | ||
"errors" | ||
"math" | ||
"testing" | ||
|
||
"github.com/google/go-cmp/cmp" | ||
) | ||
|
||
func TestGetMessageSizeBytesFromString(t *testing.T) { | ||
|
||
t.Run("8 MB", func(t *testing.T) { | ||
sizeString := "8MB" | ||
|
||
size, err := getMessageSizeBytesFromString(sizeString, 0, math.MaxInt) | ||
|
||
if err != nil { | ||
t.Fatalf("unexpected error: %s", err) | ||
} | ||
|
||
expectedSize := 8 * 1000 * 1000 | ||
if diff := cmp.Diff(expectedSize, size); diff != "" { | ||
t.Fatalf("unexpected size (-want +got):\n%s", diff) | ||
} | ||
}) | ||
|
||
t.Run("just small enough", func(t *testing.T) { | ||
sizeString := "4MB" // inside large-end of range | ||
|
||
fourMegaBytes := 4 * 1000 * 1000 | ||
size, err := getMessageSizeBytesFromString(sizeString, 0, uint64(fourMegaBytes)) | ||
if err != nil { | ||
t.Fatalf("unexpected error: %s", err) | ||
} | ||
|
||
if diff := cmp.Diff(fourMegaBytes, size); diff != "" { | ||
t.Fatalf("unexpected size (-want +got):\n%s", diff) | ||
} | ||
}) | ||
|
||
t.Run("just large enough", func(t *testing.T) { | ||
sizeString := "4MB" // inside low-end of range | ||
|
||
fourMegaBytes := 4 * 1000 * 1000 | ||
size, err := getMessageSizeBytesFromString(sizeString, uint64(fourMegaBytes), math.MaxInt) | ||
if err != nil { | ||
t.Fatalf("unexpected error: %s", err) | ||
} | ||
|
||
if diff := cmp.Diff(fourMegaBytes, size); diff != "" { | ||
t.Fatalf("unexpected size (-want +got):\n%s", diff) | ||
} | ||
}) | ||
|
||
t.Run("invalid size", func(t *testing.T) { | ||
sizeString := "this-is-not-a-size" | ||
|
||
_, err := getMessageSizeBytesFromString(sizeString, 0, math.MaxInt) | ||
var expectedErr *parseError | ||
if !errors.As(err, &expectedErr) { | ||
t.Fatalf("expected parseError, got error %q", err) | ||
} | ||
}) | ||
|
||
t.Run("empty", func(t *testing.T) { | ||
sizeString := "" | ||
|
||
_, err := getMessageSizeBytesFromString(sizeString, 0, math.MaxInt) | ||
var expectedErr *parseError | ||
if !errors.As(err, &expectedErr) { | ||
t.Fatalf("expected parseError, got error %q", err) | ||
} | ||
}) | ||
|
||
t.Run("too large", func(t *testing.T) { | ||
sizeString := "4MB" // above range | ||
|
||
twoMegaBytes := 2 * 1024 * 1024 | ||
_, err := getMessageSizeBytesFromString(sizeString, 0, uint64(twoMegaBytes)) | ||
var expectedErr *sizeOutOfRangeError | ||
if !errors.As(err, &expectedErr) { | ||
t.Fatalf("expected sizeOutOfRangeError, got error %q", err) | ||
} | ||
}) | ||
|
||
t.Run("too small", func(t *testing.T) { | ||
sizeString := "1MB" // below range | ||
|
||
twoMegaBytes := 2 * 1024 * 1024 | ||
_, err := getMessageSizeBytesFromString(sizeString, uint64(twoMegaBytes), math.MaxInt) | ||
var expectedErr *sizeOutOfRangeError | ||
if !errors.As(err, &expectedErr) { | ||
t.Fatalf("expected sizeOutOfRangeError, got error %q", err) | ||
} | ||
}) | ||
} |