From 3c1c7e38f8520f30943e8837db3e068c5c9a713b Mon Sep 17 00:00:00 2001 From: Guilherme Kunigami Date: Thu, 3 Oct 2024 12:20:17 -0700 Subject: [PATCH] handle errors in remote function server when get_throwOnError() is false (#11083) Summary: Pull Request resolved: https://github.com/facebookincubator/velox/pull/11083 Reviewed By: pedroerp Differential Revision: D63346758 --- velox/functions/remote/client/Remote.cpp | 20 +++++++ .../client/tests/RemoteFunctionTest.cpp | 57 +++++++++++++++++++ .../functions/remote/if/RemoteFunction.thrift | 3 + .../remote/server/RemoteFunctionService.cpp | 53 +++++++++++++++-- .../remote/server/RemoteFunctionService.h | 12 ++++ 5 files changed, 141 insertions(+), 4 deletions(-) diff --git a/velox/functions/remote/client/Remote.cpp b/velox/functions/remote/client/Remote.cpp index 1f88745aa73e..8458b84baaef 100644 --- a/velox/functions/remote/client/Remote.cpp +++ b/velox/functions/remote/client/Remote.cpp @@ -119,6 +119,26 @@ class RemoteFunction : public exec::VectorFunction { *context.pool(), serde_.get()); result = outputRowVector->childAt(0); + + if (auto errorPayload = remoteResponse.get_result().errorPayload()) { + auto errorsRowVector = IOBufToRowVector( + *errorPayload, ROW({VARCHAR()}), *context.pool(), serde_.get()); + auto errorsVector = + errorsRowVector->childAt(0)->asFlatVector(); + VELOX_CHECK(errorsVector, "Should be convertible to flat vector"); + + SelectivityVector selectedRows(errorsRowVector->size()); + selectedRows.applyToSelected([&](vector_size_t i) { + if (errorsVector->isNullAt(i)) { + return; + } + try { + throw std::runtime_error(errorsVector->valueAt(i)); + } catch (const std::exception& ex) { + context.setError(i, std::current_exception()); + } + }); + } } const std::string functionName_; diff --git a/velox/functions/remote/client/tests/RemoteFunctionTest.cpp b/velox/functions/remote/client/tests/RemoteFunctionTest.cpp index d49e8f8f695a..3965686d97d5 100644 --- a/velox/functions/remote/client/tests/RemoteFunctionTest.cpp +++ b/velox/functions/remote/client/tests/RemoteFunctionTest.cpp @@ -25,6 +25,7 @@ #include "velox/functions/Registerer.h" #include "velox/functions/lib/CheckedArithmetic.h" #include "velox/functions/prestosql/Arithmetic.h" +#include "velox/functions/prestosql/Fail.h" #include "velox/functions/prestosql/StringFunctions.h" #include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h" #include "velox/functions/remote/client/Remote.h" @@ -62,6 +63,13 @@ class RemoteFunctionTest .build()}; registerRemoteFunction("remote_plus", plusSignatures, metadata); + auto failSignatures = {exec::FunctionSignatureBuilder() + .returnType("unknown") + .argumentType("integer") + .argumentType("varchar") + .build()}; + registerRemoteFunction("remote_fail", failSignatures, metadata); + RemoteVectorFunctionMetadata wrongMetadata = metadata; wrongMetadata.location = folly::SocketAddress(); // empty address. registerRemoteFunction("remote_wrong_port", plusSignatures, wrongMetadata); @@ -84,6 +92,8 @@ class RemoteFunctionTest // needed for tests since the thrift service runs in the same process. registerFunction( {remotePrefix_ + ".remote_plus"}); + registerFunction( + {remotePrefix_ + ".remote_fail"}); registerFunction( {remotePrefix_ + ".remote_divide"}); registerFunction( @@ -161,6 +171,53 @@ TEST_P(RemoteFunctionTest, string) { assertEqualVectors(expected, results); } +TEST_P(RemoteFunctionTest, tryException) { + // remote_divide throws if denominator is 0. + auto numeratorVector = makeFlatVector({0, 1, 4, 9, 16}); + auto denominatorVector = makeFlatVector({0, 1, 2, 3, 4}); + auto data = makeRowVector({numeratorVector, denominatorVector}); + auto results = + evaluate>("TRY(remote_divide(c0, c1))", data); + + ASSERT_EQ(results->size(), 5); + auto expected = makeFlatVector({0 /* doesn't matter*/, 1, 2, 3, 4}); + expected->setNull(0, true); + + assertEqualVectors(expected, results); +} + +TEST_P(RemoteFunctionTest, conditionalConjunction) { + // conditional conjunction disables throwing on error. + auto inputVector0 = makeFlatVector({true, true}); + auto inputVector1 = makeFlatVector({1, 2}); + auto data = makeRowVector({inputVector0, inputVector1}); + auto results = evaluate>( + "case when (c0 OR remote_fail(c1, 'error')) then 'hello' else 'world' end", + data); + + ASSERT_EQ(results->size(), 2); + auto expected = makeFlatVector({"hello", "hello"}); + assertEqualVectors(expected, results); +} + +TEST_P(RemoteFunctionTest, tryErrorCode) { + // remote_fail doesn't throw, but returns error code. + auto errorCodesVector = makeFlatVector({1, 2}); + auto errorMessagesVector = + makeFlatVector({"failed 1", "failed 2"}); + auto data = makeRowVector({errorCodesVector, errorMessagesVector}); + exec::ExprSet exprSet( + {makeTypedExpr("TRY(remote_fail(c0, c1))", asRowType(data->type()))}, + &execCtx_); + std::optional rows; + exec::EvalCtx context(&execCtx_, &exprSet, data.get()); + std::vector results(1); + SelectivityVector defaultRows(data->size()); + exprSet.eval(defaultRows, context, results); + + ASSERT_EQ(results[0]->size(), 2); +} + TEST_P(RemoteFunctionTest, connectionError) { auto inputVector = makeFlatVector({1, 2, 3, 4, 5}); auto func = [&]() { diff --git a/velox/functions/remote/if/RemoteFunction.thrift b/velox/functions/remote/if/RemoteFunction.thrift index edac31ef22e8..559f0192c155 100644 --- a/velox/functions/remote/if/RemoteFunction.thrift +++ b/velox/functions/remote/if/RemoteFunction.thrift @@ -52,6 +52,9 @@ struct RemoteFunctionPage { /// The number of logical rows in this page. 3: i64 rowCount; + + /// Serialized errors. + 4: optional IOBuf errorPayload; } /// The parameters passed to the remote thrift call. diff --git a/velox/functions/remote/server/RemoteFunctionService.cpp b/velox/functions/remote/server/RemoteFunctionService.cpp index c3310ccb2cd0..2cc7a0abac12 100644 --- a/velox/functions/remote/server/RemoteFunctionService.cpp +++ b/velox/functions/remote/server/RemoteFunctionService.cpp @@ -15,6 +15,7 @@ */ #include "velox/functions/remote/server/RemoteFunctionService.h" +#include "velox/common/base/Exceptions.h" #include "velox/expression/Expr.h" #include "velox/functions/remote/if/GetSerde.h" #include "velox/type/fbhive/HiveTypeParser.h" @@ -66,6 +67,45 @@ std::vector getExpressions( returnType, std::move(inputs), functionName)}; } +void RemoteFunctionServiceHandler::handleErrors( + apache::thrift::field_ref result, + exec::EvalErrors* evalErrors, + const std::unique_ptr& serde) const { + const std::int64_t numRows = result->get_rowCount(); + BufferPtr dataBuffer = + AlignedBuffer::allocate(numRows, pool_.get()); + + auto flatVector = std::make_shared>( + pool_.get(), + VARCHAR(), + nullptr, // null vectors + numRows, + std::move(dataBuffer), + std::vector{}); + + for (vector_size_t i = 0; i < numRows; ++i) { + if (evalErrors->hasErrorAt(i)) { + auto exceptionPtr = *evalErrors->errorAt(i); + try { + std::rethrow_exception(*exceptionPtr); + } catch (const std::exception& ex) { + flatVector->set(i, ex.what()); + } + } else { + flatVector->set(i, StringView()); + flatVector->setNull(i, true); + } + } + auto errorRowVector = std::make_shared( + pool_.get(), + ROW({VARCHAR()}), + BufferPtr(), + numRows, + std::vector{flatVector}); + result->errorPayload_ref() = + rowVectorToIOBuf(errorRowVector, *pool_, serde.get()); +} + void RemoteFunctionServiceHandler::invokeFunction( remote::RemoteFunctionResponse& response, std::unique_ptr request) { @@ -75,10 +115,6 @@ void RemoteFunctionServiceHandler::invokeFunction( LOG(INFO) << "Got a request for '" << functionHandle.get_name() << "': " << inputs.get_rowCount() << " input rows."; - if (!request->get_throwOnError()) { - VELOX_NYI("throwOnError not implemented yet on remote server."); - } - // Deserialize types and data. auto inputType = deserializeArgTypes(functionHandle.get_argumentTypes()); auto outputType = deserializeType(functionHandle.get_returnType()); @@ -102,7 +138,11 @@ void RemoteFunctionServiceHandler::invokeFunction( outputType, getFunctionName(functionPrefix_, functionHandle.get_name())), &execCtx}; + exec::EvalCtx evalCtx(&execCtx, &exprSet, inputVector.get()); + if (!request->get_throwOnError()) { + *evalCtx.mutableThrowOnError() = false; + } std::vector expressionResult; exprSet.eval(rows, evalCtx, expressionResult); @@ -116,6 +156,11 @@ void RemoteFunctionServiceHandler::invokeFunction( result->pageFormat_ref() = serdeFormat; result->payload_ref() = rowVectorToIOBuf(outputRowVector, rows.end(), *pool_, serde.get()); + + auto evalErrors = evalCtx.errors(); + if (evalErrors != nullptr && evalErrors->hasError()) { + handleErrors(result, evalErrors, serde); + } } } // namespace facebook::velox::functions diff --git a/velox/functions/remote/server/RemoteFunctionService.h b/velox/functions/remote/server/RemoteFunctionService.h index b033eb333b27..88a01c9198e9 100644 --- a/velox/functions/remote/server/RemoteFunctionService.h +++ b/velox/functions/remote/server/RemoteFunctionService.h @@ -19,6 +19,11 @@ #include #include "velox/common/memory/Memory.h" #include "velox/functions/remote/if/gen-cpp2/RemoteFunctionService.h" +#include "velox/vector/VectorStream.h" + +namespace facebook::velox::exec { +class EvalErrors; +} namespace facebook::velox::functions { @@ -35,6 +40,13 @@ class RemoteFunctionServiceHandler std::unique_ptr request) override; private: + /// Add evalErrors to result by serializing them to a vector of strings and + /// converting the result to a Velox flat vector. + void handleErrors( + apache::thrift::field_ref result, + exec::EvalErrors* evalErrors, + const std::unique_ptr& serde) const; + std::shared_ptr pool_{ memory::memoryManager()->addLeafPool()}; const std::string functionPrefix_;