Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions taskiq/abc/broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,13 @@
from taskiq.utils import maybe_awaitable
from taskiq.warnings import TaskiqDeprecationWarning

if sys.version_info >= (3, 11):
if sys.version_info >= (
3,
11,
): # Check which python version are we running to import correctly
from typing import Self
else:
from typing_extensions import Self
from typing_extensions import Self # pragma: no cover


if TYPE_CHECKING: # pragma: no cover
Expand All @@ -46,6 +49,7 @@
_FuncParams = ParamSpec("_FuncParams")
_ReturnType = TypeVar("_ReturnType")

# an event handler can be either a sync or an async function that has one parameter of type TaskiqState
EventHandler: TypeAlias = Callable[[TaskiqState], Awaitable[None] | None]

logger = getLogger("taskiq")
Expand Down
2 changes: 1 addition & 1 deletion taskiq/cli/worker/process_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def __init__(
for path_to_watch in watch_paths:
logger.debug(f"Watching directory: {path_to_watch}")
observer.schedule(
FileWatcher(
FileWatcher( # type: ignore
callback=schedule_workers_reload,
path=Path(path_to_watch),
use_gitignore=not args.no_gitignore,
Expand Down
1 change: 1 addition & 0 deletions taskiq/cli/worker/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ def interrupt_handler(signum: int, _frame: Any) -> None:
receiver = receiver_type(
broker=broker,
executor=pool,
observer=getattr(broker, "_receiver_observer", None),
validate_params=not args.no_parse,
max_async_tasks=args.max_async_tasks,
max_prefetch=args.max_prefetch,
Expand Down
131 changes: 130 additions & 1 deletion taskiq/middlewares/prometheus_middleware.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import datetime
import os
from logging import getLogger
from pathlib import Path
Expand All @@ -6,6 +7,7 @@

from taskiq.abc.middleware import TaskiqMiddleware
from taskiq.message import TaskiqMessage
from taskiq.receiver.observer import ReceiverObserver
from taskiq.result import TaskiqResult

logger = getLogger("taskiq.prometheus")
Expand All @@ -20,7 +22,7 @@ class PrometheusMiddleware(TaskiqMiddleware):

:param server_port: The port to listen on.
:param server_addr: The address to listen on.
:paam metrics_path: The path to store metrics for multiproc env.
:param metrics_path: The path to store metrics for multiproc env.
"""

def __init__(
Expand Down Expand Up @@ -74,6 +76,18 @@ def __init__(
"Time of function execution",
["task_name"],
)

self.queue_wait_seconds = Histogram(
"queue_wait_seconds",
"time task spent in message queue",
["task_name"],
)
self.task_errors_by_type = Counter(
"task_errors_by_type",
"Number of errors raised in tasks by their type",
["task_name", "error_type"],
)

self.server_port = server_port
self.server_addr = server_addr

Expand Down Expand Up @@ -104,6 +118,24 @@ def startup(self) -> None:
except OSError as exc:
logger.debug("Cannot start prometheus server: %s", exc)

def pre_send(
    self,
    message: "TaskiqMessage",
) -> "TaskiqMessage":
    """
    Stamp the message with the time it was enqueued.

    The timestamp is later read by ``pre_execute`` on the worker side to
    measure how long the task waited in the queue.  An existing stamp is
    never overwritten, so repeated sends keep the original enqueue time.

    :param message: current message.
    :return: message
    """
    if not message.labels.get("_taskiq_enqueue_timestamp"):
        # datetime.timezone.utc (not datetime.UTC) keeps compatibility with
        # Python < 3.11, which this project still supports (see the
        # sys.version_info guard used for the Self import).
        message.labels["_taskiq_enqueue_timestamp"] = datetime.datetime.now(
            datetime.timezone.utc,
        ).isoformat()
    return message

def pre_execute(
    self,
    message: "TaskiqMessage",
) -> "TaskiqMessage":
    """
    Count the received task and record its queue waiting time.

    If the sender also runs this middleware, the message carries an
    ``_taskiq_enqueue_timestamp`` label; the elapsed time since that stamp
    is observed in the ``queue_wait_seconds`` histogram.  Messages without
    the label (sender without the middleware) are still counted.

    :param message: current message.
    :return: message
    """
    enqueue_ts = message.labels.get("_taskiq_enqueue_timestamp")
    if enqueue_ts:
        # datetime.timezone.utc (not datetime.UTC) keeps compatibility
        # with Python < 3.11, which this project still supports.
        wait_seconds = (
            datetime.datetime.now(datetime.timezone.utc)
            - datetime.datetime.fromisoformat(enqueue_ts)
        ).total_seconds()
        # Clock skew between sender and worker can yield a negative delta;
        # clamp to zero so the histogram never sees negative observations.
        self.queue_wait_seconds.labels(message.task_name).observe(
            max(0.0, wait_seconds),
        )

    self.received_tasks.labels(message.task_name).inc()
    return message

def on_error(
    self,
    message: "TaskiqMessage",
    result: "TaskiqResult[Any]",  # pylint: disable=unused-argument
    exception: BaseException,
) -> None:
    """
    Count task failures, partitioned by task name and exception class.

    Annotations are string-quoted for consistency with the other hooks in
    this middleware and to avoid evaluating them at definition time.

    :param message: the received task message
    :param result: the result of task (unused; required by the hook API)
    :param exception: exception raised
    """
    error_type = type(exception).__name__
    self.task_errors_by_type.labels(message.task_name, error_type).inc()

def post_execute(
self,
message: "TaskiqMessage",
Expand All @@ -137,6 +200,15 @@ def post_execute(
self.success_tasks.labels(message.task_name).inc()
self.execution_time.labels(message.task_name).observe(result.execution_time)

def set_broker(self, broker: "AsyncBroker") -> None:  # noqa: F821
    """
    Set broker and attach receiver observer.

    Delegates to the base middleware, then stores a
    ``PrometheusReceiverObserver`` on the broker under the private
    ``_receiver_observer`` attribute, which the worker CLI reads
    (via ``getattr``) when constructing the ``Receiver``.

    :param broker: broker to set.
    """
    super().set_broker(broker)
    # NOTE(review): unconditionally replaces any observer another
    # middleware may have attached earlier — confirm this is intended.
    broker._receiver_observer = PrometheusReceiverObserver()  # noqa: SLF001

def post_save(
self,
message: "TaskiqMessage",
Expand All @@ -149,3 +221,60 @@ def post_save(
:param result: result of execution.
"""
self.saved_results.labels(message.task_name).inc()


class PrometheusReceiverObserver(ReceiverObserver):
    """Prometheus-backed implementation of the receiver observer hooks."""

    def __init__(self) -> None:
        """
        Create all receiver-level metrics.

        :raises ImportError: if prometheus_client is not installed.
        """
        try:
            from prometheus_client import Counter, Gauge  # noqa: PLC0415
        except ImportError as exc:
            raise ImportError(
                "Cannot initialize metrics. Please install 'taskiq[metrics]'.",
            ) from exc

        # "livesum" aggregates gauge values across live worker processes
        # when prometheus multiprocess mode is enabled.
        live = {"multiprocess_mode": "livesum"}
        self.prefetch_queue_size = Gauge(
            "prefetch_queue_size",
            "The number of task in the prefetch queue.",
            **live,
        )
        self.semaphore_available = Gauge(
            "semaphore_available",
            "Number of semaphore slots available in broker",
            **live,
        )
        self.active_tasks_count = Gauge(
            "worker_active_tasks_count",
            "Number of active tasks in worker",
            **live,
        )
        self.task_not_found_total = Counter(
            "task_not_found_total",
            "Number of times the worker got a task not registered",
            ["task_name"],
        )
        self.deserialize_error = Counter(
            "deserialize_error_count",
            "Number of times broker faced a deserialization error",
        )

    def on_deserialize_error(self, raw: bytes, error: Exception) -> None:
        """Count messages that could not be deserialized."""
        self.deserialize_error.inc()

    def on_task_not_found(self, task_name: str) -> None:
        """Count lookups of tasks that are not registered on this worker."""
        self.task_not_found_total.labels(task_name).inc()

    def on_prefetch_queue_size(self, size: int) -> None:
        """Publish the current depth of the prefetch queue."""
        self.prefetch_queue_size.set(size)

    def on_semaphore_status(self, available: int) -> None:
        """Publish how many semaphore slots are currently free."""
        self.semaphore_available.set(available)

    def on_active_tasks_count(self, count: int) -> None:
        """Publish the number of tasks currently being executed."""
        self.active_tasks_count.set(count)
3 changes: 2 additions & 1 deletion taskiq/receiver/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Package for message receiver."""

from taskiq.receiver.receiver import Receiver
from taskiq.receiver.observer import ReceiverObserver

__all__ = ["Receiver"]
__all__ = ["Receiver", "ReceiverObserver"]
35 changes: 35 additions & 0 deletions taskiq/receiver/observer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from typing import Protocol, runtime_checkable


@runtime_checkable
class ReceiverObserver(Protocol):
    """
    Observer for receiver stats.

    Implementations of this protocol receive callbacks from the worker's
    receiver so they can collect runtime metrics.  Hooks cover:

    - prefetch queue depth (``on_prefetch_queue_size``)
    - free semaphore slots, i.e. spare execution capacity
      (``on_semaphore_status``)
    - number of currently running tasks (``on_active_tasks_count``)
    - lookups of unregistered tasks (``on_task_not_found``)
    - message deserialization failures (``on_deserialize_error``)

    Being ``runtime_checkable``, it supports ``isinstance`` checks, though
    only method presence — not signatures — is verified at runtime.
    """

    def on_prefetch_queue_size(self, size: int) -> None:
        """Called when the prefetch queue size changes."""
        ...

    def on_semaphore_status(self, available: int) -> None:
        """Called when semaphore availability changes."""
        ...

    def on_active_tasks_count(self, count: int) -> None:
        """Called when the number of active tasks changes."""
        ...

    def on_task_not_found(self, task_name: str) -> None:
        """Called when a received task is not registered."""
        ...

    def on_deserialize_error(self, raw: bytes, error: Exception) -> None:
        """Called when a message fails to deserialize."""
        ...
40 changes: 38 additions & 2 deletions taskiq/receiver/receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from taskiq.context import Context
from taskiq.exceptions import NoResultError
from taskiq.message import TaskiqMessage
from taskiq.receiver.observer import ReceiverObserver
from taskiq.receiver.params_parser import parse_params
from taskiq.result import TaskiqResult
from taskiq.state import TaskiqState
Expand All @@ -35,6 +36,7 @@ def __init__(
self,
broker: AsyncBroker,
executor: Executor | None = None,
observer: ReceiverObserver | None = None,
validate_params: bool = True,
max_async_tasks: "int | None" = None,
max_prefetch: int = 0,
Expand All @@ -54,6 +56,7 @@ def __init__(
self.dependency_graphs: dict[str, DependencyGraph] = {}
self.propagate_exceptions = propagate_exceptions
self.on_exit = on_exit
self.observer = observer
self.ack_time = ack_type or AcknowledgeType.WHEN_SAVED
self.known_tasks: set[str] = set()
self.max_tasks_to_execute = max_tasks_to_execute
Expand Down Expand Up @@ -92,6 +95,11 @@ async def callback( # noqa: C901, PLR0912
taskiq_msg = self.broker.formatter.loads(message=message_data)
taskiq_msg.parse_labels()
except Exception as exc:
if self.observer is not None:
self.observer.on_deserialize_error(
raw=message_data,
error=exc,
)
logger.warning(
"Cannot parse message: %s. Skipping execution.\n %s",
message_data,
Expand All @@ -102,6 +110,11 @@ async def callback( # noqa: C901, PLR0912
logger.debug(f"Received message: {taskiq_msg}")
task = self.broker.find_task(taskiq_msg.task_name)
if task is None:
if self.observer is not None:
self.observer.on_task_not_found(
taskiq_msg.task_name,
)

logger.warning(
'task "%s" is not found. Maybe you forgot to import it?',
taskiq_msg.task_name,
Expand Down Expand Up @@ -363,6 +376,7 @@ async def prefetcher(
break
try:
await self.sem_prefetch.acquire()

if (
self.max_tasks_to_execute
and fetched_tasks >= self.max_tasks_to_execute
Expand All @@ -376,13 +390,20 @@ async def prefetcher(
# and continue the loop. So it will check if finished event was set.
if not done:
self.sem_prefetch.release()

continue
# We're done, so now we need to check
# whether task has returned an error.
message = current_message.result()
current_message = asyncio.create_task(iterator.__anext__()) # type: ignore
fetched_tasks += 1
await queue.put(message)

if self.observer is not None:
self.observer.on_prefetch_queue_size(
queue.qsize(),
)

except (asyncio.CancelledError, StopAsyncIteration):
break
# We don't want to fetch new messages if we are shutting down.
Expand All @@ -391,7 +412,7 @@ async def prefetcher(
await queue.put(QUEUE_DONE)
self.sem_prefetch.release()

async def runner(
async def runner( # noqa: C901
self,
queue: "asyncio.Queue[bytes | AckableMessage]",
) -> None:
Expand All @@ -413,17 +434,29 @@ def task_cb(task: "asyncio.Task[Any]") -> None:
:param task: finished task
"""
tasks.discard(task)
if self.observer is not None:
self.observer.on_active_tasks_count(
len(tasks),
)

if self.sem is not None:
self.sem.release()

if self.observer is not None:
self.observer.on_semaphore_status(self.sem._value) # noqa

while True:
try:
# Waits for semaphore to be released.
if self.sem is not None:
await self.sem.acquire()
if self.observer is not None:
self.observer.on_semaphore_status(self.sem._value) # noqa

self.sem_prefetch.release()
message = await queue.get()
if self.observer is not None:
self.observer.on_prefetch_queue_size(queue.qsize())
if message is QUEUE_DONE:
# asyncio.wait will throw an error if there is nothing to wait for
if tasks:
Expand All @@ -438,7 +471,10 @@ def task_cb(task: "asyncio.Task[Any]") -> None:
self.callback(message=message, raise_err=False),
)
tasks.add(task)

if self.observer is not None:
self.observer.on_active_tasks_count(
len(tasks),
)
# We want the task to remove itself from the set when it's done.
#
# Because if we won't save it anywhere,
Expand Down