blob: 6ab49ac580d6a858bf6f5a39cc06af9a349b7f98 [file] [log] [blame]
// Copyright 2017 The Clspv Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "llvm/ADT/DenseMap.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/Pass.h"
#include "llvm/Support/raw_ostream.h"
#include "clspv/AddressSpace.h"
#include "Passes.h"
using namespace llvm;
using std::string;
#define DEBUG_TYPE "hideconstantloads"
namespace {
// Name prefix shared by every wrap function created by HideConstantLoadsPass.
// UnhideConstantLoadsPass recognizes wrap functions by this prefix.
const char *const kWrapFunctionPrefix = "clspv.wrap_constant_load.";

// Wraps every load from the __constant or push-constant address space in a
// call to an opaque per-type identity function (a "wrap function").
// UnhideConstantLoadsPass later replaces those calls with their argument.
class HideConstantLoadsPass : public ModulePass {
public:
  static char ID;
  HideConstantLoadsPass() : ModulePass(ID) {}
  bool runOnModule(Module &M) override;

private:
  // Return the name for the wrap function for the given type. The first time
  // a type is seen, a new name is created and memoized. Names are numbered in
  // order of first request: prefix + "0", prefix + "1", ...
  //
  // The name is returned by value because a reference into the DenseMap
  // would be invalidated by any later insertion.
  string WrapFunctionNameForType(Type *type) {
    auto where = function_for_type_.find(type);
    if (where != function_for_type_.end())
      return where->second;
    // Compute the suffix before inserting. Then the numbering does not depend
    // on the (pre-C++17 unspecified) evaluation order of
    // `map[key] = f(map.size())`.
    string name =
        string(kWrapFunctionPrefix) + std::to_string(function_for_type_.size());
    function_for_type_[type] = name;
    return name;
  }

  // Maps a loaded type to the name of the wrap function for that type.
  DenseMap<Type *, string> function_for_type_;
};
} // namespace
// Out-of-line definition of the static pass ID declared in the class. The
// legacy pass infrastructure identifies the pass by this variable's address,
// not by its value.
char HideConstantLoadsPass::ID = 0;
// Register the pass under the command-line name "HideConstantLoads". The two
// trailing `false` arguments are INITIALIZE_PASS's CFG-only and is-analysis
// flags.
INITIALIZE_PASS(HideConstantLoadsPass, "HideConstantLoads",
"Hide loads from __constant and push constant memory", false,
false)
namespace clspv {
// Factory for the pass that wraps loads from __constant and push-constant
// memory. The pass is heap-allocated, and the caller owns it.
llvm::ModulePass *createHideConstantLoadsPass() {
  auto *pass = new HideConstantLoadsPass;
  return pass;
}
} // namespace clspv
// Wraps each load from the __constant or push-constant address space in a
// call to the wrap function for the loaded type. Every other use of the loaded
// value is redirected to the call's result.
// Returns true iff at least one load was wrapped.
bool HideConstantLoadsPass::runOnModule(Module &M) {
  // Collect the loads first. Wrapping inserts new instructions into the very
  // blocks being walked.
  SmallVector<LoadInst *, 16> WorkList;
  for (Function &F : M) {
    for (BasicBlock &BB : F) {
      for (Instruction &I : BB) {
        if (LoadInst *load = dyn_cast<LoadInst>(&I)) {
          const unsigned AS = load->getPointerAddressSpace();
          if (AS == clspv::AddressSpace::Constant ||
              AS == clspv::AddressSpace::PushConstant) {
            WorkList.push_back(load);
          }
        }
      }
    }
  }

  for (LoadInst *load : WorkList) {
    Type *loadedTy = load->getType();
    // The wrap function conceptually maps the loaded value to itself.
    const string fn_name = WrapFunctionNameForType(loadedTy);
    Function *fn = M.getFunction(fn_name);
    if (!fn) {
      // Make the function.
      FunctionType *fnTy = FunctionType::get(loadedTy, {loadedTy}, false);
      auto fn_constant = M.getOrInsertFunction(fn_name, fnTy);
      fn = cast<Function>(fn_constant.getCallee());
      fn->addFnAttr(Attribute::ReadOnly);
    }

    // Wrap the load.
    CallInst *call = CallInst::Create(fn, {load});
    call->insertAfter(load);

    // Point every use of the load, except the wrap call's own argument, at the
    // call. Each Use is visited exactly once, even when one user has several
    // operands equal to the load. The iterator is advanced before Use::set,
    // because set() unlinks the use from the load's use list.
    for (auto UI = load->use_begin(), UE = load->use_end(); UI != UE;) {
      Use &U = *UI++;
      if (U.getUser() != call)
        U.set(call);
    }
  }

  return !WorkList.empty();
}
namespace {
// Reverses HideConstantLoadsPass. Each call to a wrap function (recognized by
// the kWrapFunctionPrefix name prefix) is replaced with its argument, and the
// wrap functions are then deleted. The pass keeps no per-type state, so the
// unused type-to-name map that was copied from HideConstantLoadsPass has been
// removed.
class UnhideConstantLoadsPass : public ModulePass {
public:
  static char ID;
  UnhideConstantLoadsPass() : ModulePass(ID) {}
  bool runOnModule(Module &M) override;
};
} // namespace
// Out-of-line definition of the static pass ID declared in the class. The
// legacy pass infrastructure identifies the pass by this variable's address,
// not by its value.
char UnhideConstantLoadsPass::ID = 0;
// Register the pass under the command-line name "UnhideConstantLoads". The
// two trailing `false` arguments are INITIALIZE_PASS's CFG-only and
// is-analysis flags.
INITIALIZE_PASS(UnhideConstantLoadsPass, "UnhideConstantLoads",
"Unhide loads from __constant and push constant memory", false,
false)
namespace clspv {
// Factory for the pass that removes the wrap calls inserted by the hide pass.
// The pass is heap-allocated, and the caller owns it.
llvm::ModulePass *createUnhideConstantLoadsPass() {
  auto *pass = new UnhideConstantLoadsPass;
  return pass;
}
} // namespace clspv
// Replaces every call to a wrap function with the value it wraps, then erases
// the wrap functions themselves.
// Returns true iff any wrap function was found. In that case the module was
// modified, and the pass manager must be told so.
bool UnhideConstantLoadsPass::runOnModule(Module &M) {
  SmallVector<Function *, 16> WorkList;
  for (auto &F : M.getFunctionList()) {
    if (F.getName().startswith(kWrapFunctionPrefix)) {
      WorkList.push_back(&F);
    }
  }
  if (WorkList.empty())
    return false;

  // Erasure is deferred so that F->uses() is not mutated while it is being
  // iterated.
  SmallVector<CallInst *, 16> RemoveList;
  for (auto *F : WorkList) {
    for (auto &use : F->uses()) {
      if (auto *call = dyn_cast<CallInst>(use.getUser())) {
        assert(call->arg_size() == 1 && "wrap function takes one argument");
        auto *load = call->getArgOperand(0);
        call->replaceAllUsesWith(load);
        RemoveList.push_back(call);
      }
    }
  }
  for (auto *call : RemoveList) {
    call->eraseFromParent();
  }
  for (auto *F : WorkList) {
    assert(F->use_empty() && "wrap function still has non-call uses");
    F->eraseFromParent();
  }
  // Wrap calls and functions were erased above, so the module changed.
  // (Previously this returned false here, which misreported the change.)
  return true;
}