Skip to content

Commit

Permalink
Replacing as_text() by text for the response from generate_content.
Browse files Browse the repository at this point in the history
Fix an input conversion issue of GeminiGcpWrapper.

PiperOrigin-RevId: 676004238
  • Loading branch information
The android_world Authors committed Sep 18, 2024
1 parent 392f077 commit 551e9e4
Showing 1 changed file with 21 additions and 2 deletions.
23 changes: 21 additions & 2 deletions android_world/agents/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,9 @@ def predict_mm(

def generate(
self,
contents: content_types.ContentsType,
contents: (
content_types.ContentsType | list[str | np.ndarray | Image.Image]
),
safety_settings: safety_types.SafetySettingOptions | None = None,
generation_config: generation_types.GenerationConfigType | None = None,
) -> tuple[str, Any]:
Expand All @@ -203,14 +205,16 @@ def generate(
counter = self.max_retry
retry_delay = 1.0
response = None
if isinstance(contents, list):
contents = self.convert_content(contents)
while counter > 0:
try:
response = self.llm.generate_content(
contents=contents,
safety_settings=safety_settings,
generation_config=generation_config,
)
return response.as_text(), response
return response.text, response
except Exception as e: # pylint: disable=broad-exception-caught
counter -= 1
print('Error calling LLM, will retry in {retry_delay} seconds')
Expand All @@ -221,6 +225,21 @@ def generate(
retry_delay *= 2
raise RuntimeError(f'Error calling LLM. {response}.')

def convert_content(
self,
contents: list[str | np.ndarray | Image.Image],
) -> content_types.ContentsType:
"""Converts a list of contents to a ContentsType."""
converted = []
for item in contents:
if isinstance(item, str):
converted.append(item)
elif isinstance(item, np.ndarray):
converted.append(Image.fromarray(item))
elif isinstance(item, Image.Image):
converted.append(item)
return converted


class Gpt4Wrapper(LlmWrapper, MultimodalLlmWrapper):
"""OpenAI GPT4 wrapper.
Expand Down

0 comments on commit 551e9e4

Please sign in to comment.