How To Obtain Kafka Consumer Lags in Pyspark Structured Streaming (Part 1)

Antonio Si
Apr 2, 2022

Pyspark, the Python API for Apache Spark, is a big data computation framework commonly used by data scientists. Intuit Stream Processing Platform (SPP) provides a common ecosystem to run and monitor pyspark streaming pipelines. One of the metrics that SPP monitors is consumer lag: the number of messages in each Kafka partition that have not yet been processed. A sample of consumer lag over time is illustrated below:

Consumer lag is an important metric because it indicates how far behind a pipeline is in processing incoming messages. For example, in the diagram above, the number of messages in the backlog ranges from 50 to 450. In this 2-part series, I will describe two alternatives that a pyspark pipeline can use to compute consumer lag, together with their strengths and limitations.

In part 1 of this series, I will first describe how we can use Spark's StreamingQueryListener to compute consumer lag, along with the limitations of this approach. In part 2, I will describe how we use a Spark checkpoint file to compute consumer lag. I will annotate my descriptions with code snippets for clarity.

Context

Let me begin by defining what a pyspark pipeline means in this context. We assume each pyspark pipeline uses Structured Streaming: it consumes one or more Kafka topics, performs some transformations, and produces messages to a sink topic. There are two challenges in obtaining consumer lag in this environment:

  1. A Structured Streaming pipeline does not commit its Kafka offsets back to Kafka. It relies on checkpoint files to resume consumption from the Kafka topic.
  2. It is not possible to get hold of the Kafka consumer object used internally by Spark; therefore, it is not possible to obtain the latest Kafka offset consumed by that consumer.

We will need a different way to compute the consumer lag.

StreamingQueryListener

The Spark Structured Streaming API provides a listener interface, StreamingQueryListener, which calls back into application code during the lifecycle of a streaming query. With the help of this interface, we can obtain the last processed offset for each partition of each topic after each microbatch is processed.

This interface has 3 apis:

  • onQueryProgress: This api is called when a microbatch of a streaming query is processed.
  • onQueryStarted: This api is called when a streaming query is started.
  • onQueryTerminated: This api is called when a streaming query is terminated.

For our purpose, we only need to focus on the onQueryProgress api, which is called whenever a microbatch of a streaming query has been processed. When onQueryProgress is called, it is supplied a QueryProgressEvent object. I will go through the QueryProgressEvent api in a bit more detail later. In a nutshell, QueryProgressEvent includes, for each source, the end offset of each Kafka topic partition that the referenced microbatch has processed.

The apis onQueryStarted and onQueryTerminated are called whenever a streaming query is started and stopped, respectively. They are not needed in this discussion, but I include them for the sake of completeness.

However, at the time of writing, this interface is only available in JVM Spark (Scala/Java); it is not available in pyspark. (Spark 3.4 later added a native Python StreamingQueryListener.) To take advantage of this interface in pyspark, we need a way to expose it to pyspark. This can be achieved by adopting an observer pattern on top of the Py4J gateway.

Step 1: Implement an observer interface in Java

We define a simple interface which mirrors StreamingQueryListener. A concrete implementation of this interface will be provided by the pyspark pipeline and will act as a proxy between pyspark and JVM Spark.

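    package com.intuit.data.strmprocess.spark.observer;

    public interface PythonObserver {
        // Events are typed as Object; Py4J hands them to Python as JavaObject references
        void onQueryProgress(Object event);

        void onQueryStarted(Object event);

        void onQueryTerminated(Object event);
    }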

Step 2: Implement a StreamingQueryListener concrete class in Java

We will need a concrete implementation of the StreamingQueryListener. This StreamingQueryListener class will take a PythonObserver object and delegate all events to the PythonObserver object.

    package com.intuit.data.strmprocess.spark.observer;

    import org.apache.spark.sql.streaming.StreamingQueryListener;

    // Forwards every StreamingQueryListener callback to an observer implemented in Python
    public class PythonStreamingQueryListener extends StreamingQueryListener {
        private final PythonObserver observer;

        public PythonStreamingQueryListener(PythonObserver observer) {
            this.observer = observer;
        }

        @Override
        public void onQueryProgress(QueryProgressEvent event) {
            observer.onQueryProgress(event);
        }

        @Override
        public void onQueryStarted(QueryStartedEvent event) {
            observer.onQueryStarted(event);
        }

        @Override
        public void onQueryTerminated(QueryTerminatedEvent event) {
            observer.onQueryTerminated(event);
        }
    }

Once these two classes are available to pyspark, all we need is to provide a concrete implementation of PythonObserver in pyspark and supply it as an argument to PythonStreamingQueryListener. Once these two pieces are hooked up, each time a microbatch completes, a QueryProgressEvent is forwarded to the PythonObserver implementation in pyspark, which can use the information in the QueryProgressEvent to compute consumer lag.

Step 3: Provide an implementation of PythonObserver in pyspark

We will need a concrete implementation of the PythonObserver in pyspark. Later, we will have to register this as an observer to PythonStreamingQueryListener.

    import json
    import logging
    from typing import Callable, Dict, Optional

    _LOGGER = logging.getLogger(__name__)


    class StreamingObserver:
        # Tells Py4J that this Python class implements the Java PythonObserver interface
        class Java:
            implements = ["com.intuit.data.strmprocess.spark.observer.PythonObserver"]

        def __init__(self, config: "ConsumerLagConfig",
                     processedOffsetCallback: Optional[Callable[[Dict[str, Dict[str, int]]], None]] = None):
            self.config = config
            self.processedOffsetCallback = processedOffsetCallback

        def onQueryProgress(self, queryProgressEvent):
            try:
                processedOffset = self._getSourceLastProcessedOffset(queryProgressEvent)

                if self.processedOffsetCallback:
                    self.processedOffsetCallback(processedOffset)
            except Exception:
                _LOGGER.exception("Failed to compute processed offsets from QueryProgressEvent")
                raise

        def _getSourceLastProcessedOffset(self, queryProgressEvent) -> Dict[str, Dict[str, int]]:
            # Each Kafka source reports endOffset() as JSON: {"topic": {"partition": offset}}
            sources = queryProgressEvent.progress().sources()
            sourceOffset = {}
            for source in sources:
                endOffset = json.loads(source.endOffset())
                sourceOffset.update(endOffset)
            return sourceOffset

        def onQueryStarted(self, queryStartedEvent):
            ...

        def onQueryTerminated(self, queryTerminatedEvent):
            ...

Let's walk through this code snippet. Since PythonObserver is a Java interface, we rely on Py4J, the library pyspark uses to communicate with the JVM. Pyspark's java_gateway module launches the JVM and establishes a Py4J gateway between Python and Java code. For the JVM to call methods on a Python object, the gateway's callback server must also be running. One can refer to the Py4J documentation for additional details.

It first defines this class as an implementation of the Java interface PythonObserver that we defined earlier:

    class Java:
        implements = ["com.intuit.data.strmprocess.spark.observer.PythonObserver"]

The constructor of this class, __init__, takes in 2 arguments: a ConsumerLagConfig object and a callback function. The callback function is used later to push the processed offsets out for further processing to compute the consumer lag.

The rest of the class implements the 3 apis defined by the PythonObserver interface. As indicated, onQueryStarted and onQueryTerminated are no-ops in this context; they are included for the sake of completeness.

The bulk of the implementation is in the onQueryProgress api. The queryProgressEvent object describes the last processed microbatch: its progress() returns a list of sources, and each source's endOffset api returns a JSON string containing the last processed offset for each partition of each topic read by that source. In our implementation, we parse each JSON string and merge the results into a dictionary of last processed offsets keyed by topic and then by partition. For more details on the QueryProgressEvent object, please refer to the Spark javadoc.
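To make the data shape concrete, the sketch below feeds a hand-built stand-in for QueryProgressEvent through the same parsing logic as _getSourceLastProcessedOffset. FakeSource, FakeProgress, FakeEvent and lastProcessedOffsets are hypothetical names; the fakes mimic only the three Py4J calls the observer makes (progress(), sources(), endOffset()), so the sketch runs without Spark. The JSON layout, {"topic": {"partition": offset}}, is the format the Kafka source reports; the topic names and offsets are made up.

```python
import json
from typing import Dict, List


class FakeSource:
    """Hypothetical stand-in for a Py4J SourceProgress object."""

    def __init__(self, endOffsetJson: str):
        self._endOffsetJson = endOffsetJson

    def endOffset(self) -> str:
        return self._endOffsetJson


class FakeProgress:
    """Hypothetical stand-in for StreamingQueryProgress."""

    def __init__(self, sources: List[FakeSource]):
        self._sources = sources

    def sources(self) -> List[FakeSource]:
        return self._sources


class FakeEvent:
    """Hypothetical stand-in for QueryProgressEvent."""

    def __init__(self, progress: FakeProgress):
        self._progress = progress

    def progress(self) -> FakeProgress:
        return self._progress


def lastProcessedOffsets(queryProgressEvent) -> Dict[str, Dict[str, int]]:
    # Same logic as StreamingObserver._getSourceLastProcessedOffset:
    # merge each source's {"topic": {"partition": offset}} dictionary
    offsets: Dict[str, Dict[str, int]] = {}
    for source in queryProgressEvent.progress().sources():
        offsets.update(json.loads(source.endOffset()))
    return offsets


event = FakeEvent(FakeProgress([
    FakeSource('{"orders": {"0": 1050, "1": 980}}'),
    FakeSource('{"payments": {"0": 310}}'),
]))
print(lastProcessedOffsets(event))
# {'orders': {'0': 1050, '1': 980}, 'payments': {'0': 310}}
```

Note that partition IDs arrive as JSON object keys, i.e. strings, while offsets are integers. This is why __computeOffsetLag in Step 6 converts each partition with int(partition) before looking it up.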

Step 4: Register StreamingObserver with PythonStreamingQueryListener in pyspark

Now that StreamingObserver is defined, we need to supply it to PythonStreamingQueryListener, as shown in the following code snippet.

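    from pyspark.java_gateway import ensure_callback_server_started
    from pyspark.sql import SparkSession


    def addListener(spark: SparkSession, listener):
        sc = spark.sparkContext
        # The JVM can only call into Python objects while the Py4J callback server is running
        ensure_callback_server_started(sc._gateway)
        jlistener = sc._jvm.com.intuit.data.strmprocess.spark.observer.PythonStreamingQueryListener(
            listener
        )
        # Register with the JVM StreamingQueryManager behind spark.streams
        spark.streams._jsqm.addListener(jlistener)

        return jlistener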

In this function definition, the supplied listener is an instance of the StreamingObserver class that we defined in Step 3 above. An instance of PythonStreamingQueryListener is created with the StreamingObserver instance as its argument. The remaining code registers the PythonStreamingQueryListener object with Spark's JVM StreamingQueryManager (the object behind spark.streams), so that it is notified of the progress of the streaming queries in the session.

Step 5: Obtaining Latest Offset

Once the last processed offset of a microbatch is obtained, all we need is the current latest offset of the source topic. The difference between the two will be the consumer lag.
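The subtraction needs no +1 or -1 adjustment, because both numbers use "next offset" semantics: Kafka's end offset is the offset that the next produced message will receive, and the Kafka source's endOffset in a progress event is the exclusive upper bound of what the microbatch read. The sketch below illustrates this for a single partition; partitionLag is a hypothetical helper and the offsets are illustrative only.

```python
def partitionLag(logEndOffset: int, processedEndOffset: int) -> int:
    """Consumer lag for a single partition (hypothetical helper).

    logEndOffset: offset the next produced message will receive (KafkaConsumer.end_offsets()).
    processedEndOffset: exclusive end offset the last microbatch read up to (source endOffset).
    """
    # Both values point one past the last message they cover, so no +/-1 is needed
    return logEndOffset - processedEndOffset


# Illustrative numbers: offsets 0..1499 exist, so the log end offset is 1500;
# the last microbatch read offsets 900..1049, so its end offset is 1050
lag = partitionLag(1500, 1050)
print(lag)  # 450, i.e. offsets 1050..1499 are still unprocessed
```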

To obtain the current latest offset of the source topic, we need the help of a KafkaConsumer from the kafka-python library.

    from typing import Dict

    from kafka import KafkaConsumer, TopicPartition


    class KafkaOffsetManager:
        def __init__(self, consumerLagConfig: "ConsumerLagConfig"):
            self.topics = consumerLagConfig.getTopics()
            self.consumerLagConfig = consumerLagConfig
            # Created once and reused for every lookup (see the note below)
            self.consumer = self.getKafkaConsumer()
            self.topicPartitions = {}
            for topic in self.topics:
                # partitions_for_topic() returns None when the topic is unknown
                partitionIds = self.consumer.partitions_for_topic(topic) or set()
                self.topicPartitions[topic] = [TopicPartition(topic, p) for p in partitionIds]

        def getLatestPartitionOffset(self) -> Dict[str, Dict[TopicPartition, int]]:
            lastOffsets = {}
            for topic in self.topics:
                partitions = self.topicPartitions[topic]
                # end_offsets() maps each TopicPartition to the offset of the next message to be written
                lastOffsets[topic] = self.consumer.end_offsets(partitions)

            return lastOffsets

        def getKafkaConsumer(self) -> KafkaConsumer:
            return KafkaConsumer(
                *self.topics,
                bootstrap_servers=self.consumerLagConfig.getBootstrapServer(),
                security_protocol=self.consumerLagConfig.getSecurityProtocol(),
            )

In the constructor, __init__, we create a KafkaConsumer object and obtain all partitions of each topic with the partitions_for_topic api provided by KafkaConsumer.

The getLatestPartitionOffset api obtains the latest offset for each partition via the end_offsets() api. It returns a dictionary keyed by topic, whose values map each TopicPartition to its latest offset.

It is important to note that creating a KafkaConsumer object is an expensive operation; thus, we create it once in the constructor and reuse it every time getLatestPartitionOffset is called. If a new KafkaConsumer were created on every call, the extra setup time would delay the end-offset lookup. Since new messages keep arriving during that delay, the lookup would return a higher offset than the one at the time the microbatch finished, creating a false impression of high consumer lag.
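Note that end_offsets() keys its result by TopicPartition, whereas the processed offsets from Step 3 are keyed by topic name and then by partition ID as a string. The sketch below shows one way to bridge the two shapes before subtracting. It uses a local namedtuple as a stand-in for kafka.TopicPartition (kafka-python defines TopicPartition as a namedtuple with topic and partition fields), so it runs without a Kafka cluster; toPartitionStringKeys is a hypothetical helper and the offsets are made up.

```python
from collections import namedtuple
from typing import Dict

# Stand-in for kafka.TopicPartition, which kafka-python defines as a namedtuple
TopicPartition = namedtuple("TopicPartition", ["topic", "partition"])


def toPartitionStringKeys(endOffsets: Dict[TopicPartition, int]) -> Dict[str, Dict[str, int]]:
    """Hypothetical helper: {TopicPartition: offset} -> {topic: {"partition": offset}}."""
    byTopic: Dict[str, Dict[str, int]] = {}
    for tp, offset in endOffsets.items():
        # Use string partition IDs to match the JSON keys from QueryProgressEvent
        byTopic.setdefault(tp.topic, {})[str(tp.partition)] = offset
    return byTopic


latest = {TopicPartition("orders", 0): 1500, TopicPartition("orders", 1): 990}
processed = {"orders": {"0": 1050, "1": 980}}

latestByTopic = toPartitionStringKeys(latest)
lag = {
    topic: {p: latestByTopic[topic][p] - off for p, off in partitions.items()}
    for topic, partitions in processed.items()
}
print(lag)  # {'orders': {'0': 450, '1': 10}}
```

Converting once to string keys keeps the lag computation a plain dictionary lookup; the article's __computeOffsetLag takes the opposite route and builds a TopicPartition from each string partition ID instead.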

Step 6: Tying It All Together With a Driver Class

Now that we have the last processed offset and the latest offset for each topic partition, all we need is to compute the difference to obtain the consumer lag. We achieve this via a ConsumerLagManager class.

    from typing import Callable, Dict, Optional

    from kafka import TopicPartition


    class ConsumerLagManager:
        def __init__(self, consumerLagConfig: "ConsumerLagConfig",
                     consumerLagCallback: Optional[Callable[[Dict[str, Dict[str, int]]], None]] = None):
            # Always set the attribute, so computeConsumerLag can test it even without a callback
            self.consumerLagCallback = consumerLagCallback
            self.consumerLagConfig = consumerLagConfig
            self.kafkaOffsetManager = KafkaOffsetManager(consumerLagConfig)
            # PassiveConsumerLagManager (this part) and ProactiveConsumerLagManager (part 2) are not
            # shown; both call back into computeConsumerLag with the last processed offsets
            if consumerLagConfig.getConsumerLagStrategy() == ConsumerLagStrategy.PASSIVE:
                PassiveConsumerLagManager(consumerLagConfig, self.computeConsumerLag)
            else:
                ProactiveConsumerLagManager(consumerLagConfig, self.computeConsumerLag)

        def computeConsumerLag(self, latestProcessedOffset: Dict[str, Dict[str, int]]):
            latestOffset = self.kafkaOffsetManager.getLatestPartitionOffset()

            offsetLag = self.__computeOffsetLag(latestOffset, latestProcessedOffset)

            if self.consumerLagCallback:
                self.consumerLagCallback(offsetLag)

        def __computeOffsetLag(self, latestOffset: Dict[str, Dict[TopicPartition, int]],
                               latestProcessedOffset: Dict[str, Dict[str, int]]) -> Dict[str, Dict[str, int]]:
            offsetLagsForAllTopics = {}
            for topic in self.consumerLagConfig.getTopics():
                offsetLagForTopic = {}
                # A topic may be absent until a microbatch has read from it
                processedPartitionsOffset = latestProcessedOffset.get(topic, {})
                latestPartitionsOffset = latestOffset.get(topic, {})
                for partition, processedOffset in processedPartitionsOffset.items():
                    # Processed offsets use string partition IDs; latest offsets use TopicPartition keys
                    latestPartitionOffset = latestPartitionsOffset.get(TopicPartition(topic, int(partition)))
                    if latestPartitionOffset is not None:
                        offsetLagForTopic[partition] = latestPartitionOffset - processedOffset
                offsetLagsForAllTopics[topic] = offsetLagForTopic
            return offsetLagsForAllTopics

The constructor, __init__, takes in a ConsumerLagConfig object and a callback function, consumerLagCallback.

The ConsumerLagConfig is a simple wrapper for different configuration parameters:

    from enum import Enum
    from typing import List


    class ConsumerLagStrategy(Enum):
        # Assumed definition: the article only references the PASSIVE and PROACTIVE members
        PASSIVE = "PASSIVE"
        PROACTIVE = "PROACTIVE"


    class ConsumerLagConfig:
        def __init__(self, consumerLagStrategy: ConsumerLagStrategy, checkpointLocation: str, topics: str,
                     bootstrapServer: str, securityProtocol: str):
            self.__topics = topics  # comma-separated topic names
            self.__checkpointLocation = checkpointLocation
            self.__bootstrapServer = bootstrapServer
            self.__consumerLagStrategy = consumerLagStrategy
            self.__securityProtocol = securityProtocol

        def getTopics(self) -> List[str]:
            topics = []
            if self.__topics is not None:
                # Strip whitespace so "a, b" yields ["a", "b"]
                topics.extend(t.strip() for t in self.__topics.split(",") if t.strip())
            return topics

        def getCheckpointLocation(self) -> str:
            return self.__checkpointLocation

        def getBootstrapServer(self) -> str:
            return self.__bootstrapServer

        def getConsumerLagStrategy(self) -> ConsumerLagStrategy:
            return self.__consumerLagStrategy

        def getSecurityProtocol(self) -> str:
            return self.__securityProtocol

        def getCheckpointReader(self) -> "CheckpointReader":
            # CheckpointReader and its subclasses belong to the checkpoint-based approach in part 2
            if self.__checkpointLocation.startswith("s3"):
                return S3CheckpointReader()
            else:
                return LocalCheckpointReader()

The consumerLagCallback function is used to process the computed consumer lag. A caller can pass in a function that outputs the consumer lag to the console or streams it to a metrics monitoring platform such as Wavefront.
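As an illustration, a consumerLagCallback might flatten the nested {topic: {partition: lag}} dictionary into one line per partition plus a per-topic total, and print it. The function names and the line format below are hypothetical; a real deployment would send the same values to its metrics client instead of printing them.

```python
from typing import Dict, List


def formatConsumerLag(offsetLag: Dict[str, Dict[str, int]]) -> List[str]:
    """Hypothetical helper: turn {topic: {partition: lag}} into printable lines."""
    lines = []
    for topic in sorted(offsetLag):
        partitions = offsetLag[topic]
        # Partition keys are strings such as "0" or "10", so sort them numerically
        for partition in sorted(partitions, key=int):
            lines.append(f"consumer_lag topic={topic} partition={partition} value={partitions[partition]}")
        lines.append(f"consumer_lag_total topic={topic} value={sum(partitions.values())}")
    return lines


def printConsumerLag(offsetLag: Dict[str, Dict[str, int]]) -> None:
    # Suitable as the consumerLagCallback argument of ConsumerLagManager
    for line in formatConsumerLag(offsetLag):
        print(line)


printConsumerLag({"orders": {"1": 10, "0": 450}})
# consumer_lag topic=orders partition=0 value=450
# consumer_lag topic=orders partition=1 value=10
# consumer_lag_total topic=orders value=460
```

Keeping the formatting separate from the output call makes it easy to swap print for a metrics client without touching ConsumerLagManager.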

The code

    if consumerLagConfig.getConsumerLagStrategy() == ConsumerLagStrategy.PASSIVE:
        PassiveConsumerLagManager(consumerLagConfig, self.computeConsumerLag)
    else:
        ProactiveConsumerLagManager(consumerLagConfig, self.computeConsumerLag)

decides which consumer lag computation mechanism to use. To use the mechanism described in this article, we pass in the PASSIVE ConsumerLagStrategy, reflecting the fact that it is passively notified by the Spark engine. We will describe the PROACTIVE ConsumerLagStrategy in part 2 of this series. At a very high level, the PROACTIVE strategy proactively reads the latest checkpoint file to obtain the last processed offset.
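PassiveConsumerLagManager itself is not shown in this article. The sketch below is one plausible shape for it, assuming it only wires the pieces from Steps 3 and 4 together. To keep it self-contained and testable, the observer class and the registration function are injected as hypothetical observerFactory and registerListener parameters; since ConsumerLagManager constructs it with two arguments, a real implementation would more likely reference StreamingObserver and the Step 4 registration code directly.

```python
from typing import Any, Callable, Dict

OffsetsByTopic = Dict[str, Dict[str, int]]
OffsetCallback = Callable[[OffsetsByTopic], None]


class PassiveConsumerLagManager:
    """Hypothetical sketch of the PASSIVE strategy's wiring."""

    def __init__(self, consumerLagConfig: Any, processedOffsetCallback: OffsetCallback,
                 observerFactory: Callable[[Any, OffsetCallback], Any],
                 registerListener: Callable[[Any], Any]):
        # Step 3: the Python object that the JVM listener calls back into
        self.observer = observerFactory(consumerLagConfig, processedOffsetCallback)
        # Step 4: wrap the observer in PythonStreamingQueryListener and register it;
        # keeping the returned handle allows the listener to be removed later
        self.jlistener = registerListener(self.observer)


# Test doubles standing in for StreamingObserver and the Step 4 registration code
class RecordingObserver:
    def __init__(self, config: Any, callback: OffsetCallback):
        self.config = config
        self.callback = callback


registered = []
received = []


def fakeRegister(observer: Any) -> str:
    registered.append(observer)
    return "jlistener-handle"


manager = PassiveConsumerLagManager({"topics": "orders"}, received.append,
                                    observerFactory=RecordingObserver,
                                    registerListener=fakeRegister)
# Simulate a finished microbatch: the observer hands processed offsets to the callback
manager.observer.callback({"orders": {"0": 1050}})
print(registered[0] is manager.observer, received)  # True [{'orders': {'0': 1050}}]
```

In the real pipeline, observerFactory would be StreamingObserver, and registerListener a small wrapper around the Step 4 registration code that also supplies the SparkSession.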

The function computeConsumerLag is the callback that is passed to the StreamingObserver described in Step 3. When this function is called, it receives the last processed offsets of the just-completed microbatch. It calls KafkaOffsetManager to obtain the latest offsets of the topic partitions, and then calls an internal helper function, __computeOffsetLag, to compute the difference between the latest offset and the last processed offset. The resulting lags are passed to the consumerLagCallback function.

Step 7: Register StreamingQueryListener implementation with Spark Runtime

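Finally, the jar containing the classes from Steps 1 and 2 must be on the Spark driver's classpath, so that Py4J can instantiate PythonStreamingQueryListener in Step 4. Because streaming query listeners run on the driver, this can be achieved by adding the jar file to the spark.driver.extraClassPath configuration parameter.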

Limitations

This mechanism takes advantage of Spark's existing StreamingQueryListener notifications to compute consumer lag. However, since StreamingQueryListener was not available in pyspark at the time of writing, extra effort is needed to wire the notifications from the JVM to pyspark.

StreamingQueryListener is a listener for the microbatch lifecycle. In other words, this mechanism is only applicable to microbatch streaming pipelines. Structured Streaming also provides a continuous processing mode, which is not based on microbatches and thus requires a different mechanism for consumer lag computation. In part 2 of this series, I will describe another mechanism to compute consumer lag, based on the checkpointing information of a Spark pipeline.

Special thanks to my colleague JD Rosensweig at Intuit for suggestions and multiple rounds of editing on an early draft of this article.
