| //===- SCFToSPIRV.cpp - SCF to SPIR-V Patterns ----------------------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file implements patterns to convert SCF dialect to SPIR-V dialect. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h" |
| #include "mlir/Dialect/SCF/IR/SCF.h" |
| #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" |
| #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" |
| #include "mlir/Transforms/DialectConversion.h" |
| #include "llvm/Support/FormatVariadic.h" |
| |
| using namespace mlir; |
| |
| //===----------------------------------------------------------------------===// |
| // Context |
| //===----------------------------------------------------------------------===// |
| |
| namespace mlir { |
| struct ScfToSPIRVContextImpl { |
| // Map between the spirv region control flow operation (spirv.mlir.loop or |
| // spirv.mlir.selection) to the VariableOp created to store the region |
| // results. The order of the VariableOp matches the order of the results. |
| DenseMap<Operation *, SmallVector<spirv::VariableOp, 8>> outputVars; |
| }; |
| } // namespace mlir |
| |
| /// We use ScfToSPIRVContext to store information about the lowering of the scf |
| /// region that need to be used later on. When we lower scf.for/scf.if we create |
| /// VariableOp to store the results. We need to keep track of the VariableOp |
| /// created as we need to insert stores into them when lowering Yield. Those |
| /// StoreOp cannot be created earlier as they may use a different type than |
| /// yield operands. |
| ScfToSPIRVContext::ScfToSPIRVContext() { |
| impl = std::make_unique<::ScfToSPIRVContextImpl>(); |
| } |
| |
| ScfToSPIRVContext::~ScfToSPIRVContext() = default; |
| |
| namespace { |
| |
| //===----------------------------------------------------------------------===// |
| // Helper Functions |
| //===----------------------------------------------------------------------===// |
| |
| /// Replaces SCF op outputs with SPIR-V variable loads. |
| /// We create VariableOp to handle the results value of the control flow region. |
| /// spirv.mlir.loop/spirv.mlir.selection currently don't yield value. Right |
| /// after the loop we load the value from the allocation and use it as the SCF |
| /// op result. |
| template <typename ScfOp, typename OpTy> |
| void replaceSCFOutputValue(ScfOp scfOp, OpTy newOp, |
| ConversionPatternRewriter &rewriter, |
| ScfToSPIRVContextImpl *scfToSPIRVContext, |
| ArrayRef<Type> returnTypes) { |
| |
| Location loc = scfOp.getLoc(); |
| auto &allocas = scfToSPIRVContext->outputVars[newOp]; |
| // Clearing the allocas is necessary in case a dialect conversion path failed |
| // previously, and this is the second attempt of this conversion. |
| allocas.clear(); |
| SmallVector<Value, 8> resultValue; |
| for (Type convertedType : returnTypes) { |
| auto pointerType = |
| spirv::PointerType::get(convertedType, spirv::StorageClass::Function); |
| rewriter.setInsertionPoint(newOp); |
| auto alloc = spirv::VariableOp::create(rewriter, loc, pointerType, |
| spirv::StorageClass::Function, |
| /*initializer=*/nullptr); |
| allocas.push_back(alloc); |
| rewriter.setInsertionPointAfter(newOp); |
| Value loadResult = spirv::LoadOp::create(rewriter, loc, alloc); |
| resultValue.push_back(loadResult); |
| } |
| rewriter.replaceOp(scfOp, resultValue); |
| } |
| |
| Region::iterator getBlockIt(Region ®ion, unsigned index) { |
| return std::next(region.begin(), index); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Conversion Patterns |
| //===----------------------------------------------------------------------===// |
| |
| /// Common class for all vector to GPU patterns. |
| template <typename OpTy> |
| class SCFToSPIRVPattern : public OpConversionPattern<OpTy> { |
| public: |
| SCFToSPIRVPattern(MLIRContext *context, const SPIRVTypeConverter &converter, |
| ScfToSPIRVContextImpl *scfToSPIRVContext) |
| : OpConversionPattern<OpTy>::OpConversionPattern(converter, context), |
| scfToSPIRVContext(scfToSPIRVContext), typeConverter(converter) {} |
| |
| protected: |
| ScfToSPIRVContextImpl *scfToSPIRVContext; |
| // FIXME: We explicitly keep a reference of the type converter here instead of |
| // passing it to OpConversionPattern during construction. This effectively |
| // bypasses the conversion framework's automation on type conversion. This is |
| // needed right now because the conversion framework will unconditionally |
| // legalize all types used by SCF ops upon discovering them, for example, the |
| // types of loop carried values. We use SPIR-V variables for those loop |
| // carried values. Depending on the available capabilities, the SPIR-V |
| // variable can be different, for example, cooperative matrix or normal |
| // variable. We'd like to detach the conversion of the loop carried values |
| // from the SCF ops (which is mainly a region). So we need to "mark" types |
| // used by SCF ops as legal, if to use the conversion framework for type |
| // conversion. There isn't a straightforward way to do that yet, as when |
| // converting types, ops aren't taken into consideration. Therefore, we just |
| // bypass the framework's type conversion for now. |
| const SPIRVTypeConverter &typeConverter; |
| }; |
| |
| //===----------------------------------------------------------------------===// |
| // scf::ForOp |
| //===----------------------------------------------------------------------===// |
| |
| /// Pattern to convert a scf::ForOp within kernel functions into spirv::LoopOp. |
| struct ForOpConversion final : SCFToSPIRVPattern<scf::ForOp> { |
| using SCFToSPIRVPattern::SCFToSPIRVPattern; |
| |
| LogicalResult |
| matchAndRewrite(scf::ForOp forOp, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| // scf::ForOp can be lowered to the structured control flow represented by |
| // spirv::LoopOp by making the continue block of the spirv::LoopOp the loop |
| // latch and the merge block the exit block. The resulting spirv::LoopOp has |
| // a single back edge from the continue to header block, and a single exit |
| // from header to merge. |
| auto loc = forOp.getLoc(); |
| auto loopOp = |
| spirv::LoopOp::create(rewriter, loc, spirv::LoopControl::None); |
| loopOp.addEntryAndMergeBlock(rewriter); |
| |
| OpBuilder::InsertionGuard guard(rewriter); |
| // Create the block for the header. |
| Block *header = rewriter.createBlock(&loopOp.getBody(), |
| getBlockIt(loopOp.getBody(), 1)); |
| rewriter.setInsertionPointAfter(loopOp); |
| |
| // Create the new induction variable to use. |
| Value adapLowerBound = adaptor.getLowerBound(); |
| BlockArgument newIndVar = |
| header->addArgument(adapLowerBound.getType(), adapLowerBound.getLoc()); |
| for (Value arg : adaptor.getInitArgs()) |
| header->addArgument(arg.getType(), arg.getLoc()); |
| Block *body = forOp.getBody(); |
| |
| // Apply signature conversion to the body of the forOp. It has a single |
| // block, with argument which is the induction variable. That has to be |
| // replaced with the new induction variable. |
| TypeConverter::SignatureConversion signatureConverter( |
| body->getNumArguments()); |
| signatureConverter.remapInput(0, newIndVar); |
| for (unsigned i = 1, e = body->getNumArguments(); i < e; i++) |
| signatureConverter.remapInput(i, header->getArgument(i)); |
| body = rewriter.applySignatureConversion(&forOp.getRegion().front(), |
| signatureConverter); |
| |
| // Move the blocks from the forOp into the loopOp. This is the body of the |
| // loopOp. |
| rewriter.inlineRegionBefore(forOp->getRegion(0), loopOp.getBody(), |
| getBlockIt(loopOp.getBody(), 2)); |
| |
| SmallVector<Value, 8> args(1, adaptor.getLowerBound()); |
| args.append(adaptor.getInitArgs().begin(), adaptor.getInitArgs().end()); |
| // Branch into it from the entry. |
| rewriter.setInsertionPointToEnd(&(loopOp.getBody().front())); |
| spirv::BranchOp::create(rewriter, loc, header, args); |
| |
| // Generate the rest of the loop header. |
| rewriter.setInsertionPointToEnd(header); |
| auto *mergeBlock = loopOp.getMergeBlock(); |
| auto cmpOp = spirv::SLessThanOp::create(rewriter, loc, rewriter.getI1Type(), |
| newIndVar, adaptor.getUpperBound()); |
| |
| spirv::BranchConditionalOp::create(rewriter, loc, cmpOp, body, |
| ArrayRef<Value>(), mergeBlock, |
| ArrayRef<Value>()); |
| |
| // Generate instructions to increment the step of the induction variable and |
| // branch to the header. |
| Block *continueBlock = loopOp.getContinueBlock(); |
| rewriter.setInsertionPointToEnd(continueBlock); |
| |
| // Add the step to the induction variable and branch to the header. |
| Value updatedIndVar = spirv::IAddOp::create( |
| rewriter, loc, newIndVar.getType(), newIndVar, adaptor.getStep()); |
| spirv::BranchOp::create(rewriter, loc, header, updatedIndVar); |
| |
| // Infer the return types from the init operands. Vector type may get |
| // converted to CooperativeMatrix or to Vector type, to avoid having complex |
| // extra logic to figure out the right type we just infer it from the Init |
| // operands. |
| SmallVector<Type, 8> initTypes; |
| for (auto arg : adaptor.getInitArgs()) |
| initTypes.push_back(arg.getType()); |
| replaceSCFOutputValue(forOp, loopOp, rewriter, scfToSPIRVContext, |
| initTypes); |
| return success(); |
| } |
| }; |
| |
| //===----------------------------------------------------------------------===// |
| // scf::IfOp |
| //===----------------------------------------------------------------------===// |
| |
| /// Pattern to convert a scf::IfOp within kernel functions into |
| /// spirv::SelectionOp. |
| struct IfOpConversion : SCFToSPIRVPattern<scf::IfOp> { |
| using SCFToSPIRVPattern::SCFToSPIRVPattern; |
| |
| LogicalResult |
| matchAndRewrite(scf::IfOp ifOp, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| // When lowering `scf::IfOp` we explicitly create a selection header block |
| // before the control flow diverges and a merge block where control flow |
| // subsequently converges. |
| auto loc = ifOp.getLoc(); |
| |
| // Compute return types. |
| SmallVector<Type, 8> returnTypes; |
| for (auto result : ifOp.getResults()) { |
| auto convertedType = typeConverter.convertType(result.getType()); |
| if (!convertedType) |
| return rewriter.notifyMatchFailure( |
| loc, |
| llvm::formatv("failed to convert type '{0}'", result.getType())); |
| |
| returnTypes.push_back(convertedType); |
| } |
| |
| // Create `spirv.selection` operation, selection header block and merge |
| // block. |
| auto selectionOp = spirv::SelectionOp::create( |
| rewriter, loc, spirv::SelectionControl::None); |
| auto *mergeBlock = rewriter.createBlock(&selectionOp.getBody(), |
| selectionOp.getBody().end()); |
| spirv::MergeOp::create(rewriter, loc); |
| |
| OpBuilder::InsertionGuard guard(rewriter); |
| auto *selectionHeaderBlock = |
| rewriter.createBlock(&selectionOp.getBody().front()); |
| |
| // Inline `then` region before the merge block and branch to it. |
| auto &thenRegion = ifOp.getThenRegion(); |
| auto *thenBlock = &thenRegion.front(); |
| rewriter.setInsertionPointToEnd(&thenRegion.back()); |
| spirv::BranchOp::create(rewriter, loc, mergeBlock); |
| rewriter.inlineRegionBefore(thenRegion, mergeBlock); |
| |
| auto *elseBlock = mergeBlock; |
| // If `else` region is not empty, inline that region before the merge block |
| // and branch to it. |
| if (!ifOp.getElseRegion().empty()) { |
| auto &elseRegion = ifOp.getElseRegion(); |
| elseBlock = &elseRegion.front(); |
| rewriter.setInsertionPointToEnd(&elseRegion.back()); |
| spirv::BranchOp::create(rewriter, loc, mergeBlock); |
| rewriter.inlineRegionBefore(elseRegion, mergeBlock); |
| } |
| |
| // Create a `spirv.BranchConditional` operation for selection header block. |
| rewriter.setInsertionPointToEnd(selectionHeaderBlock); |
| spirv::BranchConditionalOp::create(rewriter, loc, adaptor.getCondition(), |
| thenBlock, ArrayRef<Value>(), elseBlock, |
| ArrayRef<Value>()); |
| |
| replaceSCFOutputValue(ifOp, selectionOp, rewriter, scfToSPIRVContext, |
| returnTypes); |
| return success(); |
| } |
| }; |
| |
| //===----------------------------------------------------------------------===// |
| // scf::YieldOp |
| //===----------------------------------------------------------------------===// |
| |
| struct TerminatorOpConversion final : SCFToSPIRVPattern<scf::YieldOp> { |
| public: |
| using SCFToSPIRVPattern::SCFToSPIRVPattern; |
| |
| LogicalResult |
| matchAndRewrite(scf::YieldOp terminatorOp, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| ValueRange operands = adaptor.getOperands(); |
| |
| Operation *parent = terminatorOp->getParentOp(); |
| |
| // TODO: Implement conversion for the remaining `scf` ops. |
| if (parent->getDialect()->getNamespace() == |
| scf::SCFDialect::getDialectNamespace() && |
| !isa<scf::IfOp, scf::ForOp, scf::WhileOp>(parent)) |
| return rewriter.notifyMatchFailure( |
| terminatorOp, |
| llvm::formatv("conversion not supported for parent op: '{0}'", |
| parent->getName())); |
| |
| // If the region return values, store each value into the associated |
| // VariableOp created during lowering of the parent region. |
| if (!operands.empty()) { |
| auto &allocas = scfToSPIRVContext->outputVars[parent]; |
| if (allocas.size() != operands.size()) |
| return failure(); |
| |
| auto loc = terminatorOp.getLoc(); |
| for (unsigned i = 0, e = operands.size(); i < e; i++) |
| spirv::StoreOp::create(rewriter, loc, allocas[i], operands[i]); |
| if (isa<spirv::LoopOp>(parent)) { |
| // For loops we also need to update the branch jumping back to the |
| // header. |
| auto br = cast<spirv::BranchOp>( |
| rewriter.getInsertionBlock()->getTerminator()); |
| SmallVector<Value, 8> args(br.getBlockArguments()); |
| args.append(operands.begin(), operands.end()); |
| rewriter.setInsertionPoint(br); |
| spirv::BranchOp::create(rewriter, terminatorOp.getLoc(), br.getTarget(), |
| args); |
| rewriter.eraseOp(br); |
| } |
| } |
| rewriter.eraseOp(terminatorOp); |
| return success(); |
| } |
| }; |
| |
| //===----------------------------------------------------------------------===// |
| // scf::WhileOp |
| //===----------------------------------------------------------------------===// |
| |
| struct WhileOpConversion final : SCFToSPIRVPattern<scf::WhileOp> { |
| using SCFToSPIRVPattern::SCFToSPIRVPattern; |
| |
| LogicalResult |
| matchAndRewrite(scf::WhileOp whileOp, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| auto loc = whileOp.getLoc(); |
| auto loopOp = |
| spirv::LoopOp::create(rewriter, loc, spirv::LoopControl::None); |
| loopOp.addEntryAndMergeBlock(rewriter); |
| |
| Region &beforeRegion = whileOp.getBefore(); |
| Region &afterRegion = whileOp.getAfter(); |
| |
| if (failed(rewriter.convertRegionTypes(&beforeRegion, typeConverter)) || |
| failed(rewriter.convertRegionTypes(&afterRegion, typeConverter))) |
| return rewriter.notifyMatchFailure(whileOp, |
| "Failed to convert region types"); |
| |
| OpBuilder::InsertionGuard guard(rewriter); |
| |
| Block &entryBlock = *loopOp.getEntryBlock(); |
| Block &beforeBlock = beforeRegion.front(); |
| Block &afterBlock = afterRegion.front(); |
| Block &mergeBlock = *loopOp.getMergeBlock(); |
| |
| auto cond = cast<scf::ConditionOp>(beforeBlock.getTerminator()); |
| SmallVector<Value> condArgs; |
| if (failed(rewriter.getRemappedValues(cond.getArgs(), condArgs))) |
| return failure(); |
| |
| Value conditionVal = rewriter.getRemappedValue(cond.getCondition()); |
| if (!conditionVal) |
| return failure(); |
| |
| auto yield = cast<scf::YieldOp>(afterBlock.getTerminator()); |
| SmallVector<Value> yieldArgs; |
| if (failed(rewriter.getRemappedValues(yield.getResults(), yieldArgs))) |
| return failure(); |
| |
| // Move the while before block as the initial loop header block. |
| rewriter.inlineRegionBefore(beforeRegion, loopOp.getBody(), |
| getBlockIt(loopOp.getBody(), 1)); |
| |
| // Move the while after block as the initial loop body block. |
| rewriter.inlineRegionBefore(afterRegion, loopOp.getBody(), |
| getBlockIt(loopOp.getBody(), 2)); |
| |
| // Jump from the loop entry block to the loop header block. |
| rewriter.setInsertionPointToEnd(&entryBlock); |
| spirv::BranchOp::create(rewriter, loc, &beforeBlock, adaptor.getInits()); |
| |
| auto condLoc = cond.getLoc(); |
| |
| SmallVector<Value> resultValues(condArgs.size()); |
| |
| // For other SCF ops, the scf.yield op yields the value for the whole SCF |
| // op. So we use the scf.yield op as the anchor to create/load/store SPIR-V |
| // local variables. But for the scf.while op, the scf.yield op yields a |
| // value for the before region, which may not matching the whole op's |
| // result. Instead, the scf.condition op returns values matching the whole |
| // op's results. So we need to create/load/store variables according to |
| // that. |
| for (const auto &it : llvm::enumerate(condArgs)) { |
| auto res = it.value(); |
| auto i = it.index(); |
| auto pointerType = |
| spirv::PointerType::get(res.getType(), spirv::StorageClass::Function); |
| |
| // Create local variables before the scf.while op. |
| rewriter.setInsertionPoint(loopOp); |
| auto alloc = spirv::VariableOp::create(rewriter, condLoc, pointerType, |
| spirv::StorageClass::Function, |
| /*initializer=*/nullptr); |
| |
| // Load the final result values after the scf.while op. |
| rewriter.setInsertionPointAfter(loopOp); |
| auto loadResult = spirv::LoadOp::create(rewriter, condLoc, alloc); |
| resultValues[i] = loadResult; |
| |
| // Store the current iteration's result value. |
| rewriter.setInsertionPointToEnd(&beforeBlock); |
| spirv::StoreOp::create(rewriter, condLoc, alloc, res); |
| } |
| |
| rewriter.setInsertionPointToEnd(&beforeBlock); |
| rewriter.replaceOpWithNewOp<spirv::BranchConditionalOp>( |
| cond, conditionVal, &afterBlock, condArgs, &mergeBlock, ValueRange()); |
| |
| // Convert the scf.yield op to a branch back to the header block. |
| rewriter.setInsertionPointToEnd(&afterBlock); |
| rewriter.replaceOpWithNewOp<spirv::BranchOp>(yield, &beforeBlock, |
| yieldArgs); |
| |
| rewriter.replaceOp(whileOp, resultValues); |
| return success(); |
| } |
| }; |
| } // namespace |
| |
| //===----------------------------------------------------------------------===// |
| // Public API |
| //===----------------------------------------------------------------------===// |
| |
| void mlir::populateSCFToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, |
| ScfToSPIRVContext &scfToSPIRVContext, |
| RewritePatternSet &patterns) { |
| patterns.add<ForOpConversion, IfOpConversion, TerminatorOpConversion, |
| WhileOpConversion>(patterns.getContext(), typeConverter, |
| scfToSPIRVContext.getImpl()); |
| } |