#ifndef LLVM_CLANG_UNITTESTS_ASTMATCHERS_ASTMATCHERSTEST_H
#define LLVM_CLANG_UNITTESTS_ASTMATCHERS_ASTMATCHERSTEST_H
#include "clang/ASTMatchers/ASTMatchFinder.h"
#include "clang/Frontend/ASTUnit.h"
#include "clang/Testing/CommandLineArgs.h"
#include "clang/Testing/TestClangConfig.h"
#include "clang/Tooling/Tooling.h"
#include "gtest/gtest.h"
namespace clang {
namespace ast_matchers {
using clang::tooling::buildASTFromCodeWithArgs;
using clang::tooling::FileContentMappings;
using clang::tooling::FrontendActionFactory;
using clang::tooling::newFrontendActionFactory;
using clang::tooling::runToolOnCodeWithArgs;
class BoundNodesCallback {
public:
virtual ~BoundNodesCallback() {}
virtual bool run(const BoundNodes *BoundNodes) = 0;
virtual bool run(const BoundNodes *BoundNodes, ASTContext *Context) = 0;
virtual void onEndOfTranslationUnit() {}
};
class VerifyMatch : public MatchFinder::MatchCallback {
public:
VerifyMatch(std::unique_ptr<BoundNodesCallback> FindResultVerifier,
bool *Verified)
: Verified(Verified), FindResultReviewer(std::move(FindResultVerifier)) {}
void run(const MatchFinder::MatchResult &Result) override {
if (FindResultReviewer != nullptr) {
*Verified |= FindResultReviewer->run(&Result.Nodes, Result.Context);
} else {
*Verified = true;
}
}
void onEndOfTranslationUnit() override {
if (FindResultReviewer)
FindResultReviewer->onEndOfTranslationUnit();
}
private:
bool *const Verified;
const std::unique_ptr<BoundNodesCallback> FindResultReviewer;
};
inline ArrayRef<TestLanguage> langCxx11OrLater() {
static const TestLanguage Result[] = {Lang_CXX11, Lang_CXX14, Lang_CXX17,
Lang_CXX20};
return ArrayRef<TestLanguage>(Result);
}
inline ArrayRef<TestLanguage> langCxx14OrLater() {
static const TestLanguage Result[] = {Lang_CXX14, Lang_CXX17, Lang_CXX20};
return ArrayRef<TestLanguage>(Result);
}
inline ArrayRef<TestLanguage> langCxx17OrLater() {
static const TestLanguage Result[] = {Lang_CXX17, Lang_CXX20};
return ArrayRef<TestLanguage>(Result);
}
inline ArrayRef<TestLanguage> langCxx20OrLater() {
static const TestLanguage Result[] = {Lang_CXX20};
return ArrayRef<TestLanguage>(Result);
}
template <typename T>
testing::AssertionResult matchesConditionally(
const Twine &Code, const T &AMatcher, bool ExpectMatch,
ArrayRef<std::string> CompileArgs,
const FileContentMappings &VirtualMappedFiles = FileContentMappings(),
StringRef Filename = "input.cc") {
bool Found = false, DynamicFound = false;
MatchFinder Finder;
VerifyMatch VerifyFound(nullptr, &Found);
Finder.addMatcher(AMatcher, &VerifyFound);
VerifyMatch VerifyDynamicFound(nullptr, &DynamicFound);
if (!Finder.addDynamicMatcher(AMatcher, &VerifyDynamicFound))
return testing::AssertionFailure() << "Could not add dynamic matcher";
std::unique_ptr<FrontendActionFactory> Factory(
newFrontendActionFactory(&Finder));
std::vector<std::string> Args = {
"-frtti", "-fexceptions",
"-Werror=c++14-extensions", "-Werror=c++17-extensions",
"-Werror=c++20-extensions"};
llvm::copy(CompileArgs, std::back_inserter(Args));
if (llvm::find(Args, "-target") == Args.end()) {
Args.push_back("-target");
Args.push_back("i386-unknown-unknown");
}
if (!runToolOnCodeWithArgs(
Factory->create(), Code, Args, Filename, "clang-tool",
std::make_shared<PCHContainerOperations>(), VirtualMappedFiles)) {
return testing::AssertionFailure() << "Parsing error in \"" << Code << "\"";
}
if (Found != DynamicFound) {
return testing::AssertionFailure()
<< "Dynamic match result (" << DynamicFound
<< ") does not match static result (" << Found << ")";
}
if (!Found && ExpectMatch) {
return testing::AssertionFailure()
<< "Could not find match in \"" << Code << "\"";
} else if (Found && !ExpectMatch) {
return testing::AssertionFailure()
<< "Found unexpected match in \"" << Code << "\"";
}
return testing::AssertionSuccess();
}
template <typename T>
testing::AssertionResult
matchesConditionally(const Twine &Code, const T &AMatcher, bool ExpectMatch,
ArrayRef<TestLanguage> TestLanguages) {
for (auto Lang : TestLanguages) {
auto Result = matchesConditionally(
Code, AMatcher, ExpectMatch, getCommandLineArgsForTesting(Lang),
FileContentMappings(), getFilenameForTesting(Lang));
if (!Result)
return Result;
}
return testing::AssertionSuccess();
}
template <typename T>
testing::AssertionResult
matches(const Twine &Code, const T &AMatcher,
ArrayRef<TestLanguage> TestLanguages = {Lang_CXX11}) {
return matchesConditionally(Code, AMatcher, true, TestLanguages);
}
template <typename T>
testing::AssertionResult
notMatches(const Twine &Code, const T &AMatcher,
ArrayRef<TestLanguage> TestLanguages = {Lang_CXX11}) {
return matchesConditionally(Code, AMatcher, false, TestLanguages);
}
template <typename T>
testing::AssertionResult matchesObjC(const Twine &Code, const T &AMatcher,
bool ExpectMatch = true) {
return matchesConditionally(Code, AMatcher, ExpectMatch,
{"-fobjc-nonfragile-abi", "-Wno-objc-root-class",
"-fblocks", "-Wno-incomplete-implementation"},
FileContentMappings(), "input.m");
}
template <typename T>
testing::AssertionResult matchesC(const Twine &Code, const T &AMatcher) {
return matchesConditionally(Code, AMatcher, true, {}, FileContentMappings(),
"input.c");
}
template <typename T>
testing::AssertionResult notMatchesObjC(const Twine &Code, const T &AMatcher) {
return matchesObjC(Code, AMatcher, false);
}
template <typename T>
testing::AssertionResult
matchesConditionallyWithCuda(const Twine &Code, const T &AMatcher,
bool ExpectMatch, llvm::StringRef CompileArg) {
const std::string CudaHeader =
"typedef unsigned int size_t;\n"
"#define __constant__ __attribute__((constant))\n"
"#define __device__ __attribute__((device))\n"
"#define __global__ __attribute__((global))\n"
"#define __host__ __attribute__((host))\n"
"#define __shared__ __attribute__((shared))\n"
"struct dim3 {"
" unsigned x, y, z;"
" __host__ __device__ dim3(unsigned x, unsigned y = 1, unsigned z = 1)"
" : x(x), y(y), z(z) {}"
"};"
"typedef struct cudaStream *cudaStream_t;"
"int cudaConfigureCall(dim3 gridSize, dim3 blockSize,"
" size_t sharedSize = 0,"
" cudaStream_t stream = 0);"
"extern \"C\" unsigned __cudaPushCallConfiguration("
" dim3 gridDim, dim3 blockDim, size_t sharedMem = 0, void *stream = "
"0);";
bool Found = false, DynamicFound = false;
MatchFinder Finder;
VerifyMatch VerifyFound(nullptr, &Found);
Finder.addMatcher(AMatcher, &VerifyFound);
VerifyMatch VerifyDynamicFound(nullptr, &DynamicFound);
if (!Finder.addDynamicMatcher(AMatcher, &VerifyDynamicFound))
return testing::AssertionFailure() << "Could not add dynamic matcher";
std::unique_ptr<FrontendActionFactory> Factory(
newFrontendActionFactory(&Finder));
std::vector<std::string> Args = {
"-xcuda", "-fno-ms-extensions", "--cuda-host-only", "-nocudainc",
"-target", "x86_64-unknown-unknown", std::string(CompileArg)};
if (!runToolOnCodeWithArgs(Factory->create(), CudaHeader + Code, Args)) {
return testing::AssertionFailure() << "Parsing error in \"" << Code << "\"";
}
if (Found != DynamicFound) {
return testing::AssertionFailure()
<< "Dynamic match result (" << DynamicFound
<< ") does not match static result (" << Found << ")";
}
if (!Found && ExpectMatch) {
return testing::AssertionFailure()
<< "Could not find match in \"" << Code << "\"";
} else if (Found && !ExpectMatch) {
return testing::AssertionFailure()
<< "Found unexpected match in \"" << Code << "\"";
}
return testing::AssertionSuccess();
}
template <typename T>
testing::AssertionResult matchesWithCuda(const Twine &Code, const T &AMatcher) {
return matchesConditionallyWithCuda(Code, AMatcher, true, "-std=c++11");
}
template <typename T>
testing::AssertionResult notMatchesWithCuda(const Twine &Code,
const T &AMatcher) {
return matchesConditionallyWithCuda(Code, AMatcher, false, "-std=c++11");
}
template <typename T>
testing::AssertionResult matchesWithOpenMP(const Twine &Code,
const T &AMatcher) {
return matchesConditionally(Code, AMatcher, true, {"-fopenmp=libomp"});
}
template <typename T>
testing::AssertionResult notMatchesWithOpenMP(const Twine &Code,
const T &AMatcher) {
return matchesConditionally(Code, AMatcher, false, {"-fopenmp=libomp"});
}
template <typename T>
testing::AssertionResult matchesWithOpenMP51(const Twine &Code,
const T &AMatcher) {
return matchesConditionally(Code, AMatcher, true,
{"-fopenmp=libomp", "-fopenmp-version=51"});
}
template <typename T>
testing::AssertionResult notMatchesWithOpenMP51(const Twine &Code,
const T &AMatcher) {
return matchesConditionally(Code, AMatcher, false,
{"-fopenmp=libomp", "-fopenmp-version=51"});
}
template <typename T>
testing::AssertionResult matchAndVerifyResultConditionally(
const Twine &Code, const T &AMatcher,
std::unique_ptr<BoundNodesCallback> FindResultVerifier, bool ExpectResult) {
bool VerifiedResult = false;
MatchFinder Finder;
VerifyMatch VerifyVerifiedResult(std::move(FindResultVerifier),
&VerifiedResult);
Finder.addMatcher(AMatcher, &VerifyVerifiedResult);
std::unique_ptr<FrontendActionFactory> Factory(
newFrontendActionFactory(&Finder));
std::vector<std::string> Args = {"-std=gnu++11", "-target",
"i386-unknown-unknown"};
if (!runToolOnCodeWithArgs(Factory->create(), Code, Args)) {
return testing::AssertionFailure() << "Parsing error in \"" << Code << "\"";
}
if (!VerifiedResult && ExpectResult) {
return testing::AssertionFailure()
<< "Could not verify result in \"" << Code << "\"";
} else if (VerifiedResult && !ExpectResult) {
return testing::AssertionFailure()
<< "Verified unexpected result in \"" << Code << "\"";
}
VerifiedResult = false;
SmallString<256> Buffer;
std::unique_ptr<ASTUnit> AST(
buildASTFromCodeWithArgs(Code.toStringRef(Buffer), Args));
if (!AST.get())
return testing::AssertionFailure()
<< "Parsing error in \"" << Code << "\" while building AST";
Finder.matchAST(AST->getASTContext());
if (!VerifiedResult && ExpectResult) {
return testing::AssertionFailure()
<< "Could not verify result in \"" << Code << "\" with AST";
} else if (VerifiedResult && !ExpectResult) {
return testing::AssertionFailure()
<< "Verified unexpected result in \"" << Code << "\" with AST";
}
return testing::AssertionSuccess();
}
template <typename T>
testing::AssertionResult matchAndVerifyResultTrue(
const Twine &Code, const T &AMatcher,
std::unique_ptr<BoundNodesCallback> FindResultVerifier) {
return matchAndVerifyResultConditionally(Code, AMatcher,
std::move(FindResultVerifier), true);
}
template <typename T>
testing::AssertionResult matchAndVerifyResultFalse(
const Twine &Code, const T &AMatcher,
std::unique_ptr<BoundNodesCallback> FindResultVerifier) {
return matchAndVerifyResultConditionally(
Code, AMatcher, std::move(FindResultVerifier), false);
}
template <typename T> class VerifyIdIsBoundTo : public BoundNodesCallback {
public:
explicit VerifyIdIsBoundTo(llvm::StringRef Id)
: Id(std::string(Id)), ExpectedCount(-1), Count(0) {}
VerifyIdIsBoundTo(llvm::StringRef Id, int ExpectedCount)
: Id(std::string(Id)), ExpectedCount(ExpectedCount), Count(0) {}
VerifyIdIsBoundTo(llvm::StringRef Id, llvm::StringRef ExpectedName,
int ExpectedCount = 1)
: Id(std::string(Id)), ExpectedCount(ExpectedCount), Count(0),
ExpectedName(std::string(ExpectedName)) {}
void onEndOfTranslationUnit() override {
if (ExpectedCount != -1) {
EXPECT_EQ(ExpectedCount, Count);
}
if (!ExpectedName.empty()) {
EXPECT_EQ(ExpectedName, Name);
}
Count = 0;
Name.clear();
}
~VerifyIdIsBoundTo() override {
EXPECT_EQ(0, Count);
EXPECT_EQ("", Name);
}
bool run(const BoundNodes *Nodes) override {
const BoundNodes::IDToNodeMap &M = Nodes->getMap();
if (Nodes->getNodeAs<T>(Id)) {
++Count;
if (const NamedDecl *Named = Nodes->getNodeAs<NamedDecl>(Id)) {
Name = Named->getNameAsString();
} else if (const NestedNameSpecifier *NNS =
Nodes->getNodeAs<NestedNameSpecifier>(Id)) {
llvm::raw_string_ostream OS(Name);
NNS->print(OS, PrintingPolicy(LangOptions()));
}
BoundNodes::IDToNodeMap::const_iterator I = M.find(Id);
EXPECT_NE(M.end(), I);
if (I != M.end()) {
EXPECT_EQ(Nodes->getNodeAs<T>(Id), I->second.get<T>());
}
return true;
}
EXPECT_TRUE(M.count(Id) == 0 ||
M.find(Id)->second.template get<T>() == nullptr);
return false;
}
bool run(const BoundNodes *Nodes, ASTContext *Context) override {
return run(Nodes);
}
private:
const std::string Id;
const int ExpectedCount;
int Count;
const std::string ExpectedName;
std::string Name;
};
class ASTMatchersTest : public ::testing::Test,
public ::testing::WithParamInterface<TestClangConfig> {
protected:
template <typename T>
testing::AssertionResult matches(const Twine &Code, const T &AMatcher) {
const TestClangConfig &TestConfig = GetParam();
return clang::ast_matchers::matchesConditionally(
Code, AMatcher, true, TestConfig.getCommandLineArgs(),
FileContentMappings(), getFilenameForTesting(TestConfig.Language));
}
template <typename T>
testing::AssertionResult notMatches(const Twine &Code, const T &AMatcher) {
const TestClangConfig &TestConfig = GetParam();
return clang::ast_matchers::matchesConditionally(
Code, AMatcher, false, TestConfig.getCommandLineArgs(),
FileContentMappings(), getFilenameForTesting(TestConfig.Language));
}
};
} }
#endif