Skip to content

Commit b3120c1

Browse files
authored
GH-49548: [C++][FlightRPC] Decouple Flight Serialize/Deserialize from gRPC transport (#49549)
### Rationale for this change

Currently the Serialize/Deserialize APIs are gRPC-dependent. This means that any code that needs to encode or decode Flight data must depend on gRPC C++ internals. While trying to build a PoC using gRPC's generic API with gRPC's BidiReactor, we concluded that these primitives should be made gRPC-agnostic.

### What changes are included in this PR?

- Move the serialization/deserialization logic from `cpp/src/arrow/flight/transport/grpc/serialization_internal.{h,cc}` to `cpp/src/arrow/flight/serialization_internal.cc`.
- Create a new gRPC-agnostic function `arrow::Result<arrow::BufferVector> SerializePayloadToBuffers(const FlightPayload& msg)`.
- Create a new gRPC-agnostic function `arrow::Result<arrow::flight::internal::FlightData> DeserializeFlightData(const std::shared_ptr<arrow::Buffer>& buffer)`.
- Keep the existing gRPC serialize/deserialize functions as simple wrappers on top of the new serialization functions; they only implement the `grpc::ByteBuffer` and `grpc::Slice` details.
- Add the utility `arrow::Result<BufferVector> SerializeToBuffers() const;` to the `FlightPayload` struct.
- Add serialize/deserialize roundtrip tests.

### Are these changes tested?

Yes, both by existing tests and new tests.

### Are there any user-facing changes?

No

* GitHub Issue: #49548

Authored-by: Raúl Cumplido <raulcumplido@gmail.com>
Signed-off-by: Raúl Cumplido <raulcumplido@gmail.com>
1 parent 9e413fa commit b3120c1

File tree

7 files changed

+340
-227
lines changed

7 files changed

+340
-227
lines changed

cpp/src/arrow/flight/flight_internals_test.cc

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,18 @@
2323
#include <gmock/gmock.h>
2424
#include <gtest/gtest.h>
2525

26+
#include "arrow/buffer.h"
2627
#include "arrow/flight/client_cookie_middleware.h"
2728
#include "arrow/flight/client_middleware.h"
2829
#include "arrow/flight/cookie_internal.h"
2930
#include "arrow/flight/serialization_internal.h"
31+
#include "arrow/flight/server.h"
3032
#include "arrow/flight/test_util.h"
33+
#include "arrow/flight/transport.h"
3134
#include "arrow/flight/transport/grpc/util_internal.h"
3235
#include "arrow/flight/types.h"
36+
#include "arrow/ipc/reader.h"
37+
#include "arrow/record_batch.h"
3338
#include "arrow/status.h"
3439
#include "arrow/testing/gtest_util.h"
3540
#include "arrow/util/string.h"
@@ -730,6 +735,74 @@ TEST(GrpcTransport, FlightDataDeserialize) {
730735
#endif
731736
}
732737

738+
// ----------------------------------------------------------------------
739+
// Transport-agnostic serialization roundtrip tests
740+
741+
// Serializes a record-batch FlightPayload (with app_metadata attached) to
// buffers, concatenates them, deserializes the result and checks that every
// FlightData field, and the decoded RecordBatch, survive the roundtrip.
TEST(FlightSerialization, RoundtripPayloadWithBody) {
  // Use RecordBatchStream to generate FlightPayloads
  auto schema = arrow::schema({arrow::field("a", arrow::int32())});
  auto arr = ArrayFromJSON(arrow::int32(), "[1, 2, 3]");
  auto batch = RecordBatch::Make(schema, 3, {arr});
  // Report a failure through the test framework instead of aborting the
  // whole test binary, as ValueOrDie() would.
  ASSERT_OK_AND_ASSIGN(auto reader, RecordBatchReader::Make({batch}));
  RecordBatchStream stream(std::move(reader));

  // Get a FlightPayload from the stream. The schema payload is fetched but not
  // otherwise exercised here; only the record batch payload is round-tripped.
  ASSERT_OK_AND_ASSIGN(auto schema_payload, stream.GetSchemaPayload());
  ASSERT_OK_AND_ASSIGN(auto flight_payload, stream.Next());

  // Add app_metadata to the flight payload
  flight_payload.app_metadata = Buffer::FromString("test-metadata");

  // Serialize FlightPayload to BufferVector
  ASSERT_OK_AND_ASSIGN(auto buffers, internal::SerializePayloadToBuffers(flight_payload));
  // empty() avoids comparing an unsigned size() against a signed literal.
  ASSERT_FALSE(buffers.empty());

  // Concatenate to a single buffer for deserialization and deserialize.
  ASSERT_OK_AND_ASSIGN(auto concat, ConcatenateBuffers(buffers));
  ASSERT_OK_AND_ASSIGN(auto data, internal::DeserializeFlightData(concat));

  // Verify IPC metadata (data_header) is present
  ASSERT_NE(data.metadata, nullptr);
  ASSERT_GT(data.metadata->size(), 0);

  // Verify app_metadata
  ASSERT_NE(data.app_metadata, nullptr);
  ASSERT_EQ(data.app_metadata->ToString(), "test-metadata");

  // Verify body and message are present
  ASSERT_NE(data.body, nullptr);
  ASSERT_GT(data.body->size(), 0);
  ASSERT_OK_AND_ASSIGN(auto message, data.OpenMessage());
  ASSERT_NE(message, nullptr);

  // Also verify the RecordBatch roundtrips correctly
  ipc::DictionaryMemo dict_memo;
  ASSERT_OK_AND_ASSIGN(auto result_batch,
                       ipc::ReadRecordBatch(*message, schema, &dict_memo,
                                            ipc::IpcReadOptions::Defaults()));
  ASSERT_NE(result_batch, nullptr);
  // AssertBatchesEqual prints a diff on mismatch, unlike ASSERT_TRUE(Equals()).
  AssertBatchesEqual(*batch, *result_batch);
}
784+
785+
// Roundtrips a payload that carries only app_metadata: no descriptor and no
// IPC message (the default-constructed FlightPayload is serialized as-is).
TEST(FlightSerialization, RoundtripMetadataOnly) {
  // A metadata-only payload (no IPC body, no descriptor)
  auto app_meta = Buffer::FromString("metadata-only-message");

  FlightPayload payload;
  payload.app_metadata = std::move(app_meta);

  // Serialize. Without an IPC body only the Protobuf header buffer is emitted.
  ASSERT_OK_AND_ASSIGN(auto buffers, internal::SerializePayloadToBuffers(payload));
  ASSERT_EQ(buffers.size(), 1u);
  ASSERT_OK_AND_ASSIGN(auto concat, ConcatenateBuffers(buffers));

  // Deserialize
  ASSERT_OK_AND_ASSIGN(auto data, internal::DeserializeFlightData(concat));

  // Verify: no descriptor, no IPC metadata, just app_metadata
  ASSERT_EQ(data.descriptor, nullptr);
  ASSERT_EQ(data.metadata, nullptr);
  ASSERT_NE(data.app_metadata, nullptr);
  ASSERT_EQ(data.app_metadata->ToString(), "metadata-only-message");

  // The deserializer substitutes an empty (non-null) buffer for a missing body.
  ASSERT_NE(data.body, nullptr);
  ASSERT_EQ(data.body->size(), 0);
}
805+
733806
// ----------------------------------------------------------------------
734807
// Transport abstraction tests
735808

cpp/src/arrow/flight/serialization_internal.cc

Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,24 @@
1717

1818
#include "arrow/flight/serialization_internal.h"
1919

20+
#include <limits>
2021
#include <memory>
2122
#include <string>
2223

2324
#include <google/protobuf/any.pb.h>
25+
#include <google/protobuf/io/coded_stream.h>
26+
#include <google/protobuf/io/zero_copy_stream_impl_lite.h>
27+
#include <google/protobuf/wire_format_lite.h>
2428

2529
#include "arrow/buffer.h"
2630
#include "arrow/flight/protocol_internal.h"
2731
#include "arrow/io/memory.h"
32+
#include "arrow/ipc/message.h"
2833
#include "arrow/ipc/reader.h"
2934
#include "arrow/ipc/writer.h"
3035
#include "arrow/result.h"
3136
#include "arrow/status.h"
37+
#include "arrow/util/logging_internal.h"
3238

3339
// Lambda helper & CTAD
3440
template <class... Ts>
@@ -612,6 +618,232 @@ Status ToProto(const CloseSessionResult& result, pb::CloseSessionResult* pb_resu
612618
return Status::OK();
613619
}
614620

621+
namespace {
622+
// Protobuf wire-format primitives used by the hand-written FlightData
// encoder/decoder in this file.
using google::protobuf::internal::WireFormatLite;
using google::protobuf::io::ArrayOutputStream;
using google::protobuf::io::CodedInputStream;
using google::protobuf::io::CodedOutputStream;

// Upper bound applied to every length that is narrowed to a 32-bit value.
constexpr int64_t kInt32Max = std::numeric_limits<int32_t>::max();

// Zero bytes appended after body buffers whose size is not a multiple of 8.
constexpr uint8_t kPaddingBytes[8] = {};
628+
629+
// Update the sizes of our Protobuf fields based on the given IPC payload.
630+
arrow::Status IpcMessageHeaderSize(const arrow::ipc::IpcPayload& ipc_msg, bool has_body,
631+
size_t* header_size, int32_t* metadata_size) {
632+
DCHECK_LE(ipc_msg.metadata->size(), kInt32Max);
633+
*metadata_size = static_cast<int32_t>(ipc_msg.metadata->size());
634+
635+
// 1 byte for metadata tag
636+
*header_size += 1 + WireFormatLite::LengthDelimitedSize(*metadata_size);
637+
638+
// 2 bytes for body tag
639+
if (has_body) {
640+
// We write the body tag in the header but not the actual body data
641+
*header_size += 2 + WireFormatLite::LengthDelimitedSize(ipc_msg.body_length) -
642+
ipc_msg.body_length;
643+
}
644+
645+
return arrow::Status::OK();
646+
}
647+
648+
bool ReadBytesZeroCopy(const std::shared_ptr<Buffer>& source_data,
649+
CodedInputStream* input, std::shared_ptr<Buffer>* out) {
650+
uint32_t length;
651+
if (!input->ReadVarint32(&length)) {
652+
return false;
653+
}
654+
auto buf =
655+
SliceBuffer(source_data, input->CurrentPosition(), static_cast<int64_t>(length));
656+
*out = buf;
657+
return input->Skip(static_cast<int>(length));
658+
}
659+
660+
} // namespace
661+
662+
// Serialize a FlightPayload into the wire encoding of a FlightData message,
// returned as a sequence of buffers.
//
// The first buffer is a freshly allocated Protobuf "header": the descriptor,
// data_header and app_metadata fields (each only if present), followed by the
// tag and length prefix of data_body when the IPC message has a body. The
// remaining buffers are the non-empty IPC body buffers, appended by reference
// (no copy), each followed by a zero-padding buffer when its size is not a
// multiple of 8. Concatenating all returned buffers yields the encoded message.
//
// Returns CapacityError if any field, or the header as a whole, is too large
// for the 32-bit lengths used by the encoding, and propagates allocation
// errors. The size limits are checked in all build types so that oversize
// input is rejected instead of being silently truncated in release builds.
arrow::Result<arrow::BufferVector> SerializePayloadToBuffers(const FlightPayload& msg) {
  // Size of the IPC body (protobuf: data_body)
  size_t body_size = 0;
  // Size of the Protobuf "header" (everything except for the body)
  size_t header_size = 0;
  // Size of IPC header metadata (protobuf: data_header)
  int32_t metadata_size = 0;

  // Account for the descriptor if present
  int32_t descriptor_size = 0;
  if (msg.descriptor != nullptr) {
    if (msg.descriptor->size() > kInt32Max) {
      return arrow::Status::CapacityError("FlightDescriptor exceeds 2GiB");
    }
    descriptor_size = static_cast<int32_t>(msg.descriptor->size());
    header_size += 1 + WireFormatLite::LengthDelimitedSize(descriptor_size);
  }

  // App metadata tag if appropriate
  int32_t app_metadata_size = 0;
  if (msg.app_metadata && msg.app_metadata->size() > 0) {
    if (msg.app_metadata->size() > kInt32Max) {
      return arrow::Status::CapacityError("Application metadata exceeds 2GiB");
    }
    app_metadata_size = static_cast<int32_t>(msg.app_metadata->size());
    header_size += 1 + WireFormatLite::LengthDelimitedSize(app_metadata_size);
  }

  const arrow::ipc::IpcPayload& ipc_msg = msg.ipc_message;
  // No data in this payload (metadata-only).
  const bool has_ipc = ipc_msg.type != ipc::MessageType::NONE;
  const bool has_body = has_ipc ? ipc::Message::HasBody(ipc_msg.type) : false;

  if (has_ipc) {
    DCHECK(has_body || ipc_msg.body_length == 0);
    // TODO(wesm): messages over 2GB unlikely to be yet supported
    if (ipc_msg.body_length > kInt32Max) {
      return arrow::Status::CapacityError(
          "Cannot serialize IPC message bodies exceeding 2GiB");
    }
    ARROW_RETURN_NOT_OK(
        IpcMessageHeaderSize(ipc_msg, has_body, &header_size, &metadata_size));
    body_size = static_cast<size_t>(ipc_msg.body_length);
  }

  // ArrayOutputStream takes an int size, so the header as a whole must fit.
  if (header_size > static_cast<size_t>(kInt32Max)) {
    return arrow::Status::CapacityError("FlightData header exceeds 2GiB");
  }

  // Allocate and initialize buffers
  ARROW_ASSIGN_OR_RAISE(auto header_buf, arrow::AllocateBuffer(header_size));

  // Force the header_stream to be destructed, which actually flushes
  // the data into header_buf.
  {
    ArrayOutputStream header_writer(header_buf->mutable_data(),
                                    static_cast<int>(header_size));
    CodedOutputStream header_stream(&header_writer);

    // Write descriptor
    if (msg.descriptor != nullptr) {
      WireFormatLite::WriteTag(pb::FlightData::kFlightDescriptorFieldNumber,
                               WireFormatLite::WIRETYPE_LENGTH_DELIMITED, &header_stream);
      header_stream.WriteVarint32(descriptor_size);
      header_stream.WriteRawMaybeAliased(msg.descriptor->data(), descriptor_size);
    }

    // Write header
    if (has_ipc) {
      WireFormatLite::WriteTag(pb::FlightData::kDataHeaderFieldNumber,
                               WireFormatLite::WIRETYPE_LENGTH_DELIMITED, &header_stream);
      header_stream.WriteVarint32(metadata_size);
      header_stream.WriteRawMaybeAliased(ipc_msg.metadata->data(), metadata_size);
    }

    // Write app metadata
    if (app_metadata_size > 0) {
      WireFormatLite::WriteTag(pb::FlightData::kAppMetadataFieldNumber,
                               WireFormatLite::WIRETYPE_LENGTH_DELIMITED, &header_stream);
      header_stream.WriteVarint32(app_metadata_size);
      header_stream.WriteRawMaybeAliased(msg.app_metadata->data(), app_metadata_size);
    }

    // Write body tag and length; the body bytes follow as separate buffers.
    if (has_body) {
      WireFormatLite::WriteTag(pb::FlightData::kDataBodyFieldNumber,
                               WireFormatLite::WIRETYPE_LENGTH_DELIMITED, &header_stream);
      header_stream.WriteVarint32(static_cast<uint32_t>(body_size));
    }

    DCHECK_EQ(static_cast<int>(header_size), header_stream.ByteCount());
  }

  // The header is the first buffer in the output vector. Reserving room for
  // every body buffer plus a possible padding buffer each avoids reallocation
  // and the front-insertion shift.
  arrow::BufferVector buffers;
  buffers.reserve(1 + (has_body ? 2 * ipc_msg.body_buffers.size() : 0));
  buffers.emplace_back(std::move(header_buf));

  if (has_body) {
    // Enqueue body buffers for writing, without copying
    for (const auto& buffer : ipc_msg.body_buffers) {
      // Buffer may be null when the row length is zero, or when all
      // entries are invalid.
      if (!buffer || buffer->size() == 0) continue;
      buffers.push_back(buffer);

      // Write padding if not multiple of 8
      const auto remainder = static_cast<int>(
          bit_util::RoundUpToMultipleOf8(buffer->size()) - buffer->size());
      if (remainder) {
        buffers.push_back(std::make_shared<arrow::Buffer>(kPaddingBytes, remainder));
      }
    }
  }

  return buffers;
}
769+
770+
// Read internal::FlightData from arrow::Buffer containing FlightData
771+
// protobuf without copying
772+
arrow::Result<arrow::flight::internal::FlightData> DeserializeFlightData(
773+
const std::shared_ptr<arrow::Buffer>& buffer) {
774+
if (!buffer) {
775+
return Status::Invalid("No payload");
776+
}
777+
778+
arrow::flight::internal::FlightData out;
779+
780+
auto buffer_length = static_cast<int>(buffer->size());
781+
CodedInputStream pb_stream(buffer->data(), buffer_length);
782+
783+
pb_stream.SetTotalBytesLimit(buffer_length);
784+
785+
// This is the bytes remaining when using CodedInputStream like this
786+
while (pb_stream.BytesUntilTotalBytesLimit()) {
787+
const uint32_t tag = pb_stream.ReadTag();
788+
const int field_number = WireFormatLite::GetTagFieldNumber(tag);
789+
switch (field_number) {
790+
case pb::FlightData::kFlightDescriptorFieldNumber: {
791+
pb::FlightDescriptor pb_descriptor;
792+
uint32_t length;
793+
if (!pb_stream.ReadVarint32(&length)) {
794+
return Status::Invalid("Unable to parse length of FlightDescriptor");
795+
}
796+
// Can't use ParseFromCodedStream as this reads the entire
797+
// rest of the stream into the descriptor command field.
798+
std::string buffer;
799+
if (!pb_stream.ReadString(&buffer, length)) {
800+
return Status::Invalid("Unable to read FlightDescriptor from protobuf");
801+
}
802+
if (!pb_descriptor.ParseFromString(buffer)) {
803+
return Status::Invalid("Unable to parse FlightDescriptor");
804+
}
805+
arrow::flight::FlightDescriptor descriptor;
806+
ARROW_RETURN_NOT_OK(
807+
arrow::flight::internal::FromProto(pb_descriptor, &descriptor));
808+
out.descriptor = std::make_unique<arrow::flight::FlightDescriptor>(descriptor);
809+
} break;
810+
case pb::FlightData::kDataHeaderFieldNumber: {
811+
if (!ReadBytesZeroCopy(buffer, &pb_stream, &out.metadata)) {
812+
return Status::Invalid("Unable to read FlightData metadata");
813+
}
814+
} break;
815+
case pb::FlightData::kAppMetadataFieldNumber: {
816+
if (!ReadBytesZeroCopy(buffer, &pb_stream, &out.app_metadata)) {
817+
return Status::Invalid("Unable to read FlightData application metadata");
818+
}
819+
} break;
820+
case pb::FlightData::kDataBodyFieldNumber: {
821+
if (!ReadBytesZeroCopy(buffer, &pb_stream, &out.body)) {
822+
return Status::Invalid("Unable to read FlightData body");
823+
}
824+
} break;
825+
default: {
826+
// Unknown field. We should skip it for compatibility.
827+
if (!WireFormatLite::SkipField(&pb_stream, tag)) {
828+
return Status::Invalid("Could not skip unknown field tag in FlightData");
829+
}
830+
break;
831+
}
832+
}
833+
}
834+
835+
// TODO(wesm): Where and when should we verify that the FlightData is not
836+
// malformed?
837+
838+
// Set the default value for an unspecified FlightData body. The other
839+
// fields can be null if they're unspecified.
840+
if (out.body == nullptr) {
841+
out.body = std::make_shared<Buffer>(nullptr, 0);
842+
}
843+
844+
return out;
845+
}
846+
615847
} // namespace internal
616848
} // namespace flight
617849
} // namespace arrow

cpp/src/arrow/flight/serialization_internal.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,15 @@ ARROW_FLIGHT_EXPORT Status ToProto(const CloseSessionResult& result,
182182

183183
Status ToPayload(const FlightDescriptor& descr, std::shared_ptr<Buffer>* out);
184184

185+
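/// \brief Serialize a FlightPayload to a vector of buffers.
///
/// The first buffer holds the Protobuf-encoded FlightData header fields; any
/// following buffers are the IPC body buffers (plus padding), referenced
/// without copying. Concatenating the buffers yields the encoded FlightData.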
186+
ARROW_FLIGHT_EXPORT
187+
arrow::Result<arrow::BufferVector> SerializePayloadToBuffers(const FlightPayload& msg);
188+
189+
/// \brief Deserialize FlightData from a contiguous buffer.
190+
ARROW_FLIGHT_EXPORT
191+
arrow::Result<internal::FlightData> DeserializeFlightData(
192+
const std::shared_ptr<arrow::Buffer>& buffer);
193+
185194
// We want to reuse RecordBatchStreamReader's implementation while
186195
// (1) Adapting it to the Flight message format
187196
// (2) Allowing pure-metadata messages before data is sent

cpp/src/arrow/flight/transport.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ class FlightStatusDetail;
7676
namespace internal {
7777

7878
/// Internal, not user-visible type used for memory-efficient reads
79-
struct FlightData {
79+
struct ARROW_FLIGHT_EXPORT FlightData {
8080
/// Used only for puts, may be null
8181
std::unique_ptr<FlightDescriptor> descriptor;
8282

0 commit comments

Comments
 (0)