From 0343759478fb9444369666ff5381ab2a9081f217 Mon Sep 17 00:00:00 2001 From: mohammedtarek Date: Thu, 19 Mar 2026 03:21:23 +0200 Subject: [PATCH 1/3] feat: add queue_wait_seconds, task_errors_by_type, and pre_send timestamp to PrometheusMiddleware --- taskiq/abc/broker.py | 8 ++- taskiq/cli/worker/process_manager.py | 2 +- taskiq/middlewares/prometheus_middleware.py | 73 ++++++++++++++++++++- 3 files changed, 78 insertions(+), 5 deletions(-) diff --git a/taskiq/abc/broker.py b/taskiq/abc/broker.py index c0902371..47ca2dd9 100644 --- a/taskiq/abc/broker.py +++ b/taskiq/abc/broker.py @@ -32,10 +32,13 @@ from taskiq.utils import maybe_awaitable from taskiq.warnings import TaskiqDeprecationWarning -if sys.version_info >= (3, 11): +if sys.version_info >= ( + 3, + 11, +): # Check which python version are we running to import correctly from typing import Self else: - from typing_extensions import Self + from typing_extensions import Self # pragma: no cover if TYPE_CHECKING: # pragma: no cover @@ -46,6 +49,7 @@ _FuncParams = ParamSpec("_FuncParams") _ReturnType = TypeVar("_ReturnType") +# an event handler can be either a sync or an async function that has one parameter of type TaskiqState EventHandler: TypeAlias = Callable[[TaskiqState], Awaitable[None] | None] logger = getLogger("taskiq") diff --git a/taskiq/cli/worker/process_manager.py b/taskiq/cli/worker/process_manager.py index 22257e86..394b05d8 100644 --- a/taskiq/cli/worker/process_manager.py +++ b/taskiq/cli/worker/process_manager.py @@ -169,7 +169,7 @@ def __init__( for path_to_watch in watch_paths: logger.debug(f"Watching directory: {path_to_watch}") observer.schedule( - FileWatcher( + FileWatcher( # type: ignore callback=schedule_workers_reload, path=Path(path_to_watch), use_gitignore=not args.no_gitignore, diff --git a/taskiq/middlewares/prometheus_middleware.py b/taskiq/middlewares/prometheus_middleware.py index 56837cf3..c9df6a9f 100644 --- a/taskiq/middlewares/prometheus_middleware.py +++ 
b/taskiq/middlewares/prometheus_middleware.py @@ -1,9 +1,9 @@ +import datetime import os from logging import getLogger from pathlib import Path from tempfile import gettempdir from typing import Any - from taskiq.abc.middleware import TaskiqMiddleware from taskiq.message import TaskiqMessage from taskiq.result import TaskiqResult @@ -43,7 +43,7 @@ def __init__( logger.debug("Initializing metrics") try: - from prometheus_client import Counter, Histogram # noqa: PLC0415 + from prometheus_client import Counter, Histogram, Gauge # noqa: PLC0415 except ImportError as exc: raise ImportError( "Cannot initialize metrics. Please install 'taskiq[metrics]'.", @@ -74,6 +74,24 @@ def __init__( "Time of function execution", ["task_name"], ) + + self.in_flight_tasks = Gauge( + "in_flight_tasks", + "Number of tasks in flight", + ["task_name"], + multiprocess_mode="livesum", + ) + self.queue_wait_seconds = Histogram( + "queue_wait_seconds", + "time task spent in message queue", + ["task_name"], + ) + self.task_errors_by_type = Counter( + "task_errors_by_type", + "Number of errors raised in tasks by their type", + ["task_name", "error_type"], + ) + self.server_port = server_port self.server_addr = server_addr @@ -104,6 +122,24 @@ def startup(self) -> None: except OSError as exc: logger.debug("Cannot start prometheus server: %s", exc) + def pre_send( + self, + message: "TaskiqMessage", + ) -> "TaskiqMessage": + """ + Function to track the time a task spends in queue. + + This function tracks the time a task spends in a queue until it is executed. + + :param message: current message. + :return: message + """ + if not message.labels.get("_taskiq_enqueue_timestamp"): + message.labels["_taskiq_enqueue_timestamp"] = datetime.datetime.now( + datetime.UTC, + ).isoformat() # Might consider using timezones too + return message + def pre_execute( self, message: "TaskiqMessage", @@ -117,9 +153,41 @@ def pre_execute( :param message: current message.
:return: message """ + if message.labels.get( + "_taskiq_enqueue_timestamp", + ): # Handle case where the sender doesn't use the prometheus middleware + time_delta = datetime.datetime.now( + datetime.UTC, + ) - datetime.datetime.fromisoformat( + message.labels["_taskiq_enqueue_timestamp"], + ) + time_delta = max(0, time_delta.total_seconds()) + self.queue_wait_seconds.labels(message.task_name).observe( + time_delta, + ) + + self.in_flight_tasks.labels(message.task_name).inc() self.received_tasks.labels(message.task_name).inc() return message + def on_error( + self, + message: TaskiqMessage, + result: TaskiqResult[Any], # pylint: disable=unused-argument + exception: BaseException, + ) -> None: + """ + This function tracks the number of errors raised by tasks. + + :param message: the received task message + :param result: the result of task + :param exception: exception raised + """ + self.task_errors_by_type.labels( + message.task_name, + type(exception).__name__, + ).inc() + def post_execute( self, message: "TaskiqMessage", @@ -135,6 +203,7 @@ def post_execute( self.found_errors.labels(message.task_name).inc() else: self.success_tasks.labels(message.task_name).inc() + self.in_flight_tasks.labels(message.task_name).dec() self.execution_time.labels(message.task_name).observe(result.execution_time) def post_save( From 2c48668d1e3c52c7a287b8c7df4455d59b4ec0d4 Mon Sep 17 00:00:00 2001 From: mohammedtarek Date: Thu, 19 Mar 2026 06:56:53 +0200 Subject: [PATCH 2/3] feat: add ReceiverObserver protocol and Prometheus receiver metrics Add production observability for the Receiver via an observer protocol that tracks prefetch queue depth, semaphore availability, active task count, unknown task lookups, and deserialization errors. 
- Add ReceiverObserver protocol (taskiq/receiver/observer.py) - Instrument Receiver with guarded observer callbacks at 5 sites - Add PrometheusReceiverObserver implementation with Gauges/Counters - Wire observer from middleware to receiver via broker attribute - Remove redundant in_flight_tasks gauge (replaced by active_tasks_count) --- taskiq/cli/worker/run.py | 3 +- taskiq/middlewares/prometheus_middleware.py | 65 ++++++++++++++++++--- taskiq/receiver/__init__.py | 3 +- taskiq/receiver/observer.py | 21 +++++++ taskiq/receiver/receiver.py | 44 +++++++++++++- 5 files changed, 125 insertions(+), 11 deletions(-) create mode 100644 taskiq/receiver/observer.py diff --git a/taskiq/cli/worker/run.py b/taskiq/cli/worker/run.py index 53cef7c0..831eb590 100644 --- a/taskiq/cli/worker/run.py +++ b/taskiq/cli/worker/run.py @@ -13,7 +13,7 @@ from taskiq.cli.utils import import_object, import_tasks from taskiq.cli.worker.args import WorkerArgs from taskiq.cli.worker.process_manager import ProcessManager -from taskiq.receiver import Receiver +from taskiq.receiver import Receiver, ReceiverObserver try: import uvloop @@ -163,6 +163,7 @@ def interrupt_handler(signum: int, _frame: Any) -> None: receiver = receiver_type( broker=broker, executor=pool, + observer=getattr(broker, "_receiver_observer", None), validate_params=not args.no_parse, max_async_tasks=args.max_async_tasks, max_prefetch=args.max_prefetch, diff --git a/taskiq/middlewares/prometheus_middleware.py b/taskiq/middlewares/prometheus_middleware.py index c9df6a9f..7686715d 100644 --- a/taskiq/middlewares/prometheus_middleware.py +++ b/taskiq/middlewares/prometheus_middleware.py @@ -7,6 +7,7 @@ from taskiq.abc.middleware import TaskiqMiddleware from taskiq.message import TaskiqMessage from taskiq.result import TaskiqResult +from taskiq.receiver.observer import ReceiverObserver logger = getLogger("taskiq.prometheus") @@ -75,12 +76,6 @@ def __init__( ["task_name"], ) - self.in_flight_tasks = Gauge( - "in_flight_tasks", - 
"Number of tasks in flight", - ["task_name"], - multiprocess_mode="livesum", - ) self.queue_wait_seconds = Histogram( "queue_wait_seconds", "time task spent in message queue", @@ -166,7 +161,6 @@ def pre_execute( time_delta, ) - self.in_flight_tasks.labels(message.task_name).inc() self.received_tasks.labels(message.task_name).inc() return message @@ -203,9 +197,12 @@ def post_execute( self.found_errors.labels(message.task_name).inc() else: self.success_tasks.labels(message.task_name).inc() - self.in_flight_tasks.labels(message.task_name).dec() self.execution_time.labels(message.task_name).observe(result.execution_time) + def set_broker(self, broker: "AsyncBroker") -> None: # noqa: F821 pyright: ignore[reportUnknownVariableType] + super().set_broker(broker) + broker._receiver_observer = PrometheusReceiverObserver() + def post_save( self, message: "TaskiqMessage", @@ -218,3 +215,55 @@ def post_save( :param result: result of execution. """ self.saved_results.labels(message.task_name).inc() + + +class PrometheusReceiverObserver(ReceiverObserver): + """Receiver observer implementation for prometheus.""" + + def __init__(self) -> None: + try: + from prometheus_client import Counter, Gauge # noqa: PLC0415 + except ImportError as exc: + raise ImportError( + "Cannot initialize metrics. 
Please install 'taskiq[metrics]'.", + ) from exc + + self.prefetch_queue_size = Gauge( + "prefetch_queue_size", + "The number of task in the prefetch queue.", + multiprocess_mode="livesum", + ) + self.semaphore_available = Gauge( + "semaphore_available", + "Number of semaphore slots available in broker", + multiprocess_mode="livesum", + ) + self.active_tasks_count = Gauge( + "worker_active_tasks_count", + "Number of active tasks in worker", + multiprocess_mode="livesum", + ) + self.task_not_found_total = Counter( + "task_not_found_total", + "Number of times the worker got a task not registered", + ["task_name"], + ) + self.deserialize_error = Counter( + "deserialize_error_count", + "Number of times broker faced a desrialization error", + ) + + def on_prefetch_queue_size(self, size: int) -> None: + self.prefetch_queue_size.set(size) + + def on_semaphore_status(self, available: int) -> None: + self.semaphore_available.set(available) + + def on_active_tasks_count(self, count: int) -> None: + self.active_tasks_count.set(count) + + def on_task_not_found(self, task_name: str) -> None: + self.task_not_found_total.labels(task_name).inc() + + def on_deserialize_error(self, raw: bytes, error: Exception) -> None: + self.deserialize_error.inc() diff --git a/taskiq/receiver/__init__.py b/taskiq/receiver/__init__.py index c6a7e66b..b9527fb3 100644 --- a/taskiq/receiver/__init__.py +++ b/taskiq/receiver/__init__.py @@ -1,5 +1,6 @@ """Package for message receiver.""" from taskiq.receiver.receiver import Receiver +from taskiq.receiver.observer import ReceiverObserver -__all__ = ["Receiver"] +__all__ = ["Receiver", "ReceiverObserver"] diff --git a/taskiq/receiver/observer.py b/taskiq/receiver/observer.py new file mode 100644 index 00000000..0e0c2a0c --- /dev/null +++ b/taskiq/receiver/observer.py @@ -0,0 +1,21 @@ +from typing import Protocol, runtime_checkable + + +@runtime_checkable +class ReceiverObserver(Protocol): + """ + Observer for reciever stats. 
+ + This classs is used to observe/collect metrics for the receiver. + This includes semaphore usage, tasks in queue, etc. + + metrics tracked: + - Number of tasks in queue + - Number of taks in execution (Semaphore uusage) + """ + + def on_prefetch_queue_size(self, size: int) -> None: ... + def on_semaphore_status(self, available: int) -> None: ... + def on_active_tasks_count(self, count: int) -> None: ... + def on_task_not_found(self, task_name: str) -> None: ... + def on_deserialize_error(self, raw: bytes, error: Exception) -> None: ... diff --git a/taskiq/receiver/receiver.py b/taskiq/receiver/receiver.py index 99298af2..fa195ff4 100644 --- a/taskiq/receiver/receiver.py +++ b/taskiq/receiver/receiver.py @@ -18,6 +18,7 @@ from taskiq.context import Context from taskiq.exceptions import NoResultError from taskiq.message import TaskiqMessage +from taskiq.receiver.observer import ReceiverObserver from taskiq.receiver.params_parser import parse_params from taskiq.result import TaskiqResult from taskiq.state import TaskiqState @@ -35,6 +36,7 @@ def __init__( self, broker: AsyncBroker, executor: Executor | None = None, + observer: ReceiverObserver | None = None, validate_params: bool = True, max_async_tasks: "int | None" = None, max_prefetch: int = 0, @@ -54,6 +56,7 @@ def __init__( self.dependency_graphs: dict[str, DependencyGraph] = {} self.propagate_exceptions = propagate_exceptions self.on_exit = on_exit + self.observer = observer self.ack_time = ack_type or AcknowledgeType.WHEN_SAVED self.known_tasks: set[str] = set() self.max_tasks_to_execute = max_tasks_to_execute @@ -92,6 +95,11 @@ async def callback( # noqa: C901, PLR0912 taskiq_msg = self.broker.formatter.loads(message=message_data) taskiq_msg.parse_labels() except Exception as exc: + if self.observer is not None: + self.observer.on_deserialize_error( + raw=message_data, + error=exc, + ) logger.warning( "Cannot parse message: %s. 
Skipping execution.\n %s", message_data, @@ -102,6 +110,11 @@ async def callback( # noqa: C901, PLR0912 logger.debug(f"Received message: {taskiq_msg}") task = self.broker.find_task(taskiq_msg.task_name) if task is None: + if self.observer is not None: + self.observer.on_task_not_found( + taskiq_msg.task_name, + ) + logger.warning( 'task "%s" is not found. Maybe you forgot to import it?', taskiq_msg.task_name, @@ -363,6 +376,7 @@ async def prefetcher( break try: await self.sem_prefetch.acquire() + if ( self.max_tasks_to_execute and fetched_tasks >= self.max_tasks_to_execute @@ -376,6 +390,7 @@ async def prefetcher( # and continue the loop. So it will check if finished event was set. if not done: self.sem_prefetch.release() + continue # We're done, so now we need to check # whether task has returned an error. @@ -383,6 +398,12 @@ async def prefetcher( current_message = asyncio.create_task(iterator.__anext__()) # type: ignore fetched_tasks += 1 await queue.put(message) + + if self.observer is not None: + self.observer.on_prefetch_queue_size( + queue.qsize(), + ) + except (asyncio.CancelledError, StopAsyncIteration): break # We don't want to fetch new messages if we are shutting down. @@ -413,17 +434,35 @@ def task_cb(task: "asyncio.Task[Any]") -> None: :param task: finished task """ tasks.discard(task) + if self.observer is not None: + self.observer.on_active_tasks_count( + len(tasks), + ) + if self.sem is not None: self.sem.release() + if self.observer is not None: + self.observer.on_semaphore_status( + self.sem._value # noqa + ) + while True: try: # Waits for semaphore to be released. 
if self.sem is not None: await self.sem.acquire() + if self.observer is not None: + self.observer.on_semaphore_status( + self.sem._value # noqa + ) self.sem_prefetch.release() message = await queue.get() + if self.observer is not None: + self.observer.on_prefetch_queue_size( + queue.qsize() # noqa + ) if message is QUEUE_DONE: # asyncio.wait will throw an error if there is nothing to wait for if tasks: @@ -438,7 +477,10 @@ def task_cb(task: "asyncio.Task[Any]") -> None: self.callback(message=message, raise_err=False), ) tasks.add(task) - + if self.observer is not None: + self.observer.on_active_tasks_count( + len(tasks), + ) # We want the task to remove itself from the set when it's done. # # Because if we won't save it anywhere, From 6498b2e4f17d106a14dca3db80fb3ab686d7dfb6 Mon Sep 17 00:00:00 2001 From: mohammedtarek Date: Thu, 19 Mar 2026 07:06:11 +0200 Subject: [PATCH 3/3] fix: clean up lint errors and typos in observability code - Fix typos in observer docstring and metric descriptions - Add missing docstrings to observer protocol and implementation methods - Remove unused Gauge import from PrometheusMiddleware.__init__ - Remove unused ReceiverObserver import from run.py - Fix import ordering (ruff I001) - Add noqa for expected complexity in runner method - Run black formatting --- taskiq/cli/worker/run.py | 2 +- taskiq/middlewares/prometheus_middleware.py | 23 +++++++++++----- taskiq/receiver/observer.py | 30 +++++++++++++++------ taskiq/receiver/receiver.py | 14 +++------- 4 files changed, 44 insertions(+), 25 deletions(-) diff --git a/taskiq/cli/worker/run.py b/taskiq/cli/worker/run.py index 831eb590..e48ac173 100644 --- a/taskiq/cli/worker/run.py +++ b/taskiq/cli/worker/run.py @@ -13,7 +13,7 @@ from taskiq.cli.utils import import_object, import_tasks from taskiq.cli.worker.args import WorkerArgs from taskiq.cli.worker.process_manager import ProcessManager -from taskiq.receiver import Receiver, ReceiverObserver +from taskiq.receiver import Receiver try: 
import uvloop diff --git a/taskiq/middlewares/prometheus_middleware.py b/taskiq/middlewares/prometheus_middleware.py index 7686715d..704c55c8 100644 --- a/taskiq/middlewares/prometheus_middleware.py +++ b/taskiq/middlewares/prometheus_middleware.py @@ -4,10 +4,11 @@ from pathlib import Path from tempfile import gettempdir from typing import Any + from taskiq.abc.middleware import TaskiqMiddleware from taskiq.message import TaskiqMessage -from taskiq.result import TaskiqResult from taskiq.receiver.observer import ReceiverObserver +from taskiq.result import TaskiqResult logger = getLogger("taskiq.prometheus") @@ -21,7 +22,7 @@ class PrometheusMiddleware(TaskiqMiddleware): :param server_port: The port to listen on. :param server_addr: The address to listen on. - :paam metrics_path: The path to store metrics for multiproc env. + :param metrics_path: The path to store metrics for multiproc env. """ def __init__( @@ -44,7 +45,7 @@ def __init__( logger.debug("Initializing metrics") try: - from prometheus_client import Counter, Histogram, Gauge # noqa: PLC0415 + from prometheus_client import Counter, Histogram # noqa: PLC0415 except ImportError as exc: raise ImportError( "Cannot initialize metrics. Please install 'taskiq[metrics]'.", @@ -199,9 +200,14 @@ def post_execute( self.success_tasks.labels(message.task_name).inc() self.execution_time.labels(message.task_name).observe(result.execution_time) - def set_broker(self, broker: "AsyncBroker") -> None: # noqa: F821 pyright: ignore[reportUnknownVariableType] + def set_broker(self, broker: "AsyncBroker") -> None: # noqa: F821 + """ + Set broker and attach receiver observer. + + :param broker: broker to set. 
+ """ super().set_broker(broker) - broker._receiver_observer = PrometheusReceiverObserver() + broker._receiver_observer = PrometheusReceiverObserver() # noqa: SLF001 def post_save( self, @@ -250,20 +256,25 @@ def __init__(self) -> None: ) self.deserialize_error = Counter( "deserialize_error_count", - "Number of times broker faced a desrialization error", + "Number of times broker faced a deserialization error", ) def on_prefetch_queue_size(self, size: int) -> None: + """Record current prefetch queue depth.""" self.prefetch_queue_size.set(size) def on_semaphore_status(self, available: int) -> None: + """Record available semaphore slots.""" self.semaphore_available.set(available) def on_active_tasks_count(self, count: int) -> None: + """Record number of currently executing tasks.""" self.active_tasks_count.set(count) def on_task_not_found(self, task_name: str) -> None: + """Increment counter for unregistered task lookups.""" self.task_not_found_total.labels(task_name).inc() def on_deserialize_error(self, raw: bytes, error: Exception) -> None: + """Increment counter for message deserialization failures.""" self.deserialize_error.inc() diff --git a/taskiq/receiver/observer.py b/taskiq/receiver/observer.py index 0e0c2a0c..70a7ccd8 100644 --- a/taskiq/receiver/observer.py +++ b/taskiq/receiver/observer.py @@ -4,18 +4,32 @@ @runtime_checkable class ReceiverObserver(Protocol): """ - Observer for reciever stats. + Observer for receiver stats. - This classs is used to observe/collect metrics for the receiver. + This class is used to observe/collect metrics for the receiver. This includes semaphore usage, tasks in queue, etc. metrics tracked: - Number of tasks in queue - - Number of taks in execution (Semaphore uusage) + - Number of tasks in execution (semaphore usage) """ - def on_prefetch_queue_size(self, size: int) -> None: ... - def on_semaphore_status(self, available: int) -> None: ... - def on_active_tasks_count(self, count: int) -> None: ... 
- def on_task_not_found(self, task_name: str) -> None: ... - def on_deserialize_error(self, raw: bytes, error: Exception) -> None: ... + def on_prefetch_queue_size(self, size: int) -> None: + """Called when the prefetch queue size changes.""" + ... + + def on_semaphore_status(self, available: int) -> None: + """Called when semaphore availability changes.""" + ... + + def on_active_tasks_count(self, count: int) -> None: + """Called when the number of active tasks changes.""" + ... + + def on_task_not_found(self, task_name: str) -> None: + """Called when a received task is not registered.""" + ... + + def on_deserialize_error(self, raw: bytes, error: Exception) -> None: + """Called when a message fails to deserialize.""" + ... diff --git a/taskiq/receiver/receiver.py b/taskiq/receiver/receiver.py index fa195ff4..6afd1101 100644 --- a/taskiq/receiver/receiver.py +++ b/taskiq/receiver/receiver.py @@ -412,7 +412,7 @@ async def prefetcher( await queue.put(QUEUE_DONE) self.sem_prefetch.release() - async def runner( + async def runner( # noqa: C901 self, queue: "asyncio.Queue[bytes | AckableMessage]", ) -> None: @@ -443,9 +443,7 @@ def task_cb(task: "asyncio.Task[Any]") -> None: self.sem.release() if self.observer is not None: - self.observer.on_semaphore_status( - self.sem._value # noqa - ) + self.observer.on_semaphore_status(self.sem._value) # noqa while True: try: @@ -453,16 +451,12 @@ def task_cb(task: "asyncio.Task[Any]") -> None: if self.sem is not None: await self.sem.acquire() if self.observer is not None: - self.observer.on_semaphore_status( - self.sem._value # noqa - ) + self.observer.on_semaphore_status(self.sem._value) # noqa self.sem_prefetch.release() message = await queue.get() if self.observer is not None: - self.observer.on_prefetch_queue_size( - queue.qsize() # noqa - ) + self.observer.on_prefetch_queue_size(queue.qsize()) if message is QUEUE_DONE: # asyncio.wait will throw an error if there is nothing to wait for if tasks: