From 264d702310b9b9403f3ca0728de04a7c75556746 Mon Sep 17 00:00:00 2001 From: Guilherme Kunigami Date: Fri, 27 Sep 2024 12:23:21 -0700 Subject: [PATCH] handle errors in remote function server when get_throwOnError() is false (#11083) Summary: Pull Request resolved: https://github.com/facebookincubator/velox/pull/11083 Differential Revision: D63346758 --- velox/functions/remote/client/Remote.cpp | 22 ++++++++ .../client/tests/RemoteFunctionTest.cpp | 42 +++++++++++++++ .../functions/remote/if/RemoteFunction.thrift | 3 ++ .../remote/server/RemoteFunctionService.cpp | 54 +++++++++++++++++-- .../remote/server/RemoteFunctionService.h | 7 +++ 5 files changed, 124 insertions(+), 4 deletions(-) diff --git a/velox/functions/remote/client/Remote.cpp b/velox/functions/remote/client/Remote.cpp index 1f88745aa73e..defc5d660194 100644 --- a/velox/functions/remote/client/Remote.cpp +++ b/velox/functions/remote/client/Remote.cpp @@ -119,6 +119,28 @@ class RemoteFunction : public exec::VectorFunction { *context.pool(), serde_.get()); result = outputRowVector->childAt(0); + + if (auto errorPayload = remoteResponse.get_result().errorPayload()) { + auto errorsRowVector = IOBufToRowVector( + *errorPayload, + velox::ROW({velox::VARCHAR()}), + *context.pool(), + serde_.get()); + auto errorsVector = + errorsRowVector->childAt(0)->asFlatVector(); + + velox::SelectivityVector selectedRows(errorsRowVector->size()); + selectedRows.applyToSelected([&](velox::vector_size_t i) { + if (errorsVector->isNullAt(i)) { + return; + } + try { + throw std::runtime_error(errorsVector->valueAt(i)); + } catch (const std::exception& ex) { + context.setError(i, std::current_exception()); + } + }); + } } const std::string functionName_; diff --git a/velox/functions/remote/client/tests/RemoteFunctionTest.cpp b/velox/functions/remote/client/tests/RemoteFunctionTest.cpp index d49e8f8f695a..0d3d37471399 100644 --- a/velox/functions/remote/client/tests/RemoteFunctionTest.cpp +++ b/velox/functions/remote/client/tests/RemoteFunctionTest.cpp @@ -37,6 +37,18 @@ using ::facebook::velox::test::assertEqualVectors; namespace facebook::velox::functions { namespace { +template +struct FailFunction { + template + FOLLY_ALWAYS_INLINE void + call(TInput& result, const TInput& a, const TInput& b) { + if (a == 0) { + throw std::runtime_error("Failure"); + } + result = b / a; + } +}; + // Parametrize in the serialization format so we can test both presto page and // unsafe row. class RemoteFunctionTest @@ -62,6 +74,13 @@ class RemoteFunctionTest .build()}; registerRemoteFunction("remote_plus", plusSignatures, metadata); + auto failSignatures = {exec::FunctionSignatureBuilder() + .returnType("bigint") + .argumentType("bigint") + .argumentType("bigint") + .build()}; + registerRemoteFunction("remote_fail", failSignatures, metadata); + RemoteVectorFunctionMetadata wrongMetadata = metadata; wrongMetadata.location = folly::SocketAddress(); // empty address. registerRemoteFunction("remote_wrong_port", plusSignatures, wrongMetadata); @@ -84,6 +103,8 @@ class RemoteFunctionTest // needed for tests since the thrift service runs in the same process. registerFunction( {remotePrefix_ + ".remote_plus"}); + registerFunction( + {remotePrefix_ + ".remote_fail"}); registerFunction( {remotePrefix_ + ".remote_divide"}); registerFunction( @@ -161,6 +182,27 @@ TEST_P(RemoteFunctionTest, string) { assertEqualVectors(expected, results); } +TEST_P(RemoteFunctionTest, fail) { + auto inputVector = makeFlatVector({1, 2, 0, 4, 5}); + auto data = makeRowVector({inputVector}); + exec::ExprSet exprSet( + {makeTypedExpr("TRY(remote_fail(c0, c0))", asRowType(data->type()))}, + &execCtx_); + std::optional rows; + exec::EvalCtx context(&execCtx_, &exprSet, data.get()); + std::vector result(1); + SelectivityVector defaultRows(data->size()); + + exprSet.eval(defaultRows, context, result); + // FIXME: why is that context has no errors here? + + ASSERT_EQ(result[0]->size(), 5); + auto expected = makeFlatVector({1, 1, 1, 1, 1}); + expected->setNull(2, true); + + assertEqualVectors(expected, result[0]); +} + TEST_P(RemoteFunctionTest, connectionError) { auto inputVector = makeFlatVector({1, 2, 3, 4, 5}); auto func = [&]() { diff --git a/velox/functions/remote/if/RemoteFunction.thrift b/velox/functions/remote/if/RemoteFunction.thrift index edac31ef22e8..559f0192c155 100644 --- a/velox/functions/remote/if/RemoteFunction.thrift +++ b/velox/functions/remote/if/RemoteFunction.thrift @@ -52,6 +52,9 @@ struct RemoteFunctionPage { /// The number of logical rows in this page. 3: i64 rowCount; + + /// Serialized errors. + 4: optional IOBuf errorPayload; } /// The parameters passed to the remote thrift call. diff --git a/velox/functions/remote/server/RemoteFunctionService.cpp b/velox/functions/remote/server/RemoteFunctionService.cpp index c3310ccb2cd0..81b4cdf19b3a 100644 --- a/velox/functions/remote/server/RemoteFunctionService.cpp +++ b/velox/functions/remote/server/RemoteFunctionService.cpp @@ -15,6 +15,7 @@ */ #include "velox/functions/remote/server/RemoteFunctionService.h" +#include "velox/common/base/Exceptions.h" #include "velox/expression/Expr.h" #include "velox/functions/remote/if/GetSerde.h" #include "velox/type/fbhive/HiveTypeParser.h" @@ -66,6 +67,46 @@ std::vector getExpressions( returnType, std::move(inputs), functionName)}; } +void RemoteFunctionServiceHandler::handleErrors( + apache::thrift::field_ref result, + exec::EvalErrors* evalErrors, + const std::unique_ptr& serde) const { + const std::int64_t numRows = result->get_rowCount(); + BufferPtr dataBuffer = + AlignedBuffer::allocate(numRows, pool_.get()); + + auto flatVector = std::make_shared>( + pool_.get(), + velox::VARCHAR(), + nullptr, // null vectors + numRows, + std::move(dataBuffer), + std::vector{}); + + for (vector_size_t i = 0; i < numRows; ++i) { + if (evalErrors->hasErrorAt(i)) { + auto exceptionPtr = *evalErrors->errorAt(i); + try { + std::rethrow_exception(*exceptionPtr); + } catch (const std::exception& ex) { + flatVector->set(i, ex.what()); + } + } else { + flatVector->set(i, velox::StringView()); + flatVector->setNull(i, true); + } + } + std::vector errorsVector{flatVector}; + auto errorRowVector = std::make_shared( + pool_.get(), + velox::ROW({velox::VARCHAR()}), + BufferPtr(), + numRows, + errorsVector); + result->errorPayload_ref() = + rowVectorToIOBuf(errorRowVector, *pool_, serde.get()); +} + void RemoteFunctionServiceHandler::invokeFunction( remote::RemoteFunctionResponse& response, std::unique_ptr request) { @@ -75,10 +116,6 @@ void RemoteFunctionServiceHandler::invokeFunction( LOG(INFO) << "Got a request for '" << functionHandle.get_name() << "': " << inputs.get_rowCount() << " input rows."; - if (!request->get_throwOnError()) { - VELOX_NYI("throwOnError not implemented yet on remote server."); - } - // Deserialize types and data. auto inputType = deserializeArgTypes(functionHandle.get_argumentTypes()); auto outputType = deserializeType(functionHandle.get_returnType()); @@ -102,7 +139,11 @@ void RemoteFunctionServiceHandler::invokeFunction( outputType, getFunctionName(functionPrefix_, functionHandle.get_name())), &execCtx}; + exec::EvalCtx evalCtx(&execCtx, &exprSet, inputVector.get()); + if (!request->get_throwOnError()) { + *evalCtx.mutableThrowOnError() = false; + } std::vector expressionResult; exprSet.eval(rows, evalCtx, expressionResult); @@ -116,6 +157,11 @@ void RemoteFunctionServiceHandler::invokeFunction( result->pageFormat_ref() = serdeFormat; result->payload_ref() = rowVectorToIOBuf(outputRowVector, rows.end(), *pool_, serde.get()); + + auto evalErrors = evalCtx.errors(); + if (evalErrors != nullptr && evalErrors->hasError()) { + handleErrors(result, evalErrors, serde); + } } } // namespace facebook::velox::functions diff --git a/velox/functions/remote/server/RemoteFunctionService.h b/velox/functions/remote/server/RemoteFunctionService.h index b033eb333b27..52778c83cd0f 100644 --- a/velox/functions/remote/server/RemoteFunctionService.h +++ b/velox/functions/remote/server/RemoteFunctionService.h @@ -18,7 +18,9 @@ #include #include "velox/common/memory/Memory.h" +#include "velox/expression/EvalCtx.h" #include "velox/functions/remote/if/gen-cpp2/RemoteFunctionService.h" +#include "velox/vector/VectorStream.h" namespace facebook::velox::functions { @@ -35,6 +37,11 @@ class RemoteFunctionServiceHandler std::unique_ptr request) override; private: + void handleErrors( + apache::thrift::field_ref result, + exec::EvalErrors* evalErrors, + const std::unique_ptr& serde) const; + std::shared_ptr pool_{ memory::memoryManager()->addLeafPool()}; const std::string functionPrefix_;