Skip to content

Commit

Permalink
Squash, hopefully, a few more refcycles for cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
wence- committed Dec 6, 2023
1 parent 793ee37 commit 67b17b4
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 14 deletions.
23 changes: 13 additions & 10 deletions python/distributed-ucxx/distributed_ucxx/ucxx.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,24 +566,27 @@ def address(self):
return f"{self.prefix}{self.ip}:{self.port}"

async def start(self):
async def serve_forever(client_ep):
ucx = self.comm_class(
async def serve_forever(client_ep, *, selfref):
ucx = selfref().comm_class(
client_ep,
local_addr=self.address,
peer_addr=self.address,
deserialize=self.deserialize,
local_addr=selfref().address,
peer_addr=selfref().address,
deserialize=selfref().deserialize,
)
ucx.allow_offload = self.allow_offload
ucx.allow_offload = selfref().allow_offload
try:
await self.on_connection(ucx)
await selfref().on_connection(ucx)
except CommClosedError:
logger.debug("Connection closed before handshake completed")
return
if self.comm_handler:
await self.comm_handler(ucx)
if selfref().comm_handler:
await selfref().comm_handler(ucx)

init_once()
self.ucxx_server = ucxx.create_listener(serve_forever, port=self._input_port)
self.ucxx_server = ucxx.create_listener(
functools.partial(serve_forever, selfref=weakref.ref(self)),
port=self._input_port,
)

def stop(self):
self.ucxx_server = None
Expand Down
11 changes: 7 additions & 4 deletions python/ucxx/_lib_async/tests/test_custom_send_recv.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
# SPDX-License-Identifier: BSD-3-Clause

import asyncio
import functools
import pickle
import weakref

import numpy as np
import pytest
Expand Down Expand Up @@ -98,11 +100,12 @@ def __init__(self):
self.comm = None

def start(self):
async def serve_forever(ep):
ucx = UCX(ep)
self.comm = ucx
async def serve_forever(ep, *, selfref):
selfref().comm = UCX(ep)

self.ucxx_server = ucxx.create_listener(serve_forever)
self.ucxx_server = ucxx.create_listener(
functools.partial(serve_forever, selfref=weakref.ref(self))
)

uu = UCXListener()
uu.start()
Expand Down

0 comments on commit 67b17b4

Please sign in to comment.