#include "AArch64MacroFusion.h"
#include "AArch64Subtarget.h"
#include "llvm/CodeGen/MacroFusion.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
using namespace llvm;
/// Returns true if FirstMI and SecondMI may be fused as an ADDS/SUBS/ANDS/BICS
/// instruction followed by a B.cc (AArch64::Bcc). A null FirstMI is treated as
/// a wildcard and matches any first instruction.
///
/// \param CmpOnly When true, FirstMI only qualifies if its operand 0 is XZR or
///        WZR, i.e. its arithmetic result is discarded and only the flags
///        survive.
static bool isArithmeticBccPair(const MachineInstr *FirstMI,
                                const MachineInstr &SecondMI, bool CmpOnly) {
  if (SecondMI.getOpcode() != AArch64::Bcc)
    return false;

  // Assume the first instruction to be a wildcard if it is unspecified.
  if (FirstMI == nullptr)
    return true;

  // Classify the first instruction by opcode BEFORE touching its operands:
  // an arbitrary instruction may have a non-register operand 0 (or none).
  switch (FirstMI->getOpcode()) {
  case AArch64::ADDSWri:
  case AArch64::ADDSWrr:
  case AArch64::ADDSXri:
  case AArch64::ADDSXrr:
  case AArch64::ANDSWri:
  case AArch64::ANDSWrr:
  case AArch64::ANDSXri:
  case AArch64::ANDSXrr:
  case AArch64::SUBSWri:
  case AArch64::SUBSWrr:
  case AArch64::SUBSXri:
  case AArch64::SUBSXrr:
  case AArch64::BICSWrr:
  case AArch64::BICSXrr:
    break;
  case AArch64::ADDSWrs:
  case AArch64::ADDSXrs:
  case AArch64::ANDSWrs:
  case AArch64::ANDSXrs:
  case AArch64::SUBSWrs:
  case AArch64::SUBSXrs:
  case AArch64::BICSWrs:
  case AArch64::BICSXrs:
    // Shifted-register forms only qualify when no shift is applied.
    if (AArch64InstrInfo::hasShiftedReg(*FirstMI))
      return false;
    break;
  default:
    return false;
  }

  if (!CmpOnly)
    return true;

  // In compare-only mode the result must be discarded into the zero register.
  // isReg() is checked defensively because getReg() asserts on non-register
  // operands.
  const auto &Dst = FirstMI->getOperand(0);
  return Dst.isReg() &&
         (Dst.getReg() == AArch64::XZR || Dst.getReg() == AArch64::WZR);
}
/// Returns true if FirstMI and SecondMI may be fused as an ALU instruction
/// followed by a CBZ/CBNZ (W or X form). A null FirstMI is a wildcard that
/// matches any first instruction.
static bool isArithmeticCbzPair(const MachineInstr *FirstMI,
                                const MachineInstr &SecondMI) {
  // Filter on the branch first; nothing else can be the second half.
  switch (SecondMI.getOpcode()) {
  case AArch64::CBZW:
  case AArch64::CBZX:
  case AArch64::CBNZW:
  case AArch64::CBNZX:
    break;
  default:
    return false;
  }

  if (!FirstMI)
    return true;

  switch (FirstMI->getOpcode()) {
  // Register/register and register/immediate forms always qualify.
  case AArch64::ADDWri:
  case AArch64::ADDWrr:
  case AArch64::ADDXri:
  case AArch64::ADDXrr:
  case AArch64::ANDWri:
  case AArch64::ANDWrr:
  case AArch64::ANDXri:
  case AArch64::ANDXrr:
  case AArch64::EORWri:
  case AArch64::EORWrr:
  case AArch64::EORXri:
  case AArch64::EORXrr:
  case AArch64::ORRWri:
  case AArch64::ORRWrr:
  case AArch64::ORRXri:
  case AArch64::ORRXrr:
  case AArch64::SUBWri:
  case AArch64::SUBWrr:
  case AArch64::SUBXri:
  case AArch64::SUBXrr:
    return true;
  // Shifted-register forms qualify only when hasShiftedReg reports no shift.
  case AArch64::ADDWrs:
  case AArch64::ADDXrs:
  case AArch64::ANDWrs:
  case AArch64::ANDXrs:
  case AArch64::SUBWrs:
  case AArch64::SUBXrs:
  case AArch64::BICWrs:
  case AArch64::BICXrs:
    return !AArch64InstrInfo::hasShiftedReg(*FirstMI);
  default:
    return false;
  }
}
/// Returns true if FirstMI and SecondMI may be fused as AESE followed by AESMC,
/// or AESD followed by AESIMC (plain or tied variants of the second). A null
/// FirstMI is a wildcard.
static bool isAESPair(const MachineInstr *FirstMI,
                      const MachineInstr &SecondMI) {
  // Determine which AES round opcode must precede this second instruction.
  unsigned RequiredFirstOpc;
  switch (SecondMI.getOpcode()) {
  case AArch64::AESMCrr:
  case AArch64::AESMCrrTied:
    RequiredFirstOpc = AArch64::AESErr;
    break;
  case AArch64::AESIMCrr:
  case AArch64::AESIMCrrTied:
    RequiredFirstOpc = AArch64::AESDrr;
    break;
  default:
    return false;
  }

  return !FirstMI || FirstMI->getOpcode() == RequiredFirstOpc;
}
/// Returns true if FirstMI and SecondMI may be fused as an AESE/AESD/PMULL
/// instruction followed by an EORv16i8. A null FirstMI is a wildcard.
static bool isCryptoEORPair(const MachineInstr *FirstMI,
                            const MachineInstr &SecondMI) {
  if (SecondMI.getOpcode() != AArch64::EORv16i8)
    return false;

  if (!FirstMI)
    return true;

  const unsigned Opc = FirstMI->getOpcode();
  return Opc == AArch64::AESErr || Opc == AArch64::AESDrr ||
         Opc == AArch64::PMULLv16i8 || Opc == AArch64::PMULLv8i8 ||
         Opc == AArch64::PMULLv1i64 || Opc == AArch64::PMULLv2i64;
}
/// Returns true if FirstMI and SecondMI may be fused as ADRP followed by
/// ADDXri. A null FirstMI is a wildcard.
static bool isAdrpAddPair(const MachineInstr *FirstMI,
                          const MachineInstr &SecondMI) {
  if (SecondMI.getOpcode() != AArch64::ADDXri)
    return false;
  return !FirstMI || FirstMI->getOpcode() == AArch64::ADRP;
}
static bool isLiteralsPair(const MachineInstr *FirstMI,
const MachineInstr &SecondMI) {
if ((FirstMI == nullptr || FirstMI->getOpcode() == AArch64::MOVZWi) &&
(SecondMI.getOpcode() == AArch64::MOVKWi &&
SecondMI.getOperand(3).getImm() == 16))
return true;
if((FirstMI == nullptr || FirstMI->getOpcode() == AArch64::MOVZXi) &&
(SecondMI.getOpcode() == AArch64::MOVKXi &&
SecondMI.getOperand(3).getImm() == 16))
return true;
if ((FirstMI == nullptr ||
(FirstMI->getOpcode() == AArch64::MOVKXi &&
FirstMI->getOperand(3).getImm() == 32)) &&
(SecondMI.getOpcode() == AArch64::MOVKXi &&
SecondMI.getOperand(3).getImm() == 48))
return true;
return false;
}
/// Returns true if FirstMI and SecondMI may be fused as an address computation
/// (ADRP, or ADR when SecondMI's operand 2 is zero) followed by one of the
/// listed `...ui` loads/stores. A null FirstMI is a wildcard.
static bool isAddressLdStPair(const MachineInstr *FirstMI,
                              const MachineInstr &SecondMI) {
  // Accept only the load/store opcodes below as the second instruction.
  switch (SecondMI.getOpcode()) {
  case AArch64::STRBBui:
  case AArch64::STRBui:
  case AArch64::STRDui:
  case AArch64::STRHHui:
  case AArch64::STRHui:
  case AArch64::STRQui:
  case AArch64::STRSui:
  case AArch64::STRWui:
  case AArch64::STRXui:
  case AArch64::LDRBBui:
  case AArch64::LDRBui:
  case AArch64::LDRDui:
  case AArch64::LDRHHui:
  case AArch64::LDRHui:
  case AArch64::LDRQui:
  case AArch64::LDRSui:
  case AArch64::LDRWui:
  case AArch64::LDRXui:
  case AArch64::LDRSBWui:
  case AArch64::LDRSBXui:
  case AArch64::LDRSHWui:
  case AArch64::LDRSHXui:
  case AArch64::LDRSWui:
    break;
  default:
    return false;
  }

  if (!FirstMI)
    return true;

  switch (FirstMI->getOpcode()) {
  case AArch64::ADRP:
    return true;
  case AArch64::ADR:
    // ADR only pairs when the memory op's operand 2 (immediate) is zero.
    return SecondMI.getOperand(2).getImm() == 0;
  default:
    return false;
  }
}
/// Returns true if FirstMI and SecondMI may be fused as a SUBS that defines the
/// same-width zero register followed by a CSEL (W or X form). A null FirstMI
/// is a wildcard.
static bool isCCSelectPair(const MachineInstr *FirstMI,
                           const MachineInstr &SecondMI) {
  const unsigned SecondOpc = SecondMI.getOpcode();
  const bool Is32 = SecondOpc == AArch64::CSELWr;
  if (!Is32 && SecondOpc != AArch64::CSELXr)
    return false;

  if (!FirstMI)
    return true;

  // FirstMI must define the zero register matching the CSEL's width.
  if (!FirstMI->definesRegister(Is32 ? AArch64::WZR : AArch64::XZR))
    return false;

  // Each SUBS variant only pairs with the CSEL of its own width.
  switch (FirstMI->getOpcode()) {
  case AArch64::SUBSWrr:
  case AArch64::SUBSWri:
    return Is32;
  case AArch64::SUBSWrs:
    return Is32 && !AArch64InstrInfo::hasShiftedReg(*FirstMI);
  case AArch64::SUBSWrx:
    return Is32 && !AArch64InstrInfo::hasExtendedReg(*FirstMI);
  case AArch64::SUBSXrr:
  case AArch64::SUBSXri:
    return !Is32;
  case AArch64::SUBSXrs:
    return !Is32 && !AArch64InstrInfo::hasShiftedReg(*FirstMI);
  case AArch64::SUBSXrx:
  case AArch64::SUBSXrx64:
    return !Is32 && !AArch64InstrInfo::hasExtendedReg(*FirstMI);
  default:
    return false;
  }
}
/// Returns true if FirstMI (an ADD/SUB, optionally ADDS/SUBS) and SecondMI (an
/// ADD/SUB/ADDS/SUBS or logical op) may be fused. A null FirstMI is a wildcard.
/// When SecondMI is ADDS/SUBS, FirstMI may not itself be ADDS/SUBS.
static bool isArithmeticLogicPair(const MachineInstr *FirstMI,
                                  const MachineInstr &SecondMI) {
  // A shifted second operand disqualifies the pair outright.
  if (AArch64InstrInfo::hasShiftedReg(SecondMI))
    return false;

  // Classify SecondMI; this decides whether ADDS/SUBS may come first.
  bool FirstMayBeFlagSetting = false;
  switch (SecondMI.getOpcode()) {
  // Arithmetic (non-flag-setting) and logic second instructions.
  case AArch64::ADDWrr:
  case AArch64::ADDXrr:
  case AArch64::SUBWrr:
  case AArch64::SUBXrr:
  case AArch64::ADDWrs:
  case AArch64::ADDXrs:
  case AArch64::SUBWrs:
  case AArch64::SUBXrs:
  case AArch64::ANDWrr:
  case AArch64::ANDXrr:
  case AArch64::BICWrr:
  case AArch64::BICXrr:
  case AArch64::EONWrr:
  case AArch64::EONXrr:
  case AArch64::EORWrr:
  case AArch64::EORXrr:
  case AArch64::ORNWrr:
  case AArch64::ORNXrr:
  case AArch64::ORRWrr:
  case AArch64::ORRXrr:
  case AArch64::ANDWrs:
  case AArch64::ANDXrs:
  case AArch64::BICWrs:
  case AArch64::BICXrs:
  case AArch64::EONWrs:
  case AArch64::EONXrs:
  case AArch64::EORWrs:
  case AArch64::EORXrs:
  case AArch64::ORNWrs:
  case AArch64::ORNXrs:
  case AArch64::ORRWrs:
  case AArch64::ORRXrs:
    FirstMayBeFlagSetting = true;
    break;
  // ADDS/SUBS second instructions.
  case AArch64::ADDSWrr:
  case AArch64::ADDSXrr:
  case AArch64::SUBSWrr:
  case AArch64::SUBSXrr:
  case AArch64::ADDSWrs:
  case AArch64::ADDSXrs:
  case AArch64::SUBSWrs:
  case AArch64::SUBSXrs:
    FirstMayBeFlagSetting = false;
    break;
  default:
    return false;
  }

  if (!FirstMI)
    return true;

  switch (FirstMI->getOpcode()) {
  case AArch64::ADDWrr:
  case AArch64::ADDXrr:
  case AArch64::SUBWrr:
  case AArch64::SUBXrr:
    return true;
  case AArch64::ADDSWrr:
  case AArch64::ADDSXrr:
  case AArch64::SUBSWrr:
  case AArch64::SUBSXrr:
    return FirstMayBeFlagSetting;
  case AArch64::ADDWrs:
  case AArch64::ADDXrs:
  case AArch64::SUBWrs:
  case AArch64::SUBXrs:
    return !AArch64InstrInfo::hasShiftedReg(*FirstMI);
  case AArch64::ADDSWrs:
  case AArch64::ADDSXrs:
  case AArch64::SUBSWrs:
  case AArch64::SUBSXrs:
    return FirstMayBeFlagSetting &&
           !AArch64InstrInfo::hasShiftedReg(*FirstMI);
  default:
    return false;
  }
}
static bool shouldScheduleAdjacent(const TargetInstrInfo &TII,
const TargetSubtargetInfo &TSI,
const MachineInstr *FirstMI,
const MachineInstr &SecondMI) {
const AArch64Subtarget &ST = static_cast<const AArch64Subtarget&>(TSI);
if (ST.hasCmpBccFusion() || ST.hasArithmeticBccFusion()) {
bool CmpOnly = !ST.hasArithmeticBccFusion();
if (isArithmeticBccPair(FirstMI, SecondMI, CmpOnly))
return true;
}
if (ST.hasArithmeticCbzFusion() && isArithmeticCbzPair(FirstMI, SecondMI))
return true;
if (ST.hasFuseAES() && isAESPair(FirstMI, SecondMI))
return true;
if (ST.hasFuseCryptoEOR() && isCryptoEORPair(FirstMI, SecondMI))
return true;
if (ST.hasFuseAdrpAdd() && isAdrpAddPair(FirstMI, SecondMI))
return true;
if (ST.hasFuseLiterals() && isLiteralsPair(FirstMI, SecondMI))
return true;
if (ST.hasFuseAddress() && isAddressLdStPair(FirstMI, SecondMI))
return true;
if (ST.hasFuseCCSelect() && isCCSelectPair(FirstMI, SecondMI))
return true;
if (ST.hasFuseArithmeticLogic() && isArithmeticLogicPair(FirstMI, SecondMI))
return true;
return false;
}
/// Creates the AArch64 macro-fusion scheduling DAG mutation, with
/// shouldScheduleAdjacent as the pairing predicate.
std::unique_ptr<ScheduleDAGMutation>
llvm::createAArch64MacroFusionDAGMutation() {
  std::unique_ptr<ScheduleDAGMutation> Mutation =
      createMacroFusionDAGMutation(shouldScheduleAdjacent);
  return Mutation;
}