//===- MemRefToSPIRV.cpp - MemRef to SPIR-V Patterns ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert MemRef dialect to SPIR-V dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Visitors.h"
#include <cassert>
#include <optional>
#define DEBUG_TYPE "memref-to-spirv-pattern"
using namespace mlir;
//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//
/// Returns the offset of the value in `targetBits` representation.
///
/// `srcIdx` is an index into a 1-D array with each element having `sourceBits`
/// bits. It's assumed to be non-negative.
///
/// When accessing an element in the array treated as having elements of
/// `targetBits`, multiple elements are loaded at the same time. This method
/// returns the bit offset of `srcIdx` within the loaded value. For example, if
/// `sourceBits` is 8 and `targetBits` is 32, the x-th element starts at bit
/// (x % 4) * 8, because one i32 holds four 8-bit elements.
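///
/// The ops emitted for `sourceBits` = 8 and `targetBits` = 32 are, before any
/// folding, roughly the following (%-names are illustrative):
///   %c4  = spirv.Constant 4
///   %c8  = spirv.Constant 8
///   %m   = spirv.UMod %srcIdx, %c4
///   %off = spirv.IMul %m, %c8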
static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
int targetBits, OpBuilder &builder) {
assert(targetBits % sourceBits == 0);
Type type = srcIdx.getType();
IntegerAttr idxAttr = builder.getIntegerAttr(type, targetBits / sourceBits);
auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, idxAttr);
IntegerAttr srcBitsAttr = builder.getIntegerAttr(type, sourceBits);
auto srcBitsValue =
builder.createOrFold<spirv::ConstantOp>(loc, type, srcBitsAttr);
auto m = builder.createOrFold<spirv::UModOp>(loc, srcIdx, idx);
return builder.createOrFold<spirv::IMulOp>(loc, type, m, srcBitsValue);
}
/// Returns an adjusted spirv::AccessChainOp. Based on the
/// extension/capabilities, certain integer bitwidths `sourceBits` might not be
/// supported. During conversion, if a memref of an unsupported type is used,
/// loads/stores to this memref need to be rewritten to use a supported higher
/// bitwidth `targetBits` and extract the required bits. When accessing a 1-D
/// array (spirv.array or spirv.rtarray), the last index is modified to load
/// the bits needed. The extraction of the actual bits needed is handled
/// separately. Note that this only works for a 1-D tensor.
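///
/// For example, with `sourceBits` = 8 and `targetBits` = 32, an access to
/// element i of the i8 array becomes an access to element i / 4 of the i32
/// array; the remaining bit offset within that word is computed separately by
/// getOffsetForBitwidth().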
static Value
adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter,
spirv::AccessChainOp op, int sourceBits,
int targetBits, OpBuilder &builder) {
assert(targetBits % sourceBits == 0);
const auto loc = op.getLoc();
Value lastDim = op->getOperand(op.getNumOperands() - 1);
Type type = lastDim.getType();
IntegerAttr attr = builder.getIntegerAttr(type, targetBits / sourceBits);
auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, attr);
auto indices = llvm::to_vector<4>(op.getIndices());
  // There are two indices if this is a 1-D tensor.
assert(indices.size() == 2);
indices.back() = builder.createOrFold<spirv::SDivOp>(loc, lastDim, idx);
Type t = typeConverter.convertType(op.getComponentPtr().getType());
return builder.create<spirv::AccessChainOp>(loc, t, op.getBasePtr(), indices);
}
/// Casts the given `srcBool` into an integer of `dstType`.
static Value castBoolToIntN(Location loc, Value srcBool, Type dstType,
OpBuilder &builder) {
assert(srcBool.getType().isInteger(1));
if (dstType.isInteger(1))
return srcBool;
Value zero = spirv::ConstantOp::getZero(dstType, loc, builder);
Value one = spirv::ConstantOp::getOne(dstType, loc, builder);
return builder.createOrFold<spirv::SelectOp>(loc, dstType, srcBool, one,
zero);
}
/// Returns `value` cast to the `targetBits`-wide destination type, masked with
/// `mask`, and shifted left by the given `offset`.
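///
/// For example, for an 8-bit `value` of 0xAB with a mask of 0xFF and an
/// `offset` of 8 (both i32), the result is 0x0000AB00: the value is
/// zero-extended, masked, and then shifted left.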
static Value shiftValue(Location loc, Value value, Value offset, Value mask,
OpBuilder &builder) {
IntegerType dstType = cast<IntegerType>(mask.getType());
int targetBits = static_cast<int>(dstType.getWidth());
int valueBits = value.getType().getIntOrFloatBitWidth();
assert(valueBits <= targetBits);
if (valueBits == 1) {
value = castBoolToIntN(loc, value, dstType, builder);
} else {
if (valueBits < targetBits) {
value = builder.create<spirv::UConvertOp>(
loc, builder.getIntegerType(targetBits), value);
}
value = builder.createOrFold<spirv::BitwiseAndOp>(loc, value, mask);
}
return builder.createOrFold<spirv::ShiftLeftLogicalOp>(loc, value.getType(),
value, offset);
}
/// Returns true if the allocation or deallocation `allocOp` of a memref of
/// `type` can be lowered to SPIR-V.
static bool isAllocationSupported(Operation *allocOp, MemRefType type) {
if (isa<memref::AllocOp, memref::DeallocOp>(allocOp)) {
auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
if (!sc || sc.getValue() != spirv::StorageClass::Workgroup)
return false;
} else if (isa<memref::AllocaOp>(allocOp)) {
auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
if (!sc || sc.getValue() != spirv::StorageClass::Function)
return false;
} else {
return false;
}
  // Currently we only support static shapes and element types that are int or
  // float, or vectors thereof.
if (!type.hasStaticShape())
return false;
Type elementType = type.getElementType();
if (auto vecType = dyn_cast<VectorType>(elementType))
elementType = vecType.getElementType();
return elementType.isIntOrFloat();
}
/// Returns the scope to use for atomic operations used to emulate store
/// operations of unsupported integer bitwidths, based on the memref
/// type. Returns std::nullopt on failure.
static std::optional<spirv::Scope> getAtomicOpScope(MemRefType type) {
  auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
  if (!sc)
    return std::nullopt;
  switch (sc.getValue()) {
case spirv::StorageClass::StorageBuffer:
return spirv::Scope::Device;
case spirv::StorageClass::Workgroup:
return spirv::Scope::Workgroup;
default:
break;
}
return {};
}
/// Casts the given `srcInt` into a boolean value.
static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) {
if (srcInt.getType().isInteger(1))
return srcInt;
  auto zero = spirv::ConstantOp::getZero(srcInt.getType(), loc, builder);
  return builder.createOrFold<spirv::INotEqualOp>(loc, srcInt, zero);
}
//===----------------------------------------------------------------------===//
// Operation conversion
//===----------------------------------------------------------------------===//
// Note that DRR cannot be used for the patterns in this file: we may need to
// convert types along the way, which requires ConversionPattern; DRR generates
// normal RewritePatterns.
namespace {
/// Converts memref.alloca to SPIR-V Function variables.
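/// For example (the exact pointee type is illustrative; it depends on the
/// type converter):
///   %0 = memref.alloca() : memref<4xf32, #spirv.storage_class<Function>>
/// becomes roughly:
///   %0 = spirv.Variable : !spirv.ptr<!spirv.array<4 x f32>, Function>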
class AllocaOpPattern final : public OpConversionPattern<memref::AllocaOp> {
public:
using OpConversionPattern<memref::AllocaOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
/// Converts an allocation operation to SPIR-V. Currently only supports lowering
/// to Workgroup memory when the size is constant. Note that this pattern needs
/// to be applied in a pass that runs at least at spirv.module scope since it
/// will add global variables into the spirv.module.
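/// For example (the exact pointee type is illustrative):
///   %0 = memref.alloc() : memref<4xf32, #spirv.storage_class<Workgroup>>
/// becomes roughly:
///   spirv.GlobalVariable @__workgroup_mem__0 :
///       !spirv.ptr<!spirv.array<4 x f32>, Workgroup>
///   ...
///   %0 = spirv.mlir.addressof @__workgroup_mem__0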
class AllocOpPattern final : public OpConversionPattern<memref::AllocOp> {
public:
using OpConversionPattern<memref::AllocOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
/// Converts memref.atomic_rmw operations to SPIR-V atomic operations.
class AtomicRMWOpPattern final
: public OpConversionPattern<memref::AtomicRMWOp> {
public:
using OpConversionPattern<memref::AtomicRMWOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
/// Removes a deallocation if it deallocates a supported allocation. Currently
/// only removes deallocations of workgroup memory.
class DeallocOpPattern final : public OpConversionPattern<memref::DeallocOp> {
public:
using OpConversionPattern<memref::DeallocOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(memref::DeallocOp operation, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
/// Converts memref.load to spirv.Load + spirv.AccessChain on integers.
class IntLoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
public:
using OpConversionPattern<memref::LoadOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
/// Converts memref.load to spirv.Load + spirv.AccessChain.
class LoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
public:
using OpConversionPattern<memref::LoadOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
/// Converts memref.store to spirv.Store on integers.
class IntStoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
public:
using OpConversionPattern<memref::StoreOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
/// Converts memref.memory_space_cast to the appropriate spirv cast operations.
class MemorySpaceCastOpPattern final
: public OpConversionPattern<memref::MemorySpaceCastOp> {
public:
using OpConversionPattern<memref::MemorySpaceCastOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
/// Converts memref.store to spirv.Store.
class StoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
public:
using OpConversionPattern<memref::StoreOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
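/// Converts memref.reinterpret_cast to spirv.InBoundsPtrAccessChain applying
/// the cast's (static or dynamic) element offset; casts with a zero offset
/// fold away entirely.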
class ReinterpretCastPattern final
: public OpConversionPattern<memref::ReinterpretCastOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
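/// Erases memref.cast when the source and result types convert to the same
/// SPIR-V type, forwarding the converted source value.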
class CastPattern final : public OpConversionPattern<memref::CastOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(memref::CastOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value src = adaptor.getSource();
Type srcType = src.getType();
const TypeConverter *converter = getTypeConverter();
Type dstType = converter->convertType(op.getType());
if (srcType != dstType)
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
diag << "types doesn't match: " << srcType << " and " << dstType;
});
rewriter.replaceOp(op, src);
return success();
}
};
/// Converts memref.extract_aligned_pointer_as_index to spirv.ConvertPtrToU.
class ExtractAlignedPointerAsIndexOpPattern final
: public OpConversionPattern<memref::ExtractAlignedPointerAsIndexOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
} // namespace
//===----------------------------------------------------------------------===//
// AllocaOp
//===----------------------------------------------------------------------===//
LogicalResult
AllocaOpPattern::matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
MemRefType allocType = allocaOp.getType();
if (!isAllocationSupported(allocaOp, allocType))
return rewriter.notifyMatchFailure(allocaOp, "unhandled allocation type");
// Get the SPIR-V type for the allocation.
Type spirvType = getTypeConverter()->convertType(allocType);
if (!spirvType)
return rewriter.notifyMatchFailure(allocaOp, "type conversion failed");
rewriter.replaceOpWithNewOp<spirv::VariableOp>(allocaOp, spirvType,
spirv::StorageClass::Function,
/*initializer=*/nullptr);
return success();
}
//===----------------------------------------------------------------------===//
// AllocOp
//===----------------------------------------------------------------------===//
LogicalResult
AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
MemRefType allocType = operation.getType();
if (!isAllocationSupported(operation, allocType))
return rewriter.notifyMatchFailure(operation, "unhandled allocation type");
// Get the SPIR-V type for the allocation.
Type spirvType = getTypeConverter()->convertType(allocType);
if (!spirvType)
return rewriter.notifyMatchFailure(operation, "type conversion failed");
// Insert spirv.GlobalVariable for this allocation.
Operation *parent =
SymbolTable::getNearestSymbolTable(operation->getParentOp());
if (!parent)
return failure();
Location loc = operation.getLoc();
spirv::GlobalVariableOp varOp;
{
OpBuilder::InsertionGuard guard(rewriter);
Block &entryBlock = *parent->getRegion(0).begin();
rewriter.setInsertionPointToStart(&entryBlock);
auto varOps = entryBlock.getOps<spirv::GlobalVariableOp>();
std::string varName =
std::string("__workgroup_mem__") +
std::to_string(std::distance(varOps.begin(), varOps.end()));
varOp = rewriter.create<spirv::GlobalVariableOp>(loc, spirvType, varName,
/*initializer=*/nullptr);
}
// Get pointer to global variable at the current scope.
rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(operation, varOp);
return success();
}
//===----------------------------------------------------------------------===//
// AtomicRMWOp
//===----------------------------------------------------------------------===//
LogicalResult
AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
if (isa<FloatType>(atomicOp.getType()))
return rewriter.notifyMatchFailure(atomicOp,
"unimplemented floating-point case");
auto memrefType = cast<MemRefType>(atomicOp.getMemref().getType());
std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
if (!scope)
return rewriter.notifyMatchFailure(atomicOp,
"unsupported memref memory space");
auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
Type resultType = typeConverter.convertType(atomicOp.getType());
if (!resultType)
return rewriter.notifyMatchFailure(atomicOp,
"failed to convert result type");
auto loc = atomicOp.getLoc();
Value ptr =
spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
adaptor.getIndices(), loc, rewriter);
if (!ptr)
return failure();
#define ATOMIC_CASE(kind, spirvOp) \
case arith::AtomicRMWKind::kind: \
rewriter.replaceOpWithNewOp<spirv::spirvOp>( \
atomicOp, resultType, ptr, *scope, \
spirv::MemorySemantics::AcquireRelease, adaptor.getValue()); \
break
switch (atomicOp.getKind()) {
ATOMIC_CASE(addi, AtomicIAddOp);
ATOMIC_CASE(maxs, AtomicSMaxOp);
ATOMIC_CASE(maxu, AtomicUMaxOp);
ATOMIC_CASE(mins, AtomicSMinOp);
ATOMIC_CASE(minu, AtomicUMinOp);
ATOMIC_CASE(ori, AtomicOrOp);
ATOMIC_CASE(andi, AtomicAndOp);
default:
return rewriter.notifyMatchFailure(atomicOp, "unimplemented atomic kind");
}
#undef ATOMIC_CASE
return success();
}
//===----------------------------------------------------------------------===//
// DeallocOp
//===----------------------------------------------------------------------===//
LogicalResult
DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
MemRefType deallocType = cast<MemRefType>(operation.getMemref().getType());
if (!isAllocationSupported(operation, deallocType))
return rewriter.notifyMatchFailure(operation, "unhandled allocation type");
rewriter.eraseOp(operation);
return success();
}
//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//
struct MemoryRequirements {
spirv::MemoryAccessAttr memoryAccess;
IntegerAttr alignment;
};
/// Given an accessed SPIR-V pointer, calculates its alignment requirements, if
/// any.
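///
/// For example, an i32 pointee accessed through a PhysicalStorageBuffer
/// pointer yields the `Aligned` memory access with alignment 4 (the scalar's
/// byte size); pointers in other storage classes carry no alignment. The
/// `Nontemporal` flag, when requested, is set in either case.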
static FailureOr<MemoryRequirements>
calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) {
MLIRContext *ctx = accessedPtr.getContext();
auto memoryAccess = spirv::MemoryAccess::None;
if (isNontemporal) {
memoryAccess = spirv::MemoryAccess::Nontemporal;
}
auto ptrType = cast<spirv::PointerType>(accessedPtr.getType());
if (ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer) {
if (memoryAccess == spirv::MemoryAccess::None) {
return MemoryRequirements{spirv::MemoryAccessAttr{}, IntegerAttr{}};
}
return MemoryRequirements{spirv::MemoryAccessAttr::get(ctx, memoryAccess),
IntegerAttr{}};
}
// PhysicalStorageBuffers require the `Aligned` attribute.
auto pointeeType = dyn_cast<spirv::ScalarType>(ptrType.getPointeeType());
if (!pointeeType)
return failure();
// For scalar types, the alignment is determined by their size.
std::optional<int64_t> sizeInBytes = pointeeType.getSizeInBytes();
if (!sizeInBytes.has_value())
return failure();
memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
auto memAccessAttr = spirv::MemoryAccessAttr::get(ctx, memoryAccess);
auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), *sizeInBytes);
return MemoryRequirements{memAccessAttr, alignment};
}
/// Given an accessed SPIR-V pointer and the original memref load/store
/// `memAccess` op, calculates the alignment requirements, if any. Takes into
/// account the alignment attributes applied to the load/store op.
template <class LoadOrStoreOp>
static FailureOr<MemoryRequirements>
calculateMemoryRequirements(Value accessedPtr, LoadOrStoreOp loadOrStoreOp) {
static_assert(
llvm::is_one_of<LoadOrStoreOp, memref::LoadOp, memref::StoreOp>::value,
"Must be called on either memref::LoadOp or memref::StoreOp");
Operation *memrefAccessOp = loadOrStoreOp.getOperation();
auto memrefMemAccess = memrefAccessOp->getAttrOfType<spirv::MemoryAccessAttr>(
spirv::attributeName<spirv::MemoryAccess>());
auto memrefAlignment =
memrefAccessOp->getAttrOfType<IntegerAttr>("alignment");
if (memrefMemAccess && memrefAlignment)
return MemoryRequirements{memrefMemAccess, memrefAlignment};
return calculateMemoryRequirements(accessedPtr,
loadOrStoreOp.getNontemporal());
}
LogicalResult
IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto loc = loadOp.getLoc();
auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
if (!memrefType.getElementType().isSignlessInteger())
return failure();
const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
Value accessChain =
spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
adaptor.getIndices(), loc, rewriter);
if (!accessChain)
return failure();
int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
bool isBool = srcBits == 1;
if (isBool)
srcBits = typeConverter.getOptions().boolNumBits;
auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
if (!pointerType)
return rewriter.notifyMatchFailure(loadOp, "failed to convert memref type");
Type pointeeType = pointerType.getPointeeType();
Type dstType;
if (typeConverter.allows(spirv::Capability::Kernel)) {
if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
dstType = arrayType.getElementType();
else
dstType = pointeeType;
} else {
// For Vulkan we need to extract element from wrapping struct and array.
Type structElemType =
cast<spirv::StructType>(pointeeType).getElementType(0);
if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
dstType = arrayType.getElementType();
else
dstType = cast<spirv::RuntimeArrayType>(structElemType).getElementType();
}
int dstBits = dstType.getIntOrFloatBitWidth();
assert(dstBits % srcBits == 0);
  // If the rewritten load op has the same bit width, use the loaded value
  // directly.
if (srcBits == dstBits) {
auto memoryRequirements = calculateMemoryRequirements(accessChain, loadOp);
if (failed(memoryRequirements))
return rewriter.notifyMatchFailure(
loadOp, "failed to determine memory requirements");
auto [memoryAccess, alignment] = *memoryRequirements;
Value loadVal = rewriter.create<spirv::LoadOp>(loc, accessChain,
memoryAccess, alignment);
if (isBool)
loadVal = castIntNToBool(loc, loadVal, rewriter);
rewriter.replaceOp(loadOp, loadVal);
return success();
}
// Bitcasting is currently unsupported for Kernel capability /
// spirv.PtrAccessChain.
if (typeConverter.allows(spirv::Capability::Kernel))
return failure();
auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>();
if (!accessChainOp)
return failure();
  // Assume that getElementPtr() linearizes the access. Even for a scalar, the
  // method still returns a linearized access chain. If the access were not
  // linearized, there would be offset issues.
assert(accessChainOp.getIndices().size() == 2);
Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
srcBits, dstBits, rewriter);
auto memoryRequirements = calculateMemoryRequirements(adjustedPtr, loadOp);
if (failed(memoryRequirements))
return rewriter.notifyMatchFailure(
loadOp, "failed to determine memory requirements");
auto [memoryAccess, alignment] = *memoryRequirements;
Value spvLoadOp = rewriter.create<spirv::LoadOp>(loc, dstType, adjustedPtr,
memoryAccess, alignment);
// Shift the bits to the rightmost.
// ____XXXX________ -> ____________XXXX
Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
Value result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
loc, spvLoadOp.getType(), spvLoadOp, offset);
  // Apply the mask to extract the corresponding bits.
Value mask = rewriter.createOrFold<spirv::ConstantOp>(
loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
result =
rewriter.createOrFold<spirv::BitwiseAndOp>(loc, dstType, result, mask);
  // Apply sign extension to the loaded value unconditionally. The signedness
  // semantics are carried in the operations themselves; we rely on other
  // patterns to handle the casting.
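  // E.g., for srcBits = 8 and dstBits = 32, shifting left and then
  // arithmetic-right by 24 replicates the extracted byte's sign bit through
  // the upper 24 bits.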
IntegerAttr shiftValueAttr =
rewriter.getIntegerAttr(dstType, dstBits - srcBits);
Value shiftValue =
rewriter.createOrFold<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
result = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(loc, dstType,
result, shiftValue);
result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
loc, dstType, result, shiftValue);
rewriter.replaceOp(loadOp, result);
assert(accessChainOp.use_empty());
rewriter.eraseOp(accessChainOp);
return success();
}
LogicalResult
LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
if (memrefType.getElementType().isSignlessInteger())
return failure();
Value loadPtr = spirv::getElementPtr(
*getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
adaptor.getIndices(), loadOp.getLoc(), rewriter);
if (!loadPtr)
return failure();
auto memoryRequirements = calculateMemoryRequirements(loadPtr, loadOp);
if (failed(memoryRequirements))
return rewriter.notifyMatchFailure(
loadOp, "failed to determine memory requirements");
auto [memoryAccess, alignment] = *memoryRequirements;
rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr, memoryAccess,
alignment);
return success();
}
LogicalResult
IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
if (!memrefType.getElementType().isSignlessInteger())
return rewriter.notifyMatchFailure(storeOp,
"element type is not a signless int");
auto loc = storeOp.getLoc();
auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
Value accessChain =
spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
adaptor.getIndices(), loc, rewriter);
if (!accessChain)
return rewriter.notifyMatchFailure(
storeOp, "failed to convert element pointer type");
int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
bool isBool = srcBits == 1;
if (isBool)
srcBits = typeConverter.getOptions().boolNumBits;
auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
if (!pointerType)
return rewriter.notifyMatchFailure(storeOp,
"failed to convert memref type");
Type pointeeType = pointerType.getPointeeType();
IntegerType dstType;
if (typeConverter.allows(spirv::Capability::Kernel)) {
if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
dstType = dyn_cast<IntegerType>(arrayType.getElementType());
else
dstType = dyn_cast<IntegerType>(pointeeType);
} else {
// For Vulkan we need to extract element from wrapping struct and array.
Type structElemType =
cast<spirv::StructType>(pointeeType).getElementType(0);
if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
dstType = dyn_cast<IntegerType>(arrayType.getElementType());
else
dstType = dyn_cast<IntegerType>(
cast<spirv::RuntimeArrayType>(structElemType).getElementType());
}
if (!dstType)
return rewriter.notifyMatchFailure(
storeOp, "failed to determine destination element type");
int dstBits = static_cast<int>(dstType.getWidth());
assert(dstBits % srcBits == 0);
if (srcBits == dstBits) {
auto memoryRequirements = calculateMemoryRequirements(accessChain, storeOp);
if (failed(memoryRequirements))
return rewriter.notifyMatchFailure(
storeOp, "failed to determine memory requirements");
auto [memoryAccess, alignment] = *memoryRequirements;
Value storeVal = adaptor.getValue();
if (isBool)
storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, accessChain, storeVal,
memoryAccess, alignment);
return success();
}
// Bitcasting is currently unsupported for Kernel capability /
// spirv.PtrAccessChain.
if (typeConverter.allows(spirv::Capability::Kernel))
return failure();
auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>();
if (!accessChainOp)
return failure();
  // Since multiple threads may be storing concurrently, the emulation is done
  // with atomic operations. E.g., if the stored value is i8, rewrite the
  // StoreOp to:
  // 1) load a 32-bit integer
  // 2) clear 8 bits in the loaded value
  // 3) set 8 bits in the loaded value
  // 4) store the 32-bit value back
  //
  // Step 2 is done with AtomicAnd, and step 3 is done with AtomicOr (of the
  // loaded 32-bit value and the shifted 8-bit store value) as another atomic
  // step.
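  //
  // The emitted sequence is roughly the following (%-names illustrative,
  // scope/semantics operands elided):
  //   %ptr = spirv.AccessChain ...      // adjusted to the enclosing i32 word
  //   %0   = spirv.AtomicAnd %ptr, ...  // clear the target bits
  //   %1   = spirv.AtomicOr  %ptr, ...  // set them to the shifted store value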
assert(accessChainOp.getIndices().size() == 2);
Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
// Create a mask to clear the destination. E.g., if it is the second i8 in
// i32, 0xFFFF00FF is created.
Value mask = rewriter.createOrFold<spirv::ConstantOp>(
loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
Value clearBitsMask = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(
loc, dstType, mask, offset);
clearBitsMask =
rewriter.createOrFold<spirv::NotOp>(loc, dstType, clearBitsMask);
Value storeVal = shiftValue(loc, adaptor.getValue(), offset, mask, rewriter);
Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
srcBits, dstBits, rewriter);
std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
if (!scope)
return rewriter.notifyMatchFailure(storeOp, "atomic scope not available");
  rewriter.create<spirv::AtomicAndOp>(loc, dstType, adjustedPtr, *scope,
                                      spirv::MemorySemantics::AcquireRelease,
                                      clearBitsMask);
  rewriter.create<spirv::AtomicOrOp>(loc, dstType, adjustedPtr, *scope,
                                     spirv::MemorySemantics::AcquireRelease,
                                     storeVal);
  // The two atomic ops above fully implement the store, so we can simply erase
  // the original StoreOp. Note that rewriter.replaceOp() doesn't work here
  // because it requires the numbers of results to match.
rewriter.eraseOp(storeOp);
assert(accessChainOp.use_empty());
rewriter.eraseOp(accessChainOp);
return success();
}
//===----------------------------------------------------------------------===//
// MemorySpaceCastOp
//===----------------------------------------------------------------------===//
LogicalResult MemorySpaceCastOpPattern::matchAndRewrite(
memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = addrCastOp.getLoc();
auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
if (!typeConverter.allows(spirv::Capability::Kernel))
return rewriter.notifyMatchFailure(
loc, "address space casts require kernel capability");
auto sourceType = dyn_cast<MemRefType>(addrCastOp.getSource().getType());
if (!sourceType)
return rewriter.notifyMatchFailure(
loc, "SPIR-V lowering requires ranked memref types");
auto resultType = cast<MemRefType>(addrCastOp.getResult().getType());
auto sourceStorageClassAttr =
dyn_cast_or_null<spirv::StorageClassAttr>(sourceType.getMemorySpace());
if (!sourceStorageClassAttr)
return rewriter.notifyMatchFailure(loc, [sourceType](Diagnostic &diag) {
diag << "source address space " << sourceType.getMemorySpace()
<< " must be a SPIR-V storage class";
});
auto resultStorageClassAttr =
dyn_cast_or_null<spirv::StorageClassAttr>(resultType.getMemorySpace());
if (!resultStorageClassAttr)
return rewriter.notifyMatchFailure(loc, [resultType](Diagnostic &diag) {
diag << "result address space " << resultType.getMemorySpace()
<< " must be a SPIR-V storage class";
});
spirv::StorageClass sourceSc = sourceStorageClassAttr.getValue();
spirv::StorageClass resultSc = resultStorageClassAttr.getValue();
Value result = adaptor.getSource();
Type resultPtrType = typeConverter.convertType(resultType);
if (!resultPtrType)
return rewriter.notifyMatchFailure(addrCastOp,
"failed to convert memref type");
Type genericPtrType = resultPtrType;
  // SPIR-V doesn't have a general address space cast operation. Instead, it has
  // conversions to and from generic pointers. To implement the general case,
  // we use specific-to-generic conversions when the source class is not
  // generic. Then, when the result storage class is not generic, we convert the
  // generic pointer (either the input or an intermediate result) to that
  // class. This also means that we need the intermediate generic pointer
  // type if neither the source nor the destination storage class is generic.
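  // For example, a cast from Workgroup to CrossWorkgroup (both non-generic)
  // becomes, roughly:
  //   %g = spirv.PtrCastToGeneric %src : ... to !spirv.ptr<..., Generic>
  //   %r = spirv.GenericCastToPtr %g : ... to !spirv.ptr<..., CrossWorkgroup>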
if (sourceSc != spirv::StorageClass::Generic &&
resultSc != spirv::StorageClass::Generic) {
Type intermediateType =
MemRefType::get(sourceType.getShape(), sourceType.getElementType(),
sourceType.getLayout(),
rewriter.getAttr<spirv::StorageClassAttr>(
spirv::StorageClass::Generic));
genericPtrType = typeConverter.convertType(intermediateType);
}
if (sourceSc != spirv::StorageClass::Generic) {
result =
rewriter.create<spirv::PtrCastToGenericOp>(loc, genericPtrType, result);
}
if (resultSc != spirv::StorageClass::Generic) {
result =
rewriter.create<spirv::GenericCastToPtrOp>(loc, resultPtrType, result);
}
rewriter.replaceOp(addrCastOp, result);
return success();
}
LogicalResult
StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
if (memrefType.getElementType().isSignlessInteger())
return rewriter.notifyMatchFailure(storeOp, "signless int");
auto storePtr = spirv::getElementPtr(
*getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
adaptor.getIndices(), storeOp.getLoc(), rewriter);
if (!storePtr)
return rewriter.notifyMatchFailure(storeOp, "type conversion failed");
auto memoryRequirements = calculateMemoryRequirements(storePtr, storeOp);
if (failed(memoryRequirements))
return rewriter.notifyMatchFailure(
storeOp, "failed to determine memory requirements");
auto [memoryAccess, alignment] = *memoryRequirements;
rewriter.replaceOpWithNewOp<spirv::StoreOp>(
storeOp, storePtr, adaptor.getValue(), memoryAccess, alignment);
return success();
}
LogicalResult ReinterpretCastPattern::matchAndRewrite(
memref::ReinterpretCastOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value src = adaptor.getSource();
auto srcType = dyn_cast<spirv::PointerType>(src.getType());
if (!srcType)
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
diag << "invalid src type " << src.getType();
});
const TypeConverter *converter = getTypeConverter();
auto dstType = converter->convertType<spirv::PointerType>(op.getType());
if (dstType != srcType)
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
diag << "invalid dst type " << op.getType();
});
OpFoldResult offset =
getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter)
.front();
if (isZeroInteger(offset)) {
rewriter.replaceOp(op, src);
return success();
}
Type intType = converter->convertType(rewriter.getIndexType());
if (!intType)
return rewriter.notifyMatchFailure(op, "failed to convert index type");
Location loc = op.getLoc();
auto offsetValue = [&]() -> Value {
if (auto val = dyn_cast<Value>(offset))
return val;
int64_t attrVal = cast<IntegerAttr>(cast<Attribute>(offset)).getInt();
Attribute attr = rewriter.getIntegerAttr(intType, attrVal);
return rewriter.createOrFold<spirv::ConstantOp>(loc, intType, attr);
}();
rewriter.replaceOpWithNewOp<spirv::InBoundsPtrAccessChainOp>(
op, src, offsetValue, ValueRange());
return success();
}
//===----------------------------------------------------------------------===//
// ExtractAlignedPointerAsIndexOp
//===----------------------------------------------------------------------===//
LogicalResult ExtractAlignedPointerAsIndexOpPattern::matchAndRewrite(
memref::ExtractAlignedPointerAsIndexOp extractOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
Type indexType = typeConverter.getIndexType();
rewriter.replaceOpWithNewOp<spirv::ConvertPtrToUOp>(extractOp, indexType,
adaptor.getSource());
return success();
}
//===----------------------------------------------------------------------===//
// Pattern population
//===----------------------------------------------------------------------===//
namespace mlir {
void populateMemRefToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns) {
patterns
.add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern, LoadOpPattern,
MemorySpaceCastOpPattern, StoreOpPattern, ReinterpretCastPattern,
CastPattern, ExtractAlignedPointerAsIndexOpPattern>(
typeConverter, patterns.getContext());
}
} // namespace mlir