#include "SparcTargetMachine.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/SelectionDAGISel.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
using namespace llvm;
namespace {

/// SPARC instruction selector built on top of SelectionDAGISel.
///
/// Besides the hooks overridden from the base class, it exposes the two
/// address matchers (SelectADDRrr / SelectADDRri) and keeps a pointer to the
/// subtarget of the machine function currently being processed.
class SparcDAGToDAGISel : public SelectionDAGISel {
  /// Subtarget of the function being selected. Null until
  /// runOnMachineFunction() has been entered; refreshed on every run.
  const SparcSubtarget *Subtarget = nullptr;

public:
  explicit SparcDAGToDAGISel(SparcTargetMachine &tm) : SelectionDAGISel(tm) {}

  /// Human-readable name of this pass.
  StringRef getPassName() const override {
    return "SPARC DAG->DAG Pattern Instruction Selection";
  }

  /// Caches the SPARC subtarget of \p MF, then defers to the generic
  /// SelectionDAGISel driver.
  bool runOnMachineFunction(MachineFunction &MF) override {
    const SparcSubtarget &ST = MF.getSubtarget<SparcSubtarget>();
    Subtarget = &ST;
    return SelectionDAGISel::runOnMachineFunction(MF);
  }

  /// Main selection hook, called with the node \p N to select.
  void Select(SDNode *N) override;

  /// Address matchers: register + register and register + immediate forms.
  bool SelectADDRrr(SDValue N, SDValue &R1, SDValue &R2);
  bool SelectADDRri(SDValue N, SDValue &Base, SDValue &Offset);

  /// Matches the address of an inline-asm memory operand; selected operands
  /// are appended to \p OutOps. Returns true on failure, false on success.
  bool SelectInlineAsmMemoryOperand(const SDValue &Op, unsigned ConstraintID,
                                    std::vector<SDValue> &OutOps) override;

  // Pulled into the class body so its contents become members.
  // NOTE(review): presumably the TableGen-generated matcher that provides
  // SelectCode() -- confirm against the build rules.
#include "SparcGenDAGISel.inc"

private:
  /// Returns the register node holding the global base register.
  SDNode *getGlobalBaseReg();

  /// Rewrites register-pair operands of an inline-asm node; returns true if
  /// the node was replaced.
  bool tryInlineAsm(SDNode *N);
};

} // end anonymous namespace
/// Returns a register node, typed as the target pointer type, for the global
/// base register that the instruction info reports for the current function.
SDNode *SparcDAGToDAGISel::getGlobalBaseReg() {
  // Ask the instruction info first, exactly as before; only then build the
  // DAG register node that wraps the result.
  const Register BaseReg = Subtarget->getInstrInfo()->getGlobalBaseReg(MF);
  const MVT PtrVT = TLI->getPointerTy(CurDAG->getDataLayout());
  SDValue RegNode = CurDAG->getRegister(BaseReg, PtrVT);
  return RegNode.getNode();
}
/// Matches \p Addr as a "register + immediate" address.
///
/// Forms recognised, in order:
///   * a bare frame index            -> (target frame index, 0)
///   * (add X, C), C fits 13 signed bits
///                                   -> (X or its target frame index, C)
///   * (add (SPISD::Lo S), X) or (add X, (SPISD::Lo S))
///                                   -> (X, S)
///   * anything else                 -> (Addr, 0)
///
/// \returns false only when \p Addr is a TargetExternalSymbol,
/// TargetGlobalAddress or TargetGlobalTLSAddress node; true otherwise, with
/// \p Base and \p Offset filled in.
bool SparcDAGToDAGISel::SelectADDRri(SDValue Addr,
                                     SDValue &Base, SDValue &Offset) {
  // Builders for the two kinds of nodes this matcher creates.
  auto MakeFrameBase = [this](const FrameIndexSDNode *FI) {
    return CurDAG->getTargetFrameIndex(
        FI->getIndex(), TLI->getPointerTy(CurDAG->getDataLayout()));
  };
  auto MakeImm32 = [this, &Addr](uint64_t Val) {
    return CurDAG->getTargetConstant(Val, SDLoc(Addr), MVT::i32);
  };

  // A frame index on its own: offset zero from the frame slot.
  if (auto *FI = dyn_cast<FrameIndexSDNode>(Addr)) {
    Base = MakeFrameBase(FI);
    Offset = MakeImm32(0);
    return true;
  }

  const unsigned Opc = Addr.getOpcode();
  switch (Opc) {
  case ISD::TargetExternalSymbol:
  case ISD::TargetGlobalAddress:
  case ISD::TargetGlobalTLSAddress:
    // These target address nodes are never matched as reg+imm.
    return false;
  default:
    break;
  }

  if (Opc == ISD::ADD) {
    SDValue LHS = Addr.getOperand(0);
    SDValue RHS = Addr.getOperand(1);

    // (add X, simm13): fold the constant into the offset field.
    auto *Imm = dyn_cast<ConstantSDNode>(RHS);
    if (Imm && isInt<13>(Imm->getSExtValue())) {
      if (auto *FI = dyn_cast<FrameIndexSDNode>(LHS))
        Base = MakeFrameBase(FI);
      else
        Base = LHS;
      Offset = MakeImm32(Imm->getZExtValue());
      return true;
    }

    // One side is an SPISD::Lo node: its operand becomes the offset and the
    // other side of the add becomes the base. The left side is tested first.
    if (LHS.getOpcode() == SPISD::Lo) {
      Base = RHS;
      Offset = LHS.getOperand(0);
      return true;
    }
    if (RHS.getOpcode() == SPISD::Lo) {
      Base = LHS;
      Offset = RHS.getOperand(0);
      return true;
    }
  }

  // Fallback: the whole address is the base, with a zero displacement.
  Base = Addr;
  Offset = MakeImm32(0);
  return true;
}
/// Matches \p Addr as a "register + register" address.
///
/// \returns false for frame indices, for TargetExternalSymbol /
/// TargetGlobalAddress / TargetGlobalTLSAddress nodes, and for ADD nodes
/// whose right operand is a constant fitting 13 signed bits or that have an
/// SPISD::Lo operand on either side. Otherwise returns true with \p R1 and
/// \p R2 set: the two ADD operands, or (\p Addr, SP::G0) when \p Addr is not
/// an ADD.
bool SparcDAGToDAGISel::SelectADDRrr(SDValue Addr, SDValue &R1, SDValue &R2) {
  const unsigned Opc = Addr.getOpcode();

  switch (Opc) {
  case ISD::FrameIndex:
  case ISD::TargetExternalSymbol:
  case ISD::TargetGlobalAddress:
  case ISD::TargetGlobalTLSAddress:
    return false;
  default:
    break;
  }

  // Not an addition: pair the address with SP::G0 as the second register.
  if (Opc != ISD::ADD) {
    R1 = Addr;
    R2 = CurDAG->getRegister(SP::G0,
                             TLI->getPointerTy(CurDAG->getDataLayout()));
    return true;
  }

  SDValue LHS = Addr.getOperand(0);
  SDValue RHS = Addr.getOperand(1);

  // Reject (add X, simm13).
  if (auto *Imm = dyn_cast<ConstantSDNode>(RHS))
    if (isInt<13>(Imm->getSExtValue()))
      return false;

  // Reject additions that involve an SPISD::Lo operand.
  const bool HasLoOperand =
      LHS.getOpcode() == SPISD::Lo || RHS.getOpcode() == SPISD::Lo;
  if (HasLoOperand)
    return false;

  R1 = LHS;
  R2 = RHS;
  return true;
}
// Rewrites an inline-asm node so that two-register operand groups use a single
// register of SP::IntPairRegClass (typed v2i32) instead of two separate
// registers.
//
// A group is rewritten when it is a register use, def or early-clobber def
// with exactly two registers, and either carries the SP::IntRegsRegClassID
// register-class constraint or is tied to a def that was already rewritten.
//
// Returns true if N was replaced by a new node, false if nothing had to change
// (in which case N is left alone).
bool SparcDAGToDAGISel::tryInlineAsm(SDNode *N){
  // Operand list of the replacement node, built up while walking N's operands.
  std::vector<SDValue> AsmNodeOperands;
  unsigned Flag, Kind;
  bool Changed = false;
  unsigned NumOps = N->getNumOperands();
      
      
  SDLoc dl(N);
  // If N has a glued node, its last operand is taken as the glue. It is kept
  // aside here and re-appended (possibly replaced) after the loop.
  SDValue Glue = N->getGluedNode() ? N->getOperand(NumOps - 1) : SDValue();
  // One entry per operand group that has registers, in visiting order; set to
  // true once that group has been rewritten to a register pair.
  SmallVector<bool, 8> OpChanged;
    for(unsigned i = 0, e = N->getGluedNode() ? NumOps - 1 : NumOps; i < e; ++i) {
    SDValue op = N->getOperand(i);
    AsmNodeOperands.push_back(op);
    // Operands before InlineAsm::Op_FirstOperand are copied through unchanged.
    if (i < InlineAsm::Op_FirstOperand)
      continue;
    // Only constant operands are decoded as flag words; anything else has
    // already been copied above and is skipped.
    if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(N->getOperand(i))) {
      Flag = C->getZExtValue();
      Kind = InlineAsm::getKind(Flag);
    }
    else
      continue;
    // For an immediate operand the operand after the flag word is copied
    // verbatim and stepped over, so it is never decoded as a flag itself.
                    if (Kind == InlineAsm::Kind_Imm) {
      SDValue op = N->getOperand(++i);
      AsmNodeOperands.push_back(op);
      continue;
    }
    unsigned NumRegs = InlineAsm::getNumOperandRegisters(Flag);
    if (NumRegs)
      OpChanged.push_back(false);
    unsigned DefIdx = 0;
    bool IsTiedToChangedOp = false;
    // Tied operands only matter once some earlier group has been rewritten:
    // a use tied to a rewritten def must be rewritten as well.
            if (Changed && InlineAsm::isUseOperandTiedToDef(Flag, DefIdx))
      IsTiedToChangedOp = OpChanged[DefIdx];
    // Only register uses and (early-clobber) defs are candidates.
    if (Kind != InlineAsm::Kind_RegUse && Kind != InlineAsm::Kind_RegDef
        && Kind != InlineAsm::Kind_RegDefEarlyClobber)
      continue;
    unsigned RC;
    bool HasRC = InlineAsm::hasRegClassConstraint(Flag, RC);
    // Skip unless the group has exactly two registers and is either
    // constrained to SP::IntRegsRegClassID or tied to a rewritten def.
    if ((!IsTiedToChangedOp && (!HasRC || RC != SP::IntRegsRegClassID))
        || NumRegs != 2)
      continue;
    assert((i+2 < NumOps) && "Invalid number of operands in inline asm");
    // The two register operands that follow the flag word.
    SDValue V0 = N->getOperand(i+1);
    SDValue V1 = N->getOperand(i+2);
    Register Reg0 = cast<RegisterSDNode>(V0)->getReg();
    Register Reg1 = cast<RegisterSDNode>(V1)->getReg();
    SDValue PairedReg;
    MachineRegisterInfo &MRI = MF->getRegInfo();
    if (Kind == InlineAsm::Kind_RegDef ||
        Kind == InlineAsm::Kind_RegDefEarlyClobber) {
            
      // Def: the asm is made to define a fresh pair register. Behind the asm
      // node (chained/glued on results 0 and 1 of N) the pair is read back,
      // split into its sub_even / sub_odd halves, and each half is copied into
      // the register the original node named.
      Register GPVR = MRI.createVirtualRegister(&SP::IntPairRegClass);
      PairedReg = CurDAG->getRegister(GPVR, MVT::v2i32);
      SDValue Chain = SDValue(N,0);
      // NOTE(review): GU is dereferenced below without a null check; this
      // presumably relies on a def group always having a glued user -- confirm.
      SDNode *GU = N->getGluedUser();
      SDValue RegCopy = CurDAG->getCopyFromReg(Chain, dl, GPVR, MVT::v2i32,
                                               Chain.getValue(1));
            SDValue Sub0 = CurDAG->getTargetExtractSubreg(SP::sub_even, dl, MVT::i32,
                                                    RegCopy);
      SDValue Sub1 = CurDAG->getTargetExtractSubreg(SP::sub_odd, dl, MVT::i32,
                                                    RegCopy);
      SDValue T0 = CurDAG->getCopyToReg(Sub0, dl, Reg0, Sub0,
                                        RegCopy.getValue(1));
      SDValue T1 = CurDAG->getCopyToReg(Sub1, dl, Reg1, Sub1, T0.getValue(1));
      // Re-point the glued user: keep all of its operands except the last one,
      // which is replaced by the glue result of the final copy.
            std::vector<SDValue> Ops(GU->op_begin(), GU->op_end()-1);
      Ops.push_back(T1.getValue(1));
      CurDAG->UpdateNodeOperands(GU, Ops);
    }
    else {
      // Use: read the two original registers off the asm's current input
      // chain, combine them with REG_SEQUENCE into one v2i32 value, and copy
      // that into a fresh pair register ahead of the asm node.
                  SDValue Chain = AsmNodeOperands[InlineAsm::Op_InputChain];
            SDValue T0 = CurDAG->getCopyFromReg(Chain, dl, Reg0, MVT::i32,
                                          Chain.getValue(1));
      SDValue T1 = CurDAG->getCopyFromReg(Chain, dl, Reg1, MVT::i32,
                                          T0.getValue(1));
      SDValue Pair = SDValue(
          CurDAG->getMachineNode(
              TargetOpcode::REG_SEQUENCE, dl, MVT::v2i32,
              {
                  CurDAG->getTargetConstant(SP::IntPairRegClassID, dl,
                                            MVT::i32),
                  T0,
                  CurDAG->getTargetConstant(SP::sub_even, dl, MVT::i32),
                  T1,
                  CurDAG->getTargetConstant(SP::sub_odd, dl, MVT::i32),
              }),
          0);
                  Register GPVR = MRI.createVirtualRegister(&SP::IntPairRegClass);
      PairedReg = CurDAG->getRegister(GPVR, MVT::v2i32);
      Chain = CurDAG->getCopyToReg(T1, dl, GPVR, Pair, T1.getValue(1));
      // The new asm node takes its input chain and glue from this copy.
      AsmNodeOperands[InlineAsm::Op_InputChain] = Chain;
      Glue = Chain.getValue(1);
    }
    Changed = true;
    // Both branches above assign PairedReg, so this is expected to hold.
    if(PairedReg.getNode()) {
      OpChanged[OpChanged.size() -1 ] = true;
      // New flag word: same kind, now describing a single register. It is
      // either re-tied to the matching def or constrained to the pair class.
      Flag = InlineAsm::getFlagWord(Kind, 1 );
      if (IsTiedToChangedOp)
        Flag = InlineAsm::getFlagWordForMatchingOp(Flag, DefIdx);
      else
        Flag = InlineAsm::getFlagWordForRegClass(Flag, SP::IntPairRegClassID);
      // Overwrite the old flag word that was pushed at the top of this
      // iteration, then add the pair register as its only register operand.
            AsmNodeOperands[AsmNodeOperands.size() -1] = CurDAG->getTargetConstant(
          Flag, dl, MVT::i32);
            AsmNodeOperands.push_back(PairedReg);
      // Step over the two original register operands; they are dropped.
            i += 2;
    }
  }
  // Re-attach the glue operand: the original one, or the one produced by the
  // last rewritten use.
  if (Glue.getNode())
    AsmNodeOperands.push_back(Glue);
  if (!Changed)
    return false;
  // The operand list changed, so a replacement node with the same opcode is
  // created and swapped in for N.
  SelectInlineAsmMemoryOperands(AsmNodeOperands, SDLoc(N));
  SDValue New = CurDAG->getNode(N->getOpcode(), SDLoc(N),
      CurDAG->getVTList(MVT::Other, MVT::Glue), AsmNodeOperands);
  New->setNodeId(-1);
  ReplaceNode(N, New.getNode());
  return true;
}
/// Selects machine code for \p N.
///
/// Custom handling is applied to:
///   * nodes that already carry a machine opcode (only their node id is
///     reset to -1),
///   * inline asm, which is first offered to tryInlineAsm(),
///   * SPISD::GLOBAL_BASE_REG, replaced by the global base register node,
///   * SDIV / UDIV whose result type is not i64, which get an explicit copy
///     into SP::Y glued in front of SDIVrr / UDIVrr.
/// Everything else, including inline asm that tryInlineAsm() left untouched
/// and i64 divisions, is handed to SelectCode().
void SparcDAGToDAGISel::Select(SDNode *N) {
  SDLoc DL(N);

  // Already selected: nothing left to do for this node.
  if (N->isMachineOpcode()) {
    N->setNodeId(-1);
    return;
  }

  const unsigned Opc = N->getOpcode();
  const bool IsDivide = Opc == ISD::SDIV || Opc == ISD::UDIV;

  if (Opc == ISD::INLINEASM || Opc == ISD::INLINEASM_BR) {
    if (tryInlineAsm(N))
      return;
  } else if (Opc == SPISD::GLOBAL_BASE_REG) {
    ReplaceNode(N, getGlobalBaseReg());
    return;
  } else if (IsDivide && !(N->getValueType(0) == MVT::i64)) {
    const bool IsSigned = Opc == ISD::SDIV;
    SDValue Dividend = N->getOperand(0);
    SDValue Divisor = N->getOperand(1);

    // Value to place in SP::Y before dividing: the dividend shifted right
    // arithmetically by 31 for SDIV, SP::G0 for UDIV.
    SDValue YValue;
    if (IsSigned) {
      SDValue ShiftAmt = CurDAG->getTargetConstant(31, DL, MVT::i32);
      SDNode *Sra = CurDAG->getMachineNode(SP::SRAri, DL, MVT::i32, Dividend,
                                           ShiftAmt);
      YValue = SDValue(Sra, 0);
    } else {
      YValue = CurDAG->getRegister(SP::G0, MVT::i32);
    }

    // Copy into Y off the entry node; result 1 of the copy (its glue) is
    // passed to the divide as its third operand.
    SDValue YCopy = CurDAG->getCopyToReg(CurDAG->getEntryNode(), DL, SP::Y,
                                         YValue, SDValue());
    SDValue YGlue = YCopy.getValue(1);

    const unsigned DivOpc = IsSigned ? SP::SDIVrr : SP::UDIVrr;
    CurDAG->SelectNodeTo(N, DivOpc, MVT::i32, Dividend, Divisor, YGlue);
    return;
  }

  SelectCode(N);
}
/// Selects the address operands for an inline-asm memory operand.
///
/// Only the 'o' and 'm' constraints are handled. The address is matched as
/// register + register first and as register + immediate second; the two
/// resulting operands are appended to \p OutOps.
///
/// \returns true on failure (unsupported constraint, or neither address form
/// matched), false on success.
bool
SparcDAGToDAGISel::SelectInlineAsmMemoryOperand(const SDValue &Op,
                                                unsigned ConstraintID,
                                                std::vector<SDValue> &OutOps) {
  switch (ConstraintID) {
  default:
    return true;
  case InlineAsm::Constraint_o:
  case InlineAsm::Constraint_m:
    break;
  }

  SDValue Op0, Op1;
  // SelectADDRri can also refuse the address. Its result used to be ignored,
  // which appended two null operands and still reported success. Report the
  // failure to the caller instead and leave OutOps untouched.
  if (!SelectADDRrr(Op, Op0, Op1) && !SelectADDRri(Op, Op0, Op1))
    return true;

  OutOps.push_back(Op0);
  OutOps.push_back(Op1);
  return false;
}
/// Factory for the SPARC DAG-to-DAG instruction selection pass: returns a
/// newly heap-allocated selector bound to \p TM.
FunctionPass *llvm::createSparcISelDag(SparcTargetMachine &TM) {
  FunctionPass *ISel = new SparcDAGToDAGISel(TM);
  return ISel;
}