// blob: 521ad81f225bc4b7db53db5ddda2ee98fb99d37b [file] [log] [blame]
// Copyright 2019 The Clspv Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <vector>
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/raw_ostream.h"
#include "clspv/AddressSpace.h"
#include "NormalizeGlobalVariable.h"
using namespace llvm;
namespace {
// Finds the single leaf type (a type that is neither an array nor a struct)
// that |type| is composed of. Arrays are looked through to their element type.
// A struct qualifies only when every member resolves to the same leaf type.
// Returns nullptr when there is no such unique type; this includes a struct
// that has no members at all.
Type *SoleContainedType(Type *type) {
  if (auto *as_array = dyn_cast<ArrayType>(type))
    return SoleContainedType(as_array->getArrayElementType());

  auto *as_struct = dyn_cast<StructType>(type);
  if (!as_struct)
    return type; // Already a leaf type.

  Type *common_leaf = nullptr;
  for (Type *member_ty : as_struct->elements()) {
    Type *member_leaf = SoleContainedType(member_ty);
    // A member without a unique leaf type disqualifies the whole struct.
    if (!member_leaf)
      return nullptr;
    if (!common_leaf) {
      common_leaf = member_leaf;
    } else if (common_leaf != member_leaf) {
      return nullptr;
    }
  }
  return common_leaf;
}
// Returns how many direct subtypes |type| has: the member count for a struct,
// the element count for an array, and 0 for every other kind of type.
uint64_t GetNumElements(Type *type) {
  if (type->isStructTy())
    return type->getStructNumElements();
  if (type->isArrayTy())
    return type->getArrayNumElements();
  return 0;
}
// Flattens |constant| into |flattened|. The array/struct nesting of
// |constant| is walked depth-first in element order, and every constituent
// constant whose type is neither an array nor a struct is appended to
// |flattened|. Constants of any other type (vectors included) are appended
// whole, which matches the notion of a leaf used by SoleContainedType.
//
// If |constant| itself is not of array or struct type, nothing is appended.
void FlattenConstant(Constant *constant, std::vector<Constant *> *flattened) {
  const uint64_t num_elements = GetNumElements(constant->getType());
  for (uint64_t i = 0; i != num_elements; ++i) {
    // getAggregateElement() is relied on to produce the i-th element for
    // every aggregate constant representation (explicit aggregates, zero
    // initializers, packed data arrays), so no per-representation special
    // cases are needed here. Special-casing zero initializers and data
    // sequences previously dropped vector-typed leaves (their element count
    // is reported as 0 by GetNumElements) and expanded nested zero
    // initializers by only a single level.
    Constant *element = constant->getAggregateElement(static_cast<unsigned>(i));
    assert(element && "aggregate constant has no extractable element");
    Type *element_ty = element->getType();
    if (element_ty->isArrayTy() || element_ty->isStructTy()) {
      FlattenConstant(element, flattened);
    } else {
      flattened->push_back(element);
    }
  }
}
// Returns a constant of type |new_type| built out of |elements|.
//
// |new_type| must be a struct or array type. Its nested array/struct layout is
// walked depth-first; each position that is neither an array nor a struct
// consumes the next entry of |elements|. |ele_index| is the position of the
// next unconsumed entry and is advanced as entries are used, so that recursive
// calls continue from where the caller left off.
//
// |elements| is taken by const reference so that the recursion does not copy
// the full list of flattened constants at every nesting level.
Constant *BuildConstant(Type *new_type,
                        const std::vector<Constant *> &elements,
                        uint64_t *ele_index) {
  // Type of the |element|-th direct member of |type|.
  auto GetSubType = [](Type *type, uint64_t element) {
    if (type->isStructTy()) {
      return type->getContainedType(element);
    } else if (type->isArrayTy()) {
      return type->getArrayElementType();
    }
    return type;
  };

  const uint64_t num_eles = GetNumElements(new_type);
  std::vector<Constant *> constants;
  constants.reserve(num_eles);
  for (uint64_t i = 0; i != num_eles; ++i) {
    auto *subtype = GetSubType(new_type, i);
    if (subtype->isArrayTy() || subtype->isStructTy()) {
      constants.push_back(BuildConstant(subtype, elements, ele_index));
    } else {
      assert(*ele_index < elements.size() &&
             "ran out of flattened constants while building constant");
      constants.push_back(elements[*ele_index]);
      ++(*ele_index);
    }
  }

  // Generate the new constant.
  if (auto struct_ty = dyn_cast<StructType>(new_type)) {
    return ConstantStruct::get(struct_ty, constants);
  } else {
    return ConstantArray::get(cast<ArrayType>(new_type), constants);
  }
}
// Re-expresses the aggregate constant |init| as a constant of type
// |new_type|: |init| is first flattened into its constituent constants, and
// those are then regrouped following the layout of |new_type|. |init| must be
// of struct or array type.
Constant *TranslateConstant(Constant *init, Type *new_type) {
  Type *init_ty = init->getType();
  assert(init_ty->isStructTy() || init_ty->isArrayTy());

  std::vector<Constant *> leaves;
  FlattenConstant(init, &leaves);

  // Cursor into |leaves| shared by the recursive rebuild.
  uint64_t next_leaf = 0;
  return BuildConstant(new_type, leaves, &next_leaf);
}
// Decides whether |GV| should be normalized to the pointer type |to_type|.
// |gv_contained_ty| is the sole contained type of |GV|'s pointee.
//
// Returns true only when all of the following hold:
//  - the pointee of |GV| is a struct or an array,
//  - the pointee of |to_type| is a struct or an array,
//  - both pointees have the same store size under the module's data layout,
//  - the sole contained type of |to_type|'s pointee is |gv_contained_ty|.
bool VariableNeedsNormalized(GlobalVariable *GV, Type *gv_contained_ty,
                             Type *to_type) {
  auto is_aggregate = [](Type *ty) {
    return ty->isStructTy() || ty->isArrayTy();
  };

  Type *from_pointee = GV->getType()->getPointerElementType();
  if (!is_aggregate(from_pointee))
    return false;

  Type *to_pointee = to_type->getPointerElementType();
  if (!is_aggregate(to_pointee))
    return false;

  const auto &layout = GV->getParent()->getDataLayout();
  if (layout.getTypeStoreSize(from_pointee) !=
      layout.getTypeStoreSize(to_pointee))
    return false;

  return gv_contained_ty == SoleContainedType(to_pointee);
}
// Normalizes a single user |user| of |GV|. A new global variable is created
// whose value type is the pointee type of |user|; it mirrors |GV|'s constness,
// linkage, thread-local mode, address space and externally-initialized flag.
// When |GV| has an initializer, it is translated to the new type and used as
// the new variable's initializer. The new variable takes over |GV|'s name and
// every use of |user| is redirected to it. Returns the new variable.
GlobalVariable *NormalizeVariable(GlobalVariable *GV, Value *user) {
  Type *target_ty = user->getType()->getPointerElementType();

  Constant *target_init =
      GV->hasInitializer() ? TranslateConstant(GV->getInitializer(), target_ty)
                           : nullptr;

  const unsigned address_space = GV->getType()->getPointerAddressSpace();
  // The module passed to the constructor takes the new variable into its
  // global list.
  auto *replacement = new GlobalVariable(
      *GV->getParent(), target_ty, GV->isConstant(), GV->getLinkage(),
      target_init, "", nullptr, GV->getThreadLocalMode(), address_space,
      GV->isExternallyInitialized());

  replacement->takeName(GV);
  user->replaceAllUsesWith(replacement);
  return replacement;
}
// Normalizes the users of |GV|.
//
// Every user of |GV| that is a bitcast constant expression to a compatible
// aggregate pointer type (see VariableNeedsNormalized) is replaced by a fresh
// global variable of the bitcast's pointee type. Afterwards, constant users of
// |GV| that have become dead are removed, and |GV| itself is erased from its
// module if nothing uses it any more.
void NormalizeVariableUsers(GlobalVariable *GV) {
  // The sole contained type depends only on |GV|, so compute it once instead
  // of once per user.
  Type *gv_contained_ty =
      SoleContainedType(GV->getType()->getPointerElementType());

  // Without a unique contained type no user can be normalized, but the
  // cleanup below still has to run.
  if (gv_contained_ty) {
    for (auto *user : GV->users()) {
      auto *bitcast = dyn_cast<ConstantExpr>(user);
      if (!bitcast || bitcast->getOpcode() != Instruction::BitCast)
        continue;

      if (!VariableNeedsNormalized(GV, gv_contained_ty, bitcast->getType()))
        continue;

      NormalizeVariable(GV, bitcast);
    }
  }

  GV->removeDeadConstantUsers();
  if (GV->use_empty()) {
    GV->eraseFromParent();
  }
}
} // namespace
namespace clspv {
// Normalizes every global variable of |M| that has an initializer and lives in
// the constant address space.
void NormalizeGlobalVariables(Module &M) {
  // Gather the candidates up front: normalizing a variable can add new globals
  // to the module and erase the variable itself, so the module's global list
  // must not be iterated while that happens.
  SmallVector<GlobalVariable *, 8> candidates;
  for (GlobalVariable &global : M.globals()) {
    if (!global.hasInitializer())
      continue;
    if (global.getType()->getPointerAddressSpace() !=
        clspv::AddressSpace::Constant)
      continue;
    candidates.push_back(&global);
  }

  for (GlobalVariable *candidate : candidates)
    NormalizeVariableUsers(candidate);
}
} // namespace clspv