
[SPARK-52333][SS][PYTHON] Squeeze the protocol of retrieving timers for transformWithState in PySpark #51036

Closed · wants to merge 4 commits
Changes from all commits
158 changes: 81 additions & 77 deletions python/pyspark/sql/streaming/proto/StateMessage_pb2.py

Large diffs are not rendered by default.

76 changes: 75 additions & 1 deletion python/pyspark/sql/streaming/proto/StateMessage_pb2.pyi
@@ -366,6 +366,74 @@ class StateResponseWithMapIterator(google.protobuf.message.Message):

global___StateResponseWithMapIterator = StateResponseWithMapIterator

class TimerInfo(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

KEY_FIELD_NUMBER: builtins.int
TIMESTAMPMS_FIELD_NUMBER: builtins.int
key: builtins.bytes
timestampMs: builtins.int
def __init__(
self,
*,
key: builtins.bytes | None = ...,
timestampMs: builtins.int = ...,
) -> None: ...
def HasField(
self, field_name: typing_extensions.Literal["_key", b"_key", "key", b"key"]
) -> builtins.bool: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
"_key", b"_key", "key", b"key", "timestampMs", b"timestampMs"
],
) -> None: ...
def WhichOneof(
self, oneof_group: typing_extensions.Literal["_key", b"_key"]
) -> typing_extensions.Literal["key"] | None: ...

global___TimerInfo = TimerInfo

class StateResponseWithTimer(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

STATUSCODE_FIELD_NUMBER: builtins.int
ERRORMESSAGE_FIELD_NUMBER: builtins.int
TIMER_FIELD_NUMBER: builtins.int
REQUIRENEXTFETCH_FIELD_NUMBER: builtins.int
statusCode: builtins.int
errorMessage: builtins.str
@property
def timer(
self,
) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
global___TimerInfo
]: ...
requireNextFetch: builtins.bool
def __init__(
self,
*,
statusCode: builtins.int = ...,
errorMessage: builtins.str = ...,
timer: collections.abc.Iterable[global___TimerInfo] | None = ...,
requireNextFetch: builtins.bool = ...,
) -> None: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
"errorMessage",
b"errorMessage",
"requireNextFetch",
b"requireNextFetch",
"statusCode",
b"statusCode",
"timer",
b"timer",
],
) -> None: ...

global___StateResponseWithTimer = StateResponseWithTimer
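For readers who have not worked with the generated bindings, the sketch below (not part of the PR) builds a StateResponseWithTimer with inlined TimerInfo entries and round-trips it through serialization; it uses only the fields declared in the stub above, and the key/timestamp values are made up.

```python
# Illustration only: exercise the new messages with the generated bindings.
import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage

response = stateMessage.StateResponseWithTimer(
    statusCode=0,
    timer=[
        stateMessage.TimerInfo(key=b"pickled-grouping-key", timestampMs=1_700_000_000_000),
        stateMessage.TimerInfo(timestampMs=1_700_000_060_000),  # key is declared optional
    ],
    requireNextFetch=False,  # True would tell the client to request another page
)
payload = response.SerializeToString()

parsed = stateMessage.StateResponseWithTimer()
parsed.ParseFromString(payload)
assert [t.timestampMs for t in parsed.timer] == [1_700_000_000_000, 1_700_000_060_000]
```

Inlining the timers as repeated TimerInfo entries is what lets the client read them straight from the response instead of issuing a follow-up Arrow read.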

class StatefulProcessorCall(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

@@ -634,15 +702,21 @@ global___TimerValueRequest = TimerValueRequest
class ExpiryTimerRequest(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

ITERATORID_FIELD_NUMBER: builtins.int
EXPIRYTIMESTAMPMS_FIELD_NUMBER: builtins.int
iteratorId: builtins.str
expiryTimestampMs: builtins.int
def __init__(
self,
*,
iteratorId: builtins.str = ...,
expiryTimestampMs: builtins.int = ...,
) -> None: ...
def ClearField(
- self, field_name: typing_extensions.Literal["expiryTimestampMs", b"expiryTimestampMs"]
self,
field_name: typing_extensions.Literal[
"expiryTimestampMs", b"expiryTimestampMs", "iteratorId", b"iteratorId"
],
) -> None: ...

global___ExpiryTimerRequest = ExpiryTimerRequest
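On the request side, iteratorId is the new field; it lets the server resume a paged listing of expired timers for the same client-side iterator. A minimal sketch of populating the request, using the TimerRequest/StateRequest wrappers that appear later in this diff (the concrete values are placeholders):

```python
# Illustration only: build an ExpiryTimerRequest that carries an iterator id.
import uuid
import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage

expiry_timer_call = stateMessage.ExpiryTimerRequest(
    iteratorId=str(uuid.uuid4()),         # identifies one client-side iterator
    expiryTimestampMs=1_700_000_000_000,  # expiry threshold in epoch millis
)
timer_request = stateMessage.TimerRequest(expiryTimerRequest=expiry_timer_call)
message = stateMessage.StateRequest(timerRequest=timer_request)
wire_bytes = message.SerializeToString()  # sent via _send_proto_message(...)
```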
156 changes: 109 additions & 47 deletions python/pyspark/sql/streaming/stateful_processor_api_client.py
@@ -24,17 +24,13 @@
from pyspark.sql.pandas.serializers import ArrowStreamSerializer
from pyspark.sql.types import (
StructType,
- TYPE_CHECKING,
Row,
)
from pyspark.sql.pandas.types import convert_pandas_using_numpy_type
from pyspark.serializers import CPickleSerializer
from pyspark.errors import PySparkRuntimeError
import uuid

- if TYPE_CHECKING:
- from pyspark.sql.pandas._typing import DataFrameLike as PandasDataFrameLike

__all__ = ["StatefulProcessorApiClient", "StatefulProcessorHandleState"]


@@ -80,9 +76,11 @@ def __init__(
self.utf8_deserializer = UTF8Deserializer()
self.pickleSer = CPickleSerializer()
self.serializer = ArrowStreamSerializer()
- # Dictionaries to store the mapping between iterator id and a tuple of pandas DataFrame
# Dictionaries to store the mapping between iterator id and a tuple of data batch
# and the index of the last row that was read.
- self.list_timer_iterator_cursors: Dict[str, Tuple["PandasDataFrameLike", int]] = {}
self.list_timer_iterator_cursors: Dict[str, Tuple[Any, int, bool]] = {}
self.expiry_timer_iterator_cursors: Dict[str, Tuple[Any, int, bool]] = {}

# statefulProcessorApiClient is initialized per batch per partition,
# so we will have new timestamps for a new batch
self._batch_timestamp = -1
@@ -222,76 +220,93 @@ def delete_timer(self, expiry_time_stamp_ms: int) -> None:
# TODO(SPARK-49233): Classify user facing errors.
raise PySparkRuntimeError(f"Error deleting timer: " f"{response_message[1]}")

- def get_list_timer_row(self, iterator_id: str) -> int:
def get_list_timer_row(self, iterator_id: str) -> Tuple[int, bool]:
import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage

if iterator_id in self.list_timer_iterator_cursors:
# if the iterator is already in the dictionary, return the next row
- pandas_df, index = self.list_timer_iterator_cursors[iterator_id]
data_batch, index, require_next_fetch = self.list_timer_iterator_cursors[iterator_id]
else:
list_call = stateMessage.ListTimers(iteratorId=iterator_id)
state_call_command = stateMessage.TimerStateCallCommand(list=list_call)
call = stateMessage.StatefulProcessorCall(timerStateCall=state_call_command)
message = stateMessage.StateRequest(statefulProcessorCall=call)

self._send_proto_message(message.SerializeToString())
- response_message = self._receive_proto_message()
response_message = self._receive_proto_message_with_timers()
status = response_message[0]
if status == 0:
- iterator = self._read_arrow_state()
- # We need to exhaust the iterator here to make sure all the arrow batches are read,
- # even though there is only one batch in the iterator. Otherwise, the stream might
- # block further reads since it thinks there might still be some arrow batches left.
- # We only need to read the first batch in the iterator because it's guaranteed that
- # there would only be one batch sent from the JVM side.
- data_batch = None
- for batch in iterator:
- if data_batch is None:
- data_batch = batch
- if data_batch is None:
- # TODO(SPARK-49233): Classify user facing errors.
- raise PySparkRuntimeError("Error getting map state entry.")
- pandas_df = data_batch.to_pandas()
data_batch = list(map(lambda x: x.timestampMs, response_message[2]))
require_next_fetch = response_message[3]
index = 0
else:
raise StopIteration()

is_last_row = False
new_index = index + 1
- if new_index < len(pandas_df):
if new_index < len(data_batch):
# Update the index in the dictionary.
- self.list_timer_iterator_cursors[iterator_id] = (pandas_df, new_index)
self.list_timer_iterator_cursors[iterator_id] = (
data_batch,
new_index,
require_next_fetch,
)
else:
- # If the index is at the end of the DataFrame, remove the state from the dictionary.
# If the index is at the end of the data batch, remove the state from the dictionary.
self.list_timer_iterator_cursors.pop(iterator_id, None)
- return pandas_df.at[index, "timestamp"].item()
is_last_row = True

is_last_row_from_iterator = is_last_row and not require_next_fetch
timestamp = data_batch[index]
return (timestamp, is_last_row_from_iterator)
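The cursor bookkeeping above is the same for both timer listings: cache the fetched page together with a read index, and report end-of-iterator only when the cached page is exhausted and the server did not ask for another fetch. A standalone sketch of that logic, with a hypothetical fetch_batch callback standing in for the RPC (it assumes the server never returns an empty page):

```python
from typing import Any, Callable, Dict, List, Tuple

# Hypothetical, simplified cursor store mirroring list_timer_iterator_cursors.
_cursors: Dict[str, Tuple[List[Any], int, bool]] = {}

def next_from_batch(
    iterator_id: str, fetch_batch: Callable[[str], Tuple[List[Any], bool]]
) -> Tuple[Any, bool]:
    """Return (value, is_last_value_from_iterator) for a paged server listing."""
    if iterator_id in _cursors:
        values, index, require_next_fetch = _cursors[iterator_id]
    else:
        values, require_next_fetch = fetch_batch(iterator_id)
        index = 0

    is_last_row = False
    if index + 1 < len(values):
        _cursors[iterator_id] = (values, index + 1, require_next_fetch)
    else:
        _cursors.pop(iterator_id, None)  # next call will fetch the next page
        is_last_row = True

    # Only the final value of the final page ends the iterator.
    return values[index], is_last_row and not require_next_fetch

# Example: three timestamps delivered in two pages.
pages = iter([([10, 20], True), ([30], False)])
out, done = [], False
while not done:
    value, done = next_from_batch("it-1", lambda _id: next(pages))
    out.append(value)
assert out == [10, 20, 30]
```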

def get_expiry_timers_iterator(
- self, expiry_timestamp: int
- ) -> Iterator[list[Tuple[Tuple, int]]]:
self, iterator_id: str, expiry_timestamp: int
) -> Tuple[Tuple, int, bool]:
import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage

- while True:
- expiry_timer_call = stateMessage.ExpiryTimerRequest(expiryTimestampMs=expiry_timestamp)
if iterator_id in self.expiry_timer_iterator_cursors:
# If the state is already in the dictionary, return the next row.
data_batch, index, require_next_fetch = self.expiry_timer_iterator_cursors[iterator_id]
else:
expiry_timer_call = stateMessage.ExpiryTimerRequest(
expiryTimestampMs=expiry_timestamp, iteratorId=iterator_id
)
timer_request = stateMessage.TimerRequest(expiryTimerRequest=expiry_timer_call)
message = stateMessage.StateRequest(timerRequest=timer_request)

self._send_proto_message(message.SerializeToString())
- response_message = self._receive_proto_message()
response_message = self._receive_proto_message_with_timers()
status = response_message[0]
- if status == 1:
- break
- elif status == 0:
- result_list = []
- iterator = self._read_arrow_state()
- for batch in iterator:
- batch_df = batch.to_pandas()
- for i in range(batch.num_rows):
- deserialized_key = self.pickleSer.loads(batch_df.at[i, "key"])
- timestamp = batch_df.at[i, "timestamp"].item()
- result_list.append((tuple(deserialized_key), timestamp))
- yield result_list
if status == 0:
Reviewer (Contributor): Correct me if I am wrong: we are now sending expiry timers not in an Arrow batch but as a list of rows. Is this because it improves performance by avoiding Arrow?

PR author's response: Two things:

  1. We see the benefit of inlining the data into the proto message to save one round trip.
  2. Arrow is a columnar format, which is known to be efficient when there is a lot of data. It is not a good fit (though sometimes needed) to use an Arrow RecordBatch for a small number of records. It might be a different story once there are enough records, especially since a pickled Python Row appears to carry its schema as JSON, which is not needed at all with an Arrow RecordBatch. We have not tested with a large number of records.

data_batch = list(
map(
lambda x: (self._deserialize_from_bytes(x.key), x.timestampMs),
response_message[2],
)
)
require_next_fetch = response_message[3]
index = 0
else:
- # TODO(SPARK-49233): Classify user facing errors.
- raise PySparkRuntimeError(f"Error getting expiry timers: " f"{response_message[1]}")
raise StopIteration()

is_last_row = False
new_index = index + 1
if new_index < len(data_batch):
# Update the index in the dictionary.
self.expiry_timer_iterator_cursors[iterator_id] = (
data_batch,
new_index,
require_next_fetch,
)
else:
# If the index is at the end of the data batch, remove the state from the dictionary.
self.expiry_timer_iterator_cursors.pop(iterator_id, None)
is_last_row = True

is_last_row_from_iterator = is_last_row and not require_next_fetch
key, timestamp = data_batch[index]
return (key, timestamp, is_last_row_from_iterator)

def get_timestamps(self, time_mode: str) -> Tuple[int, int]:
if time_mode.lower() == "none":
@@ -461,6 +476,18 @@ def _receive_proto_message_with_map_pairs(self) -> Tuple[int, str, Any, bool]:

return message.statusCode, message.errorMessage, message.kvPair, message.requireNextFetch

# The third return type is RepeatedScalarFieldContainer[TimerInfo], which is protobuf's
# container type. We simplify it to Any here to avoid unnecessary complexity.
def _receive_proto_message_with_timers(self) -> Tuple[int, str, Any, bool]:
import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage

length = read_int(self.sockfile)
bytes = self.sockfile.read(length)
message = stateMessage.StateResponseWithTimer()
message.ParseFromString(bytes)

return message.statusCode, message.errorMessage, message.timer, message.requireNextFetch
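The helper above reads a length-prefixed StateResponseWithTimer off the socket. For context, a sketch of what the matching write side looks like, assuming the 4-byte big-endian framing used by pyspark.serializers.read_int (the actual JVM-side writer is outside this diff):

```python
import io
import struct

def write_length_prefixed(sockfile, payload: bytes) -> None:
    # Mirror of read_int(self.sockfile) followed by self.sockfile.read(length):
    # a 4-byte big-endian length, then the serialized proto bytes.
    sockfile.write(struct.pack("!i", len(payload)))
    sockfile.write(payload)
    sockfile.flush()

# Round trip against an in-memory "socket"; the bytes stand in for a
# serialized StateResponseWithTimer.
buf = io.BytesIO()
write_length_prefixed(buf, b"\x08\x00")
buf.seek(0)
(length,) = struct.unpack("!i", buf.read(4))
assert buf.read(length) == b"\x08\x00"
```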

def _receive_str(self) -> str:
return self.utf8_deserializer.loads(self.sockfile)

@@ -552,9 +579,44 @@ def __init__(self, stateful_processor_api_client: StatefulProcessorApiClient):
# same partition won't interfere with each other
self.iterator_id = str(uuid.uuid4())
self.stateful_processor_api_client = stateful_processor_api_client
self.iterator_fully_consumed = False

def __iter__(self) -> Iterator[int]:
return self

def __next__(self) -> int:
return self.stateful_processor_api_client.get_list_timer_row(self.iterator_id)
if self.iterator_fully_consumed:
raise StopIteration()

ts, is_last_row = self.stateful_processor_api_client.get_list_timer_row(self.iterator_id)
if is_last_row:
self.iterator_fully_consumed = True

return ts


class ExpiredTimerIterator:
def __init__(
self, stateful_processor_api_client: StatefulProcessorApiClient, expiry_timestamp: int
):
# Generate a unique identifier for the iterator to make sure iterators on the
# same partition won't interfere with each other
self.iterator_id = str(uuid.uuid4())
self.stateful_processor_api_client = stateful_processor_api_client
self.expiry_timestamp = expiry_timestamp
self.iterator_fully_consumed = False

def __iter__(self) -> Iterator[Tuple[Tuple, int]]:
return self

def __next__(self) -> Tuple[Tuple, int]:
if self.iterator_fully_consumed:
raise StopIteration()

key, ts, is_last_row = self.stateful_processor_api_client.get_expiry_timers_iterator(
self.iterator_id, self.expiry_timestamp
)
if is_last_row:
self.iterator_fully_consumed = True

return (key, ts)
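A brief usage sketch of ExpiredTimerIterator, as it is driven below from stateful_processor_util.py; client and watermark_ms are placeholders for a StatefulProcessorApiClient instance and the expiry threshold. Note that the iterator is single-use: once iterator_fully_consumed is set, it only raises StopIteration.

```python
from pyspark.sql.streaming.stateful_processor_api_client import ExpiredTimerIterator

def collect_expired_timers(client, watermark_ms):
    # Each iteration yields one (grouping key tuple, expiry timestamp) pair,
    # fetched page by page through get_expiry_timers_iterator under the hood.
    return [(key_obj, ts) for key_obj, ts in ExpiredTimerIterator(client, watermark_ms)]
```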
30 changes: 13 additions & 17 deletions python/pyspark/sql/streaming/stateful_processor_util.py
@@ -28,6 +28,7 @@
StatefulProcessorHandle,
TimerValues,
)
from pyspark.sql.streaming.stateful_processor_api_client import ExpiredTimerIterator
from pyspark.sql.types import Row

if TYPE_CHECKING:
@@ -218,24 +219,19 @@ def _handle_expired_timers(
)

if self._time_mode.lower() == "processingtime":
- expiry_list_iter = stateful_processor_api_client.get_expiry_timers_iterator(
- batch_timestamp
- )
expiry_iter = ExpiredTimerIterator(stateful_processor_api_client, batch_timestamp)
elif self._time_mode.lower() == "eventtime":
- expiry_list_iter = stateful_processor_api_client.get_expiry_timers_iterator(
- watermark_timestamp
- )
expiry_iter = ExpiredTimerIterator(stateful_processor_api_client, watermark_timestamp)
else:
- expiry_list_iter = iter([[]])
expiry_iter = iter([]) # type: ignore[assignment]

# process with expiry timers, only timer related rows will be emitted
- for expiry_list in expiry_list_iter:
- for key_obj, expiry_timestamp in expiry_list:
- stateful_processor_api_client.set_implicit_key(key_obj)
- for pd in self._stateful_processor.handleExpiredTimer(
- key=key_obj,
- timerValues=TimerValues(batch_timestamp, watermark_timestamp),
- expiredTimerInfo=ExpiredTimerInfo(expiry_timestamp),
- ):
- yield pd
- stateful_processor_api_client.delete_timer(expiry_timestamp)
for key_obj, expiry_timestamp in expiry_iter:
stateful_processor_api_client.set_implicit_key(key_obj)
for pd in self._stateful_processor.handleExpiredTimer(
key=key_obj,
timerValues=TimerValues(batch_timestamp, watermark_timestamp),
expiredTimerInfo=ExpiredTimerInfo(expiry_timestamp),
):
yield pd
stateful_processor_api_client.delete_timer(expiry_timestamp)
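For completeness, a hypothetical user-side processor showing the handleExpiredTimer hook this loop drives; only the handleExpiredTimer keyword signature comes from this diff, while the imports, class name, and the handleInputRows body are illustrative assumptions about the public transformWithStateInPandas API.

```python
# Hypothetical processor; column names and logic are illustrative only.
import pandas as pd
from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle

class TimeoutEmittingProcessor(StatefulProcessor):
    def init(self, handle: StatefulProcessorHandle) -> None:
        self.handle = handle

    def handleInputRows(self, key, rows, timerValues):
        # Input handling elided; timers would be registered here via the handle.
        yield pd.DataFrame({"key": [key[0]], "event": ["input"]})

    def handleExpiredTimer(self, key, timerValues, expiredTimerInfo):
        # Called once per (key, expired timestamp) pair that
        # _handle_expired_timers pulls from ExpiredTimerIterator above.
        yield pd.DataFrame(
            {"key": [key[0]], "event": ["timer"],
             "expiredAtMs": [expiredTimerInfo.getExpiryTimeInMs()]}
        )

    def close(self) -> None:
        pass
```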