From 34d958a5f910872b248e4188a44d5a7c93f7c7a6 Mon Sep 17 00:00:00 2001 From: Logan Markewich Date: Sun, 29 Sep 2024 14:29:04 -0600 Subject: [PATCH 1/3] Fix instrumentation for workflows --- .../core/instrumentation/dispatcher.py | 49 ++++++++++++++++--- 1 file changed, 42 insertions(+), 7 deletions(-) diff --git a/llama-index-core/llama_index/core/instrumentation/dispatcher.py b/llama-index-core/llama_index/core/instrumentation/dispatcher.py index 8e892e17162aa..4f5683430402f 100644 --- a/llama-index-core/llama_index/core/instrumentation/dispatcher.py +++ b/llama-index-core/llama_index/core/instrumentation/dispatcher.py @@ -1,3 +1,5 @@ +import asyncio +from functools import partial from contextlib import contextmanager from contextvars import ContextVar, Token from typing import Any, Callable, Generator, List, Optional, Dict, Protocol @@ -261,20 +263,53 @@ def wrapper(func: Callable, instance: Any, args: list, kwargs: dict) -> Any: parent_id=parent_id, tags=tags, ) + + def handle_future_result(future, span_id, bound_args, instance): + try: + result = future.result() + self.span_exit( + id_=span_id, + bound_args=bound_args, + instance=instance, + result=result, + ) + return result + except BaseException as e: + self.event(SpanDropEvent(span_id=span_id, err_str=str(e))) + self.span_drop( + id_=span_id, bound_args=bound_args, instance=instance, err=e + ) + raise + finally: + active_span_id.reset(token) + try: result = func(*args, **kwargs) + if isinstance(result, asyncio.Future): + # If the result is a Future, wrap it + new_future = asyncio.ensure_future(result) + new_future.add_done_callback( + partial( + handle_future_result, + span_id=id_, + bound_args=bound_args, + instance=instance, + ) + ) + return new_future + else: + # For non-Future results, proceed as before + self.span_exit( + id_=id_, bound_args=bound_args, instance=instance, result=result + ) + return result except BaseException as e: self.event(SpanDropEvent(span_id=id_, err_str=str(e))) self.span_drop(id_=id_, bound_args=bound_args, instance=instance, err=e) raise - else: - self.span_exit( - id_=id_, bound_args=bound_args, instance=instance, result=result - ) - return result finally: - # clean up - active_span_id.reset(token) + if not isinstance(result, asyncio.Future): + active_span_id.reset(token) @wrapt.decorator async def async_wrapper( From 6966e2cc5e89500272977eaa6af04e12ba7715e9 Mon Sep 17 00:00:00 2001 From: Logan Markewich Date: Mon, 30 Sep 2024 10:16:54 -0600 Subject: [PATCH 2/3] linting --- .../llama_index/core/instrumentation/dispatcher.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/llama-index-core/llama_index/core/instrumentation/dispatcher.py b/llama-index-core/llama_index/core/instrumentation/dispatcher.py index 4f5683430402f..5917036bb6c8e 100644 --- a/llama-index-core/llama_index/core/instrumentation/dispatcher.py +++ b/llama-index-core/llama_index/core/instrumentation/dispatcher.py @@ -264,7 +264,12 @@ def wrapper(func: Callable, instance: Any, args: list, kwargs: dict) -> Any: tags=tags, ) - def handle_future_result(future, span_id, bound_args, instance): + def handle_future_result( + future: asyncio.Future, + span_id: str, + bound_args: inspect.BoundArguments, + instance: Any, + ) -> None: try: result = future.result() self.span_exit( From ab455dbcced2d3ab881a02822381088e988fe14f Mon Sep 17 00:00:00 2001 From: Logan Markewich Date: Mon, 30 Sep 2024 11:00:54 -0600 Subject: [PATCH 3/3] fix tests --- llama-index-core/llama_index/core/instrumentation/dispatcher.py | 1 + 1 file changed, 1 insertion(+) diff --git a/llama-index-core/llama_index/core/instrumentation/dispatcher.py b/llama-index-core/llama_index/core/instrumentation/dispatcher.py index 5917036bb6c8e..c05f326cb6fd1 100644 --- a/llama-index-core/llama_index/core/instrumentation/dispatcher.py +++ b/llama-index-core/llama_index/core/instrumentation/dispatcher.py @@ -253,6 +253,7 @@ def wrapper(func: Callable, instance: Any, args: list, kwargs: dict) -> Any: bound_args = inspect.signature(func).bind(*args, **kwargs) id_ = f"{func.__qualname__}-{uuid.uuid4()}" tags = active_instrument_tags.get() + result = None token = active_span_id.set(id_) parent_id = None if token.old_value is Token.MISSING else token.old_value