From 66171246b3ff23ac5db4c74d88ff538b06233f47 Mon Sep 17 00:00:00 2001 From: Chris Rawles Date: Thu, 12 Sep 2024 11:55:26 -0700 Subject: [PATCH] Add option to use UIAutomator dump accessibility tree information. PiperOrigin-RevId: 673942389 --- android_world/env/adb_utils.py | 11 +++ android_world/env/android_world_controller.py | 75 ++++++++++++++----- android_world/env/env_launcher.py | 11 ++- android_world/env/representation_utils.py | 56 ++++++++++++++ 4 files changed, 132 insertions(+), 21 deletions(-) diff --git a/android_world/env/adb_utils.py b/android_world/env/adb_utils.py index 9eaa5ce..29afc82 100644 --- a/android_world/env/adb_utils.py +++ b/android_world/env/adb_utils.py @@ -1631,3 +1631,14 @@ def set_root_if_needed( return response return issue_generic_request(['root'], env, timeout_sec) + + +def uiautomator_dump(env) -> str: + """Issues a uiautomator dump request and returns the UI hierarchy.""" + dump_args = 'shell uiautomator dump /sdcard/window_dump.xml' + issue_generic_request(dump_args, env) + + read_args = 'shell cat /sdcard/window_dump.xml' + response = issue_generic_request(read_args, env) + + return response.generic.output.decode('utf-8') diff --git a/android_world/env/android_world_controller.py b/android_world/env/android_world_controller.py index 0f4f798..0901b68 100644 --- a/android_world/env/android_world_controller.py +++ b/android_world/env/android_world_controller.py @@ -15,6 +15,7 @@ """Controller for Android that adds UI tree information to the observation.""" import contextlib +import enum import os import time from typing import Any @@ -121,6 +122,16 @@ def get_a11y_tree( OBSERVATION_KEY_UI_ELEMENTS = 'ui_elements' +class A11yMethod(enum.Enum): + """Method to get a11y tree.""" + + # Custom gRPC wrapper that uses a11y forwarder app. + A11Y_FORWARDER_APP = 'a11y_forwarder_app' + + # From `uiautomator dump``. + UIAUTOMATOR = 'uiautomator' + + class AndroidWorldController(base_wrapper.BaseWrapper): """Controller for an Android instance that adds accessibility tree data. @@ -131,26 +142,23 @@ class AndroidWorldController(base_wrapper.BaseWrapper): element. """ - def __init__(self, env: env_interface.AndroidEnvInterface): - self._env = a11y_grpc_wrapper.A11yGrpcWrapper( - env, - install_a11y_forwarding=True, - start_a11y_service=True, - enable_a11y_tree_info=True, - latest_a11y_info_only=True, - ) - self._env.reset() # Initializes required server services in a11y wrapper. - - def _process_timestep(self, timestep: dm_env.TimeStep) -> dm_env.TimeStep: - """Adds a11y tree info to the observation.""" - forest = self.get_a11y_forest() - ui_elements = representation_utils.forest_to_ui_elements( - forest, - exclude_invisible_elements=True, - ) - timestep.observation[OBSERVATION_KEY_FOREST] = forest - timestep.observation[OBSERVATION_KEY_UI_ELEMENTS] = ui_elements - return timestep + def __init__( + self, + env: env_interface.AndroidEnvInterface, + a11y_method: A11yMethod = A11yMethod.A11Y_FORWARDER_APP, + ): + if a11y_method == A11yMethod.A11Y_FORWARDER_APP: + self._env = a11y_grpc_wrapper.A11yGrpcWrapper( + env, + install_a11y_forwarding=True, + start_a11y_service=True, + enable_a11y_tree_info=True, + latest_a11y_info_only=True, + ) + self._env.reset() # Initializes required server services in a11y wrapper. + else: + self._env = env + self._a11y_method = a11y_method @property def device_screen_size(self) -> tuple[int, int]: @@ -196,6 +204,33 @@ def get_a11y_forest( self.refresh_env() return get_a11y_tree(self._env) + def get_ui_elements(self) -> list[representation_utils.UIElement]: + """Returns the most recent UI elements from the device.""" + if self._a11y_method == A11yMethod.A11Y_FORWARDER_APP: + return representation_utils.forest_to_ui_elements( + self.get_a11y_forest(), + exclude_invisible_elements=True, + ) + else: + return representation_utils.xml_dump_to_ui_elements( + adb_utils.uiautomator_dump(self._env) + ) + + def _process_timestep(self, timestep: dm_env.TimeStep) -> dm_env.TimeStep: + """Adds a11y tree info to the observation.""" + if self._a11y_method == A11yMethod.A11Y_FORWARDER_APP: + forest = self.get_a11y_forest() + ui_elements = representation_utils.forest_to_ui_elements( + forest, + exclude_invisible_elements=True, + ) + else: + forest = None + ui_elements = self.get_ui_elements() + timestep.observation[OBSERVATION_KEY_FOREST] = forest + timestep.observation[OBSERVATION_KEY_UI_ELEMENTS] = ui_elements + return timestep + def pull_file( self, remote_db_file_path: str, timeout_sec: Optional[float] = None ) -> contextlib._GeneratorContextManager[str]: diff --git a/android_world/env/env_launcher.py b/android_world/env/env_launcher.py index c91e420..eaf7b30 100644 --- a/android_world/env/env_launcher.py +++ b/android_world/env/env_launcher.py @@ -15,6 +15,7 @@ """Launches the environment used in the benchmark.""" import resource +from typing import Optional from absl import logging from android_world.env import adb_utils @@ -97,6 +98,9 @@ def load_and_setup_env( freeze_datetime: bool = True, adb_path: str = android_world_controller.DEFAULT_ADB_PATH, grpc_port: int = 8554, + controller: Optional[ + android_world_controller.AndroidWorldController + ] = None, ) -> interface.AsyncEnv: """Create environment with `get_env()` and perform env setup and validation. @@ -116,10 +120,15 @@ def load_and_setup_env( 2023, to ensure consistent benchmarking. adb_path: The location of the adb binary. grpc_port: The port for gRPC communication with the emulator. + controller: The controller to use. If None, a new controller will be + created. Returns: An interactable Android environment. """ - env = _get_env(console_port, adb_path, grpc_port) + if controller is None: + env = _get_env(console_port, adb_path, grpc_port) + else: + env = interface.AsyncAndroidEnv(controller) setup_env(env, emulator_setup, freeze_datetime) return env diff --git a/android_world/env/representation_utils.py b/android_world/env/representation_utils.py index af80aa4..c465d5b 100644 --- a/android_world/env/representation_utils.py +++ b/android_world/env/representation_utils.py @@ -16,6 +16,7 @@ import dataclasses from typing import Any, Optional +import xml.etree.ElementTree as ET @dataclasses.dataclass @@ -158,3 +159,58 @@ def forest_to_ui_elements( else: elements.append(_accessibility_node_to_ui_element(node, screen_size)) return elements + + +def _parse_ui_hierarchy(xml_string: str) -> dict[str, Any]: + """Parses the UI hierarchy XML into a dictionary structure.""" + root = ET.fromstring(xml_string) + + def parse_node(node): + result = node.attrib + result['children'] = [parse_node(child) for child in node] + return result + + return parse_node(root) + + +def xml_dump_to_ui_elements(xml_string: str) -> list[UIElement]: + """Converts a UI hierarchy XML dump from uiautomator dump to UIElements.""" + parsed_hierarchy = _parse_ui_hierarchy(xml_string) + ui_elements = [] + + def process_node(node, is_root): + bounds = node.get('bounds') + if bounds: + x_min, y_min, x_max, y_max = map( + int, bounds.strip('[]').replace('][', ',').split(',') + ) + bbox = BoundingBox(x_min, x_max, y_min, y_max) + else: + bbox = None + + ui_element = UIElement( + text=node.get('text'), + content_description=node.get('content-desc'), + class_name=node.get('class'), + bbox=bbox, + bbox_pixels=bbox, + is_checked=node.get('checked') == 'true', + is_checkable=node.get('checkable') == 'true', + is_clickable=node.get('clickable') == 'true', + is_enabled=node.get('enabled') == 'true', + is_focused=node.get('focused') == 'true', + is_focusable=node.get('focusable') == 'true', + is_long_clickable=node.get('long-clickable') == 'true', + is_scrollable=node.get('scrollable') == 'true', + is_selected=node.get('selected') == 'true', + package_name=node.get('package'), + resource_id=node.get('resource-id'), + ) + if not is_root: + ui_elements.append(ui_element) + + for child in node.get('children', []): + process_node(child, is_root=False) + + process_node(parsed_hierarchy, is_root=True) + return ui_elements