Skip to content

Commit

Permalink
apacheGH-36975: [C++][FlightRPC] Skip unknown fields, don't crash (ap…
Browse files Browse the repository at this point in the history
…ache#36979)

### Rationale for this change

We should skip unknown fields instead of crashing, for forwards compatibility. 

### What changes are included in this PR?

Skip unknown fields in the FlightData deserializer.

### Are these changes tested?

Yes.

### Are there any user-facing changes?

No.
* Closes: apache#36975

Authored-by: David Li <[email protected]>
Signed-off-by: Sutou Kouhei <[email protected]>
  • Loading branch information
lidavidm authored Aug 1, 2023
1 parent a046b88 commit 0edef8c
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 15 deletions.
35 changes: 35 additions & 0 deletions cpp/src/arrow/flight/flight_internals_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@
#include "arrow/testing/gtest_util.h"
#include "arrow/util/string.h"

// Include after Flight headers
#include <grpc/slice.h>

namespace arrow {
namespace flight {

Expand Down Expand Up @@ -651,6 +654,38 @@ TEST_F(TestCookieParsing, CookieCache) {
AddCookieVerifyCache({"id0=0;", "id1=1;", "id2=2"}, "id0=0; id1=1; id2=2");
}

// ----------------------------------------------------------------------
// Protobuf tests

TEST(GrpcTransport, FlightDataDeserialize) {
#ifndef _WIN32
pb::FlightData raw;
// Tack on known and unknown fields by hand here
raw.GetReflection()->MutableUnknownFields(&raw)->AddFixed32(900, 1024);
raw.GetReflection()->MutableUnknownFields(&raw)->AddFixed64(901, 1024);
raw.GetReflection()->MutableUnknownFields(&raw)->AddVarint(902, 1024);
raw.GetReflection()->MutableUnknownFields(&raw)->AddLengthDelimited(903, "foobar");
// Known field comes at end
raw.GetReflection()->MutableUnknownFields(&raw)->AddLengthDelimited(
pb::FlightData::kDataBodyFieldNumber, "data");

auto serialized = raw.SerializeAsString();

grpc_slice slice = grpc_slice_from_copied_buffer(serialized.data(), serialized.size());
// gRPC requires that grpc_slice and grpc::Slice have the same representation
grpc::ByteBuffer buffer(reinterpret_cast<const grpc::Slice*>(&slice), /*nslices=*/1);

flight::internal::FlightData out;
auto status = flight::transport::grpc::FlightDataDeserialize(&buffer, &out);
ASSERT_TRUE(status.ok());
ASSERT_EQ("data", out.body->ToString());

grpc_slice_unref(slice);
#else
GTEST_SKIP() << "Can't use Protobuf symbols on Windows";
#endif
}

// ----------------------------------------------------------------------
// Transport abstraction tests

Expand Down
34 changes: 19 additions & 15 deletions cpp/src/arrow/flight/transport/grpc/serialization_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

#include <cstdint>
#include <limits>
#include <memory>
#include <string>
#include <vector>

Expand Down Expand Up @@ -297,7 +298,7 @@ ::grpc::Status FlightDataSerialize(const FlightPayload& msg, ByteBuffer* out,
const auto remainder = static_cast<int>(
bit_util::RoundUpToMultipleOf8(buffer->size()) - buffer->size());
if (remainder) {
slices.push_back(::grpc::Slice(kPaddingBytes, remainder));
slices.emplace_back(kPaddingBytes, remainder);
}
}
}
Expand All @@ -316,7 +317,7 @@ ::grpc::Status FlightDataSerialize(const FlightPayload& msg, ByteBuffer* out,
::grpc::Status FlightDataDeserialize(ByteBuffer* buffer,
arrow::flight::internal::FlightData* out) {
if (!buffer) {
return ::grpc::Status(::grpc::StatusCode::INTERNAL, "No payload");
return {::grpc::StatusCode::INTERNAL, "No payload"};
}

// Reset fields in case the caller reuses a single allocation
Expand All @@ -342,42 +343,45 @@ ::grpc::Status FlightDataDeserialize(ByteBuffer* buffer,
pb::FlightDescriptor pb_descriptor;
uint32_t length;
if (!pb_stream.ReadVarint32(&length)) {
return ::grpc::Status(::grpc::StatusCode::INTERNAL,
"Unable to parse length of FlightDescriptor");
return {::grpc::StatusCode::INTERNAL,
"Unable to parse length of FlightDescriptor"};
}
// Can't use ParseFromCodedStream as this reads the entire
// rest of the stream into the descriptor command field.
std::string buffer;
pb_stream.ReadString(&buffer, length);
if (!pb_descriptor.ParseFromString(buffer)) {
return ::grpc::Status(::grpc::StatusCode::INTERNAL,
"Unable to parse FlightDescriptor");
return {::grpc::StatusCode::INTERNAL, "Unable to parse FlightDescriptor"};
}
arrow::flight::FlightDescriptor descriptor;
GRPC_RETURN_NOT_OK(
arrow::flight::internal::FromProto(pb_descriptor, &descriptor));
out->descriptor.reset(new arrow::flight::FlightDescriptor(descriptor));
out->descriptor = std::make_unique<arrow::flight::FlightDescriptor>(descriptor);
} break;
case pb::FlightData::kDataHeaderFieldNumber: {
if (!ReadBytesZeroCopy(wrapped_buffer, &pb_stream, &out->metadata)) {
return ::grpc::Status(::grpc::StatusCode::INTERNAL,
"Unable to read FlightData metadata");
return {::grpc::StatusCode::INTERNAL, "Unable to read FlightData metadata"};
}
} break;
case pb::FlightData::kAppMetadataFieldNumber: {
if (!ReadBytesZeroCopy(wrapped_buffer, &pb_stream, &out->app_metadata)) {
return ::grpc::Status(::grpc::StatusCode::INTERNAL,
"Unable to read FlightData application metadata");
return {::grpc::StatusCode::INTERNAL,
"Unable to read FlightData application metadata"};
}
} break;
case pb::FlightData::kDataBodyFieldNumber: {
if (!ReadBytesZeroCopy(wrapped_buffer, &pb_stream, &out->body)) {
return ::grpc::Status(::grpc::StatusCode::INTERNAL,
"Unable to read FlightData body");
return {::grpc::StatusCode::INTERNAL, "Unable to read FlightData body"};
}
} break;
default:
DCHECK(false) << "cannot happen";
default: {
// Unknown field. We should skip it for compatibility.
if (!WireFormatLite::SkipField(&pb_stream, tag)) {
return {::grpc::StatusCode::INTERNAL,
"Could not skip unknown field tag in FlightData"};
}
break;
}
}
}
buffer->Clear();
Expand Down

0 comments on commit 0edef8c

Please sign in to comment.