
Commit

add docstring (#116)
* BaseLLM generate docstring
liujiangning30 authored Jan 30, 2024
1 parent 76a46c9 commit 94ba3a1
Showing 2 changed files with 70 additions and 14 deletions.
25 changes: 21 additions & 4 deletions lagent/llms/huggingface.py
@@ -3,7 +3,7 @@
import sys
import traceback
import warnings
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Union

from lagent.schema import AgentStatusCode
from .base_llm import BaseModel
@@ -93,20 +93,37 @@ def tokenize(self, inputs: str):

def generate(
self,
inputs: List[str],
do_sample=True,
inputs: Union[str, List[str]],
do_sample: bool = True,
**kwargs,
):
"""Return the chat completions in non-stream mode.
Args:
inputs (Union[str, List[str]]): input texts to be completed.
do_sample (bool): do sampling if enabled
Returns:
(a list of/batched) text/chat completion
"""
for chunk in self.stream_generate(inputs, do_sample, **kwargs):
response = chunk
return response

def stream_generate(
self,
inputs: List[str],
do_sample=True,
do_sample: bool = True,
**kwargs,
):
"""Return the chat completions in stream mode.
Args:
inputs (Union[str, List[str]]): input texts to be completed.
do_sample (bool): do sampling if enabled
Returns:
tuple(Status, str, int): status, text/chat completion,
generated token number
"""
try:
import torch
from torch import nn
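For orientation, a brief usage sketch of the two methods documented above (not part of the commit). The HFTransformer class name, the import path and the checkpoint path are assumptions, not something this diff establishes:

from lagent.llms import HFTransformer  # import path assumed

model = HFTransformer(path='internlm/internlm-chat-7b')  # illustrative checkpoint

# Non-stream mode: returns the (batched) text/chat completion.
response = model.generate('Hi, please introduce yourself.', do_sample=True)

# Stream mode: iterate over chunks; per the new docstring each chunk carries the
# status, the completion text and the generated token number.
for chunk in model.stream_generate('Hi, please introduce yourself.', do_sample=True):
    print(chunk)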
59 changes: 49 additions & 10 deletions lagent/llms/lmdepoly_wrapper.py
@@ -116,7 +116,7 @@ def stream_chat(self,
sequence_end: bool = True,
**kwargs):
"""Start a new round conversation of a session. Return the chat
completions in non-stream mode.
completions in stream mode.
Args:
session_id (int): the identical id of a session
@@ -206,10 +206,10 @@ class LMDeployPipeline(BaseModel):
"Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
and so on.
model_name (str): needed when model_path is a pytorch model on
huggingface.co, such as "internlm/internlm-chat-7b",
"Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat" and so on.
tp (int):
pipeline_cfg (dict):
huggingface.co, such as "internlm-chat-7b",
"Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on.
tp (int): tensor parallel
pipeline_cfg (dict): config of pipeline
"""

def __init__(self,
@@ -226,8 +226,18 @@ def __init__(self,

def generate(self,
inputs: Union[str, List[str]],
do_preprocess=None,
do_preprocess: bool = None,
**kwargs):
"""Return the chat completions in non-stream mode.
Args:
inputs (Union[str, List[str]]): input texts to be completed.
do_preprocess (bool): whether pre-process the messages. Default to
True, which means chat_template will be applied.
Returns:
(a list of/batched) text/chat completion
"""
batched = True
if isinstance(inputs, str):
inputs = [inputs]
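
A minimal sketch of the pipeline wrapper documented above (not part of the commit); the import path and the model path are assumptions:

from lagent.llms import LMDeployPipeline  # import path assumed

llm = LMDeployPipeline(path='internlm/internlm-chat-7b', tp=1)  # illustrative model

# A single string yields a single completion; a list of strings yields a batch.
single = llm.generate('Hello!', do_preprocess=True)
batch = llm.generate(['Hello!', 'Tell me a joke.'])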
@@ -262,11 +272,11 @@ class LMDeployServer(BaseModel):
"Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
and so on.
model_name (str): needed when model_path is a pytorch model on
huggingface.co, such as "internlm/internlm-chat-7b",
"Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat" and so on.
huggingface.co, such as "internlm-chat-7b",
"Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on.
server_name (str): host ip for serving
server_port (int): server port
tp (int):
tp (int): tensor parallel
log_level (str): set log level whose value among
[CRITICAL, ERROR, WARNING, INFO, DEBUG]
"""
@@ -300,6 +310,20 @@ def generate(self,
ignore_eos: bool = False,
timeout: int = 30,
**kwargs) -> List[str]:
"""Start a new round conversation of a session. Return the chat
completions in non-stream mode.
Args:
inputs (str, List[str]): user's prompt(s) in this round
session_id (int): the identical id of a session
sequence_start (bool): start flag of a session
sequence_end (bool): end flag of a session
ignore_eos (bool): indicator for ignoring eos
timeout (int): max time to wait for response
Returns:
(a list of/batched) text/chat completion
"""

batched = True
if isinstance(inputs, str):
inputs = [inputs]
@@ -337,7 +361,21 @@ def stream_chat(self,
ignore_eos: bool = False,
timeout: int = 30,
**kwargs):
"""Start a new round conversation of a session. Return the chat
completions in stream mode.
Args:
session_id (int): the identical id of a session
inputs (List[dict]): user's inputs in this round conversation
sequence_start (bool): start flag of a session
sequence_end (bool): end flag of a session
stream (bool): return in a streaming format if enabled
ignore_eos (bool): indicator for ignoring eos
timeout (int): max time to wait for response
Returns:
tuple(Status, str, int): status, text/chat completion,
generated token number
"""
gen_params = self.update_gen_params(**kwargs)
prompt = self.template_parser(inputs)

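A minimal sketch covering the two LMDeployServer methods documented above (not part of the commit); the import path, model path and port are assumptions:

from lagent.llms import LMDeployServer  # import path assumed

llm = LMDeployServer(path='internlm/internlm-chat-7b',
                     server_name='0.0.0.0', server_port=23333,
                     tp=1, log_level='WARNING')  # illustrative settings

# Non-stream mode: a prompt (or list of prompts) completed within one session.
reply = llm.generate('Hello!', session_id=0, sequence_start=True,
                     sequence_end=True, ignore_eos=False, timeout=30)

# Stream mode: inputs are role/content messages; per the new docstring each
# yielded item carries the status, the completion text and the token number.
messages = [dict(role='user', content='Hello!')]
for chunk in llm.stream_chat(messages, session_id=0, stream=True):
    print(chunk)
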
@@ -374,7 +412,8 @@ class LMDeployClient(LMDeployServer):
Args:
path (str): The path to the model.
url (str):
url (str): communicating address 'http://<ip>:<port>' of
api_server
"""

def __init__(self, path: str, url: str, **kwargs):
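Finally, a minimal sketch of the client wrapper documented above (not part of the commit); the url and model name are illustrative assumptions:

from lagent.llms import LMDeployClient  # import path assumed

client = LMDeployClient(path='internlm-chat-7b',
                        url='http://127.0.0.1:23333')  # a running api_server
print(client.generate('Hello!'))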
