How To Obtain Kafka Consumer Lags in Pyspark Structured Streaming (Part 2)

Antonio Si
Jul 8, 2022

In this article, I describe an alternative way to obtain consumer lag, this time based on Spark's checkpoint files. As in my previous article, I annotate the descriptions with code snippets for clarity, and I conclude with the strengths and limitations of the approach.

Context

Let me begin by revisiting the context from part 1. We assume each pyspark pipeline uses structured streaming: it consumes one or more Kafka topics, runs some transformations, and produces messages to a sink topic. In this article, the pipeline runs under the continuous processing model.

Bird’s Eye View

Ideally, the continuous processing model would provide a callback or listener, as microbatch processing does, that reports the offset of each partition of each topic that the pipeline has last processed. That would let us obtain the lag the same way we did for microbatch processing. Unfortunately, I could not find any such mechanism in Spark continuous processing.

To obtain the lag, I therefore use a somewhat hacky, though still quite generic, approach.

A streaming pipeline usually relies on checkpointing for resilient operation. Spark records the offsets of each batch (or, under the continuous model, each epoch) in the offsets subdirectory of the checkpoint location. Each offset file is a text file like the one below:

v1
{"batchWatermarkMs":0,"batchTimestampMs":1631478980101,"conf":{"spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider","spark.sql.streaming.join.stateFormatVersion":"2","spark.sql.streaming.flatMapGroupsWithState.stateFormatVersion":"2","spark.sql.streaming.multipleWatermarkPolicy":"min","spark.sql.streaming.aggregation.stateFormatVersion":"2","spark.sql.shuffle.partitions":"200"}}
{"test-idp-session-input-take1":{"2":83485,"5":83414,"4":84307,"1":1293554,"3":83053,"0":81912}}

We are only interested in the last line of this file, which contains the offset of each partition of each topic that the pipeline has last processed. In other words, where microbatch processing hands us the processed offsets through the QueryProgressEvent delivered to a StreamingQueryListener, under the continuous model we read them from the checkpoint file instead. Since Spark does not push the last processed offset to the pipeline, our approach pulls it: it periodically reads the latest checkpoint file and extracts the last processed offsets.
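To make the layout concrete, the sketch below parses the sample file shown above. It assumes the layout we observe there: a version line, a batch metadata line, then one JSON line per streaming source, with a line containing only "-" for a source that has no offset yet (our reading of Spark's offset log format). With a single Kafka source, the last line is the one holding the processed offsets. parseOffsetLog is a hypothetical helper, not part of Spark or of the author's code.

```python
import json
from typing import Dict, List, Optional

# The sample offset file from above (metadata "conf" shortened for readability)
SAMPLE = (
    'v1\n'
    '{"batchWatermarkMs":0,"batchTimestampMs":1631478980101,"conf":{"spark.sql.shuffle.partitions":"200"}}\n'
    '{"test-idp-session-input-take1":{"2":83485,"5":83414,"4":84307,"1":1293554,"3":83053,"0":81912}}'
)


def parseOffsetLog(text: str) -> List[Optional[Dict[str, Dict[str, int]]]]:
    """Return one entry per source line; None for a source recorded as "-" (no offset yet)."""
    # assumed layout: version line, batch metadata line, then one line per source
    lines = text.rstrip('\n').split('\n')
    if not lines[0].startswith('v'):
        raise ValueError("unexpected version line: {!r}".format(lines[0]))
    return [None if line == '-' else json.loads(line) for line in lines[2:]]


sources = parseOffsetLog(SAMPLE)
lastSource = sources[-1]  # the line the "last line" approach reads
print(lastSource["test-idp-session-input-take1"]["1"])  # 1293554
```

Note that partition ids arrive as JSON object keys, so they are strings ("1"), while the offsets are integers.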

The latest offset of a topic can still be obtained from a KafkaConsumer, as explained in part 1, and the consumer lag is the difference between the latest offset and the last processed offset.
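As a worked example of that difference, the sketch below subtracts processed offsets from end offsets for a hypothetical topic "orders" with made-up numbers. It assumes, as we understand the two APIs, that the offset Spark records for a partition is the next offset it will read and that Kafka's end offset is the offset of the next message to be written, so the difference counts messages not yet processed. computeLag is a hypothetical stand-in for the article's own lag computation, and for brevity the end offsets are keyed by integer partition instead of kafka-python's TopicPartition objects.

```python
from typing import Dict


def computeLag(processed: Dict[str, Dict[str, int]],
               endOffsets: Dict[str, Dict[int, int]]) -> Dict[str, Dict[int, int]]:
    """Lag per topic and partition: end offset minus the next offset Spark will read."""
    lag: Dict[str, Dict[int, int]] = {}
    for topic, partitions in processed.items():
        # checkpoint JSON uses string partition ids; convert them to match endOffsets
        lag[topic] = {int(p): endOffsets[topic][int(p)] - offset
                      for p, offset in partitions.items()}
    return lag


processed = {"orders": {"0": 81912, "1": 83554}}
endOffsets = {"orders": {0: 82000, 1: 83554}}
print(computeLag(processed, endOffsets))  # {'orders': {0: 88, 1: 0}}
```

A lag of 0 for partition 1 means the pipeline has processed everything written to that partition so far.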

Step 1: Obtaining the latest checkpoint file

When a pipeline needs the current consumer lag, it first has to identify the latest checkpoint file, which holds the last processed offsets.

import glob
import logging
import os


class LocalCheckpointReader(CheckpointReader):
    _LOGGER = logging.getLogger("LocalCheckpointReader")

    def getLastOffsetFile(self, checkpointLocation: str) -> str:
        LocalCheckpointReader._LOGGER.info("getLastOffsetFile, checkpointLocation={}".format(checkpointLocation))

        # one directory level below checkpointLocation, then every file in its offsets/ subdirectory;
        # glob's "*" does not match hidden files such as .crc files
        list_of_files = glob.glob(checkpointLocation + '/*/offsets/*')
        if not list_of_files:
            raise FileNotFoundError("no offset files under {}".format(checkpointLocation))
        latest_file = max(list_of_files, key=os.path.getmtime)

        LocalCheckpointReader._LOGGER.info("latest offset path: {}".format(latest_file))

        return latest_file

    def _getLastLineOfFile(self, fileName: str) -> str:
        with open(fileName, 'r') as f:
            lastLine = f.readlines()[-1]
        LocalCheckpointReader._LOGGER.info("lastline: {}".format(lastLine))

        return lastLine

In the snippet above, getLastOffsetFile returns the most recently modified offset file under a given checkpoint path. This simple version assumes the checkpoint location is on a local file system; the appendix shows how to obtain the latest checkpoint file from S3.
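getLastOffsetFile picks the newest file by modification time. An alternative, sketched below, is to take the file with the highest numeric name, assuming (as we understand Spark's offset log) that each file in an offsets directory is named by its batch or epoch ID, and to skip everything else, such as hidden .crc files. This avoids depending on file-system timestamps, but the names must be compared as integers: as strings, "9" sorts after "10". latestOffsetFileByBatchId is a hypothetical helper that takes an offsets directory directly rather than the glob pattern used above.

```python
import os
import tempfile
from typing import Optional


def latestOffsetFileByBatchId(offsetsDir: str) -> Optional[str]:
    """Return the path of the offset file with the highest numeric name, or None."""
    # keep only purely numeric names; this skips hidden/temporary files such as .10.crc
    names = [name for name in os.listdir(offsetsDir) if name.isdecimal()]
    if not names:
        return None
    # compare numerically, not lexicographically
    return os.path.join(offsetsDir, max(names, key=int))


# hypothetical usage against a throwaway directory
with tempfile.TemporaryDirectory() as offsetsDir:
    for name in ["0", "9", "10", ".10.crc"]:
        open(os.path.join(offsetsDir, name), "w").close()
    latestName = os.path.basename(latestOffsetFileByBatchId(offsetsDir))
print(latestName)  # 10
```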

Step 2: Obtaining the processed offset

The snippet above also defines _getLastLineOfFile, which returns the last line of the checkpoint file. As mentioned earlier, this line holds the offset of each partition of each topic consumed by the pipeline. To use these processed offsets, we parse the line into a Python dictionary in readLatestCommitOffsetFromFile, which is defined in the superclass, CheckpointReader.

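import json
from abc import ABC, abstractmethod
from typing import Dict


class CheckpointReader(ABC):
    @abstractmethod
    def getLastOffsetFile(self, checkpointLocation: str) -> str:
        pass

    @abstractmethod
    def _getLastLineOfFile(self, fileName: str) -> str:
        pass

    def readLatestCommitOffsetFromFile(self, offsetFileName: str) -> Dict[str, Dict[str, int]]:
        try:
            # e.g. {"topic": {"0": 81912, "1": 1293554, ...}}
            return json.loads(self._getLastLineOfFile(offsetFileName))
        except Exception:
            # unreadable file or a line that is not valid JSON: report no offsets
            return {}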

Step 3: Obtaining Latest Offset

Once the last processed offset is obtained, all we need is the current latest offset of the source topic. The difference between the two will be the consumer lag.

To obtain the current latest offset of the source topic, we reuse the KafkaOffsetManager class from my previous article, which wraps a KafkaConsumer and is shown and explained again below.

from typing import Dict

from kafka import KafkaConsumer, TopicPartition


class KafkaOffsetManager:
    def __init__(self, consumerLagConfig: ConsumerLagConfig):
        self.topics = consumerLagConfig.getTopics()
        self.consumerLagConfig = consumerLagConfig
        # create the consumer once and reuse it (see the note below)
        self.consumer = self.getKafkaConsumer()
        self.topicPartitions = {}
        for topic in self.topics:
            # partitions_for_topic may return None for an unknown topic
            partitionIds = self.consumer.partitions_for_topic(topic) or set()
            self.topicPartitions[topic] = [TopicPartition(topic, p) for p in partitionIds]

    def getLatestPartitionOffset(self) -> Dict[str, Dict[TopicPartition, int]]:
        lastOffsets = {}
        for topic in self.topics:
            partitions = self.topicPartitions[topic]
            # {TopicPartition: offset of the next message to be written to that partition}
            lastOffsets[topic] = self.consumer.end_offsets(partitions)

        return lastOffsets

    def getKafkaConsumer(self) -> KafkaConsumer:
        return KafkaConsumer(*self.topics,
                             bootstrap_servers=self.consumerLagConfig.getBootstrapServer(),
                             security_protocol=self.consumerLagConfig.getSecurityProtocol())

In the constructor, __init__, we create a KafkaConsumer object and obtain all partitions of each topic with the partitions_for_topic() API provided by KafkaConsumer.

The getLatestPartitionOffset() method obtains the latest offset of each partition via the end_offsets() API. As before, it returns a dictionary of offsets keyed by topic and then by TopicPartition.
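The two sides of the lag computation use different keys: the checkpoint JSON identifies partitions by string ids such as "0", while end_offsets() keys them by TopicPartition. The ConsumerLagManager in Step 4 bridges this by building TopicPartition(topic, int(partition)) for each lookup. Another option, sketched here, is to re-key the end offsets once. To stay runnable without a broker or the kafka package, the sketch defines a namedtuple stand-in with the same topic and partition fields as kafka-python's TopicPartition; flattenEndOffsets is a hypothetical helper.

```python
from collections import namedtuple
from typing import Dict

# Stand-in with the same fields as kafka-python's TopicPartition (topic, partition)
TopicPartition = namedtuple("TopicPartition", ["topic", "partition"])


def flattenEndOffsets(latest: Dict[str, Dict[TopicPartition, int]]) -> Dict[str, Dict[str, int]]:
    """Re-key {topic: {TopicPartition: offset}} by the string partition ids used in checkpoints."""
    return {topic: {str(tp.partition): offset for tp, offset in perPartition.items()}
            for topic, perPartition in latest.items()}


latest = {"orders": {TopicPartition("orders", 0): 82000, TopicPartition("orders", 1): 83554}}
print(flattenEndOffsets(latest))  # {'orders': {'0': 82000, '1': 83554}}
```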

As a reminder, creating a KafkaConsumer is expensive, so we create it once in the constructor and reuse it every time getLatestPartitionOffset() is called. Creating a new KafkaConsumer on every call would add latency to getLatestPartitionOffset(); the end offsets would then be sampled later than the processed offsets they are compared against, so they would tend to be higher and create a false impression of high consumer lag.

Step 4: Tying It All Together With a Driver Class

Now that we have the last processed offset and the latest offset for each topic partition, all we need is to compute the difference to obtain the consumer lag. We achieve this using the same ConsumerLagManager class as described in part 1.

from typing import Callable, Dict, Optional

from kafka import TopicPartition


class ConsumerLagManager:
    def __init__(self, consumerLagConfig: ConsumerLagConfig,
                 consumerLagCallback: Optional[Callable[[Dict[str, Dict[str, int]]], None]] = None):
        # always set the attribute so computeConsumerLag can test it safely
        self.consumerLagCallback = consumerLagCallback
        self.consumerLagConfig = consumerLagConfig
        self.kafkaOffsetManager = KafkaOffsetManager(consumerLagConfig)
        # keep a reference to the strategy-specific manager
        if consumerLagConfig.getConsumerLagStrategy() == ConsumerLagStrategy.PASSIVE:
            self.lagManager = PassiveConsumerLagManager(consumerLagConfig, self.computeConsumerLag)
        else:
            self.lagManager = ProactiveConsumerLagManager(consumerLagConfig, self.computeConsumerLag)

    def computeConsumerLag(self, latestProcessedOffset: Dict[str, Dict[str, int]]):
        latestOffset = self.kafkaOffsetManager.getLatestPartitionOffset()

        offsetLag = self.__computeOffsetLag(latestOffset, latestProcessedOffset)

        if self.consumerLagCallback:
            self.consumerLagCallback(offsetLag)

    def __computeOffsetLag(self, latestOffset: Dict[str, Dict[TopicPartition, int]],
                           latestProcessedOffset: Dict[str, Dict[str, int]]) -> Dict[str, Dict[str, int]]:
        offsetLagsForAllTopics = {}
        for topic in self.consumerLagConfig.getTopics():
            offsetLagForTopic = {}
            # a topic may be missing if no offsets have been checkpointed for it yet
            processedPartitionsOffset = latestProcessedOffset.get(topic, {})
            latestPartitionsOffset = latestOffset[topic]
            for partition, commitOffsetForPartition in processedPartitionsOffset.items():
                # checkpoint partition ids are strings; end_offsets() is keyed by TopicPartition
                topicP = TopicPartition(topic, int(partition))
                latestPartitionOffset = latestPartitionsOffset[topicP]
                offsetLagForTopic[partition] = latestPartitionOffset - commitOffsetForPartition
            offsetLagsForAllTopics[topic] = offsetLagForTopic
        return offsetLagsForAllTopics

ConsumerLagConfig is a simple helper class holding the pipeline's configuration: the source Kafka topics (as a comma-separated string), the Kafka bootstrap server, the security protocol, the ConsumerLagStrategy, and the checkpoint location. If the ConsumerLagStrategy is PASSIVE, the checkpoint location is not used.

from typing import List


class ConsumerLagConfig:
    def __init__(self, consumerLagStrategy: "ConsumerLagStrategy", checkpointLocation: str,
                 topics: str, bootstrapServer: str, securityProtocol: str):
        self.__topics = topics  # comma-separated topic names
        self.__checkpointLocation = checkpointLocation
        self.__bootstrapServer = bootstrapServer
        self.__consumerLagStrategy = consumerLagStrategy
        self.__securityProtocol = securityProtocol

    def getTopics(self) -> List[str]:
        topics = []
        if self.__topics is not None:
            topics.extend(t.strip() for t in self.__topics.split(","))
        return topics

    def getCheckpointLocation(self) -> str:
        return self.__checkpointLocation

    def getBootstrapServer(self) -> str:
        return self.__bootstrapServer

    def getConsumerLagStrategy(self) -> "ConsumerLagStrategy":
        return self.__consumerLagStrategy

    def getSecurityProtocol(self) -> str:
        return self.__securityProtocol

    def getCheckpointReader(self) -> CheckpointReader:
        # matches s3://, s3a:// and s3n:// locations
        if self.__checkpointLocation.startswith("s3"):
            return S3CheckpointReader()
        return LocalCheckpointReader()

ConsumerLagStrategy is a simple enum with two values, PASSIVE and PROACTIVE.

from enum import Enum


class ConsumerLagStrategy(Enum):
    PASSIVE = 1
    PROACTIVE = 2

I named the strategies PASSIVE and PROACTIVE to indicate how the consumer lag is computed: either when the Spark engine calls a listener (PASSIVE), or by periodically probing the checkpoint files for the last processed offset and computing the lag from it (PROACTIVE).

If the consumer lag strategy is not PASSIVE, ConsumerLagManager instantiates a ProactiveConsumerLagManager.

class ProactiveConsumerLagManager:
    def __init__(self, config: ConsumerLagConfig, computeConsumerLag):
        self.lagTrigger = SparkLagTrigger(config, computeConsumerLag)

This class uses a SparkLagTrigger object, which spawns a thread that periodically reads the latest checkpoint file to obtain the latest processed offset.

import logging
import sched
import threading
import time
from typing import Dict


class SparkLagTrigger:
    _LOGGER = logging.getLogger("SparkLagTrigger")

    def __init__(self, config: ConsumerLagConfig, processedOffsetCallback=None):
        self.config = config
        self.checkpointReader = config.getCheckpointReader()
        self.processedOffsetCallback = processedOffsetCallback
        # once set, the report stops rescheduling itself and the thread exits
        self.cease_continuous_run = threading.Event()

        self.scheduler = sched.scheduler(time.time, time.sleep)
        # first report after 30 seconds; each report reschedules itself 1 second later
        self.scheduler.enter(30, 1, self.__scheduleConsumerOffsetReport, (self.scheduler,))

        # scheduler.run() keeps running while events are rescheduled, so one thread is enough;
        # daemon=True keeps the thread from blocking interpreter shutdown
        self.continuous_thread = threading.Thread(target=self.scheduler.run, daemon=True)
        self.continuous_thread.start()

    def stop(self):
        self.cease_continuous_run.set()

    def __scheduleConsumerOffsetReport(self, sc):
        try:
            processedOffset = self.__getLatestCommitOffset()
            # skip the callback when no offsets could be read
            if processedOffset and self.processedOffsetCallback:
                self.processedOffsetCallback(processedOffset)
        except Exception:
            SparkLagTrigger._LOGGER.exception("failed to report processed offsets")
        finally:
            if not self.cease_continuous_run.is_set():
                sc.enter(1, 1, self.__scheduleConsumerOffsetReport, (sc,))

    def __getLatestCommitOffset(self) -> Dict[str, Dict[str, int]]:
        return self.__readLatestCommitOffsetFromFile()

    def __readLatestCommitOffsetFromFile(self) -> Dict[str, Dict[str, int]]:
        try:
            latestOffsetFile = self.checkpointReader.getLastOffsetFile(self.config.getCheckpointLocation())
            return self.checkpointReader.readLatestCommitOffsetFromFile(latestOffsetFile)
        except Exception:
            SparkLagTrigger._LOGGER.exception("failed to read the latest offset file")
            return {}

The constructor, __init__, spawns a thread that runs a scheduler. The scheduler calls __scheduleConsumerOffsetReport, which obtains the latest processed offsets, passes them to the callback (computeConsumerLag of ConsumerLagManager), and schedules itself to run again.
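A common alternative to the sched-based loop, sketched below rather than taken from the author's code, is a daemon thread that waits on a threading.Event between probes. Event.wait(timeout) serves both as the sleep and as the stop signal, so stop() ends the loop promptly, even during the initial delay. PeriodicPoller is a hypothetical class; probe stands in for reading the latest checkpoint file and callback for computeConsumerLag. The defaults mirror the 30-second initial delay and 1-second interval used above.

```python
import threading
import time
from typing import Callable, Optional


class PeriodicPoller:
    """Calls probe() periodically on a daemon thread and passes non-empty results to callback."""

    def __init__(self, probe: Callable[[], Optional[dict]], callback: Callable[[dict], None],
                 interval: float = 1.0, initialDelay: float = 30.0):
        self._probe = probe
        self._callback = callback
        self._interval = interval
        self._initialDelay = initialDelay
        self._stopped = threading.Event()
        self._thread = threading.Thread(target=self._run, daemon=True)

    def start(self) -> None:
        self._thread.start()

    def stop(self, timeout: Optional[float] = None) -> None:
        self._stopped.set()
        if self._thread.is_alive():
            self._thread.join(timeout)

    def _run(self) -> None:
        delay = self._initialDelay
        # Event.wait returns True as soon as stop() sets the event, which ends the loop
        while not self._stopped.wait(delay):
            try:
                result = self._probe()
                if result:
                    self._callback(result)
            except Exception:
                pass  # keep polling after a failed probe; real code would log the error
            delay = self._interval


# hypothetical usage: a probe that always returns the same processed offsets
seen = []
poller = PeriodicPoller(lambda: {"t": {"0": 1}}, seen.append, interval=0.01, initialDelay=0)
poller.start()
time.sleep(0.1)
poller.stop()
print(len(seen))
```

The trade-off is that a fixed-interval loop is easier to stop and reason about, while sched makes it simple to vary the delay per event.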

__getLatestCommitOffset uses the checkpoint reader described earlier to read the latest processed offsets from the latest checkpoint file.

Limitations and Extensions

This mechanism provides a way to compute consumer lag under the continuous processing model. Its major limitation is that it depends on being able to locate the latest checkpoint file of the pyspark pipeline. This article shows how to do that when the checkpoint files are saved on a local file system, and the appendix covers Amazon S3.

Since the Spark engine already handles checkpointing, it would be cleaner if the engine were extended to push the last processed offset through a checkpoint listener. The pipeline could then simply register a listener to obtain the last processed offset.

Another option is to let the pipeline register a KafkaConsumer with the Spark context for the Spark processing engine to use. The pipeline could then use the same KafkaConsumer to find out the last processed offset.

Appendix: Obtaining Checkpoint File from S3

The appendix illustrates a way to obtain the checkpoint file if the checkpoint files are saved in Amazon S3.

import os
from typing import Optional, Tuple
from urllib.parse import urlparse

import boto3


class S3CheckpointReader(CheckpointReader):
    _REGION = 'us-west-2'

    def __init__(self):
        # ASSUMED_ROLE_ENV_NAME is the name of the environment variable holding the IAM role ARN;
        # it is defined elsewhere in the original code
        self.assumedRole = os.getenv(ASSUMED_ROLE_ENV_NAME)

    def _getCredentials(self) -> dict:
        # assume the role via STS to obtain temporary credentials
        stsClient = boto3.client('sts')
        assumedRoleObject = stsClient.assume_role(
            RoleArn=self.assumedRole,
            RoleSessionName="AssumeRoleSession")
        return assumedRoleObject['Credentials']

    def _getS3Client(self):
        creds = self._getCredentials()
        return boto3.client('s3', aws_access_key_id=creds['AccessKeyId'],
                            aws_secret_access_key=creds['SecretAccessKey'],
                            aws_session_token=creds['SessionToken'],
                            region_name=S3CheckpointReader._REGION)

    def _getS3Resource(self):
        creds = self._getCredentials()
        return boto3.resource('s3', aws_access_key_id=creds['AccessKeyId'],
                              aws_secret_access_key=creds['SecretAccessKey'],
                              aws_session_token=creds['SessionToken'],
                              region_name=S3CheckpointReader._REGION)

    def _getLastOffsetFileFromResponse(self, response) -> Optional[dict]:
        # a response page with no keys has no 'Contents' entry
        offsetContents = [obj for obj in response.get('Contents', [])
                          if 'offsets' in obj['Key'] and 'offsets/.' not in obj['Key']]
        if not offsetContents:
            return None
        return max(offsetContents, key=lambda obj: obj['LastModified'])

    def getLastOffsetFile(self, checkpointLocation: str) -> str:
        bucket, prefix = self._getBucketAndPrefix(checkpointLocation)
        prefix = prefix if prefix.endswith('/') else prefix + '/'
        s3Client = self._getS3Client()

        # list_objects_v2 returns at most 1000 keys per call; if there are more,
        # response['IsTruncated'] is True and response['NextContinuationToken'] must be
        # passed to the next call. The paginator follows these tokens for us.
        latestOffsetFile = None
        paginator = s3Client.get_paginator('list_objects_v2')
        for response in paginator.paginate(Bucket=bucket, Prefix=prefix):
            candidate = self._getLastOffsetFileFromResponse(response)
            if candidate is not None and (latestOffsetFile is None
                                          or latestOffsetFile['LastModified'] < candidate['LastModified']):
                latestOffsetFile = candidate

        if latestOffsetFile is None:
            raise FileNotFoundError("no offset files under {}".format(checkpointLocation))
        return "s3a://{}/{}".format(bucket, latestOffsetFile['Key'])

    def _getBucketAndPrefix(self, checkpointLocation: str) -> Tuple[str, str]:
        o = urlparse(checkpointLocation, allow_fragments=False)
        # drop the leading "/" so the path can be used as an S3 key prefix
        return o.netloc, o.path.lstrip('/')

    def _getObjectContent(self, fileName: str) -> bytes:
        bucket, key = self._getBucketAndPrefix(fileName)

        lastError = None
        for _ in range(3):  # up to three attempts
            try:
                s3Resource = self._getS3Resource()
                response = s3Resource.Object(bucket, key).get()
                return response['Body'].read()
            except Exception as e:
                # Python unbinds `e` when the except block ends, so keep our own reference
                lastError = e
        raise lastError

    def _getLastLineOfFile(self, fileName: str) -> str:
        content = self._getObjectContent(fileName).decode("utf-8")
        # ignore any trailing newline so the last line holds the offsets
        return content.rstrip('\n').split('\n')[-1]
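Because the selection logic only looks at a few fields of each list_objects_v2 response (Contents, Key and LastModified), it can be exercised without AWS by feeding it hand-built dictionaries. The sketch below applies the same filter as the appendix code (keys containing "offsets", excluding hidden files under "offsets/.") across several pages; latestOffsetKey and the sample keys and timestamps are hypothetical.

```python
from datetime import datetime, timezone
from typing import Iterable, Optional


def latestOffsetKey(pages: Iterable[dict]) -> Optional[str]:
    """Return the key of the most recently modified offset file across all pages."""
    latest = None
    for page in pages:
        for obj in page.get('Contents', []):  # a page with no keys has no 'Contents'
            key = obj['Key']
            if 'offsets' not in key or 'offsets/.' in key:
                continue  # skip non-offset objects and hidden files such as .crc
            if latest is None or obj['LastModified'] > latest['LastModified']:
                latest = obj
    return latest['Key'] if latest else None


def ts(minute: int) -> datetime:
    return datetime(2022, 7, 8, 12, minute, tzinfo=timezone.utc)


pages = [
    {'Contents': [{'Key': 'ckpt/offsets/9', 'LastModified': ts(1)},
                  {'Key': 'ckpt/offsets/.9.crc', 'LastModified': ts(5)}]},
    {},
    {'Contents': [{'Key': 'ckpt/offsets/10', 'LastModified': ts(2)},
                  {'Key': 'ckpt/commits/10', 'LastModified': ts(3)}]},
]
print(latestOffsetKey(pages))  # ckpt/offsets/10
```

Keeping the selection logic as a pure function of the response pages makes it easy to unit-test separately from the STS and S3 calls.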

Special thanks to my colleague JD Rosensweig at Intuit for his continuous support and editing of my articles.
