Created July 14, 2017 12:50
-
-
Save tvoinarovskyi/05a5d083a0f96cae3e9b4c2af580be74 to your computer and use it in GitHub Desktop.
Kafka consumer enhanced with worker threads and `consumer.pause()`.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging
import queue
import threading
import time

from kafka import (
    KafkaConsumer, TopicPartition, OffsetAndMetadata, ConsumerRebalanceListener
)

# Module-level logger shared by the consumer thread, the worker threads and
# the rebalance listener.
log = logging.getLogger(__name__)

# Number of worker threads started by ``main()`` to process fetched batches.
NUM_WORKERS = 5
class Shutdown(Exception):
    """Signal raised by ``WorkerQueue`` once a shutdown has been requested.

    The worker and consumer threads catch it to leave their loops.
    """
class WorkerQueue(object):
    """Hand-off point between the consumer thread and the worker threads.

    The consumer ``put``s one batch of messages per TopicPartition. Workers
    ``get`` batches, process them, and report the last processed message
    through ``finished_processing``. The consumer later collects those
    results with ``get_finished``.

    The instance also acts as the shutdown bus. Once ``start_shutdown`` has
    been called, ``put``, ``get`` and ``check_shutdown`` raise ``Shutdown``.
    """

    def __init__(self):
        self._queue = queue.Queue()
        # NOTE(review): not used by any code visible here; kept in case other
        # code relies on it.
        self._processing_tps = set()
        # tp -> last message finished by a worker, not yet collected.
        self._finished = {}
        # Guards ``_finished`` and, together with ``put``, the shutdown flag.
        self._lock = threading.Lock()
        self._start_shutdown = False

    def put(self, tp, messages):
        """Queue a batch of *messages* fetched from partition *tp*.

        Raises ``Shutdown`` if a shutdown was already requested.
        """
        assert isinstance(tp, TopicPartition)
        with self._lock:
            # Check the flag under the same lock that ``start_shutdown`` uses
            # to set it. Otherwise a batch could be enqueued right after the
            # shutdown drain. No shutting-down worker would ever take it, and
            # ``join()`` would then block forever.
            self.check_shutdown()
            assert tp not in self._finished
            self._queue.put_nowait((tp, messages))

    def get(self):
        """Block until a ``(tp, messages)`` batch is available and return it.

        Wakes up every second to re-check the shutdown flag. Raises
        ``Shutdown`` once a shutdown has been requested.
        """
        while True:
            self.check_shutdown()
            try:
                return self._queue.get(timeout=1)
            except queue.Empty:
                continue

    def join(self):
        """Block until every queued batch has been finished or dropped."""
        return self._queue.join()

    def finished_processing(self, tp, last_message):
        """Mark the current batch of *tp* done, ending at *last_message*."""
        with self._lock:
            self._queue.task_done()
            self._finished[tp] = last_message

    def get_finished(self):
        """Return the ``{tp: last_message}`` map of finished batches and reset it."""
        with self._lock:
            finished, self._finished = self._finished, {}
        return finished

    def drop_pending(self):
        """ Remove any records that were not started processing by workers

        Each dropped batch is marked done so ``join()`` does not wait on it.
        """
        while True:
            try:
                self._queue.get_nowait()
            except queue.Empty:
                break
            self._queue.task_done()

    def start_shutdown(self):
        """Request shutdown of every thread that uses this queue.

        Sets the shutdown flag under the lock, so ``put`` cannot enqueue
        after this point. Then clears all batches that have not started
        processing.
        """
        with self._lock:
            self._start_shutdown = True
        # Clear all items in queue, that has not started processing
        self.drop_pending()

    def check_shutdown(self):
        """Raise ``Shutdown`` if ``start_shutdown`` has been called."""
        if self._start_shutdown:
            raise Shutdown()
class RebalanceListener(ConsumerRebalanceListener):
    """Commit offsets of processed batches around a consumer-group rebalance.

    Passed to ``consumer.subscribe(..., listener=...)`` so that work the
    worker threads have already completed is committed before partitions
    are revoked.
    """

    def __init__(self, worker_queue, consumer):
        self._worker_queue = worker_queue
        self._consumer = consumer

    def _commit_finished(self):
        """Resume partitions whose batch finished and commit their offsets."""
        finished = self._worker_queue.get_finished()
        if not finished:
            # Nothing to resume or commit.
            return
        log.warning("Committing %d partitions on revoke", len(finished))
        paused = self._consumer.paused()
        commit_offsets = {}
        for tp, last_message in finished.items():
            # The consumer thread pauses every partition it hands to a worker.
            assert tp in paused
            self._consumer.resume(tp)
            # Commit the offset of the next message to read, not the last one.
            commit_offsets[tp] = OffsetAndMetadata(
                last_message.offset + 1, "")
        self._consumer.commit(commit_offsets)

    def on_partitions_revoked(self, revoked):
        """ Commit all processed items before rebalancing partitions

        Batches no worker has started are dropped. They are not committed, so
        presumably they are re-consumed after the rebalance (verify against
        the kafka-python assignment behaviour).
        """
        log.info("Revoking %d partitions", len(revoked))
        self._worker_queue.drop_pending()
        # We commit before and after, as we may not succeed later. `join()`
        # can take a while.
        self._commit_finished()
        self._worker_queue.join()
        self._commit_finished()

    def on_partitions_assigned(self, assigned):
        """Log newly assigned partitions; no other action is needed."""
        log.info("Assigned %d partitions", len(assigned))
def worker_thread(worker_queue):
    """Worker loop: take batches from *worker_queue*, process them, report back.

    Each ``(tp, messages)`` batch is "processed" (simulated by a 5-second
    sleep). Its last message is then passed to
    ``worker_queue.finished_processing``. The loop ends when the queue
    raises ``Shutdown``. Any other exception is logged and ends this worker.
    """
    tid = threading.get_ident
    try:
        log.info("Starting worker thread %s", tid())
        while True:
            partition, batch = worker_queue.get()
            count = len(batch)
            log.info("Processing %d messages from tp %s on tid=%s",
                     count, partition, tid())
            # Placeholder for real message processing.
            time.sleep(5)
            newest = batch[-1]
            worker_queue.finished_processing(partition, newest)
    except Shutdown:
        print("Worker thread {} shutdown".format(tid()))
    except Exception:
        log.exception("Unexpected error in worker", exc_info=True)
def _offsets_to_commit(finished):
    """Map ``{tp: last_message}`` to ``{tp: OffsetAndMetadata}`` for commit().

    The committed offset is ``last_message.offset + 1``, i.e. the next message
    the group should read from that partition.
    """
    return {
        tp: OffsetAndMetadata(last_message.offset + 1, "")
        for tp, last_message in finished.items()
    }


def consumer_thread(worker_queue):
    """Consumer loop: poll Kafka, hand batches to workers, commit results.

    Each non-empty batch returned by ``poll()`` is queued on *worker_queue*.
    Its partition is then paused until a worker reports the batch finished;
    at that point the partition is resumed and its offset committed.

    The loop exits in two ways:
    - When *worker_queue* signals ``Shutdown``. It first waits for in-flight
      batches and commits everything the workers finished.
    - On an unexpected error. It logs the error and requests shutdown of
      the workers.

    The consumer is always closed on exit.
    """
    consumer = None
    try:
        log.info("Starting consumer thread")
        # Offsets are committed manually, only once a worker has finished a
        # batch. With no committed offset, start from the earliest message.
        consumer = KafkaConsumer(
            group_id='my-group',
            bootstrap_servers=['localhost:9092'],
            enable_auto_commit=False,
            auto_offset_reset="earliest")
        rebalance_listener = RebalanceListener(worker_queue, consumer)
        consumer.subscribe('my-new-topic', listener=rebalance_listener)
        while True:
            worker_queue.check_shutdown()
            # You can use `max_records` to limit the number of results
            msg_pack = consumer.poll(timeout_ms=1000)
            log.info("poll() returned %s partitions, %s paused",
                     len(msg_pack), len(consumer.paused()))
            for tp, messages in msg_pack.items():
                if messages:
                    worker_queue.put(tp, messages)
                    # Stop fetching this partition until its batch is done.
                    consumer.pause(tp)
            # Unpause any finished ones and commit offsets
            finished = worker_queue.get_finished()
            if finished:
                paused = consumer.paused()
                for tp in finished:
                    assert tp in paused
                    consumer.resume(tp)
                consumer.commit(_offsets_to_commit(finished))
    except Shutdown:
        log.info("Starting consumer thread shutdown")
        # We have to commit all processed data on normal shutdowns
        worker_queue.join()  # Wait for all `get` data to finish processing
        last_finished = worker_queue.get_finished()
        if last_finished:
            log.warning("Committing for last %s partitions", len(last_finished))
            consumer.commit(_offsets_to_commit(last_finished))
        print("Consumer thread shutdown")
    except Exception:
        log.exception("Unexpected error in consumer thread")
        worker_queue.start_shutdown()
    finally:
        # ``consumer`` is still None if KafkaConsumer(...) itself failed.
        if consumer is not None:
            consumer.close()
def main():
    """Start one consumer thread and ``NUM_WORKERS`` worker threads.

    Runs until Ctrl+C, or until the consumer thread exits on its own. Then
    shuts everything down through the shared ``WorkerQueue`` and waits for
    all threads to finish.
    """
    logging.basicConfig(level=logging.INFO)
    worker_queue = WorkerQueue()
    c_thread = threading.Thread(target=consumer_thread, args=(worker_queue, ))
    c_thread.start()
    w_threads = []
    for _ in range(NUM_WORKERS):
        t = threading.Thread(target=worker_thread, args=(worker_queue, ))
        t.start()
        w_threads.append(t)
    try:
        # Supervise the consumer thread. An explicit check (instead of
        # ``assert``) still works under ``python -O``. It also falls through
        # to the orderly shutdown below instead of escaping as AssertionError.
        while c_thread.is_alive():
            time.sleep(1)
        log.error("Consumer thread exited unexpectedly, shutting down")
    except KeyboardInterrupt:
        print("Received ctrl+C, shutting down...")
    # All shutdown will be done through WorkerQueue instance, as the
    # central bus.
    worker_queue.start_shutdown()
    worker_queue.join()
    c_thread.join()
    for w in w_threads:
        w.join()


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.
Thank you