#ifndef LLVM_IR_MATRIXBUILDER_H
#define LLVM_IR_MATRIXBUILDER_H
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/Alignment.h"
namespace llvm {
class Function;
class Twine;
class Module;
class MatrixBuilder {
IRBuilderBase &B;
Module *getModule() { return B.GetInsertBlock()->getParent()->getParent(); }
std::pair<Value *, Value *> splatScalarOperandIfNeeded(Value *LHS,
Value *RHS) {
assert((LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()) &&
"One of the operands must be a matrix (embedded in a vector)");
if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) {
assert(!isa<ScalableVectorType>(LHS->getType()) &&
"LHS Assumed to be fixed width");
RHS = B.CreateVectorSplat(
cast<VectorType>(LHS->getType())->getElementCount(), RHS,
"scalar.splat");
} else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {
assert(!isa<ScalableVectorType>(RHS->getType()) &&
"RHS Assumed to be fixed width");
LHS = B.CreateVectorSplat(
cast<VectorType>(RHS->getType())->getElementCount(), LHS,
"scalar.splat");
}
return {LHS, RHS};
}
public:
MatrixBuilder(IRBuilderBase &Builder) : B(Builder) {}
CallInst *CreateColumnMajorLoad(Type *EltTy, Value *DataPtr, Align Alignment,
Value *Stride, bool IsVolatile, unsigned Rows,
unsigned Columns, const Twine &Name = "") {
auto *RetType = FixedVectorType::get(EltTy, Rows * Columns);
Value *Ops[] = {DataPtr, Stride, B.getInt1(IsVolatile), B.getInt32(Rows),
B.getInt32(Columns)};
Type *OverloadedTypes[] = {RetType, Stride->getType()};
Function *TheFn = Intrinsic::getDeclaration(
getModule(), Intrinsic::matrix_column_major_load, OverloadedTypes);
CallInst *Call = B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
Attribute AlignAttr =
Attribute::getWithAlignment(Call->getContext(), Alignment);
Call->addParamAttr(0, AlignAttr);
return Call;
}
CallInst *CreateColumnMajorStore(Value *Matrix, Value *Ptr, Align Alignment,
Value *Stride, bool IsVolatile,
unsigned Rows, unsigned Columns,
const Twine &Name = "") {
Value *Ops[] = {Matrix, Ptr,
Stride, B.getInt1(IsVolatile),
B.getInt32(Rows), B.getInt32(Columns)};
Type *OverloadedTypes[] = {Matrix->getType(), Stride->getType()};
Function *TheFn = Intrinsic::getDeclaration(
getModule(), Intrinsic::matrix_column_major_store, OverloadedTypes);
CallInst *Call = B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
Attribute AlignAttr =
Attribute::getWithAlignment(Call->getContext(), Alignment);
Call->addParamAttr(1, AlignAttr);
return Call;
}
CallInst *CreateMatrixTranspose(Value *Matrix, unsigned Rows,
unsigned Columns, const Twine &Name = "") {
auto *OpType = cast<VectorType>(Matrix->getType());
auto *ReturnType =
FixedVectorType::get(OpType->getElementType(), Rows * Columns);
Type *OverloadedTypes[] = {ReturnType};
Value *Ops[] = {Matrix, B.getInt32(Rows), B.getInt32(Columns)};
Function *TheFn = Intrinsic::getDeclaration(
getModule(), Intrinsic::matrix_transpose, OverloadedTypes);
return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
}
CallInst *CreateMatrixMultiply(Value *LHS, Value *RHS, unsigned LHSRows,
unsigned LHSColumns, unsigned RHSColumns,
const Twine &Name = "") {
auto *LHSType = cast<VectorType>(LHS->getType());
auto *RHSType = cast<VectorType>(RHS->getType());
auto *ReturnType =
FixedVectorType::get(LHSType->getElementType(), LHSRows * RHSColumns);
Value *Ops[] = {LHS, RHS, B.getInt32(LHSRows), B.getInt32(LHSColumns),
B.getInt32(RHSColumns)};
Type *OverloadedTypes[] = {ReturnType, LHSType, RHSType};
Function *TheFn = Intrinsic::getDeclaration(
getModule(), Intrinsic::matrix_multiply, OverloadedTypes);
return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
}
Value *CreateMatrixInsert(Value *Matrix, Value *NewVal, Value *RowIdx,
Value *ColumnIdx, unsigned NumRows) {
return B.CreateInsertElement(
Matrix, NewVal,
B.CreateAdd(B.CreateMul(ColumnIdx, ConstantInt::get(
ColumnIdx->getType(), NumRows)),
RowIdx));
}
Value *CreateAdd(Value *LHS, Value *RHS) {
assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy());
if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) {
assert(!isa<ScalableVectorType>(LHS->getType()) &&
"LHS Assumed to be fixed width");
RHS = B.CreateVectorSplat(
cast<VectorType>(LHS->getType())->getElementCount(), RHS,
"scalar.splat");
} else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {
assert(!isa<ScalableVectorType>(RHS->getType()) &&
"RHS Assumed to be fixed width");
LHS = B.CreateVectorSplat(
cast<VectorType>(RHS->getType())->getElementCount(), LHS,
"scalar.splat");
}
return cast<VectorType>(LHS->getType())
->getElementType()
->isFloatingPointTy()
? B.CreateFAdd(LHS, RHS)
: B.CreateAdd(LHS, RHS);
}
Value *CreateSub(Value *LHS, Value *RHS) {
assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy());
if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) {
assert(!isa<ScalableVectorType>(LHS->getType()) &&
"LHS Assumed to be fixed width");
RHS = B.CreateVectorSplat(
cast<VectorType>(LHS->getType())->getElementCount(), RHS,
"scalar.splat");
} else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {
assert(!isa<ScalableVectorType>(RHS->getType()) &&
"RHS Assumed to be fixed width");
LHS = B.CreateVectorSplat(
cast<VectorType>(RHS->getType())->getElementCount(), LHS,
"scalar.splat");
}
return cast<VectorType>(LHS->getType())
->getElementType()
->isFloatingPointTy()
? B.CreateFSub(LHS, RHS)
: B.CreateSub(LHS, RHS);
}
Value *CreateScalarMultiply(Value *LHS, Value *RHS) {
std::tie(LHS, RHS) = splatScalarOperandIfNeeded(LHS, RHS);
if (LHS->getType()->getScalarType()->isFloatingPointTy())
return B.CreateFMul(LHS, RHS);
return B.CreateMul(LHS, RHS);
}
Value *CreateScalarDiv(Value *LHS, Value *RHS, bool IsUnsigned) {
assert(LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy());
assert(!isa<ScalableVectorType>(LHS->getType()) &&
"LHS Assumed to be fixed width");
RHS =
B.CreateVectorSplat(cast<VectorType>(LHS->getType())->getElementCount(),
RHS, "scalar.splat");
return cast<VectorType>(LHS->getType())
->getElementType()
->isFloatingPointTy()
? B.CreateFDiv(LHS, RHS)
: (IsUnsigned ? B.CreateUDiv(LHS, RHS) : B.CreateSDiv(LHS, RHS));
}
void CreateIndexAssumption(Value *Idx, unsigned NumElements,
Twine const &Name = "") {
Value *NumElts =
B.getIntN(Idx->getType()->getScalarSizeInBits(), NumElements);
auto *Cmp = B.CreateICmpULT(Idx, NumElts);
if (isa<ConstantInt>(Cmp))
assert(cast<ConstantInt>(Cmp)->isOne() && "Index must be valid!");
else
B.CreateAssumption(Cmp);
}
Value *CreateIndex(Value *RowIdx, Value *ColumnIdx, unsigned NumRows,
Twine const &Name = "") {
unsigned MaxWidth = std::max(RowIdx->getType()->getScalarSizeInBits(),
ColumnIdx->getType()->getScalarSizeInBits());
Type *IntTy = IntegerType::get(RowIdx->getType()->getContext(), MaxWidth);
RowIdx = B.CreateZExt(RowIdx, IntTy);
ColumnIdx = B.CreateZExt(ColumnIdx, IntTy);
Value *NumRowsV = B.getIntN(MaxWidth, NumRows);
return B.CreateAdd(B.CreateMul(ColumnIdx, NumRowsV), RowIdx);
}
};
}
#endif