Skip to content

Commit

Permalink
FIX: Send ControlMessage *only* to worker 0.
Browse files Browse the repository at this point in the history
  • Loading branch information
xsedla1o committed Aug 23, 2024
1 parent 174c722 commit df073f5
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 14 deletions.
3 changes: 3 additions & 0 deletions dp3/common/control.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ class ControlMessage(Task):
def routing_key(self) -> str:
    """Return the textual routing key of a control message.

    Returns:
        Always the empty string.
    """
    return ""

def hashed_routing_key(self) -> int:
    """
    Returns:
        Always 0, overriding the hash-based default.
    """
    # The task queue reduces this value modulo the number of workers to pick
    # the target worker, so a constant 0 always selects worker 0.
    return 0

def as_message(self) -> str:
    """
    Returns:
        The message body, as produced by `model_dump_json()`.
    """
    return self.model_dump_json()

Expand Down
18 changes: 18 additions & 0 deletions dp3/common/task.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import hashlib
from abc import ABC, abstractmethod
from collections.abc import Iterator
from contextlib import contextmanager
Expand All @@ -21,6 +22,16 @@
_init_context_var = ContextVar("_init_context_var", default=None)


def HASH(key: str) -> int:
    """Hash function used to distribute tasks to worker processes.

    The key is UTF-8 encoded and hashed with MD5; the last four hexadecimal
    digits of the digest (its last 2 bytes, i.e. 16 bits) are interpreted
    as an integer.

    Args:
        key: to be hashed
    Returns:
        Integer in the range 0..65535 (0xFFFF) derived from the MD5 digest.
    """
    # hexdigest() yields hex characters, so [-4:] selects 16 bits, not 4 bytes.
    return int(hashlib.md5(key.encode("utf8")).hexdigest()[-4:], 16)


@contextmanager
def task_context(model_spec: ModelSpec) -> Iterator[None]:
"""Context manager for setting the `model_spec` context variable."""
Expand All @@ -45,6 +56,13 @@ def routing_key(self) -> str:
A string to be used as a routing key between workers.
"""

def hashed_routing_key(self) -> int:
    """
    Returns:
        An integer to be used as a hashed routing key between workers.
        By default this is `HASH` applied to the result of `routing_key()`.
    """
    return HASH(self.routing_key())

@abstractmethod
def as_message(self) -> str:
"""
Expand Down
2 changes: 1 addition & 1 deletion dp3/scripts/add_hashes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from pymongo import UpdateOne

from dp3.common.config import ModelSpec, read_config_dir
from dp3.common.task import HASH
from dp3.database.database import EntityDatabase, MongoConfig
from dp3.task_processing.task_queue import HASH

# Arguments parser
parser = argparse.ArgumentParser(
Expand Down
15 changes: 2 additions & 13 deletions dp3/task_processing/task_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@

import collections
import contextlib
import hashlib
import logging
import threading
import time
Expand All @@ -53,16 +52,6 @@
DEFAULT_PRIORITY_QUEUE = "{}-worker-{}-pri"


def HASH(key: str) -> int:
    """Hash function used to distribute tasks to worker processes.

    The key is UTF-8 encoded and hashed with MD5; the last four hexadecimal
    digits of the digest (its last 2 bytes, i.e. 16 bits) are interpreted
    as an integer.

    Args:
        key: to be hashed
    Returns:
        Integer in the range 0..65535 (0xFFFF) derived from the MD5 digest.
    """
    # hexdigest() yields hex characters, so [-4:] selects 16 bits, not 4 bytes.
    return int(hashlib.md5(key.encode("utf8")).hexdigest()[-4:], 16)


# When reading, pre-fetch only a limited number of messages
# (because pre-fetched messages are not counted toward the queue length limit)
PREFETCH_COUNT = 50
Expand Down Expand Up @@ -294,8 +283,8 @@ def put_task(self, task: Task, priority: bool = False) -> None:

# Prepare routing key
body = task.as_message()
key = task.routing_key()
routing_key = HASH(key) % self.workers # index of the worker to send the task to
# index of the worker to send the task to
routing_key = task.hashed_routing_key() % self.workers

exchange = self.exchange_pri if priority else self.exchange
self._send_message(routing_key, exchange, body)
Expand Down

0 comments on commit df073f5

Please sign in to comment.