Skip to content

Commit

Permalink
Read check-pointed data in same order that they are run in.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 671381606
  • Loading branch information
Chris Rawles authored and The android_world Authors committed Sep 5, 2024
1 parent 90abf85 commit 4d94115
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 5 deletions.
7 changes: 6 additions & 1 deletion android_world/checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from absl import logging

INSTANCE_SEPARATOR = '_'

Episode = dict[str, Any]

Expand Down Expand Up @@ -108,8 +109,12 @@ def save_episodes(self, task_episodes: list[Episode], task_name: str):

def load(self, fields: list[str] | None = None) -> list[Episode]:
"""Loads all task groups from disk."""
# Keep same order as runtime.
directories = os.listdir(self.directory)
directories.sort(key=lambda x: x.split(INSTANCE_SEPARATOR)[0])

data = []
for filename in os.listdir(self.directory):
for filename in directories:
if filename.endswith('.pkl.gz'):
task_group_id = filename[:-7] # Remove ".pkl.gz" extension
task_group = self._load_task_group(task_group_id)
Expand Down
18 changes: 16 additions & 2 deletions android_world/suite_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,9 +274,21 @@ def _run_task(


def _get_task_info(episodes: list[dict[str, Any]]) -> tuple[set[str], set[str]]:
"""Gets task info from episodes.
Args:
episodes: Episodes to get info from.
Returns:
A tuple of completed and failed task names.
"""
completed, failed = [], []
for episode in episodes:
instance_name = f'{episode[constants.EpisodeConstants.TASK_TEMPLATE]}_{episode[constants.EpisodeConstants.INSTANCE_ID]}'
instance_name = (
episode[constants.EpisodeConstants.TASK_TEMPLATE]
+ checkpointer_lib.INSTANCE_SEPARATOR
+ str(episode[constants.EpisodeConstants.INSTANCE_ID])
)
if episode.get(constants.EpisodeConstants.EXCEPTION_INFO) is not None:
failed.append(instance_name)
else:
Expand Down Expand Up @@ -323,7 +335,9 @@ def _run_task_suite(
print(msg + '\n' + '=' * len(msg))

for i, instance in enumerate(instances):
instance_name = f'{instance.name}_{i}'
instance_name = (
instance.name + checkpointer_lib.INSTANCE_SEPARATOR + str(i)
)
already_processed = (
instance_name in completed_tasks and instance_name not in failed_tasks
)
Expand Down
4 changes: 2 additions & 2 deletions android_world/task_evals/robustness_study/screen_variation.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,13 @@ def initialize_task(self, env: interface.AsyncEnv):
super().initialize_task(env)
# Go back to home screen with a reset.
env.reset(True)
adb_utils.set_screen_size(self.width, self.height, env.base_env)
adb_utils.set_screen_size(self.width, self.height, env.controller)
# It has been observed that without this pause, the following orientation
# change will not work.
time.sleep(2)
# Task starts from the home screen and the following orientation change
# will take effect for the next app opened but expired after closing.
adb_utils.change_orientation(self.orientation, env.base_env)
adb_utils.change_orientation(self.orientation, env.controller)

@property
def name(self) -> str:
Expand Down

0 comments on commit 4d94115

Please sign in to comment.