Force BQIO to output elements in the correct window #32584

Open · wants to merge 3 commits into master
2 changes: 1 addition & 1 deletion .github/trigger_files/beam_PostCommit_Python.json
@@ -1,5 +1,5 @@
 {
   "comment": "Modify this file in a trivial way to cause this test suite to run.",
-  "modification": 2
+  "modification": 3
 }

@@ -1,3 +1,4 @@
 {
-  "comment": "Modify this file in a trivial way to cause this test suite to run"
+  "comment": "Modify this file in a trivial way to cause this test suite to run",
+  "modification": 1
 }
55 changes: 30 additions & 25 deletions sdks/python/apache_beam/io/gcp/bigquery.py
@@ -418,7 +418,6 @@ def chain_after(result):
 from apache_beam.transforms.sideinputs import SIDE_INPUT_PREFIX
 from apache_beam.transforms.sideinputs import get_sideinput_index
 from apache_beam.transforms.util import ReshufflePerKey
-from apache_beam.transforms.window import GlobalWindows
 from apache_beam.typehints.row_type import RowTypeConstraint
 from apache_beam.typehints.schemas import schema_from_element_type
 from apache_beam.utils import retry
@@ -1581,7 +1580,8 @@ def _create_table_if_needed(self, table_reference, schema=None):
           additional_create_parameters=self.additional_bq_parameters)
       _KNOWN_TABLES.add(str_table_reference)

-  def process(self, element, *schema_side_inputs):
+  def process(
+      self, element, window_value=DoFn.WindowedValueParam, *schema_side_inputs):
     destination = bigquery_tools.get_hashable_destination(element[0])

     if callable(self.schema):
@@ -1608,12 +1608,11 @@ def process(self, element, *schema_side_inputs):
         return [
             pvalue.TaggedOutput(
                 BigQueryWriteFn.FAILED_ROWS_WITH_ERRORS,
-                GlobalWindows.windowed_value(
+                window_value.with_value(
                     (destination, row_and_insert_id[0], error))),
             pvalue.TaggedOutput(
                 BigQueryWriteFn.FAILED_ROWS,
-                GlobalWindows.windowed_value(
-                    (destination, row_and_insert_id[0])))
+                window_value.with_value((destination, row_and_insert_id[0])))
         ]

       # Flush current batch first if adding this row will exceed our limits
@@ -1624,19 +1623,20 @@ def process(self, element, *schema_side_inputs):
         flushed_batch = self._flush_batch(destination)
         # After flushing our existing batch, we now buffer the current row
         # for the next flush
-        self._rows_buffer[destination].append(row_and_insert_id)
+        self._rows_buffer[destination].append((row_and_insert_id, window_value))
         self._destination_buffer_byte_size[destination] = row_byte_size
         return flushed_batch

-      self._rows_buffer[destination].append(row_and_insert_id)
+      self._rows_buffer[destination].append((row_and_insert_id, window_value))
       self._destination_buffer_byte_size[destination] += row_byte_size
       self._total_buffered_rows += 1
       if self._total_buffered_rows >= self._max_buffered_rows:
         return self._flush_all_batches()
     else:
       # The input is already batched per destination, flush the rows now.
       batched_rows = element[1]
-      self._rows_buffer[destination].extend(batched_rows)
+      for r in batched_rows:
+        self._rows_buffer[destination].append((r, window_value))
       return self._flush_batch(destination)

   def finish_bundle(self):
@@ -1659,7 +1659,7 @@ def _flush_all_batches(self):
   def _flush_batch(self, destination):

     # Flush the current batch of rows to BigQuery.
-    rows_and_insert_ids = self._rows_buffer[destination]
+    rows_and_insert_ids_with_windows = self._rows_buffer[destination]
     table_reference = bigquery_tools.parse_table_reference(destination)
     if table_reference.projectId is None:
       table_reference.projectId = vp.RuntimeValueProvider.get_value(
@@ -1668,9 +1668,11 @@ def _flush_batch(self, destination):
     _LOGGER.debug(
         'Flushing data to %s. Total %s rows.',
         destination,
-        len(rows_and_insert_ids))
-    self.batch_size_metric.update(len(rows_and_insert_ids))
+        len(rows_and_insert_ids_with_windows))
+    self.batch_size_metric.update(len(rows_and_insert_ids_with_windows))

+    rows_and_insert_ids = [r[0] for r in rows_and_insert_ids_with_windows]
Review comment (Contributor):

You can use the "zip to transpose" trick here.

    rows_and_insert_ids, window_values = zip(*rows_and_insert_ids_with_windows)
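
For reference, a standalone sketch of the suggested "zip to transpose" trick, using made-up sample data rather than values from this PR: zip(*pairs) transposes a list of (row_and_insert_id, window_value) pairs into two parallel sequences, replacing the two separate list comprehensions. Note that zip yields tuples rather than lists, and that an empty buffer would leave nothing to unpack, so the comprehension form handles that edge case more gracefully.

    # Hypothetical buffer contents, shaped like the (row_and_insert_id, window)
    # pairs stored in self._rows_buffer[destination] above.
    buffered = [
        (({'col': 'a'}, 'insert-id-1'), 'windowed-value-1'),
        (({'col': 'b'}, 'insert-id-2'), 'windowed-value-2'),
    ]

    # zip(*...) unpacks the pairs and regroups them column-wise.
    rows_and_insert_ids, window_values = zip(*buffered)

    assert rows_and_insert_ids == (({'col': 'a'}, 'insert-id-1'),
                                   ({'col': 'b'}, 'insert-id-2'))
    assert window_values == ('windowed-value-1', 'windowed-value-2')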

+    window_values = [r[1] for r in rows_and_insert_ids_with_windows]
     rows = [r[0] for r in rows_and_insert_ids]
     if self.ignore_insert_ids:
       insert_ids = [None for r in rows_and_insert_ids]
@@ -1689,7 +1691,8 @@ def _flush_batch(self, destination):
             ignore_unknown_values=self.ignore_unknown_columns)
         self.batch_latency_metric.update((time.time() - start) * 1000)

-        failed_rows = [(rows[entry['index']], entry["errors"])
+        failed_rows = [(
+            rows[entry['index']], entry["errors"], window_values[entry['index']])
                        for entry in errors]
         retry_backoff = next(self._backoff_calculator, None)
@@ -1729,19 +1732,21 @@ def _flush_batch(self, destination):
     if destination in self._destination_buffer_byte_size:
       del self._destination_buffer_byte_size[destination]

Review comment (Contributor):

In the retry loop above, with

    rows = [fr[0] for fr in failed_rows]

do we also need to update window_values so the indices line up? (Looking at this, shouldn't insert_ids have been updated as well?)
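
To make the reviewer's concern concrete, a small hedged sketch with invented sample values (not code from this PR): failed_rows holds (row, errors, window) tuples as built in the diff above, so any per-row state consulted on the retry would need to be rebuilt from those same tuples to keep entry['index'] lookups aligned.

    # Hypothetical per-row state after one insert attempt.
    rows = ['row-a', 'row-b', 'row-c']
    insert_ids = ['id-a', 'id-b', 'id-c']
    window_values = ['win-a', 'win-b', 'win-c']
    errors = [{'index': 0, 'errors': ['quota']}, {'index': 2, 'errors': ['schema']}]

    # failed_rows as built above: one (row, errors, window) tuple per failure.
    failed_rows = [(rows[e['index']], e['errors'], window_values[e['index']])
                   for e in errors]

    # Before retrying, rebuild every parallel list from the same tuples so the
    # next attempt's entry['index'] points at matching positions.
    rows = [fr[0] for fr in failed_rows]
    window_values = [fr[2] for fr in failed_rows]
    # insert_ids would need to travel through failed_rows as well (e.g. as a
    # fourth tuple element) for the retried request to reuse the same ids.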

-    return itertools.chain([
-        pvalue.TaggedOutput(
-            BigQueryWriteFn.FAILED_ROWS_WITH_ERRORS,
-            GlobalWindows.windowed_value((destination, row, err))) for row,
-        err in failed_rows
-    ],
-                           [
-                               pvalue.TaggedOutput(
-                                   BigQueryWriteFn.FAILED_ROWS,
-                                   GlobalWindows.windowed_value(
-                                       (destination, row))) for row,
-                               unused_err in failed_rows
-                           ])
+    return itertools.chain(
+        [
+            pvalue.TaggedOutput(
+                BigQueryWriteFn.FAILED_ROWS_WITH_ERRORS,
+                w.with_value((destination, row, err))) for row,
+            err,
+            w in failed_rows
+        ],
+        [
+            pvalue.TaggedOutput(
+                BigQueryWriteFn.FAILED_ROWS, w.with_value((destination, row)))
+            for row,
+            unused_err,
+            w in failed_rows
+        ])


# The number of shards per destination when writing via streaming inserts.
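Stepping back from the diff, a minimal self-contained sketch of the pattern the change switches to; the DoFn, tag, and element shape here are invented for illustration and are not code from this PR. DoFn.WindowedValueParam hands process() the current element's WindowedValue, and with_value() re-wraps a new payload with that element's window, timestamp, and pane info, instead of forcing outputs into the global window as GlobalWindows.windowed_value() did.

    import apache_beam as beam
    from apache_beam.transforms.core import DoFn


    class TagFailuresDoFn(DoFn):
      """Illustrative only: routes 'bad' elements to a tagged output while
      keeping them in the window they arrived in."""
      FAILED = 'failed'

      def process(self, element, window_value=DoFn.WindowedValueParam):
        if element.get('bad'):
          # Yielding a WindowedValue built with with_value() preserves the
          # element's original window and timestamp on the tagged output.
          yield beam.pvalue.TaggedOutput(
              self.FAILED, window_value.with_value(element))
        else:
          yield element

Used as beam.ParDo(TagFailuresDoFn()).with_outputs(TagFailuresDoFn.FAILED, main='ok'), the tagged failure output stays in the same window as its input, which is what the WindowInto step added to the integration test below exercises.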
1 change: 1 addition & 0 deletions sdks/python/apache_beam/io/gcp/bigquery_write_it_test.py
@@ -491,6 +491,7 @@ def test_big_query_write_insert_non_transient_api_call_error(self):
     # pylint: disable=expression-not-assigned
     errors = (
         p | 'create' >> beam.Create(input_data)
+        | beam.WindowInto(beam.transforms.window.FixedWindows(10))
         | 'write' >> beam.io.WriteToBigQuery(
             table_id,
             schema=table_schema,