#include "clang/AST/ASTContext.h"
#include "clang/AST/Attr.h"
#include "clang/AST/RecursiveASTVisitor.h"
#include "clang/Frontend/FrontendPluginRegistry.h"
#include "clang/Sema/ParsedAttr.h"
#include "clang/Sema/Sema.h"
#include "clang/Sema/SemaDiagnostic.h"
#include "llvm/ADT/SmallPtrSet.h"
using namespace clang;
namespace {
llvm::SmallPtrSet<const CXXMethodDecl *, 16> MarkedMethods;
bool isMarkedAsCallSuper(const CXXMethodDecl *D) {
    return MarkedMethods.contains(D);
}
class MethodUsageVisitor : public RecursiveASTVisitor<MethodUsageVisitor> {
public:
  bool IsOverriddenUsed = false;
  explicit MethodUsageVisitor(
      llvm::SmallPtrSet<const CXXMethodDecl *, 16> &MustCalledMethods)
      : MustCalledMethods(MustCalledMethods) {}
  bool VisitCallExpr(CallExpr *CallExpr) {
    const CXXMethodDecl *Callee = nullptr;
    for (const auto &MustCalled : MustCalledMethods) {
      if (CallExpr->getCalleeDecl() == MustCalled) {
                                Callee = MustCalled;
      }
    }
    if (Callee)
      MustCalledMethods.erase(Callee);
    return true;
  }
private:
  llvm::SmallPtrSet<const CXXMethodDecl *, 16> &MustCalledMethods;
};
class CallSuperVisitor : public RecursiveASTVisitor<CallSuperVisitor> {
public:
  CallSuperVisitor(DiagnosticsEngine &Diags) : Diags(Diags) {
    WarningSuperNotCalled = Diags.getCustomDiagID(
        DiagnosticsEngine::Warning,
        "virtual function %q0 is marked as 'call_super' but this overriding "
        "method does not call the base version");
    NotePreviousCallSuperDeclaration = Diags.getCustomDiagID(
        DiagnosticsEngine::Note, "function marked 'call_super' here");
  }
  bool VisitCXXMethodDecl(CXXMethodDecl *MethodDecl) {
    if (MethodDecl->isThisDeclarationADefinition() && MethodDecl->hasBody()) {
            llvm::SmallPtrSet<const CXXMethodDecl *, 16> OverriddenMarkedMethods;
      for (const auto *Overridden : MethodDecl->overridden_methods()) {
        if (isMarkedAsCallSuper(Overridden)) {
          OverriddenMarkedMethods.insert(Overridden);
        }
      }
            MethodUsageVisitor Visitor(OverriddenMarkedMethods);
      Visitor.TraverseDecl(MethodDecl);
                  for (const auto &LeftOverriddens : OverriddenMarkedMethods) {
        Diags.Report(MethodDecl->getLocation(), WarningSuperNotCalled)
            << LeftOverriddens << MethodDecl;
        Diags.Report(LeftOverriddens->getLocation(),
                     NotePreviousCallSuperDeclaration);
      }
    }
    return true;
  }
private:
  DiagnosticsEngine &Diags;
  unsigned WarningSuperNotCalled;
  unsigned NotePreviousCallSuperDeclaration;
};
class CallSuperConsumer : public ASTConsumer {
public:
  void HandleTranslationUnit(ASTContext &Context) override {
    auto &Diags = Context.getDiagnostics();
    for (const auto *Method : MarkedMethods) {
      lateDiagAppertainsToDecl(Diags, Method);
    }
    CallSuperVisitor Visitor(Context.getDiagnostics());
    Visitor.TraverseDecl(Context.getTranslationUnitDecl());
  }
private:
          void lateDiagAppertainsToDecl(DiagnosticsEngine &Diags,
                                const CXXMethodDecl *MethodDecl) {
    if (MethodDecl->hasAttr<FinalAttr>()) {
      unsigned ID = Diags.getCustomDiagID(
          DiagnosticsEngine::Warning,
          "'call_super' attribute marked on a final method");
      Diags.Report(MethodDecl->getLocation(), ID);
    }
  }
};
class CallSuperAction : public PluginASTAction {
public:
  std::unique_ptr<ASTConsumer> CreateASTConsumer(CompilerInstance &CI,
                                                 llvm::StringRef) override {
    return std::make_unique<CallSuperConsumer>();
  }
  bool ParseArgs(const CompilerInstance &CI,
                 const std::vector<std::string> &args) override {
    if (!args.empty() && args[0] == "help")
      llvm::errs() << "Help for the CallSuperAttr plugin goes here\n";
    return true;
  }
  PluginASTAction::ActionType getActionType() override {
    return AddBeforeMainAction;
  }
};
struct CallSuperAttrInfo : public ParsedAttrInfo {
  CallSuperAttrInfo() {
    OptArgs = 0;
    static constexpr Spelling S[] = {
        {ParsedAttr::AS_GNU, "call_super"},
        {ParsedAttr::AS_CXX11, "clang::call_super"}};
    Spellings = S;
  }
  bool diagAppertainsToDecl(Sema &S, const ParsedAttr &Attr,
                            const Decl *D) const override {
    const auto *TheMethod = dyn_cast_or_null<CXXMethodDecl>(D);
    if (!TheMethod || !TheMethod->isVirtual()) {
      S.Diag(Attr.getLoc(), diag::warn_attribute_wrong_decl_type_str)
          << Attr << "virtual functions";
      return false;
    }
    MarkedMethods.insert(TheMethod);
    return true;
  }
  AttrHandling handleDeclAttribute(Sema &S, Decl *D,
                                   const ParsedAttr &Attr) const override {
                return AttributeNotApplied;
  }
};
} static FrontendPluginRegistry::Add<CallSuperAction>
    X("call_super_plugin", "clang plugin, checks every overridden virtual "
                           "function whether called this function or not.");
static ParsedAttrInfoRegistry::Add<CallSuperAttrInfo>
    Y("call_super_attr", "Attr plugin to define 'call_super' attribute");