// blob: afd02a2f435af0e1237064674c9fceee50ad690f [file] [log] [blame]
// Copyright 2019 The Clspv Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/IR/CallingConv.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/Pass.h"
#include "llvm/Support/raw_ostream.h"
#include "Builtins.h"
#include "Constants.h"
#include "Passes.h"
#include "Types.h"
#include "clspv/Option.h"
using namespace clspv;
using namespace clspv::Builtins;
using namespace llvm;
namespace {
// Module pass that replaces the image type of each image argument with a
// specialized image type.
//
// A specialized type is a pointer to a named struct whose name is the original
// image struct name extended with a component suffix (".float", ".int" or
// ".uint") and, for images that are not storage images, ".sampled". Kernel
// arguments are the starting points; the new type is propagated through the
// uses of each argument and into the functions it is passed to, and the
// signatures of all affected functions are rewritten afterwards.
class SpecializeImageTypesPass : public ModulePass {
public:
// Pass ID handed to the ModulePass constructor.
static char ID;
SpecializeImageTypesPass() : ModulePass(ID) {}
// Specializes the image arguments of every kernel definition in |M|. Returns
// true if at least one image kernel argument was processed.
bool runOnModule(Module &M) override;
private:
// Returns the specialized image type for |arg|, or nullptr if none of its
// uses determines one.
Type *RemapType(Argument *arg);
// Returns the specialized image type for operand |operand_no| in |value|, or
// nullptr if it cannot be determined from that use.
Type *RemapUse(Value *value, unsigned operand_no);
// Specializes |arg| as |new_type|. Recursively updates the use chain.
void SpecializeArg(Function *f, Argument *arg, Type *new_type);
// Returns a replacement image builtin function for the specialized type
// |type|.
Function *ReplaceImageBuiltin(Function *f, Type *type);
// Rewrites |f| using the |remapped_args_| to determine the updated types.
void RewriteFunction(Function *f);
// Tracks the generation of specialized types so they are not further
// specialized.
DenseSet<Type *> specialized_images_;
// Maps an argument to a specialized type.
DenseMap<Argument *, Type *> remapped_args_;
// Tracks which functions need to be rewritten due to modified arguments.
DenseSet<Function *> functions_to_modify_;
};
} // namespace
// Storage for the pass ID that the constructor passes to ModulePass.
char SpecializeImageTypesPass::ID = 0;
// Registers the pass under the name "SpecializeImageTypesPass" with the
// description "Specialize image types".
// NOTE(review): the two trailing `false` arguments are presumably the
// "CFG only" and "is analysis" flags of INITIALIZE_PASS - confirm against the
// LLVM version in use.
INITIALIZE_PASS(SpecializeImageTypesPass, "SpecializeImageTypesPass",
"Specialize image types", false, false)
namespace clspv {
// Factory for the image type specialization pass. Returns a newly allocated
// SpecializeImageTypesPass.
// NOTE(review): the caller (presumably a pass manager) is expected to take
// ownership of the returned pointer - confirm at the call sites.
ModulePass *createSpecializeImageTypesPass() {
return new SpecializeImageTypesPass();
}
} // namespace clspv
namespace {
// Specializes the image-typed arguments of every kernel definition in |M| and
// then rewrites the signature of every function whose arguments changed.
//
// The specialized type of an argument is derived from its uses (RemapType).
// When no use determines a type, the image is assumed to hold float data.
//
// Returns true if at least one image kernel argument was processed.
bool SpecializeImageTypesPass::runOnModule(Module &M) {
  bool changed = false;

  // Gather the kernels first: specializing arguments below can erase builtin
  // declarations from the module.
  SmallVector<Function *, 8> kernels;
  for (auto &F : M) {
    if (F.isDeclaration() || F.getCallingConv() != CallingConv::SPIR_KERNEL)
      continue;
    kernels.push_back(&F);
  }

  for (auto f : kernels) {
    for (auto &Arg : f->args()) {
      if (!IsImageType(Arg.getType()))
        continue;

      Type *new_type = RemapType(&Arg);
      if (!new_type) {
        // No specializing information found, assume the image is sampled with
        // a float type.
        std::string name =
            cast<StructType>(Arg.getType()->getPointerElementType())
                ->getName()
                .str();
        name += ".float";
        const auto pos = name.find("_wo_t");
        if (!IsStorageImageType(Arg.getType())) {
          // Images that are not storage images are translated as sampled
          // images.
          name += ".sampled";
        } else if (clspv::Option::Language() >=
                       clspv::Option::SourceLanguage::OpenCL_C_20 &&
                   pos != std::string::npos) {
          // In OpenCL 2.0 (or later), treat write_only images as read_write
          // images. This prevents the compiler from generating duplicate
          // image types (invalid SPIR-V).
          name = name.substr(0, pos) + "_rw_t" + name.substr(pos + 5);
        }

        // Reuse the named struct if it already exists in the context.
        StructType *new_struct =
            StructType::getTypeByName(M.getContext(), name);
        if (!new_struct)
          new_struct = StructType::create(Arg.getContext(), name);
        new_type = PointerType::get(new_struct,
                                    Arg.getType()->getPointerAddressSpace());
      }

      specialized_images_.insert(new_type);
      changed = true;
      SpecializeArg(f, &Arg, new_type);
    }
  }

  // Keep functions in the same relative order.
  std::vector<Function *> to_rewrite;
  for (auto &F : M) {
    if (functions_to_modify_.count(&F))
      to_rewrite.push_back(&F);
  }
  for (auto f : to_rewrite) {
    RewriteFunction(f);
  }

  // RewriteFunction deletes the functions (and therefore the arguments) it
  // replaces, so the keys held in these containers may now dangle. Drop them
  // so a later run on this pass instance never compares against stale
  // pointers. |specialized_images_| holds types, which are not deleted here,
  // and is intentionally kept.
  remapped_args_.clear();
  functions_to_modify_.clear();

  return changed;
}
// Looks through the uses of |arg| for one that determines a specialized image
// type. Returns the first type RemapUse reports, or nullptr when no use yields
// one.
Type *SpecializeImageTypesPass::RemapType(Argument *arg) {
  for (auto use_it = arg->use_begin(), use_end = arg->use_end();
       use_it != use_end; ++use_it) {
    Type *specialized = RemapUse(use_it->getUser(), use_it->getOperandNo());
    if (specialized != nullptr)
      return specialized;
  }
  return nullptr;
}
// Returns the specialized image type implied by operand |operand_no| of
// |value|, or nullptr if that use does not determine one.
//
// - If |value| is a call to a read_image*/write_image* builtin, the result is
//   a pointer to a struct named after the image struct of the call's first
//   operand plus a component suffix (".float", ".int" or ".uint"), with
//   ".sampled" appended for images that are not storage images. An image type
//   that is already specialized is returned unchanged.
// - If |value| is a call to any other function that has a body, the uses of
//   the corresponding parameter are searched recursively.
// - If |value| is not a call but has an image type, its own uses are searched
//   recursively.
Type *SpecializeImageTypesPass::RemapUse(Value *value, unsigned operand_no) {
  auto *call = dyn_cast<CallInst>(value);
  if (!call) {
    // Not a call: keep following the value if it still carries an image.
    if (IsImageType(value->getType())) {
      for (auto &U : value->uses()) {
        if (auto new_type = RemapUse(U.getUser(), U.getOperandNo()))
          return new_type;
      }
    }
    return nullptr;
  }

  Function *called = call->getCalledFunction();
  // getCalledFunction() returns null for indirect calls; there is no callee
  // to inspect, so this use provides no information.
  if (!called)
    return nullptr;

  const auto &func_info = Builtins::Lookup(called);

  // Component suffix selected by the builtin; stays null for anything that is
  // not an image read or write.
  const char *component_suffix = nullptr;
  switch (func_info.getType()) {
  case Builtins::kReadImagef:
  case Builtins::kWriteImagef:
    component_suffix = ".float";
    break;
  case Builtins::kReadImagei:
  case Builtins::kWriteImagei:
    component_suffix = ".int";
    break;
  case Builtins::kReadImageui:
  case Builtins::kWriteImageui:
    component_suffix = ".uint";
    break;
  default:
    break;
  }

  if (!component_suffix) {
    // Some other callee: look at how it uses the matching parameter. Only
    // functions with a body can be inspected, and the operand must map onto a
    // declared parameter.
    if (!called->isDeclaration() && operand_no < called->arg_size()) {
      for (auto &U : called->getArg(operand_no)->uses()) {
        if (auto new_type = RemapUse(U.getUser(), U.getOperandNo()))
          return new_type;
      }
    }
    return nullptr;
  }

  // Specialize the image type based on its usage in the builtin.
  Value *image = call->getOperand(0);
  Type *imageTy = image->getType();

  // Check if this type is already specialized.
  if (specialized_images_.count(imageTy))
    return imageTy;

  std::string name =
      cast<StructType>(imageTy->getPointerElementType())->getName().str();
  name += component_suffix;

  const auto pos = name.find("_wo_t");
  if (!IsStorageImageType(imageTy)) {
    // Read only images are translated as sampled images.
    name += ".sampled";
  } else if (clspv::Option::Language() >=
                 clspv::Option::SourceLanguage::OpenCL_C_20 &&
             pos != std::string::npos) {
    // In OpenCL 2.0 (or later), treat write_only images as read_write
    // images. This prevents the compiler from generating duplicate image
    // types (invalid SPIR-V).
    name = name.substr(0, pos) + "_rw_t" + name.substr(pos + 5);
  }

  // Reuse the named struct if it already exists in the context.
  StructType *new_struct = StructType::getTypeByName(call->getContext(), name);
  if (!new_struct) {
    new_struct = StructType::create(call->getContext(), name);
  }
  return PointerType::get(new_struct, imageTy->getPointerAddressSpace());
}
// Specializes |arg| of |f| as |new_type| and propagates the new type along the
// use chain.
//
// The argument is recorded in |remapped_args_| and |f| in
// |functions_to_modify_|; an argument that was already recorded is left
// untouched. Every value reached from |arg| that still has the old type is
// mutated to |new_type|. For call users:
// - calls to image builtins are redirected to a builtin declaration
//   specialized for |new_type| (the old declaration is erased once unused);
// - calls to any other function specialize the corresponding parameter of the
//   callee recursively.
void SpecializeImageTypesPass::SpecializeArg(Function *f, Argument *arg,
                                             Type *new_type) {
  // Each argument is specialized at most once.
  if (remapped_args_.count(arg))
    return;

  remapped_args_[arg] = new_type;
  functions_to_modify_.insert(f);

  // Fix all uses of |arg|.
  std::vector<Value *> stack;
  stack.push_back(arg);
  while (!stack.empty()) {
    Value *value = stack.back();
    stack.pop_back();

    if (value->getType() == new_type)
      continue;

    auto old_type = value->getType();
    value->mutateType(new_type);
    for (auto &u : value->uses()) {
      User *user = u.getUser();
      if (auto call = dyn_cast<CallInst>(user)) {
        // getCalledFunction() returns null for indirect calls; there is no
        // callee to specialize in that case.
        if (Function *called = call->getCalledFunction()) {
          auto &func_info = Builtins::Lookup(called);
          if (BUILTIN_IN_GROUP(func_info.getType(), Image)) {
            auto new_func = ReplaceImageBuiltin(called, new_type);
            call->setCalledFunction(new_func);
            // Drop the unspecialized declaration once nothing refers to it.
            if (called->use_empty())
              called->eraseFromParent();
          } else {
            SpecializeArg(called, called->getArg(u.getOperandNo()), new_type);
            // Ensure the called function type matches the called function's
            // type.
            call->setCalledFunction(call->getCalledFunction());
          }
        }
      }
      // Users that produce a value of the old image type carry the image
      // onwards and must be retyped as well.
      if (old_type == user->getType()) {
        stack.push_back(user);
      }
    }
  }
}
// Returns the variant of image builtin |f| that operates on the specialized
// image type |type|.
//
// The variant is named "<name of f>.<name of the struct |type| points to>".
// If the module already contains a function with that name it is returned;
// otherwise a new function is inserted whose image-typed parameters use
// |type| and whose return type, variadic flag, attributes, calling convention
// and metadata are taken from |f|.
Function *SpecializeImageTypesPass::ReplaceImageBuiltin(Function *f,
                                                        Type *type) {
  Module *module = f->getParent();

  std::string specialized_name = f->getName().str();
  specialized_name += ".";
  specialized_name +=
      cast<StructType>(type->getPointerElementType())->getName();

  Function *existing = module->getFunction(specialized_name);
  if (existing != nullptr)
    return existing;

  // Change the image argument to the specialized type.
  SmallVector<Type *, 4> param_types;
  for (auto &param : f->args()) {
    Type *param_type = param.getType();
    param_types.push_back(IsImageType(param_type) ? type : param_type);
  }

  FunctionType *specialized_type =
      FunctionType::get(f->getReturnType(), param_types, f->isVarArg());
  auto callee = module->getOrInsertFunction(specialized_name, specialized_type,
                                            f->getAttributes());
  Function *specialized = cast<Function>(callee.getCallee());
  specialized->setCallingConv(f->getCallingConv());
  specialized->copyMetadata(f, 0);
  return specialized;
}
// Replaces |f| with a function of the same name whose parameter types reflect
// |remapped_args_|.
//
// Nothing happens when the resulting function type equals the current one.
// Otherwise |f| is removed from its module, a new function with the updated
// type is inserted, the basic blocks are moved over, argument uses and names
// are transferred, call instructions that used |f| are redirected, and |f| is
// deleted.
void SpecializeImageTypesPass::RewriteFunction(Function *f) {
  Module *parent = f->getParent();

  // Parameter types of the rewritten function: the remapped type where one
  // was recorded, the existing type otherwise.
  SmallVector<Type *, 8> param_types;
  for (auto &param : f->args()) {
    auto remapped = remapped_args_.find(&param);
    param_types.push_back(remapped != remapped_args_.end() ? remapped->second
                                                           : param.getType());
  }

  FunctionType *rewritten_type =
      FunctionType::get(f->getReturnType(), param_types, f->isVarArg());
  if (rewritten_type == f->getFunctionType())
    return;

  // Detach |f| first so the replacement can be inserted under the same name.
  f->removeFromParent();
  auto callee = parent->getOrInsertFunction(f->getName(), rewritten_type,
                                            f->getAttributes());
  Function *replacement = cast<Function>(callee.getCallee());
  replacement->setCallingConv(f->getCallingConv());
  replacement->copyMetadata(f, 0);

  // Move the basic blocks. They are collected first because moving a block
  // changes the list being walked.
  if (!f->isDeclaration()) {
    std::vector<BasicBlock *> body;
    for (BasicBlock &block : *f)
      body.push_back(&block);
    for (BasicBlock *block : body) {
      block->removeFromParent();
      block->insertInto(replacement);
    }
  }

  // Replace args uses.
  auto new_param = replacement->arg_begin();
  for (Argument &old_param : f->args()) {
    // Mutate the old arg type to satisfy RAUW.
    old_param.mutateType(new_param->getType());
    old_param.replaceAllUsesWith(&*new_param);
    new_param->takeName(&old_param);
    ++new_param;
  }

  // Copy uses because they will be modified.
  SmallVector<Value *, 8> pending_users;
  for (auto *user : f->users())
    pending_users.push_back(user);
  for (Value *user : pending_users) {
    if (auto *call = dyn_cast<CallInst>(user))
      call->setCalledFunction(replacement);
  }

  delete f;
}
} // namespace