diff --git a/llvm/include/llvm/Transforms/IPO/MergeFunctions.h b/llvm/include/llvm/Transforms/IPO/MergeFunctions.h index 822f0fd99188d0..1b3b1d22f11e28 100644 --- a/llvm/include/llvm/Transforms/IPO/MergeFunctions.h +++ b/llvm/include/llvm/Transforms/IPO/MergeFunctions.h @@ -15,7 +15,10 @@ #ifndef LLVM_TRANSFORMS_IPO_MERGEFUNCTIONS_H #define LLVM_TRANSFORMS_IPO_MERGEFUNCTIONS_H +#include "llvm/IR/Function.h" #include "llvm/IR/PassManager.h" +#include +#include namespace llvm { @@ -25,6 +28,10 @@ class Module; class MergeFunctionsPass : public PassInfoMixin { public: PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM); + + static bool runOnModule(Module &M); + static std::pair> + runOnFunctions(std::set &F); }; } // end namespace llvm diff --git a/llvm/lib/Transforms/IPO/MergeFunctions.cpp b/llvm/lib/Transforms/IPO/MergeFunctions.cpp index b50a700e09038f..a434d7920b6ccf 100644 --- a/llvm/lib/Transforms/IPO/MergeFunctions.cpp +++ b/llvm/lib/Transforms/IPO/MergeFunctions.cpp @@ -123,6 +123,7 @@ #include #include #include +#include #include #include #include @@ -198,6 +199,8 @@ class MergeFunctions { } bool runOnModule(Module &M); + bool runOnFunctions(std::set &F); + std::map &getDelToNewMap(); private: // The function comparison operator is provided here so that FunctionNodes do @@ -298,17 +301,31 @@ class MergeFunctions { // dangling iterators into FnTree. The invariant that preserves this is that // there is exactly one mapping F -> FN for each FunctionNode FN in FnTree. DenseMap, FnTreeType::iterator> FNodesInTree; + + /// Deleted-New functions mapping + std::map DelToNewMap; }; } // end anonymous namespace PreservedAnalyses MergeFunctionsPass::run(Module &M, ModuleAnalysisManager &AM) { - MergeFunctions MF; - if (!MF.runOnModule(M)) + if (!MergeFunctionsPass::runOnModule(M)) return PreservedAnalyses::all(); return PreservedAnalyses::none(); } +bool MergeFunctionsPass::runOnModule(Module &M) { + MergeFunctions MF; + return MF.runOnModule(M); +} + +std::pair> +MergeFunctionsPass::runOnFunctions(std::set &F) { + MergeFunctions MF; + bool MergeResult = MF.runOnFunctions(F); + return {MergeResult, MF.getDelToNewMap()}; +} + #ifndef NDEBUG bool MergeFunctions::doFunctionalCheck(std::vector &Worklist) { if (const unsigned Max = NumFunctionsForVerificationCheck) { @@ -468,6 +485,47 @@ bool MergeFunctions::runOnModule(Module &M) { return Changed; } +bool MergeFunctions::runOnFunctions(std::set &F) { + bool Changed = false; + std::vector> HashedFuncs; + for (Function *Func : F) { + if (isEligibleForMerging(*Func)) { + HashedFuncs.push_back({StructuralHash(*Func), Func}); + } + } + llvm::stable_sort(HashedFuncs, less_first()); + auto S = HashedFuncs.begin(); + for (auto I = HashedFuncs.begin(), IE = HashedFuncs.end(); I != IE; ++I) { + if ((I != S && std::prev(I)->first == I->first) || + (std::next(I) != IE && std::next(I)->first == I->first)) { + Deferred.push_back(WeakTrackingVH(I->second)); + } + } + do { + std::vector Worklist; + Deferred.swap(Worklist); + LLVM_DEBUG(dbgs() << "size of function: " << F.size() << '\n'); + LLVM_DEBUG(dbgs() << "size of worklist: " << Worklist.size() << '\n'); + for (WeakTrackingVH &I : Worklist) { + if (!I) + continue; + Function *F = cast(I); + if (!F->isDeclaration() && !F->hasAvailableExternallyLinkage()) { + Changed |= insert(F); + } + } + LLVM_DEBUG(dbgs() << "size of FnTree: " << FnTree.size() << '\n'); + } while (!Deferred.empty()); + FnTree.clear(); + FNodesInTree.clear(); + GlobalNumbers.clear(); + return Changed; +} + +std::map &MergeFunctions::getDelToNewMap() { + return this->DelToNewMap; +} + // Replace direct callers of Old with New. void MergeFunctions::replaceDirectCallers(Function *Old, Function *New) { for (Use &U : llvm::make_early_inc_range(Old->uses())) { @@ -1004,6 +1062,7 @@ bool MergeFunctions::insert(Function *NewFunction) { Function *DeleteF = NewFunction; mergeTwoFunctions(OldF.getFunc(), DeleteF); + this->DelToNewMap.emplace(DeleteF, OldF.getFunc()); return true; } diff --git a/llvm/unittests/Transforms/Utils/CMakeLists.txt b/llvm/unittests/Transforms/Utils/CMakeLists.txt index 5c7ec28709c169..7effa5d8e7d6d2 100644 --- a/llvm/unittests/Transforms/Utils/CMakeLists.txt +++ b/llvm/unittests/Transforms/Utils/CMakeLists.txt @@ -26,6 +26,7 @@ add_llvm_unittest(UtilsTests LoopUtilsTest.cpp MemTransferLowering.cpp ModuleUtilsTest.cpp + MergeFunctionsTest.cpp ScalarEvolutionExpanderTest.cpp SizeOptsTest.cpp SSAUpdaterBulkTest.cpp diff --git a/llvm/unittests/Transforms/Utils/MergeFunctionsTest.cpp b/llvm/unittests/Transforms/Utils/MergeFunctionsTest.cpp new file mode 100644 index 00000000000000..696c5391ef4f68 --- /dev/null +++ b/llvm/unittests/Transforms/Utils/MergeFunctionsTest.cpp @@ -0,0 +1,271 @@ +//===- MergeFunctionsTest.cpp - Unit tests for +//MergeFunctionsPass-----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/IPO/MergeFunctions.h" +#include "llvm/AsmParser/Parser.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/Support/SourceMgr.h" +#include "gtest/gtest.h" +#include + +using namespace llvm; + +namespace { + +TEST(MergeFunctions, TrueOutputModuleTest) { + LLVMContext Ctx; + SMDiagnostic Err; + std::unique_ptr M(parseAssemblyString(R"invalid( + @.str = private unnamed_addr constant [10 x i8] c"On f: %d\0A\00", align 1 + @.str.1 = private unnamed_addr constant [13 x i8] c"On main: %d\0A\00", align 1 + + define dso_local i32 @f(i32 noundef %arg) #0 { + entry: + %add109 = call i32 @_slice_add10(i32 %arg) + %call = call i32 (ptr, ...) @printf(ptr noundef @.str, i32 noundef %add109) + ret i32 %add109 + } + + declare i32 @printf(ptr noundef, ...) #1 + + define dso_local i32 @main(i32 noundef %argc, ptr noundef %argv) #0 { + entry: + %add99 = call i32 @_slice_add10(i32 %argc) + %call = call i32 @f(i32 noundef 2) + %sub = sub nsw i32 %call, 6 + %call10 = call i32 (ptr, ...) @printf(ptr noundef @.str.1, i32 noundef %add99) + ret i32 %add99 + } + + define internal i32 @_slice_add10(i32 %arg) #2 { + sliceclone_entry: + %0 = mul nsw i32 %arg, %arg + %1 = mul nsw i32 %0, 2 + %2 = mul nsw i32 %1, 2 + %3 = mul nsw i32 %2, 2 + %4 = add nsw i32 %3, 2 + ret i32 %4 + } + + define internal i32 @_slice_add10_alt(i32 %arg) #2 { + sliceclone_entry: + %0 = mul nsw i32 %arg, %arg + %1 = mul nsw i32 %0, 2 + %2 = mul nsw i32 %1, 2 + %3 = mul nsw i32 %2, 2 + %4 = add nsw i32 %3, 2 + ret i32 %4 + } + + attributes #0 = { noinline nounwind uwtable "frame-pointer"="all" "min-legal-vector-width"="0" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cmov,+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" } + attributes #1 = { "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cmov,+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" } + attributes #2 = { nounwind willreturn } + )invalid", + Err, Ctx)); + + // Expects true after merging _slice_add10 and _slice_add10_alt + EXPECT_TRUE(MergeFunctionsPass::runOnModule(*M)); +} + +TEST(MergeFunctions, TrueOutputFunctionsTest) { + LLVMContext Ctx; + SMDiagnostic Err; + std::unique_ptr M(parseAssemblyString(R"invalid( + @.str = private unnamed_addr constant [10 x i8] c"On f: %d\0A\00", align 1 + @.str.1 = private unnamed_addr constant [13 x i8] c"On main: %d\0A\00", align 1 + + define dso_local i32 @f(i32 noundef %arg) #0 { + entry: + %add109 = call i32 @_slice_add10(i32 %arg) + %call = call i32 (ptr, ...) @printf(ptr noundef @.str, i32 noundef %add109) + ret i32 %add109 + } + + declare i32 @printf(ptr noundef, ...) #1 + + define dso_local i32 @main(i32 noundef %argc, ptr noundef %argv) #0 { + entry: + %add99 = call i32 @_slice_add10(i32 %argc) + %call = call i32 @f(i32 noundef 2) + %sub = sub nsw i32 %call, 6 + %call10 = call i32 (ptr, ...) @printf(ptr noundef @.str.1, i32 noundef %add99) + ret i32 %add99 + } + + define internal i32 @_slice_add10(i32 %arg) #2 { + sliceclone_entry: + %0 = mul nsw i32 %arg, %arg + %1 = mul nsw i32 %0, 2 + %2 = mul nsw i32 %1, 2 + %3 = mul nsw i32 %2, 2 + %4 = add nsw i32 %3, 2 + ret i32 %4 + } + + define internal i32 @_slice_add10_alt(i32 %arg) #2 { + sliceclone_entry: + %0 = mul nsw i32 %arg, %arg + %1 = mul nsw i32 %0, 2 + %2 = mul nsw i32 %1, 2 + %3 = mul nsw i32 %2, 2 + %4 = add nsw i32 %3, 2 + ret i32 %4 + } + + attributes #0 = { noinline nounwind uwtable "frame-pointer"="all" "min-legal-vector-width"="0" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cmov,+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" } + attributes #1 = { "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cmov,+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" } + attributes #2 = { nounwind willreturn } + )invalid", + Err, Ctx)); + + std::set FunctionsSet; + for (Function &F : *M) + FunctionsSet.insert(&F); + + std::pair> MergeResult = + MergeFunctionsPass::runOnFunctions(FunctionsSet); + + // Expects true after merging _slice_add10 and _slice_add10_alt + EXPECT_TRUE(MergeResult.first); + + // Expects that both functions (_slice_add10 and _slice_add10_alt) + // be mapped to the same new function + EXPECT_TRUE(MergeResult.second.size() > 0); + std::map DelToNew = MergeResult.second; + Function *NewFunction = M->getFunction("_slice_add10"); + for (auto P : DelToNew) + if (P.second) + EXPECT_EQ(P.second, NewFunction); +} + +TEST(MergeFunctions, FalseOutputModuleTest) { + LLVMContext Ctx; + SMDiagnostic Err; + std::unique_ptr M(parseAssemblyString(R"invalid( + @.str = private unnamed_addr constant [10 x i8] c"On f: %d\0A\00", align 1 + @.str.1 = private unnamed_addr constant [13 x i8] c"On main: %d\0A\00", align 1 + + define dso_local i32 @f(i32 noundef %arg) #0 { + entry: + %add109 = call i32 @_slice_add10(i32 %arg) + %call = call i32 (ptr, ...) @printf(ptr noundef @.str, i32 noundef %add109) + ret i32 %add109 + } + + declare i32 @printf(ptr noundef, ...) #1 + + define dso_local i32 @main(i32 noundef %argc, ptr noundef %argv) #0 { + entry: + %add99 = call i32 @_slice_add10(i32 %argc) + %call = call i32 @f(i32 noundef 2) + %sub = sub nsw i32 %call, 6 + %call10 = call i32 (ptr, ...) @printf(ptr noundef @.str.1, i32 noundef %add99) + ret i32 %add99 + } + + define internal i32 @_slice_add10(i32 %arg) #2 { + sliceclone_entry: + %0 = mul nsw i32 %arg, %arg + %1 = mul nsw i32 %0, 2 + %2 = mul nsw i32 %1, 2 + %3 = mul nsw i32 %2, 2 + %4 = add nsw i32 %3, 2 + ret i32 %4 + } + + define internal i32 @_slice_add10_alt(i32 %arg) #2 { + sliceclone_entry: + %0 = mul nsw i32 %arg, %arg + %1 = mul nsw i32 %0, 2 + %2 = mul nsw i32 %1, 2 + %3 = mul nsw i32 %2, 2 + %4 = add nsw i32 %3, 2 + ret i32 %0 + } + + attributes #0 = { noinline nounwind uwtable "frame-pointer"="all" "min-legal-vector-width"="0" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cmov,+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" } + attributes #1 = { "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cmov,+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" } + attributes #2 = { nounwind willreturn } + )invalid", + Err, Ctx)); + + // Expects false after trying to merge _slice_add10 and _slice_add10_alt + EXPECT_FALSE(MergeFunctionsPass::runOnModule(*M)); +} + +TEST(MergeFunctions, FalseOutputFunctionsTest) { + LLVMContext Ctx; + SMDiagnostic Err; + std::unique_ptr M(parseAssemblyString(R"invalid( + @.str = private unnamed_addr constant [10 x i8] c"On f: %d\0A\00", align 1 + @.str.1 = private unnamed_addr constant [13 x i8] c"On main: %d\0A\00", align 1 + + define dso_local i32 @f(i32 noundef %arg) #0 { + entry: + %add109 = call i32 @_slice_add10(i32 %arg) + %call = call i32 (ptr, ...) @printf(ptr noundef @.str, i32 noundef %add109) + ret i32 %add109 + } + + declare i32 @printf(ptr noundef, ...) #1 + + define dso_local i32 @main(i32 noundef %argc, ptr noundef %argv) #0 { + entry: + %add99 = call i32 @_slice_add10(i32 %argc) + %call = call i32 @f(i32 noundef 2) + %sub = sub nsw i32 %call, 6 + %call10 = call i32 (ptr, ...) @printf(ptr noundef @.str.1, i32 noundef %add99) + ret i32 %add99 + } + + define internal i32 @_slice_add10(i32 %arg) #2 { + sliceclone_entry: + %0 = mul nsw i32 %arg, %arg + %1 = mul nsw i32 %0, 2 + %2 = mul nsw i32 %1, 2 + %3 = mul nsw i32 %2, 2 + %4 = add nsw i32 %3, 2 + ret i32 %4 + } + + define internal i32 @_slice_add10_alt(i32 %arg) #2 { + sliceclone_entry: + %0 = mul nsw i32 %arg, %arg + %1 = mul nsw i32 %0, 2 + %2 = mul nsw i32 %1, 2 + %3 = mul nsw i32 %2, 2 + %4 = add nsw i32 %3, 2 + ret i32 %0 + } + + attributes #0 = { noinline nounwind uwtable "frame-pointer"="all" "min-legal-vector-width"="0" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cmov,+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" } + attributes #1 = { "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cmov,+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" } + attributes #2 = { nounwind willreturn } + )invalid", + Err, Ctx)); + + std::set FunctionsSet; + for (Function &F : *M) + FunctionsSet.insert(&F); + + std::pair> MergeResult = + MergeFunctionsPass::runOnFunctions(FunctionsSet); + + for (auto P : MergeResult.second) + std::cout << P.first << " " << P.second << "\n"; + + // Expects false after trying to merge _slice_add10 and _slice_add10_alt + EXPECT_FALSE(MergeResult.first); + + // Expects empty map + EXPECT_EQ(MergeResult.second.size(), 0u); +} + +} // namespace \ No newline at end of file diff --git a/llvm/utils/gn/secondary/llvm/unittests/Transforms/Utils/BUILD.gn b/llvm/utils/gn/secondary/llvm/unittests/Transforms/Utils/BUILD.gn index 380ed71a2bc010..fcea55c91f083c 100644 --- a/llvm/utils/gn/secondary/llvm/unittests/Transforms/Utils/BUILD.gn +++ b/llvm/utils/gn/secondary/llvm/unittests/Transforms/Utils/BUILD.gn @@ -27,6 +27,7 @@ unittest("UtilsTests") { "LoopUtilsTest.cpp", "MemTransferLowering.cpp", "ModuleUtilsTest.cpp", + "MergeFunctionsTest.cpp", "ProfDataUtilTest.cpp", "SSAUpdaterBulkTest.cpp", "ScalarEvolutionExpanderTest.cpp",