
Multi-Tensor Input in Servo-Beam #10

Open | wants to merge 42 commits into base: master

Conversation

@meixinzhang commented Jul 23, 2020

Internally uses an Arrow RecordBatch for processing and supports multi-tensor input:

  • separate public API calls for running inference on Examples and SequenceExamples
  • there is now an internal API supporting a PCollection of RecordBatch as input
  • serialized examples may be bytes or strings; both are supported
  • modified signature processing now allows multiple inputs, and we extract the key/name for each input
  • when the predict model requires more than one input, we feed the model the feature tensors that have the same names
  • since there is no clean way to extract a tensor proto from a composite tensor, post-processing includes request values only when the input tensor is dense; otherwise only input info such as type and name, plus the output, is included

Integrate Arrow as internal processing container
@@ -383,14 +504,20 @@ def setup(self):
# user agent once custom header is supported in googleapiclient.
self._api_client = discovery.build('ml', 'v1')

def _extract_from_recordBatch(self, elements: pa.RecordBatch):
serialized_examples = bsl_util.ExtractSerializedExampleFromRecordBatch(elements)


Seems this is the same in Batch and Remote DoFn. Maybe extract this out to Base, and only get model_input in _extract_from_recordBatch?

) -> Mapping[Text, np.ndarray]:
self._check_elements(elements)
outputs = self._run_tf_operations(elements)
self, tensors: Mapping[Any, Any]) -> Mapping[Text, np.ndarray]:
@rose-rong-liu Aug 5, 2020

Add comment on what's expected in tensors. And is the Mapping key a Text?

self, elements: Mapping[Any, Any],
outputs: Mapping[Text, np.ndarray]
) -> Iterable[Tuple[Union[str, bytes], classification_pb2.Classifications]]:
serialized_examples, = elements.values()


Remove ','

Author

Removing the comma won't give the right answer: the trailing comma unpacks the single value from elements.values() rather than binding the whole view.
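A minimal illustration of the distinction (the dict contents here are made up):

```python
elements = {"input_tensor": [b"serialized-example"]}

# With the trailing comma: unpack the single value out of the view.
serialized_examples, = elements.values()
print(serialized_examples)        # [b'serialized-example']

# Without the comma: we get the dict_values view itself, not the list.
no_unpack = elements.values()
print(type(no_unpack).__name__)   # dict_values
```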

self, elements: Mapping[Any, Any],
outputs: Mapping[Text, np.ndarray]
) -> Iterable[Tuple[Union[str, bytes], classification_pb2.Classifications]]:
serialized_examples, = elements.values()


Are elements.values() the serialized examples?

raise ValueError('Expected to have one name and one alias per tensor')

include_request = True
if len(input_tensor_names) == 1:


Can we make the determination of a single input string tensor in an internal utility function inside BaseDoFn?

Author

The input tensor names are not available in BaseDoFn.


include_request = True
if len(input_tensor_names) == 1:
serialized_examples, = elements.values()


Shall we also check that the type of elements.values() is string/bytes?

Author

It's checked in ExtractSerializedExampleFromRecordBatch.

Comment on lines 838 to 839
else:
input_tensor_proto.tensor_shape.dim.add().size = len(elements[tensor_name][0])


Why is the dim size len(elements[tensor_name][0]) instead of:

for s in elements[tensor_name][0].shape:
    input_tensor_proto.tensor_shape.dim.add().size = s

Author

We have an ndarray; I don't think we will have a shape parameter.
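For reference, a NumPy ndarray does carry a shape attribute, and iterating over it covers every rank, while len(...) only gives the leading dimension:

```python
import numpy as np

a = np.zeros((3, 4))
print(len(a))    # 3 -- only the leading dimension
print(a.shape)   # (3, 4) -- the full shape to copy into tensor_shape.dim

# Copying every dimension, as the review suggestion does:
dims = [s for s in a.shape]
print(dims)      # [3, 4]
```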

for alias, tensor_name in zip(input_tensor_alias, input_tensor_names):
input_tensor_proto = predict_log_tmpl.request.inputs[alias]
input_tensor_proto.dtype = tf.as_dtype(input_tensor_types[alias]).as_datatype_enum
if len(input_tensor_alias) == 1:
@rose-rong-liu Aug 6, 2020

Could the single input case be handled separately?

alias = input_tensor_alias[0]
predict_log.request.inputs[alias].string_val.append(process_elements[i])
else:
for alias, tensor_name in zip(input_tensor_alias, input_tensor_names):
@rose-rong-liu Aug 6, 2020

Is this correct, given it's already inside the loop over alias, tensor_name?

) -> Iterable[Tuple[tf.train.Example, inference_pb2.MultiInferenceResponse]]:
self, elements: Mapping[Any, Any],
outputs: Mapping[Text, np.ndarray]
) -> Iterable[Tuple[Union[str, bytes], inference_pb2.MultiInferenceResponse]]:


Can this just be bytes instead of Union[str, bytes]? str is the same as bytes in py2.

Author

Just wanted to make sure it's compatible with py2.


model_input = None
if (len(self._io_tensor_spec.input_tensor_names) == 1):
model_input = {self._io_tensor_spec.input_tensor_names[0]: serialized_examples}


Can we just leave this in _BaseBatchSavedModelDoFn and move the rest to _BatchPredictDoFn?
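The single-input path in the snippet above amounts to the following (the variable contents are illustrative):

```python
# When the model exposes exactly one input tensor, the serialized
# examples are fed to it directly under that tensor's name.
serialized_examples = [b"example-1", b"example-2"]
input_tensor_names = ["examples"]

model_input = None
if len(input_tensor_names) == 1:
    model_input = {input_tensor_names[0]: serialized_examples}
print(model_input)  # {'examples': [b'example-1', b'example-2']}
```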


Args:
examples: A PCollection containing examples.
inference_spec_type: Model inference endpoint.
schema [optional]: required for models that require


Mention that this is only available for the Predict method.


_KERAS_INPUT_SUFFIX = '_input'

def ExtractSerializedExampleFromRecordBatch(elements: pa.RecordBatch) -> List[Text]:
Member

ExtractSerializedExamplesFromRecordBatch

def ExtractSerializedExampleFromRecordBatch(elements: pa.RecordBatch) -> List[Text]:
serialized_examples = None
for column_name, column_array in zip(elements.schema.names, elements.columns):
if column_name == _RECORDBATCH_COLUMN:
Member

Should _RECORDBATCH_COLUMN be passed as an argument to the API?

If we use a constant here, it would mean users would have to use this same constant when creating the TFXIO.


@beam.ptransform_fn
@beam.typehints.with_input_types(Union[tf.train.Example,
tf.train.SequenceExample])
@beam.typehints.with_input_types(tf.train.Example)
@beam.typehints.with_output_types(prediction_log_pb2.PredictionLog)
def RunInference( # pylint: disable=invalid-name
Member

Is the long-term plan to deprecate the tf.Example API and only have a RecordBatch API? If so, mention it in a comment.

if prepare_instances_serialized:
return [{'b64': base64.b64encode(value).decode()} for value in df[_RECORDBATCH_COLUMN]]
else:
as_binary = df.columns.str.endswith("_bytes")
Member

Why does the name end with "_bytes"?

Author

User-specified byte columns; this is consistent with the original implementation.


This is required by Cloud AI Platform, which identifies bytes features by the '_bytes' suffix.
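As an illustration of that convention (the feature names here are made up): Cloud AI Platform expects binary values wrapped as {"b64": ...}, with the feature name carrying the "_bytes" suffix.

```python
import base64

row = {"image_bytes": b"\x89PNG...", "label": 3}

# Wrap binary features for the JSON request body; leave others untouched.
instance = {
    name: {"b64": base64.b64encode(value).decode()}
    if name.endswith("_bytes") else value
    for name, value in row.items()
}
print(instance["label"])                # 3
print(list(instance["image_bytes"]))    # ['b64']
```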

@beam.typehints.with_output_types(prediction_log_pb2.PredictionLog)
def RunInferenceImpl( # pylint: disable=invalid-name
def RunInferenceOnExamples( # pylint: disable=invalid-name


Let's use the first option for the public API here: a polymorphic RunInference plus RunInferenceImpl.
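A sketch of what that polymorphic shape could look like (all names beyond RunInference are hypothetical, and the Beam plumbing is omitted):

```python
def _run_inference_on_serialized(serialized: bytes) -> str:
    """Type-specific implementation for serialized examples (stub)."""
    return "serialized:%d" % len(serialized)

def _run_inference_on_columns(columns: dict) -> str:
    """Type-specific implementation for column-oriented input (stub)."""
    return "columns:%d" % len(columns)

def run_inference(element):
    """Single public entry point dispatching on the element type."""
    if isinstance(element, bytes):
        return _run_inference_on_serialized(element)
    if isinstance(element, dict):
        return _run_inference_on_columns(element)
    raise TypeError("Unsupported element type: %s" % type(element).__name__)

print(run_inference(b"abc"))        # serialized:3
print(run_inference({"f": [1]}))    # columns:1
```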

@googlebot
Thanks for your pull request. It looks like this may be your first contribution to a Google open source project (if not, look below for help). Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

📝 Please visit https://cla.developers.google.com/ to sign.

Once you've signed (or fixed any issues), please reply here with @googlebot I signed it! and we'll verify it.


What to do if you already signed the CLA

Individual signers
Corporate signers

ℹ️ Googlers: Go here for more info.

@meixinzhang (Author)

@googlebot I signed it!

@googlebot

CLAs look good, thanks!

