Skip to content

Commit

Permalink
Expose the generate_content in GeminiGcpWrapper. It is needed to supp…
Browse files Browse the repository at this point in the history
…ort arbitrary sequence of mixed texts and images.

PiperOrigin-RevId: 671916860
  • Loading branch information
The android_world Authors committed Sep 10, 2024
1 parent 4d94115 commit 108e34b
Showing 1 changed file with 47 additions and 2 deletions.
49 changes: 47 additions & 2 deletions android_world/agents/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@
import google.generativeai as genai
from google.generativeai import types
from google.generativeai.types import answer_types
from google.generativeai.types import content_types
from google.generativeai.types import generation_types
from google.generativeai.types import safety_types
import numpy as np
from PIL import Image
import requests
Expand All @@ -32,9 +34,13 @@
ERROR_CALLING_LLM = 'Error calling LLM'


def _array_to_jpeg_bytes(image: np.ndarray) -> bytes:
def array_to_jpeg_bytes(image: np.ndarray) -> bytes:
  """Encodes a numpy image array as a JPEG byte string.

  Args:
    image: The array to convert; it is passed to `Image.fromarray`.

  Returns:
    The byte string that `image_to_jpeg_bytes` produces for the resulting
    image.
  """
  pil_image = Image.fromarray(image)
  return image_to_jpeg_bytes(pil_image)


def image_to_jpeg_bytes(image: Image.Image) -> bytes:
in_mem_file = io.BytesIO()
image.save(in_mem_file, format='JPEG')
# Reset file pointer to start
Expand Down Expand Up @@ -176,6 +182,45 @@ def predict_mm(
return ERROR_CALLING_LLM, False, output
return ERROR_CALLING_LLM, None, None

def generate(
    self,
    contents: content_types.ContentsType,
    safety_settings: safety_types.SafetySettingOptions | None = None,
    generation_config: generation_types.GenerationConfigType | None = None,
) -> tuple[str, Any]:
  """Exposes the generate_content API, retrying on failure.

  Calls `self.llm.generate_content` up to `self.max_retry` times. After a
  failed attempt that still has retries left, sleeps before trying again; the
  delay starts at one second and doubles after every failure.

  Args:
    contents: The input to the LLM.
    safety_settings: Safety settings.
    generation_config: Generation config.

  Returns:
    The output text and the raw response.

  Raises:
    RuntimeError: If no attempt succeeded (including when `self.max_retry` is
      not positive). The last underlying exception, if any, is attached as
      the cause.
  """
  counter = self.max_retry
  retry_delay = 1.0
  response = None
  last_error = None
  while counter > 0:
    try:
      response = self.llm.generate_content(
          contents=contents,
          safety_settings=safety_settings,
          generation_config=generation_config,
      )
      return response.as_text(), response
    except Exception as e:  # pylint: disable=broad-exception-caught
      counter -= 1
      # Keep a reference: `e` is unbound as soon as the except block exits.
      last_error = e
      # f-string so the actual delay is reported, not the literal placeholder.
      print(f'Error calling LLM, will retry in {retry_delay} seconds')
      print(e)
      if counter > 0:
        # Expo backoff
        time.sleep(retry_delay)
        retry_delay *= 2
  # Chain the last failure so the root cause is visible in the traceback.
  raise RuntimeError(f'Error calling LLM. {response}.') from last_error


class Gpt4Wrapper(LlmWrapper, MultimodalLlmWrapper):
"""OpenAI GPT4 wrapper.
Expand Down Expand Up @@ -208,7 +253,7 @@ def __init__(

@classmethod
def encode_image(cls, image: np.ndarray) -> str:
  """Encodes an image array as a base64 string of its JPEG bytes.

  Args:
    image: The image as a numpy array.

  Returns:
    The base64 encoding of the JPEG bytes, decoded as a UTF-8 string.
  """
  # Use the public `array_to_jpeg_bytes` helper only; an earlier `return` that
  # called the old private name `_array_to_jpeg_bytes` made this call
  # unreachable.
  return base64.b64encode(array_to_jpeg_bytes(image)).decode('utf-8')

def predict(
self,
Expand Down

0 comments on commit 108e34b

Please sign in to comment.