Skip to content

Commit

Permalink
Catch and ignore exceptions when loading checkpoints.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 676067177
  • Loading branch information
The android_world Authors committed Sep 18, 2024
1 parent 532c54c commit 1b443ff
Showing 1 changed file with 11 additions and 8 deletions.
19 changes: 11 additions & 8 deletions android_world/checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,14 +116,17 @@ def load(self, fields: list[str] | None = None) -> list[Episode]:
data = []
for filename in directories:
if filename.endswith('.pkl.gz'):
task_group_id = filename[:-7] # Remove ".pkl.gz" extension
task_group = self._load_task_group(task_group_id)
if fields is not None:
task_group = [
{field: episode[field] for field in fields}
for episode in task_group
]
data.extend(task_group)
try:
task_group_id = filename[:-7] # Remove ".pkl.gz" extension
task_group = self._load_task_group(task_group_id)
if fields is not None:
task_group = [
{field: episode[field] for field in fields}
for episode in task_group
]
data.extend(task_group)
except Exception as e:
print(e)
return data

def _load_task_group(self, task_group_id: str) -> list[Episode]:
Expand Down

0 comments on commit 1b443ff

Please sign in to comment.