|
17 | 17 |
|
18 | 18 | #include "arrow/flight/serialization_internal.h" |
19 | 19 |
|
| 20 | +#include <limits> |
20 | 21 | #include <memory> |
21 | 22 | #include <string> |
22 | 23 |
|
23 | 24 | #include <google/protobuf/any.pb.h> |
| 25 | +#include <google/protobuf/io/coded_stream.h> |
| 26 | +#include <google/protobuf/io/zero_copy_stream_impl_lite.h> |
| 27 | +#include <google/protobuf/wire_format_lite.h> |
24 | 28 |
|
25 | 29 | #include "arrow/buffer.h" |
26 | 30 | #include "arrow/flight/protocol_internal.h" |
27 | 31 | #include "arrow/io/memory.h" |
| 32 | +#include "arrow/ipc/message.h" |
28 | 33 | #include "arrow/ipc/reader.h" |
29 | 34 | #include "arrow/ipc/writer.h" |
30 | 35 | #include "arrow/result.h" |
31 | 36 | #include "arrow/status.h" |
| 37 | +#include "arrow/util/logging_internal.h" |
32 | 38 |
|
33 | 39 | // Lambda helper & CTAD |
34 | 40 | template <class... Ts> |
@@ -612,6 +618,232 @@ Status ToProto(const CloseSessionResult& result, pb::CloseSessionResult* pb_resu |
612 | 618 | return Status::OK(); |
613 | 619 | } |
614 | 620 |
|
| 621 | +namespace { |
| 622 | +using google::protobuf::internal::WireFormatLite; |
| 623 | +using google::protobuf::io::ArrayOutputStream; |
| 624 | +using google::protobuf::io::CodedInputStream; |
| 625 | +using google::protobuf::io::CodedOutputStream; |
| 626 | +static constexpr int64_t kInt32Max = std::numeric_limits<int32_t>::max(); |
| 627 | +const uint8_t kPaddingBytes[8] = {0, 0, 0, 0, 0, 0, 0, 0}; |
| 628 | + |
| 629 | +// Update the sizes of our Protobuf fields based on the given IPC payload. |
| 630 | +arrow::Status IpcMessageHeaderSize(const arrow::ipc::IpcPayload& ipc_msg, bool has_body, |
| 631 | + size_t* header_size, int32_t* metadata_size) { |
| 632 | + DCHECK_LE(ipc_msg.metadata->size(), kInt32Max); |
| 633 | + *metadata_size = static_cast<int32_t>(ipc_msg.metadata->size()); |
| 634 | + |
| 635 | + // 1 byte for metadata tag |
| 636 | + *header_size += 1 + WireFormatLite::LengthDelimitedSize(*metadata_size); |
| 637 | + |
| 638 | + // 2 bytes for body tag |
| 639 | + if (has_body) { |
| 640 | + // We write the body tag in the header but not the actual body data |
| 641 | + *header_size += 2 + WireFormatLite::LengthDelimitedSize(ipc_msg.body_length) - |
| 642 | + ipc_msg.body_length; |
| 643 | + } |
| 644 | + |
| 645 | + return arrow::Status::OK(); |
| 646 | +} |
| 647 | + |
| 648 | +bool ReadBytesZeroCopy(const std::shared_ptr<Buffer>& source_data, |
| 649 | + CodedInputStream* input, std::shared_ptr<Buffer>* out) { |
| 650 | + uint32_t length; |
| 651 | + if (!input->ReadVarint32(&length)) { |
| 652 | + return false; |
| 653 | + } |
| 654 | + auto buf = |
| 655 | + SliceBuffer(source_data, input->CurrentPosition(), static_cast<int64_t>(length)); |
| 656 | + *out = buf; |
| 657 | + return input->Skip(static_cast<int>(length)); |
| 658 | +} |
| 659 | + |
| 660 | +} // namespace |
| 661 | + |
| 662 | +arrow::Result<arrow::BufferVector> SerializePayloadToBuffers(const FlightPayload& msg) { |
| 663 | + // Size of the IPC body (protobuf: data_body) |
| 664 | + size_t body_size = 0; |
| 665 | + // Size of the Protobuf "header" (everything except for the body) |
| 666 | + size_t header_size = 0; |
| 667 | + // Size of IPC header metadata (protobuf: data_header) |
| 668 | + int32_t metadata_size = 0; |
| 669 | + |
| 670 | + // Write the descriptor if present |
| 671 | + int32_t descriptor_size = 0; |
| 672 | + if (msg.descriptor != nullptr) { |
| 673 | + DCHECK_LE(msg.descriptor->size(), kInt32Max); |
| 674 | + descriptor_size = static_cast<int32_t>(msg.descriptor->size()); |
| 675 | + header_size += 1 + WireFormatLite::LengthDelimitedSize(descriptor_size); |
| 676 | + } |
| 677 | + |
| 678 | + // App metadata tag if appropriate |
| 679 | + int32_t app_metadata_size = 0; |
| 680 | + if (msg.app_metadata && msg.app_metadata->size() > 0) { |
| 681 | + DCHECK_LE(msg.app_metadata->size(), kInt32Max); |
| 682 | + app_metadata_size = static_cast<int32_t>(msg.app_metadata->size()); |
| 683 | + header_size += 1 + WireFormatLite::LengthDelimitedSize(app_metadata_size); |
| 684 | + } |
| 685 | + |
| 686 | + const arrow::ipc::IpcPayload& ipc_msg = msg.ipc_message; |
| 687 | + // No data in this payload (metadata-only). |
| 688 | + bool has_ipc = ipc_msg.type != ipc::MessageType::NONE; |
| 689 | + bool has_body = has_ipc ? ipc::Message::HasBody(ipc_msg.type) : false; |
| 690 | + |
| 691 | + if (has_ipc) { |
| 692 | + DCHECK(has_body || ipc_msg.body_length == 0); |
| 693 | + ARROW_RETURN_NOT_OK( |
| 694 | + IpcMessageHeaderSize(ipc_msg, has_body, &header_size, &metadata_size)); |
| 695 | + body_size = static_cast<size_t>(ipc_msg.body_length); |
| 696 | + } |
| 697 | + |
| 698 | + // TODO(wesm): messages over 2GB unlikely to be yet supported |
| 699 | + // Validated in WritePayload since returning error here causes gRPC to fail an assertion |
| 700 | + DCHECK_LE(body_size, kInt32Max); |
| 701 | + |
| 702 | + // Allocate and initialize buffers |
| 703 | + arrow::BufferVector buffers; |
| 704 | + ARROW_ASSIGN_OR_RAISE(auto header_buf, arrow::AllocateBuffer(header_size)); |
| 705 | + |
| 706 | + // Force the header_stream to be destructed, which actually flushes |
| 707 | + // the data into the slice. |
| 708 | + { |
| 709 | + ArrayOutputStream header_writer(const_cast<uint8_t*>(header_buf->mutable_data()), |
| 710 | + static_cast<int>(header_size)); |
| 711 | + CodedOutputStream header_stream(&header_writer); |
| 712 | + |
| 713 | + // Write descriptor |
| 714 | + if (msg.descriptor != nullptr) { |
| 715 | + WireFormatLite::WriteTag(pb::FlightData::kFlightDescriptorFieldNumber, |
| 716 | + WireFormatLite::WIRETYPE_LENGTH_DELIMITED, &header_stream); |
| 717 | + header_stream.WriteVarint32(descriptor_size); |
| 718 | + header_stream.WriteRawMaybeAliased(msg.descriptor->data(), |
| 719 | + static_cast<int>(msg.descriptor->size())); |
| 720 | + } |
| 721 | + |
| 722 | + // Write header |
| 723 | + if (has_ipc) { |
| 724 | + WireFormatLite::WriteTag(pb::FlightData::kDataHeaderFieldNumber, |
| 725 | + WireFormatLite::WIRETYPE_LENGTH_DELIMITED, &header_stream); |
| 726 | + header_stream.WriteVarint32(metadata_size); |
| 727 | + header_stream.WriteRawMaybeAliased(ipc_msg.metadata->data(), |
| 728 | + static_cast<int>(ipc_msg.metadata->size())); |
| 729 | + } |
| 730 | + |
| 731 | + // Write app metadata |
| 732 | + if (app_metadata_size > 0) { |
| 733 | + WireFormatLite::WriteTag(pb::FlightData::kAppMetadataFieldNumber, |
| 734 | + WireFormatLite::WIRETYPE_LENGTH_DELIMITED, &header_stream); |
| 735 | + header_stream.WriteVarint32(app_metadata_size); |
| 736 | + header_stream.WriteRawMaybeAliased(msg.app_metadata->data(), |
| 737 | + static_cast<int>(msg.app_metadata->size())); |
| 738 | + } |
| 739 | + |
| 740 | + if (has_body) { |
| 741 | + // Write body tag |
| 742 | + WireFormatLite::WriteTag(pb::FlightData::kDataBodyFieldNumber, |
| 743 | + WireFormatLite::WIRETYPE_LENGTH_DELIMITED, &header_stream); |
| 744 | + header_stream.WriteVarint32(static_cast<uint32_t>(body_size)); |
| 745 | + |
| 746 | + // Enqueue body buffers for writing, without copying |
| 747 | + for (const auto& buffer : ipc_msg.body_buffers) { |
| 748 | + // Buffer may be null when the row length is zero, or when all |
| 749 | + // entries are invalid. |
| 750 | + if (!buffer || buffer->size() == 0) continue; |
| 751 | + buffers.push_back(buffer); |
| 752 | + |
| 753 | + // Write padding if not multiple of 8 |
| 754 | + const auto remainder = static_cast<int>( |
| 755 | + bit_util::RoundUpToMultipleOf8(buffer->size()) - buffer->size()); |
| 756 | + if (remainder) { |
| 757 | + buffers.push_back(std::make_shared<arrow::Buffer>(kPaddingBytes, remainder)); |
| 758 | + } |
| 759 | + } |
| 760 | + } |
| 761 | + |
| 762 | + DCHECK_EQ(static_cast<int>(header_size), header_stream.ByteCount()); |
| 763 | + } |
| 764 | + // Once header is written we add it as the first buffer in the output vector. |
| 765 | + buffers.insert(buffers.begin(), std::move(header_buf)); |
| 766 | + |
| 767 | + return buffers; |
| 768 | +} |
| 769 | + |
| 770 | +// Read internal::FlightData from arrow::Buffer containing FlightData |
| 771 | +// protobuf without copying |
| 772 | +arrow::Result<arrow::flight::internal::FlightData> DeserializeFlightData( |
| 773 | + const std::shared_ptr<arrow::Buffer>& buffer) { |
| 774 | + if (!buffer) { |
| 775 | + return Status::Invalid("No payload"); |
| 776 | + } |
| 777 | + |
| 778 | + arrow::flight::internal::FlightData out; |
| 779 | + |
| 780 | + auto buffer_length = static_cast<int>(buffer->size()); |
| 781 | + CodedInputStream pb_stream(buffer->data(), buffer_length); |
| 782 | + |
| 783 | + pb_stream.SetTotalBytesLimit(buffer_length); |
| 784 | + |
| 785 | + // This is the bytes remaining when using CodedInputStream like this |
| 786 | + while (pb_stream.BytesUntilTotalBytesLimit()) { |
| 787 | + const uint32_t tag = pb_stream.ReadTag(); |
| 788 | + const int field_number = WireFormatLite::GetTagFieldNumber(tag); |
| 789 | + switch (field_number) { |
| 790 | + case pb::FlightData::kFlightDescriptorFieldNumber: { |
| 791 | + pb::FlightDescriptor pb_descriptor; |
| 792 | + uint32_t length; |
| 793 | + if (!pb_stream.ReadVarint32(&length)) { |
| 794 | + return Status::Invalid("Unable to parse length of FlightDescriptor"); |
| 795 | + } |
| 796 | + // Can't use ParseFromCodedStream as this reads the entire |
| 797 | + // rest of the stream into the descriptor command field. |
| 798 | + std::string buffer; |
| 799 | + if (!pb_stream.ReadString(&buffer, length)) { |
| 800 | + return Status::Invalid("Unable to read FlightDescriptor from protobuf"); |
| 801 | + } |
| 802 | + if (!pb_descriptor.ParseFromString(buffer)) { |
| 803 | + return Status::Invalid("Unable to parse FlightDescriptor"); |
| 804 | + } |
| 805 | + arrow::flight::FlightDescriptor descriptor; |
| 806 | + ARROW_RETURN_NOT_OK( |
| 807 | + arrow::flight::internal::FromProto(pb_descriptor, &descriptor)); |
| 808 | + out.descriptor = std::make_unique<arrow::flight::FlightDescriptor>(descriptor); |
| 809 | + } break; |
| 810 | + case pb::FlightData::kDataHeaderFieldNumber: { |
| 811 | + if (!ReadBytesZeroCopy(buffer, &pb_stream, &out.metadata)) { |
| 812 | + return Status::Invalid("Unable to read FlightData metadata"); |
| 813 | + } |
| 814 | + } break; |
| 815 | + case pb::FlightData::kAppMetadataFieldNumber: { |
| 816 | + if (!ReadBytesZeroCopy(buffer, &pb_stream, &out.app_metadata)) { |
| 817 | + return Status::Invalid("Unable to read FlightData application metadata"); |
| 818 | + } |
| 819 | + } break; |
| 820 | + case pb::FlightData::kDataBodyFieldNumber: { |
| 821 | + if (!ReadBytesZeroCopy(buffer, &pb_stream, &out.body)) { |
| 822 | + return Status::Invalid("Unable to read FlightData body"); |
| 823 | + } |
| 824 | + } break; |
| 825 | + default: { |
| 826 | + // Unknown field. We should skip it for compatibility. |
| 827 | + if (!WireFormatLite::SkipField(&pb_stream, tag)) { |
| 828 | + return Status::Invalid("Could not skip unknown field tag in FlightData"); |
| 829 | + } |
| 830 | + break; |
| 831 | + } |
| 832 | + } |
| 833 | + } |
| 834 | + |
| 835 | + // TODO(wesm): Where and when should we verify that the FlightData is not |
| 836 | + // malformed? |
| 837 | + |
| 838 | + // Set the default value for an unspecified FlightData body. The other |
| 839 | + // fields can be null if they're unspecified. |
| 840 | + if (out.body == nullptr) { |
| 841 | + out.body = std::make_shared<Buffer>(nullptr, 0); |
| 842 | + } |
| 843 | + |
| 844 | + return out; |
| 845 | +} |
| 846 | + |
615 | 847 | } // namespace internal |
616 | 848 | } // namespace flight |
617 | 849 | } // namespace arrow |
0 commit comments