Skip to content

Commit

Permalink
handle errors in remote function server when get_throwOnError() is fa…
Browse files Browse the repository at this point in the history
…lse (#11083)

Summary: Pull Request resolved: #11083

Reviewed By: pedroerp

Differential Revision: D63346758
  • Loading branch information
Guilherme Kunigami authored and facebook-github-bot committed Oct 3, 2024
1 parent 8369142 commit 3c1c7e3
Show file tree
Hide file tree
Showing 5 changed files with 141 additions and 4 deletions.
20 changes: 20 additions & 0 deletions velox/functions/remote/client/Remote.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,26 @@ class RemoteFunction : public exec::VectorFunction {
*context.pool(),
serde_.get());
result = outputRowVector->childAt(0);

if (auto errorPayload = remoteResponse.get_result().errorPayload()) {
auto errorsRowVector = IOBufToRowVector(
*errorPayload, ROW({VARCHAR()}), *context.pool(), serde_.get());
auto errorsVector =
errorsRowVector->childAt(0)->asFlatVector<StringView>();
VELOX_CHECK(errorsVector, "Should be convertible to flat vector");

SelectivityVector selectedRows(errorsRowVector->size());
selectedRows.applyToSelected([&](vector_size_t i) {
if (errorsVector->isNullAt(i)) {
return;
}
try {
throw std::runtime_error(errorsVector->valueAt(i));
} catch (const std::exception& ex) {
context.setError(i, std::current_exception());
}
});
}
}

const std::string functionName_;
Expand Down
57 changes: 57 additions & 0 deletions velox/functions/remote/client/tests/RemoteFunctionTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "velox/functions/Registerer.h"
#include "velox/functions/lib/CheckedArithmetic.h"
#include "velox/functions/prestosql/Arithmetic.h"
#include "velox/functions/prestosql/Fail.h"
#include "velox/functions/prestosql/StringFunctions.h"
#include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h"
#include "velox/functions/remote/client/Remote.h"
Expand Down Expand Up @@ -62,6 +63,13 @@ class RemoteFunctionTest
.build()};
registerRemoteFunction("remote_plus", plusSignatures, metadata);

auto failSignatures = {exec::FunctionSignatureBuilder()
.returnType("unknown")
.argumentType("integer")
.argumentType("varchar")
.build()};
registerRemoteFunction("remote_fail", failSignatures, metadata);

RemoteVectorFunctionMetadata wrongMetadata = metadata;
wrongMetadata.location = folly::SocketAddress(); // empty address.
registerRemoteFunction("remote_wrong_port", plusSignatures, wrongMetadata);
Expand All @@ -84,6 +92,8 @@ class RemoteFunctionTest
// needed for tests since the thrift service runs in the same process.
registerFunction<PlusFunction, int64_t, int64_t, int64_t>(
{remotePrefix_ + ".remote_plus"});
registerFunction<FailFunction, UnknownValue, int32_t, Varchar>(
{remotePrefix_ + ".remote_fail"});
registerFunction<CheckedDivideFunction, double, double, double>(
{remotePrefix_ + ".remote_divide"});
registerFunction<SubstrFunction, Varchar, Varchar, int32_t>(
Expand Down Expand Up @@ -161,6 +171,53 @@ TEST_P(RemoteFunctionTest, string) {
assertEqualVectors(expected, results);
}

TEST_P(RemoteFunctionTest, tryException) {
// remote_divide throws if denominator is 0.
auto numeratorVector = makeFlatVector<double>({0, 1, 4, 9, 16});
auto denominatorVector = makeFlatVector<double>({0, 1, 2, 3, 4});
auto data = makeRowVector({numeratorVector, denominatorVector});
auto results =
evaluate<SimpleVector<double>>("TRY(remote_divide(c0, c1))", data);

ASSERT_EQ(results->size(), 5);
auto expected = makeFlatVector<double>({0 /* doesn't matter*/, 1, 2, 3, 4});
expected->setNull(0, true);

assertEqualVectors(expected, results);
}

TEST_P(RemoteFunctionTest, conditionalConjunction) {
// conditional conjunction disables throwing on error.
auto inputVector0 = makeFlatVector<bool>({true, true});
auto inputVector1 = makeFlatVector<int32_t>({1, 2});
auto data = makeRowVector({inputVector0, inputVector1});
auto results = evaluate<SimpleVector<StringView>>(
"case when (c0 OR remote_fail(c1, 'error')) then 'hello' else 'world' end",
data);

ASSERT_EQ(results->size(), 2);
auto expected = makeFlatVector<StringView>({"hello", "hello"});
assertEqualVectors(expected, results);
}

TEST_P(RemoteFunctionTest, tryErrorCode) {
// remote_fail doesn't throw, but returns error code.
auto errorCodesVector = makeFlatVector<int32_t>({1, 2});
auto errorMessagesVector =
makeFlatVector<StringView>({"failed 1", "failed 2"});
auto data = makeRowVector({errorCodesVector, errorMessagesVector});
exec::ExprSet exprSet(
{makeTypedExpr("TRY(remote_fail(c0, c1))", asRowType(data->type()))},
&execCtx_);
std::optional<SelectivityVector> rows;
exec::EvalCtx context(&execCtx_, &exprSet, data.get());
std::vector<VectorPtr> results(1);
SelectivityVector defaultRows(data->size());
exprSet.eval(defaultRows, context, results);

ASSERT_EQ(results[0]->size(), 2);
}

TEST_P(RemoteFunctionTest, connectionError) {
auto inputVector = makeFlatVector<int64_t>({1, 2, 3, 4, 5});
auto func = [&]() {
Expand Down
3 changes: 3 additions & 0 deletions velox/functions/remote/if/RemoteFunction.thrift
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ struct RemoteFunctionPage {

/// The number of logical rows in this page.
3: i64 rowCount;

/// Serialized errors.
4: optional IOBuf errorPayload;
}

/// The parameters passed to the remote thrift call.
Expand Down
53 changes: 49 additions & 4 deletions velox/functions/remote/server/RemoteFunctionService.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/

#include "velox/functions/remote/server/RemoteFunctionService.h"
#include "velox/common/base/Exceptions.h"
#include "velox/expression/Expr.h"
#include "velox/functions/remote/if/GetSerde.h"
#include "velox/type/fbhive/HiveTypeParser.h"
Expand Down Expand Up @@ -66,6 +67,45 @@ std::vector<core::TypedExprPtr> getExpressions(
returnType, std::move(inputs), functionName)};
}

void RemoteFunctionServiceHandler::handleErrors(
apache::thrift::field_ref<remote::RemoteFunctionPage&> result,
exec::EvalErrors* evalErrors,
const std::unique_ptr<VectorSerde>& serde) const {
const std::int64_t numRows = result->get_rowCount();
BufferPtr dataBuffer =
AlignedBuffer::allocate<StringView>(numRows, pool_.get());

auto flatVector = std::make_shared<FlatVector<StringView>>(
pool_.get(),
VARCHAR(),
nullptr, // null vectors
numRows,
std::move(dataBuffer),
std::vector<BufferPtr>{});

for (vector_size_t i = 0; i < numRows; ++i) {
if (evalErrors->hasErrorAt(i)) {
auto exceptionPtr = *evalErrors->errorAt(i);
try {
std::rethrow_exception(*exceptionPtr);
} catch (const std::exception& ex) {
flatVector->set(i, ex.what());
}
} else {
flatVector->set(i, StringView());
flatVector->setNull(i, true);
}
}
auto errorRowVector = std::make_shared<RowVector>(
pool_.get(),
ROW({VARCHAR()}),
BufferPtr(),
numRows,
std::vector<VectorPtr>{flatVector});
result->errorPayload_ref() =
rowVectorToIOBuf(errorRowVector, *pool_, serde.get());
}

void RemoteFunctionServiceHandler::invokeFunction(
remote::RemoteFunctionResponse& response,
std::unique_ptr<remote::RemoteFunctionRequest> request) {
Expand All @@ -75,10 +115,6 @@ void RemoteFunctionServiceHandler::invokeFunction(
LOG(INFO) << "Got a request for '" << functionHandle.get_name()
<< "': " << inputs.get_rowCount() << " input rows.";

if (!request->get_throwOnError()) {
VELOX_NYI("throwOnError not implemented yet on remote server.");
}

// Deserialize types and data.
auto inputType = deserializeArgTypes(functionHandle.get_argumentTypes());
auto outputType = deserializeType(functionHandle.get_returnType());
Expand All @@ -102,7 +138,11 @@ void RemoteFunctionServiceHandler::invokeFunction(
outputType,
getFunctionName(functionPrefix_, functionHandle.get_name())),
&execCtx};

exec::EvalCtx evalCtx(&execCtx, &exprSet, inputVector.get());
if (!request->get_throwOnError()) {
*evalCtx.mutableThrowOnError() = false;
}

std::vector<VectorPtr> expressionResult;
exprSet.eval(rows, evalCtx, expressionResult);
Expand All @@ -116,6 +156,11 @@ void RemoteFunctionServiceHandler::invokeFunction(
result->pageFormat_ref() = serdeFormat;
result->payload_ref() =
rowVectorToIOBuf(outputRowVector, rows.end(), *pool_, serde.get());

auto evalErrors = evalCtx.errors();
if (evalErrors != nullptr && evalErrors->hasError()) {
handleErrors(result, evalErrors, serde);
}
}

} // namespace facebook::velox::functions
12 changes: 12 additions & 0 deletions velox/functions/remote/server/RemoteFunctionService.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@
#include <thrift/lib/cpp2/server/ThriftServer.h>
#include "velox/common/memory/Memory.h"
#include "velox/functions/remote/if/gen-cpp2/RemoteFunctionService.h"
#include "velox/vector/VectorStream.h"

namespace facebook::velox::exec {
class EvalErrors;
}

namespace facebook::velox::functions {

Expand All @@ -35,6 +40,13 @@ class RemoteFunctionServiceHandler
std::unique_ptr<remote::RemoteFunctionRequest> request) override;

private:
/// Add evalErrors to result by serializing them to a vector of strings and
/// converting the result to a Velox flat vector.
void handleErrors(
apache::thrift::field_ref<remote::RemoteFunctionPage&> result,
exec::EvalErrors* evalErrors,
const std::unique_ptr<VectorSerde>& serde) const;

std::shared_ptr<memory::MemoryPool> pool_{
memory::memoryManager()->addLeafPool()};
const std::string functionPrefix_;
Expand Down

0 comments on commit 3c1c7e3

Please sign in to comment.