Skip to content

Commit

Permalink
Enhance tool annotation (#98)
Browse files Browse the repository at this point in the history
improve `tool_api`

Co-authored-by: wangzy <[email protected]>
  • Loading branch information
braisedpork1964 and wangzy authored Jan 25, 2024
1 parent 799c6a3 commit 1b5fd99
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 69 deletions.
5 changes: 3 additions & 2 deletions lagent/actions/arxiv_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,16 @@ def __init__(self,
self.max_query_len = max_query_len
self.doc_content_chars_max = doc_content_chars_max

@tool_api(return_dict=True)
@tool_api(explode_return=True)
def get_arxiv_article_information(self, query: str) -> dict:
"""Run Arxiv search and get the article meta information.
Args:
query (:class:`str`): the content of search query
Returns:
content (:class:`str`): a list of 3 arxiv search papers
:class:`dict`: article information
* content (str): a list of 3 arxiv search papers
"""
try:
results = arxiv.Search( # type: ignore
Expand Down
69 changes: 36 additions & 33 deletions lagent/actions/base_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

def tool_api(func: Optional[Callable] = None,
*,
return_dict: bool = False,
explode_return: bool = False,
returns_named_value: bool = False,
**kwargs):
"""Turn functions into tools. It will parse typehints as well as docstrings
Expand Down Expand Up @@ -48,15 +48,13 @@ def add(a: Annotated[int, 'augend'], b: Annotated[int, 'addend'] = 1):
Args:
func (Optional[Callable]): function to decorate. Defaults to ``None``.
return_dict (bool): suggest if returned data is a single dictionary.
When enabled, the returns sections in docstrings should indicate the
key-value infomation of the dictionary rather than hint a standard
tuple return. Defaults to ``False``.
explode_return (bool): whether to flatten the dictionary or tuple return
as the ``return_data`` field. When enabled, it is recommended to
annotate the member in docstrings. Defaults to ``False``.
.. code-block:: python
# set `return_dict` True will force `returns_named_value` to be enabled
@tool_api(return_dict=True)
@tool_api(explode_return=True)
def foo(a, b):
'''A simple function
Expand All @@ -65,8 +63,9 @@ def foo(a, b):
b (int): b
Returns:
x: the value of input a
y: the value of input b
dict: information of inputs
* x: value of a
* y: value of b
'''
return {'x': a, 'y': b}
Expand All @@ -75,14 +74,16 @@ def foo(a, b):
returns_named_value (bool): whether to parse ``thing: Description`` in
returns sections as a name and description, rather than a type and
description. When true, type must be wrapped in parentheses:
``(int): Description.``. When false, parentheses are optional but
``(int): Description``. When false, parentheses are optional but
the items cannot be named: ``int: Description``. Defaults to ``False``.
Important:
``return_data`` field will be added to ``api_description`` only
when ``explode_return`` or ``returns_named_value`` is enabled.
Returns:
Callable: wrapped function or partial decorator
"""
if return_dict:
returns_named_value = True

def _detect_type(string):
field_type = 'STRING'
Expand All @@ -97,6 +98,25 @@ def _detect_type(string):
field_type = 'BOOLEAN'
return field_type

def _explode(desc):
kvs = []
desc = '\nArgs:\n' + '\n'.join([
' ' + item.lstrip(' -+*#.')
for item in desc.split('\n')[1:] if item.strip()
])
docs = Docstring(desc).parse('google')
if not docs:
return kvs
if docs[0].kind is DocstringSectionKind.parameters:
for d in docs[0].value:
d = d.as_dict()
if not d['annotation']:
d.pop('annotation')
else:
d['type'] = _detect_type(d.pop('annotation').lower())
kvs.append(d)
return kvs

def _parse_tool(function):
# remove rst syntax
docs = Docstring(
Expand All @@ -114,13 +134,11 @@ def _parse_tool(function):
if doc.kind is DocstringSectionKind.parameters:
for d in doc.value:
d = d.as_dict()
d['description'] = d['description']
d['type'] = _detect_type(d.pop('annotation').lower())
args_doc[d['name']] = d
if doc.kind is DocstringSectionKind.returns:
for d in doc.value:
d = d.as_dict()
d['description'] = d['description']
if not d['name']:
d.pop('name')
if not d['annotation']:
Expand Down Expand Up @@ -154,26 +172,11 @@ def _parse_tool(function):
if param.default is inspect.Signature.empty:
desc['required'].append(param.name)

return_data, return_annotation = [], sig.return_annotation
if return_dict:
return_data = []
if explode_return:
return_data = _explode(returns_doc[0]['description'])
elif returns_named_value:
return_data = returns_doc
elif return_annotation is not inspect.Signature.empty:
if return_annotation is tuple:
return_data = returns_doc
elif get_origin(return_annotation) is tuple:
return_annotation = get_args(return_annotation)
if not return_annotation:
return_data = returns_doc
elif len(return_annotation) >= 2:
for i, item in enumerate(return_annotation):
info = returns_doc[i]['description'] if i < len(
returns_doc) else ''
if get_origin(item) is Annotated:
item, info = get_args(item)
return_data.append({
'description': info,
'type': _detect_type(str(item))
})
if return_data:
desc['return_data'] = return_data
return desc
Expand Down
22 changes: 13 additions & 9 deletions lagent/actions/bing_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __init__(self,
self.key = key
self.base_url = 'http://dev.virtualearth.net/REST/V1/'

@tool_api(return_dict=True)
@tool_api(explode_return=True)
def get_distance(self, start: str, end: str) -> dict:
"""Get the distance between two locations in km.
Expand All @@ -34,7 +34,8 @@ def get_distance(self, start: str, end: str) -> dict:
end (:class:`str`): The end location
Returns:
distance (:class:`str`): the distance in km.
:class:`dict`: distance information
* distance (str): the distance in km.
"""
# Request URL
url = self.base_url + 'Routes/Driving?o=json&wp.0=' + start + '&wp.1=' + end + '&key=' + self.key
Expand All @@ -48,7 +49,7 @@ def get_distance(self, start: str, end: str) -> dict:
distance = route['travelDistance']
return dict(distance=distance)

@tool_api(return_dict=True)
@tool_api(explode_return=True)
def get_route(self, start: str, end: str) -> dict:
"""Get the route between two locations in km.
Expand All @@ -57,7 +58,8 @@ def get_route(self, start: str, end: str) -> dict:
end (:class:`str`): The end location
Returns:
route (:class:`list`): the route, a list of actions.
:class:`dict`: route information
* route (list): the route, a list of actions.
"""
# Request URL
url = self.base_url + 'Routes/Driving?o=json&wp.0=' + start + '&wp.1=' + end + '&key=' + self.key
Expand All @@ -74,16 +76,17 @@ def get_route(self, start: str, end: str) -> dict:
route_text.append(item['instruction']['text'])
return dict(route=route_text)

@tool_api(return_dict=True)
@tool_api(explode_return=True)
def get_coordinates(self, location: str) -> dict:
"""Get the coordinates of a location.
Args:
location (:class:`str`): the location need to get coordinates.
Returns:
latitude (:class:`float`): the latitude of the location.
longitude (:class:`float`): the longitude of the location.
:class:`dict`: coordinates information
* latitude (float): the latitude of the location.
* longitude (float): the longitude of the location.
"""
url = self.base_url + 'Locations'
params = {'query': location, 'key': self.key}
Expand All @@ -93,7 +96,7 @@ def get_coordinates(self, location: str) -> dict:
'coordinates']
return dict(latitude=coordinates[0], longitude=coordinates[1])

@tool_api(return_dict=True)
@tool_api(explode_return=True)
def search_nearby(self,
search_term: str,
places: str = 'unknown',
Expand All @@ -112,7 +115,8 @@ def search_nearby(self,
radius (:class:`int`): radius in meters. Defaults to ``5000``.
Returns:
places (:class:`list`): the list of places, each place is a dict with name and address, at most 5 places.
:class:`dict`: places information
* places (list): the list of places, each place is a dict with name and address, at most 5 places.
"""
url = self.base_url + 'LocalSearch'
if places != 'unknown':
Expand Down
34 changes: 19 additions & 15 deletions lagent/actions/google_scholar_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(self,
'as SERPER_API_KEY or pass it as `api_key` parameter.')
self.api_key = api_key

@tool_api(return_dict=True)
@tool_api(explode_return=True)
def search_google_scholar(
self,
query: str,
Expand Down Expand Up @@ -72,10 +72,11 @@ def search_google_scholar(
as_vis (Optional[str]): Defines whether to include citations or not.
Returns:
title: a list of the titles of the three selected papers
cited_by: a list of the citation numbers of the three selected papers
organic_id: a list of the organic results' ids of the three selected papers
pub_info: publication information of selected papers
:class:`dict`: article information
- title: a list of the titles of the three selected papers
- cited_by: a list of the citation numbers of the three selected papers
- organic_id: a list of the organic results' ids of the three selected papers
- pub_info: publication information of selected papers
"""
params = {
'q': query,
Expand Down Expand Up @@ -120,7 +121,7 @@ def search_google_scholar(
return ActionReturn(
errmsg=str(e), state=ActionStatusCode.HTTP_ERROR)

@tool_api(return_dict=True)
@tool_api(explode_return=True)
def get_author_information(self,
author_id: str,
hl: Optional[str] = None,
Expand All @@ -147,10 +148,11 @@ def get_author_information(self,
output (Optional[str]): Defines the final output you want. Default is 'json'.
Returns:
name: author's name
affliation: the affliation of the author
articles: at most 3 articles by the author
website: the author's homepage url
:class:`dict`: author information
* name: author's name
* affliation: the affliation of the author
* articles: at most 3 articles by the author
* website: the author's homepage url
"""
params = {
'engine': 'google_scholar_author',
Expand Down Expand Up @@ -183,7 +185,7 @@ def get_author_information(self,
return ActionReturn(
errmsg=str(e), state=ActionStatusCode.HTTP_ERROR)

@tool_api(return_dict=True)
@tool_api(explode_return=True)
def get_citation_format(self,
q: str,
no_cache: Optional[bool] = None,
Expand All @@ -198,8 +200,9 @@ def get_citation_format(self,
output (Optional[str]): Final output format. Set to 'json' to get a structured JSON of the results, or 'html' to get the raw html retrieved. Defaults to 'json'.
Returns:
authors: the authors of the article
citation: the citation format of the article
:class:`dict`: citation format
* authors: the authors of the article
* citation: the citation format of the article
"""
params = {
'q': q,
Expand All @@ -219,7 +222,7 @@ def get_citation_format(self,
return ActionReturn(
errmsg=str(e), state=ActionStatusCode.HTTP_ERROR)

@tool_api(return_dict=True)
@tool_api(explode_return=True)
def get_author_id(self,
mauthors: str,
hl: Optional[str] = 'en',
Expand All @@ -240,7 +243,8 @@ def get_author_id(self,
output (Optional[str]): Defines the final output you want. It can be set to 'json' (default) to get a structured JSON of the results, or 'html' to get the raw html retrieved. Defaults to 'json'.
Returns:
author_id: the author_id of the author
:class:`dict`: author id
* author_id: the author_id of the author
"""
params = {
'mauthors': mauthors,
Expand Down
Loading

0 comments on commit 1b5fd99

Please sign in to comment.