Skip to content

Commit

Permalink
handle errors in remote function server when get_throwOnError() is fa…
Browse files Browse the repository at this point in the history
…lse (facebookincubator#11083)

Summary: Pull Request resolved: facebookincubator#11083

Differential Revision: D63346758
  • Loading branch information
Guilherme Kunigami authored and facebook-github-bot committed Sep 27, 2024
1 parent 28a8979 commit 61fc63c
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 4 deletions.
23 changes: 23 additions & 0 deletions velox/functions/remote/client/Remote.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,29 @@ class RemoteFunction : public exec::VectorFunction {
*context.pool(),
serde_.get());
result = outputRowVector->childAt(0);

// The response may carry a second serialized page holding per-row errors: a
// row vector with a single VARCHAR column in which a non-null entry at row i
// is the error message for row i and a null entry means no error was
// reported for that row.
if (remoteResponse.get_result().errorPayload().is_set()) {
  // Bind by reference; the payload is only read while deserializing.
  const auto& errorPayload = remoteResponse.get_result().get_errorPayload();
  auto errorsRowVector = IOBufToRowVector(
      errorPayload,
      velox::ROW({velox::VARCHAR()}),
      *context.pool(),
      serde_.get());
  auto errorsVector =
      errorsRowVector->childAt(0)->asFlatVector<velox::StringView>();

  // Visit every row of the error page and record each reported error on the
  // evaluation context at the same row index.
  velox::SelectivityVector selectedRows(errorsRowVector->size());
  selectedRows.applyToSelected([&](velox::vector_size_t i) {
    if (errorsVector->isNullAt(i)) {
      return;
    }
    // Build the exception_ptr directly instead of throwing and immediately
    // catching an exception just to call std::current_exception().
    context.setError(
        i,
        std::make_exception_ptr(
            std::runtime_error(errorsVector->valueAt(i))));
  });
}
}

const std::string functionName_;
Expand Down
42 changes: 42 additions & 0 deletions velox/functions/remote/client/tests/RemoteFunctionTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,18 @@ using ::facebook::velox::test::assertEqualVectors;
namespace facebook::velox::functions {
namespace {

// Test-only function: computes `b / a`, and throws std::runtime_error
// ("Failure") when `a` is zero so that tests can exercise per-row error
// propagation.
template <typename T>
struct FailFunction {
  template <typename TInput>
  FOLLY_ALWAYS_INLINE void
  call(TInput& result, const TInput& a, const TInput& b) {
    if (a != 0) {
      result = b / a;
      return;
    }
    // Zero divisor: signal the failure instead of dividing.
    throw std::runtime_error("Failure");
  }
};

// Parametrize in the serialization format so we can test both presto page and
// unsafe row.
class RemoteFunctionTest
Expand All @@ -62,6 +74,13 @@ class RemoteFunctionTest
.build()};
registerRemoteFunction("remote_plus", plusSignatures, metadata);

// Register "remote_fail" with signature (bigint, bigint) -> bigint, using the
// same metadata as "remote_plus". It is used to test per-row error handling.
auto failSignatures = {exec::FunctionSignatureBuilder()
.returnType("bigint")
.argumentType("bigint")
.argumentType("bigint")
.build()};
registerRemoteFunction("remote_fail", failSignatures, metadata);

RemoteVectorFunctionMetadata wrongMetadata = metadata;
wrongMetadata.location = folly::SocketAddress(); // empty address.
registerRemoteFunction("remote_wrong_port", plusSignatures, wrongMetadata);
Expand All @@ -84,6 +103,8 @@ class RemoteFunctionTest
// needed for tests since the thrift service runs in the same process.
registerFunction<PlusFunction, int64_t, int64_t, int64_t>(
{remotePrefix_ + ".remote_plus"});
registerFunction<FailFunction, int64_t, int64_t, int64_t>(
{remotePrefix_ + ".remote_fail"});
registerFunction<CheckedDivideFunction, double, double, double>(
{remotePrefix_ + ".remote_divide"});
registerFunction<SubstrFunction, Varchar, Varchar, int32_t>(
Expand Down Expand Up @@ -161,6 +182,27 @@ TEST_P(RemoteFunctionTest, string) {
assertEqualVectors(expected, results);
}

// Evaluates TRY(remote_fail(c0, c0)) over an input containing a zero. The
// function registered as "remote_fail" throws when its first argument is
// zero, so the test expects NULL for that row and c0 / c0 == 1 elsewhere.
TEST_P(RemoteFunctionTest, fail) {
  auto inputVector = makeFlatVector<int64_t>({1, 2, 0, 4, 5});
  auto data = makeRowVector({inputVector});
  exec::ExprSet exprSet(
      {makeTypedExpr("TRY(remote_fail(c0, c0))", asRowType(data->type()))},
      &execCtx_);
  exec::EvalCtx context(&execCtx_, &exprSet, data.get());
  std::vector<VectorPtr> result(1);
  SelectivityVector defaultRows(data->size());

  exprSet.eval(defaultRows, context, result);
  // FIXME: Why does `context` hold no errors at this point? Presumably TRY
  // consumes the per-row errors and turns them into NULLs before returning;
  // to be confirmed.

  ASSERT_EQ(result[0]->size(), 5);
  auto expected = makeFlatVector<int64_t>({1, 1, 1, 1, 1});
  // Row 2 is the zero input, i.e. the row for which the function fails.
  expected->setNull(2, true);

  assertEqualVectors(expected, result[0]);
}

TEST_P(RemoteFunctionTest, connectionError) {
auto inputVector = makeFlatVector<int64_t>({1, 2, 3, 4, 5});
auto func = [&]() {
Expand Down
3 changes: 3 additions & 0 deletions velox/functions/remote/if/RemoteFunction.thrift
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ struct RemoteFunctionPage {

/// The number of logical rows in this page.
3: i64 rowCount;

/// Serialized per-row errors, encoded with the same serde as `payload`. The
/// content is a row vector with a single VARCHAR column: entry i holds the
/// error message for row i, or null if no error was recorded for that row.
/// The server only populates this field when evaluation recorded at least
/// one error.
4: IOBuf errorPayload;
}

/// The parameters passed to the remote thrift call.
Expand Down
54 changes: 50 additions & 4 deletions velox/functions/remote/server/RemoteFunctionService.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/

#include "velox/functions/remote/server/RemoteFunctionService.h"
#include "velox/common/base/Exceptions.h"
#include "velox/expression/Expr.h"
#include "velox/functions/remote/if/GetSerde.h"
#include "velox/type/fbhive/HiveTypeParser.h"
Expand Down Expand Up @@ -66,6 +67,46 @@ std::vector<core::TypedExprPtr> getExpressions(
returnType, std::move(inputs), functionName)};
}

// Serializes the per-row errors held by `evalErrors` into
// `result->errorPayload`.
//
// The payload is a row vector with a single VARCHAR column and
// `result->rowCount` rows. Row i holds the what() message of the error
// recorded for row i, or null if no error was recorded for that row. The
// vector is serialized with `serde`.
void RemoteFunctionServiceHandler::handleErrors(
    apache::thrift::field_ref<remote::RemoteFunctionPage&> result,
    exec::EvalErrors* evalErrors,
    const std::unique_ptr<VectorSerde>& serde) const {
  const vector_size_t rowCount = result->get_rowCount();

  // One StringView slot per row; every slot is written in the loop below.
  auto messages = std::make_shared<velox::FlatVector<velox::StringView>>(
      pool_.get(),
      velox::VARCHAR(),
      nullptr, // no nulls buffer up front
      rowCount,
      AlignedBuffer::allocate<velox::StringView>(rowCount, pool_.get()),
      std::vector<velox::BufferPtr>{});

  for (vector_size_t row = 0; row < rowCount; ++row) {
    if (!evalErrors->hasErrorAt(row)) {
      // No error for this row: write an empty value, then mark it null.
      messages->set(row, velox::StringView());
      messages->setNull(row, true);
      continue;
    }

    // Rethrow the stored exception to get access to its message.
    auto storedError = *evalErrors->errorAt(row);
    try {
      std::rethrow_exception(*storedError);
    } catch (const std::exception& e) {
      messages->set(row, e.what());
    }
  }

  auto errorPage = std::make_shared<RowVector>(
      pool_.get(),
      velox::ROW({velox::VARCHAR()}),
      BufferPtr(),
      rowCount,
      std::vector<VectorPtr>{messages});
  result->errorPayload_ref() =
      rowVectorToIOBuf(errorPage, *pool_, serde.get());
}

void RemoteFunctionServiceHandler::invokeFunction(
remote::RemoteFunctionResponse& response,
std::unique_ptr<remote::RemoteFunctionRequest> request) {
Expand All @@ -75,10 +116,6 @@ void RemoteFunctionServiceHandler::invokeFunction(
LOG(INFO) << "Got a request for '" << functionHandle.get_name()
<< "': " << inputs.get_rowCount() << " input rows.";

if (!request->get_throwOnError()) {
VELOX_NYI("throwOnError not implemented yet on remote server.");
}

// Deserialize types and data.
auto inputType = deserializeArgTypes(functionHandle.get_argumentTypes());
auto outputType = deserializeType(functionHandle.get_returnType());
Expand All @@ -102,7 +139,11 @@ void RemoteFunctionServiceHandler::invokeFunction(
outputType,
getFunctionName(functionPrefix_, functionHandle.get_name())),
&execCtx};

exec::EvalCtx evalCtx(&execCtx, &exprSet, inputVector.get());
if (!request->get_throwOnError()) {
*evalCtx.mutableThrowOnError() = false;
}

std::vector<VectorPtr> expressionResult;
exprSet.eval(rows, evalCtx, expressionResult);
Expand All @@ -116,6 +157,11 @@ void RemoteFunctionServiceHandler::invokeFunction(
result->pageFormat_ref() = serdeFormat;
result->payload_ref() =
rowVectorToIOBuf(outputRowVector, rows.end(), *pool_, serde.get());

auto evalErrors = evalCtx.errors();
if (evalErrors != nullptr && evalErrors->hasError()) {
handleErrors(result, evalErrors, serde);
}
}

} // namespace facebook::velox::functions
7 changes: 7 additions & 0 deletions velox/functions/remote/server/RemoteFunctionService.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@

#include <thrift/lib/cpp2/server/ThriftServer.h>
#include "velox/common/memory/Memory.h"
#include "velox/expression/EvalCtx.h"
#include "velox/functions/remote/if/gen-cpp2/RemoteFunctionService.h"
#include "velox/vector/VectorStream.h"

namespace facebook::velox::functions {

Expand All @@ -35,6 +37,11 @@ class RemoteFunctionServiceHandler
std::unique_ptr<remote::RemoteFunctionRequest> request) override;

private:
/// Serializes the per-row errors recorded in `evalErrors` into the
/// `errorPayload` field of `result`, as a single-column VARCHAR row vector
/// with one entry per row (the error message, or null when the row has no
/// error), encoded with `serde`.
void handleErrors(
apache::thrift::field_ref<remote::RemoteFunctionPage&> result,
exec::EvalErrors* evalErrors,
const std::unique_ptr<VectorSerde>& serde) const;

std::shared_ptr<memory::MemoryPool> pool_{
memory::memoryManager()->addLeafPool()};
const std::string functionPrefix_;
Expand Down

0 comments on commit 61fc63c

Please sign in to comment.