From 1b5fd996d3037b84b78b39a8c0525015d66ee10c Mon Sep 17 00:00:00 2001 From: BraisedPork <46232992+braisedpork1964@users.noreply.github.com> Date: Thu, 25 Jan 2024 15:43:28 +0800 Subject: [PATCH] Enhance tool annotation (#98) improve `tool_api` Co-authored-by: wangzy --- lagent/actions/arxiv_search.py | 5 +- lagent/actions/base_action.py | 69 +++++++++++++------------ lagent/actions/bing_map.py | 22 ++++---- lagent/actions/google_scholar_search.py | 34 ++++++------ lagent/actions/ppt.py | 25 +++++---- 5 files changed, 86 insertions(+), 69 deletions(-) diff --git a/lagent/actions/arxiv_search.py b/lagent/actions/arxiv_search.py index 0d833ccc..0b3e6332 100644 --- a/lagent/actions/arxiv_search.py +++ b/lagent/actions/arxiv_search.py @@ -26,7 +26,7 @@ def __init__(self, self.max_query_len = max_query_len self.doc_content_chars_max = doc_content_chars_max - @tool_api(return_dict=True) + @tool_api(explode_return=True) def get_arxiv_article_information(self, query: str) -> dict: """Run Arxiv search and get the article meta information. @@ -34,7 +34,8 @@ def get_arxiv_article_information(self, query: str) -> dict: query (:class:`str`): the content of search query Returns: - content (:class:`str`): a list of 3 arxiv search papers + :class:`dict`: article information + * content (str): a list of 3 arxiv search papers """ try: results = arxiv.Search( # type: ignore diff --git a/lagent/actions/base_action.py b/lagent/actions/base_action.py index 9561c7a0..9c7e0dca 100644 --- a/lagent/actions/base_action.py +++ b/lagent/actions/base_action.py @@ -20,7 +20,7 @@ def tool_api(func: Optional[Callable] = None, *, - return_dict: bool = False, + explode_return: bool = False, returns_named_value: bool = False, **kwargs): """Turn functions into tools. It will parse typehints as well as docstrings @@ -48,15 +48,13 @@ def add(a: Annotated[int, 'augend'], b: Annotated[int, 'addend'] = 1): Args: func (Optional[Callable]): function to decorate. Defaults to ``None``. - return_dict (bool): suggest if returned data is a single dictionary. - When enabled, the returns sections in docstrings should indicate the - key-value infomation of the dictionary rather than hint a standard - tuple return. Defaults to ``False``. + explode_return (bool): whether to flatten the dictionary or tuple return + as the ``return_data`` field. When enabled, it is recommended to + annotate the member in docstrings. Defaults to ``False``. .. code-block:: python - # set `return_dict` True will force `returns_named_value` to be enabled - @tool_api(return_dict=True) + @tool_api(explode_return=True) def foo(a, b): '''A simple function @@ -65,8 +63,9 @@ def foo(a, b): b (int): b Returns: - x: the value of input a - y: the value of input b + dict: information of inputs + * x: value of a + * y: value of b ''' return {'x': a, 'y': b} @@ -75,14 +74,16 @@ def foo(a, b): returns_named_value (bool): whether to parse ``thing: Description`` in returns sections as a name and description, rather than a type and description. When true, type must be wrapped in parentheses: - ``(int): Description.``. When false, parentheses are optional but + ``(int): Description``. When false, parentheses are optional but the items cannot be named: ``int: Description``. Defaults to ``False``. + + Important: + ``return_data`` field will be added to ``api_description`` only + when ``explode_return`` or ``returns_named_value`` is enabled. Returns: Callable: wrapped function or partial decorator """ - if return_dict: - returns_named_value = True def _detect_type(string): field_type = 'STRING' @@ -97,6 +98,25 @@ def _detect_type(string): field_type = 'BOOLEAN' return field_type + def _explode(desc): + kvs = [] + desc = '\nArgs:\n' + '\n'.join([ + ' ' + item.lstrip(' -+*#.') + for item in desc.split('\n')[1:] if item.strip() + ]) + docs = Docstring(desc).parse('google') + if not docs: + return kvs + if docs[0].kind is DocstringSectionKind.parameters: + for d in docs[0].value: + d = d.as_dict() + if not d['annotation']: + d.pop('annotation') + else: + d['type'] = _detect_type(d.pop('annotation').lower()) + kvs.append(d) + return kvs + def _parse_tool(function): # remove rst syntax docs = Docstring( @@ -114,13 +134,11 @@ def _parse_tool(function): if doc.kind is DocstringSectionKind.parameters: for d in doc.value: d = d.as_dict() - d['description'] = d['description'] d['type'] = _detect_type(d.pop('annotation').lower()) args_doc[d['name']] = d if doc.kind is DocstringSectionKind.returns: for d in doc.value: d = d.as_dict() - d['description'] = d['description'] if not d['name']: d.pop('name') if not d['annotation']: @@ -154,26 +172,11 @@ def _parse_tool(function): if param.default is inspect.Signature.empty: desc['required'].append(param.name) - return_data, return_annotation = [], sig.return_annotation - if return_dict: + return_data = [] + if explode_return: + return_data = _explode(returns_doc[0]['description']) + elif returns_named_value: return_data = returns_doc - elif return_annotation is not inspect.Signature.empty: - if return_annotation is tuple: - return_data = returns_doc - elif get_origin(return_annotation) is tuple: - return_annotation = get_args(return_annotation) - if not return_annotation: - return_data = returns_doc - elif len(return_annotation) >= 2: - for i, item in enumerate(return_annotation): - info = returns_doc[i]['description'] if i < len( - returns_doc) else '' - if get_origin(item) is Annotated: - item, info = get_args(item) - return_data.append({ - 'description': info, - 'type': _detect_type(str(item)) - }) if return_data: desc['return_data'] = return_data return desc diff --git a/lagent/actions/bing_map.py b/lagent/actions/bing_map.py index 6906cb62..01efe698 100644 --- a/lagent/actions/bing_map.py +++ b/lagent/actions/bing_map.py @@ -25,7 +25,7 @@ def __init__(self, self.key = key self.base_url = 'http://dev.virtualearth.net/REST/V1/' - @tool_api(return_dict=True) + @tool_api(explode_return=True) def get_distance(self, start: str, end: str) -> dict: """Get the distance between two locations in km. @@ -34,7 +34,8 @@ def get_distance(self, start: str, end: str) -> dict: end (:class:`str`): The end location Returns: - distance (:class:`str`): the distance in km. + :class:`dict`: distance information + * distance (str): the distance in km. """ # Request URL url = self.base_url + 'Routes/Driving?o=json&wp.0=' + start + '&wp.1=' + end + '&key=' + self.key @@ -48,7 +49,7 @@ def get_distance(self, start: str, end: str) -> dict: distance = route['travelDistance'] return dict(distance=distance) - @tool_api(return_dict=True) + @tool_api(explode_return=True) def get_route(self, start: str, end: str) -> dict: """Get the route between two locations in km. @@ -57,7 +58,8 @@ def get_route(self, start: str, end: str) -> dict: end (:class:`str`): The end location Returns: - route (:class:`list`): the route, a list of actions. + :class:`dict`: route information + * route (list): the route, a list of actions. """ # Request URL url = self.base_url + 'Routes/Driving?o=json&wp.0=' + start + '&wp.1=' + end + '&key=' + self.key @@ -74,7 +76,7 @@ def get_route(self, start: str, end: str) -> dict: route_text.append(item['instruction']['text']) return dict(route=route_text) - @tool_api(return_dict=True) + @tool_api(explode_return=True) def get_coordinates(self, location: str) -> dict: """Get the coordinates of a location. @@ -82,8 +84,9 @@ def get_coordinates(self, location: str) -> dict: location (:class:`str`): the location need to get coordinates. Returns: - latitude (:class:`float`): the latitude of the location. - longitude (:class:`float`): the longitude of the location. + :class:`dict`: coordinates information + * latitude (float): the latitude of the location. + * longitude (float): the longitude of the location. """ url = self.base_url + 'Locations' params = {'query': location, 'key': self.key} @@ -93,7 +96,7 @@ def get_coordinates(self, location: str) -> dict: 'coordinates'] return dict(latitude=coordinates[0], longitude=coordinates[1]) - @tool_api(return_dict=True) + @tool_api(explode_return=True) def search_nearby(self, search_term: str, places: str = 'unknown', @@ -112,7 +115,8 @@ def search_nearby(self, radius (:class:`int`): radius in meters. Defaults to ``5000``. Returns: - places (:class:`list`): the list of places, each place is a dict with name and address, at most 5 places. + :class:`dict`: places information + * places (list): the list of places, each place is a dict with name and address, at most 5 places. """ url = self.base_url + 'LocalSearch' if places != 'unknown': diff --git a/lagent/actions/google_scholar_search.py b/lagent/actions/google_scholar_search.py index 9209a0e6..4098f614 100644 --- a/lagent/actions/google_scholar_search.py +++ b/lagent/actions/google_scholar_search.py @@ -35,7 +35,7 @@ def __init__(self, 'as SERPER_API_KEY or pass it as `api_key` parameter.') self.api_key = api_key - @tool_api(return_dict=True) + @tool_api(explode_return=True) def search_google_scholar( self, query: str, @@ -72,10 +72,11 @@ def search_google_scholar( as_vis (Optional[str]): Defines whether to include citations or not. Returns: - title: a list of the titles of the three selected papers - cited_by: a list of the citation numbers of the three selected papers - organic_id: a list of the organic results' ids of the three selected papers - pub_info: publication information of selected papers + :class:`dict`: article information + - title: a list of the titles of the three selected papers + - cited_by: a list of the citation numbers of the three selected papers + - organic_id: a list of the organic results' ids of the three selected papers + - pub_info: publication information of selected papers """ params = { 'q': query, @@ -120,7 +121,7 @@ def search_google_scholar( return ActionReturn( errmsg=str(e), state=ActionStatusCode.HTTP_ERROR) - @tool_api(return_dict=True) + @tool_api(explode_return=True) def get_author_information(self, author_id: str, hl: Optional[str] = None, @@ -147,10 +148,11 @@ def get_author_information(self, output (Optional[str]): Defines the final output you want. Default is 'json'. Returns: - name: author's name - affliation: the affliation of the author - articles: at most 3 articles by the author - website: the author's homepage url + :class:`dict`: author information + * name: author's name + * affliation: the affliation of the author + * articles: at most 3 articles by the author + * website: the author's homepage url """ params = { 'engine': 'google_scholar_author', @@ -183,7 +185,7 @@ def get_author_information(self, return ActionReturn( errmsg=str(e), state=ActionStatusCode.HTTP_ERROR) - @tool_api(return_dict=True) + @tool_api(explode_return=True) def get_citation_format(self, q: str, no_cache: Optional[bool] = None, @@ -198,8 +200,9 @@ def get_citation_format(self, output (Optional[str]): Final output format. Set to 'json' to get a structured JSON of the results, or 'html' to get the raw html retrieved. Defaults to 'json'. Returns: - authors: the authors of the article - citation: the citation format of the article + :class:`dict`: citation format + * authors: the authors of the article + * citation: the citation format of the article """ params = { 'q': q, @@ -219,7 +222,7 @@ def get_citation_format(self, return ActionReturn( errmsg=str(e), state=ActionStatusCode.HTTP_ERROR) - @tool_api(return_dict=True) + @tool_api(explode_return=True) def get_author_id(self, mauthors: str, hl: Optional[str] = 'en', @@ -240,7 +243,8 @@ def get_author_id(self, output (Optional[str]): Defines the final output you want. It can be set to 'json' (default) to get a structured JSON of the results, or 'html' to get the raw html retrieved. Defaults to 'json'. Returns: - author_id: the author_id of the author + :class:`dict`: author id + * author_id: the author_id of the author """ params = { 'mauthors': mauthors, diff --git a/lagent/actions/ppt.py b/lagent/actions/ppt.py index 9cd021bc..f0e68502 100644 --- a/lagent/actions/ppt.py +++ b/lagent/actions/ppt.py @@ -28,7 +28,7 @@ def __init__(self, self.pointer = None self.location = None - @tool_api(return_dict=True) + @tool_api(explode_return=True) def create_file(self, theme: str, abs_location: str) -> dict: """Create a pptx file with specific themes @@ -37,7 +37,8 @@ def create_file(self, theme: str, abs_location: str) -> dict: abs_location (:class:`str`): the ppt file's absolute location Returns: - status: the result of the execution + :class:`dict`: operation status + * status: the result of the execution """ self.location = abs_location try: @@ -48,7 +49,7 @@ def create_file(self, theme: str, abs_location: str) -> dict: print(e) return dict(status='created a ppt file.') - @tool_api(return_dict=True) + @tool_api(explode_return=True) def add_first_page(self, title: str, subtitle: str) -> dict: """Add the first page of ppt. @@ -57,7 +58,8 @@ def add_first_page(self, title: str, subtitle: str) -> dict: subtitle (:class:`str`): the subtitle of ppt Returns: - status: the result of the execution + :class:`dict`: operation status + * status: the result of the execution """ layout_name = self.theme_mapping[ self.pointer.slide_master.name]['title'] @@ -70,7 +72,7 @@ def add_first_page(self, title: str, subtitle: str) -> dict: ph_subtitle.text = subtitle return dict(status='added page') - @tool_api(return_dict=True) + @tool_api(explode_return=True) def add_text_page(self, title: str, bullet_items: str) -> dict: """Add text page of ppt @@ -79,7 +81,8 @@ def add_text_page(self, title: str, bullet_items: str) -> dict: bullet_items (:class:`str`): bullet_items should be string, for multiple bullet items, please use [SPAN] to separate them. Returns: - status: the result of the execution + :class:`dict`: operation status + * status: the result of the execution """ layout_name = self.theme_mapping[ self.pointer.slide_master.name]['single'] @@ -99,7 +102,7 @@ def add_text_page(self, title: str, bullet_items: str) -> dict: p.level = 0 return dict(status='added page') - @tool_api(return_dict=True) + @tool_api(explode_return=True) def add_text_image_page(self, title: str, bullet_items: str, image: str) -> dict: """Add a text page with one image. Image should be a path @@ -110,7 +113,8 @@ def add_text_image_page(self, title: str, bullet_items: str, image (:class:`str`): the path of the image Returns: - status: the result of the execution + :class:`dict`: operation status + * status: the result of the execution """ layout_name = self.theme_mapping[self.pointer.slide_master.name]['two'] layout = next(i for i in self.pointer.slide_master.slide_layouts @@ -138,12 +142,13 @@ def add_text_image_page(self, title: str, bullet_items: str, return dict(status='added page') - @tool_api(return_dict=True) + @tool_api(explode_return=True) def submit_file(self) -> dict: """When all steps done, YOU MUST use submit_file() to submit your work. Returns: - status: the result of the execution + :class:`dict`: operation status + * status: the result of the execution """ # file_path = os.path.join(self.CACHE_DIR, f'{self._return_timestamp()}.pptx') # self.pointer.save(file_path)