| //===- VectorPattern.cpp - Vector conversion pattern to the LLVM dialect --===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Conversion/LLVMCommon/VectorPattern.h" |
| #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
| |
| using namespace mlir; |
| |
| // For >1-D vector types, extracts the necessary information to iterate over all |
| // 1-D subvectors in the underlying llrepresentation of the n-D vector |
| // Iterates on the llvm array type until we hit a non-array type (which is |
| // asserted to be an llvm vector type). |
| LLVM::detail::NDVectorTypeInfo |
| LLVM::detail::extractNDVectorTypeInfo(VectorType vectorType, |
| const LLVMTypeConverter &converter) { |
| assert(vectorType.getRank() > 1 && "expected >1D vector type"); |
| NDVectorTypeInfo info; |
| info.llvmNDVectorTy = converter.convertType(vectorType); |
| if (!info.llvmNDVectorTy || !LLVM::isCompatibleType(info.llvmNDVectorTy)) { |
| info.llvmNDVectorTy = nullptr; |
| return info; |
| } |
| info.arraySizes.reserve(vectorType.getRank() - 1); |
| auto llvmTy = info.llvmNDVectorTy; |
| while (isa<LLVM::LLVMArrayType>(llvmTy)) { |
| info.arraySizes.push_back( |
| cast<LLVM::LLVMArrayType>(llvmTy).getNumElements()); |
| llvmTy = cast<LLVM::LLVMArrayType>(llvmTy).getElementType(); |
| } |
| if (!LLVM::isCompatibleVectorType(llvmTy)) |
| return info; |
| info.llvm1DVectorTy = llvmTy; |
| return info; |
| } |
| |
| // Express `linearIndex` in terms of coordinates of `basis`. |
| // Returns the empty vector when linearIndex is out of the range [0, P] where |
| // P is the product of all the basis coordinates. |
| // |
| // Prerequisites: |
| // Basis is an array of nonnegative integers (signed type inherited from |
| // vector shape type). |
| SmallVector<int64_t, 4> LLVM::detail::getCoordinates(ArrayRef<int64_t> basis, |
| unsigned linearIndex) { |
| SmallVector<int64_t, 4> res; |
| res.reserve(basis.size()); |
| for (unsigned basisElement : llvm::reverse(basis)) { |
| res.push_back(linearIndex % basisElement); |
| linearIndex = linearIndex / basisElement; |
| } |
| if (linearIndex > 0) |
| return {}; |
| std::reverse(res.begin(), res.end()); |
| return res; |
| } |
| |
| // Iterate of linear index, convert to coords space and insert splatted 1-D |
| // vector in each position. |
| void LLVM::detail::nDVectorIterate(const LLVM::detail::NDVectorTypeInfo &info, |
| OpBuilder &builder, |
| function_ref<void(ArrayRef<int64_t>)> fun) { |
| unsigned ub = 1; |
| for (auto s : info.arraySizes) |
| ub *= s; |
| for (unsigned linearIndex = 0; linearIndex < ub; ++linearIndex) { |
| auto coords = getCoordinates(info.arraySizes, linearIndex); |
| // Linear index is out of bounds, we are done. |
| if (coords.empty()) |
| break; |
| assert(coords.size() == info.arraySizes.size()); |
| fun(coords); |
| } |
| } |
| |
| LogicalResult LLVM::detail::handleMultidimensionalVectors( |
| Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, |
| std::function<Value(Type, ValueRange)> createOperand, |
| ConversionPatternRewriter &rewriter) { |
| auto resultNDVectorType = cast<VectorType>(op->getResult(0).getType()); |
| auto resultTypeInfo = |
| extractNDVectorTypeInfo(resultNDVectorType, typeConverter); |
| auto result1DVectorTy = resultTypeInfo.llvm1DVectorTy; |
| auto resultNDVectoryTy = resultTypeInfo.llvmNDVectorTy; |
| auto loc = op->getLoc(); |
| Value desc = rewriter.create<LLVM::PoisonOp>(loc, resultNDVectoryTy); |
| nDVectorIterate(resultTypeInfo, rewriter, [&](ArrayRef<int64_t> position) { |
| // For this unrolled `position` corresponding to the `linearIndex`^th |
| // element, extract operand vectors |
| SmallVector<Value, 4> extractedOperands; |
| for (const auto &operand : llvm::enumerate(operands)) { |
| extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>( |
| loc, operand.value(), position)); |
| } |
| Value newVal = createOperand(result1DVectorTy, extractedOperands); |
| desc = rewriter.create<LLVM::InsertValueOp>(loc, desc, newVal, position); |
| }); |
| rewriter.replaceOp(op, desc); |
| return success(); |
| } |
| |
| LogicalResult LLVM::detail::vectorOneToOneRewrite( |
| Operation *op, StringRef targetOp, ValueRange operands, |
| ArrayRef<NamedAttribute> targetAttrs, |
| const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, |
| IntegerOverflowFlags overflowFlags) { |
| assert(!operands.empty()); |
| |
| // Cannot convert ops if their operands are not of LLVM type. |
| if (!llvm::all_of(operands.getTypes(), isCompatibleType)) |
| return failure(); |
| |
| auto llvmNDVectorTy = operands[0].getType(); |
| if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy)) |
| return oneToOneRewrite(op, targetOp, operands, targetAttrs, typeConverter, |
| rewriter, overflowFlags); |
| |
| auto callback = [op, targetOp, targetAttrs, overflowFlags, |
| &rewriter](Type llvm1DVectorTy, ValueRange operands) { |
| Operation *newOp = |
| rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), |
| operands, llvm1DVectorTy, targetAttrs); |
| LLVM::detail::setNativeProperties(newOp, overflowFlags); |
| return newOp->getResult(0); |
| }; |
| |
| return handleMultidimensionalVectors(op, operands, typeConverter, callback, |
| rewriter); |
| } |