Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SandboxIR] Add RegionPass/RegionPassManager #110933

Merged
merged 3 commits into from
Oct 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions llvm/include/llvm/SandboxIR/Pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
namespace llvm::sandboxir {

class Function;
class Region;

/// The base class of a Sandbox IR Pass.
class Pass {
Expand All @@ -24,6 +25,7 @@ class Pass {
const std::string Name;

public:
/// \p Name can't contain any spaces or start with '-'.
Pass(StringRef Name) : Name(Name) {
assert(!Name.contains(' ') &&
"A pass name should not contain whitespaces!");
Expand All @@ -47,11 +49,21 @@ class Pass {
/// A pass that runs on a sandbox::Function.
class FunctionPass : public Pass {
public:
/// \p Name can't contain any spaces or start with '-'.
FunctionPass(StringRef Name) : Pass(Name) {}
/// \Returns true if it modifies \p F.
virtual bool runOnFunction(Function &F) = 0;
};

/// A pass that runs on a sandbox::Region.
class RegionPass : public Pass {
public:
/// \p Name can't contain any spaces or start with '-'.
RegionPass(StringRef Name) : Pass(Name) {}
/// \Returns true if it modifies \p R.
virtual bool runOnRegion(Region &R) = 0;
};

} // namespace llvm::sandboxir

#endif // LLVM_SANDBOXIR_PASS_H
6 changes: 6 additions & 0 deletions llvm/include/llvm/SandboxIR/PassManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@ class FunctionPassManager final
bool runOnFunction(Function &F) final;
};

class RegionPassManager final : public PassManager<RegionPass, RegionPass> {
public:
RegionPassManager(StringRef Name) : PassManager(Name) {}
bool runOnRegion(Region &R) final;
};

/// Owns the passes and provides an API to get a pass by its name.
class PassRegistry {
SmallVector<std::unique_ptr<Pass>, 8> Passes;
Expand Down
10 changes: 10 additions & 0 deletions llvm/lib/SandboxIR/PassManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,16 @@ bool FunctionPassManager::runOnFunction(Function &F) {
return Change;
}

bool RegionPassManager::runOnRegion(Region &R) {
bool Change = false;
for (RegionPass *Pass : Passes) {
Change |= Pass->runOnRegion(R);
// TODO: run the verifier.
}
// TODO: Check ChangeAll against hashes before/after.
return Change;
}

FunctionPassManager &
PassRegistry::parseAndCreatePassPipeline(StringRef Pipeline) {
static constexpr const char EndToken = '\0';
Expand Down
124 changes: 124 additions & 0 deletions llvm/unittests/SandboxIR/PassTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "llvm/SandboxIR/Context.h"
#include "llvm/SandboxIR/Function.h"
#include "llvm/SandboxIR/PassManager.h"
#include "llvm/SandboxIR/Region.h"
#include "llvm/Support/SourceMgr.h"
#include "gtest/gtest.h"

Expand Down Expand Up @@ -86,6 +87,68 @@ define void @foo() {
#endif
}

TEST_F(PassTest, RegionPass) {
auto *F = parseFunction(R"IR(
define i8 @foo(i8 %v0, i8 %v1) {
%t0 = add i8 %v0, 1
%t1 = add i8 %t0, %v1, !sandboxvec !0
%t2 = add i8 %t1, %v1, !sandboxvec !0
ret i8 %t1
}

!0 = distinct !{!"sandboxregion"}
)IR",
"foo");

class TestPass final : public RegionPass {
unsigned &InstCount;

public:
TestPass(unsigned &InstCount)
: RegionPass("test-pass"), InstCount(InstCount) {}
bool runOnRegion(Region &R) final {
for ([[maybe_unused]] auto &Inst : R) {
++InstCount;
}
return false;
}
};
unsigned InstCount = 0;
TestPass TPass(InstCount);
// Check getName(),
EXPECT_EQ(TPass.getName(), "test-pass");
// Check runOnRegion();
llvm::SmallVector<std::unique_ptr<Region>> Regions =
Region::createRegionsFromMD(*F);
ASSERT_EQ(Regions.size(), 1u);
TPass.runOnRegion(*Regions[0]);
EXPECT_EQ(InstCount, 2u);
#ifndef NDEBUG
{
// Check print().
std::string Buff;
llvm::raw_string_ostream SS(Buff);
TPass.print(SS);
EXPECT_EQ(Buff, "test-pass");
}
{
// Check operator<<().
std::string Buff;
llvm::raw_string_ostream SS(Buff);
SS << TPass;
EXPECT_EQ(Buff, "test-pass");
}
// Check pass name assertions.
class TestNamePass final : public RegionPass {
public:
TestNamePass(llvm::StringRef Name) : RegionPass(Name) {}
bool runOnRegion(Region &F) { return false; }
};
EXPECT_DEATH(TestNamePass("white space"), ".*whitespace.*");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have a comment at the constructor about these name restrictions?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The base Pass class does, kinda (it mentions it can't contain spaces, but it doesn't mention it can't begin with a -). And the comment is not for the constructor but for the Name member. So, technically the answer to your question is "no, not really, and neither does FunctionPass".

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we should have some kind of comment that describes the allowed names or points to the code that does the checks.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would something like this suffice?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, looks great.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Name field in the Pass class already has a comment with the reasoning.

/// The pass name. This is also used as a command-line flag and should not
/// contain whitespaces.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah ok

unrelated to this patch, I'm not sure why we don't allow "-" as the first char. it's not conventional to start a pass name with "-", but there are plenty of other worse things you could do, like start with "("

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, using an illegal name shouldn't be too much of an issue. Feel free to drop the comments you added.

Copy link
Collaborator Author

@slackito slackito Oct 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll merge as-is if you don't mind. The comments are a slight improvement IMO (nothing wrong with documenting the public interface of the class) and I don't want to push another revision and start pre-merge checks again.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good.

EXPECT_DEATH(TestNamePass("-dash"), ".*start with.*");
#endif
}

TEST_F(PassTest, FunctionPassManager) {
auto *F = parseFunction(R"IR(
define void @foo() {
Expand Down Expand Up @@ -136,6 +199,67 @@ define void @foo() {
#endif // NDEBUG
}

TEST_F(PassTest, RegionPassManager) {
auto *F = parseFunction(R"IR(
define i8 @foo(i8 %v0, i8 %v1) {
%t0 = add i8 %v0, 1
%t1 = add i8 %t0, %v1, !sandboxvec !0
%t2 = add i8 %t1, %v1, !sandboxvec !0
ret i8 %t1
}

!0 = distinct !{!"sandboxregion"}
)IR",
"foo");

class TestPass1 final : public RegionPass {
unsigned &InstCount;

public:
TestPass1(unsigned &InstCount)
: RegionPass("test-pass1"), InstCount(InstCount) {}
bool runOnRegion(Region &R) final {
for ([[maybe_unused]] auto &Inst : R)
++InstCount;
return false;
}
};
class TestPass2 final : public RegionPass {
unsigned &InstCount;

public:
TestPass2(unsigned &InstCount)
: RegionPass("test-pass2"), InstCount(InstCount) {}
bool runOnRegion(Region &R) final {
for ([[maybe_unused]] auto &Inst : R)
++InstCount;
return false;
}
};
unsigned InstCount1 = 0;
unsigned InstCount2 = 0;
TestPass1 TPass1(InstCount1);
TestPass2 TPass2(InstCount2);

RegionPassManager RPM("test-rpm");
RPM.addPass(&TPass1);
RPM.addPass(&TPass2);
// Check runOnRegion().
llvm::SmallVector<std::unique_ptr<Region>> Regions =
Region::createRegionsFromMD(*F);
ASSERT_EQ(Regions.size(), 1u);
RPM.runOnRegion(*Regions[0]);
EXPECT_EQ(InstCount1, 2u);
EXPECT_EQ(InstCount2, 2u);
#ifndef NDEBUG
// Check dump().
std::string Buff;
llvm::raw_string_ostream SS(Buff);
RPM.print(SS);
EXPECT_EQ(Buff, "test-rpm(test-pass1,test-pass2)");
#endif // NDEBUG
}

TEST_F(PassTest, PassRegistry) {
class TestPass1 final : public FunctionPass {
public:
Expand Down
Loading