// blob: 9850484a5e28457fc031933f921e0a2aa3928684 [file] [log] [blame] [edit]
///////////////////////////////////////////////////////////////////////////////
// //
// DxilRuntimeReflection.inl //
// Copyright (C) Microsoft Corporation. All rights reserved. //
// This file is distributed under the University of Illinois Open Source //
// License. See LICENSE.TXT for details. //
// //
// Defines shader reflection for runtime usage. //
// //
///////////////////////////////////////////////////////////////////////////////
#include "dxc/DxilContainer/DxilRuntimeReflection.h"
#include <cwchar>
#include <memory>
#include <unordered_map>
#include <vector>
namespace hlsl {
namespace RDAT {
#define DEF_RDAT_TYPES DEF_RDAT_READER_IMPL
#include "dxc/DxilContainer/RDAT_Macros.inl"
// Identifies a resource by its (Class, ID) pair. Two keys compare equal only
// when both fields match.
struct ResourceKey {
  uint32_t Class;
  uint32_t ID;

  ResourceKey(uint32_t resClass, uint32_t resID) : Class(resClass), ID(resID) {}

  bool operator==(const ResourceKey &rhs) const {
    if (Class != rhs.Class)
      return false;
    return ID == rhs.ID;
  }
};
// Size-checked reader: a bounds-checked cursor over a caller-supplied byte
// buffer. The reader does not copy the bytes; every pointer or reference it
// returns points into the buffer passed to the constructor.
// on overrun: throw buffer_overrun{};
// on overlap: throw buffer_overlap{};
class CheckedReader {
  const char *Ptr; // first byte of the buffer
  size_t Size;     // total number of bytes in the buffer
  size_t Offset;   // current read position; every mutator keeps it <= Size

public:
  // Common base, so a single handler can catch both failure kinds.
  class exception : public std::exception {};
  // Thrown when a read or seek would go past the end of the buffer.
  class buffer_overrun : public exception {
  public:
    buffer_overrun() noexcept {}
    virtual const char *what() const noexcept override {
      return ("buffer_overrun");
    }
  };
  // Thrown by Advance() when asked to move backwards.
  class buffer_overlap : public exception {
  public:
    buffer_overlap() noexcept {}
    virtual const char *what() const noexcept override {
      return ("buffer_overlap");
    }
  };

  CheckedReader(const void *ptr, size_t size)
      : Ptr(reinterpret_cast<const char *>(ptr)), Size(size), Offset(0) {}

  // Moves the cursor to the absolute position `offset`; moving backwards is
  // allowed. Throws buffer_overrun unless offset < Size (so it also throws
  // for an empty buffer).
  void Reset(size_t offset = 0) {
    if (offset >= Size)
      throw buffer_overrun{};
    Offset = offset;
  }

  // offset is absolute, ensure offset is >= current offset
  // Throws buffer_overlap when offset is behind the cursor and buffer_overrun
  // unless offset < Size.
  void Advance(size_t offset = 0) {
    if (offset < Offset)
      throw buffer_overlap{};
    if (offset >= Size)
      throw buffer_overrun{};
    Offset = offset;
  }

  // Throws buffer_overrun unless `size` bytes are available at the cursor.
  void CheckBounds(size_t size) const {
    assert(Offset <= Size && "otherwise, offset larger than size");
    if (size > Size - Offset)
      throw buffer_overrun{};
  }

  // Returns the bytes at the cursor viewed as T, without moving the cursor.
  // A `size` of 0 means sizeof(T). Throws buffer_overrun if that many bytes
  // are not available.
  template <typename T> const T *Cast(size_t size = 0) {
    if (0 == size)
      size = sizeof(T);
    CheckBounds(size);
    return reinterpret_cast<const T *>(Ptr + Offset);
  }

  // Returns a reference to the T at the cursor and moves past it.
  template <typename T> const T &Read() {
    const size_t size = sizeof(T);
    const T *p = Cast<T>(size);
    Offset += size;
    return *p;
  }

  // Returns a pointer to `count` consecutive T at the cursor and moves past
  // them. Throws buffer_overrun if they do not fit in the remaining bytes.
  // Note: a count of 0 still requires room for one T, because Cast() treats a
  // zero size as sizeof(T).
  template <typename T> const T *ReadArray(size_t count = 1) {
    assert(Offset <= Size && "otherwise, offset larger than size");
    // Divide instead of multiplying first: sizeof(T) * count could wrap around
    // for a huge count and slip past the bounds check with a small size.
    if (count > (Size - Offset) / sizeof(T))
      throw buffer_overrun{};
    const size_t size = sizeof(T) * count;
    const T *p = Cast<T>(size);
    Offset += size;
    return p;
  }
};
// Default construction behaves like constructing from a null blob of size 0.
DxilRuntimeData::DxilRuntimeData() { (void)InitFromRDAT(nullptr, 0); }

// Constructs from an RDAT blob. The success/failure result of InitFromRDAT is
// not reported from the constructor.
DxilRuntimeData::DxilRuntimeData(const void *ptr, size_t size) {
  (void)InitFromRDAT(ptr, size);
}
// Reads one record-table part from PR and binds it to ctx.Table(tableIndex).
// The part is read as a RuntimeDataTableHeader immediately followed by
// RecordCount * RecordStride bytes of record data.
// Throws CheckedReader::buffer_overrun when the header or the record data
// does not fit in PR.
static void InitTable(RDATContext &ctx, CheckedReader &PR,
                      RecordTableIndex tableIndex) {
  RuntimeDataTableHeader table = PR.Read<RuntimeDataTableHeader>();
  // Widen before multiplying: the header fields come straight from the
  // serialized data (presumably 32-bit - confirm against
  // RuntimeDataTableHeader), and a product computed in their own width could
  // wrap to a small value that passes the bounds check below while Init()
  // still receives the large count and stride.
  const uint64_t tableSize64 =
      static_cast<uint64_t>(table.RecordCount) * table.RecordStride;
  const size_t tableSize = static_cast<size_t>(tableSize64);
  // On targets where size_t is narrower than 64 bits the cast may truncate; a
  // table that large cannot fit in any buffer, so report it as an overrun.
  if (tableSize != tableSize64)
    throw CheckedReader::buffer_overrun{};
  ctx.Table(tableIndex)
      .Init(PR.ReadArray<char>(tableSize), table.RecordCount,
            table.RecordStride);
}
// Initializes the reader from an RDAT blob. Returns true if no error has
// occurred.
//
// Reads the RuntimeDataHeader, then the array of PartCount part offsets, then
// each part in turn: string buffer, index arrays, raw bytes, and one case per
// record table generated from RDAT_Macros.inl. Parts of an unrecognized type
// are skipped.
//
// Returns false when pRDAT is null (m_DataSize is then set to 0), when the
// header version is below RDAT_Version_10, or when any read runs outside the
// buffer or moves backwards (CheckedReader::exception). On the latter two
// paths m_DataSize keeps the value `size`. When NDEBUG is not defined, the
// result of Validate() is returned; otherwise a successfully parsed blob
// returns true without validation.
bool DxilRuntimeData::InitFromRDAT(const void *pRDAT, size_t size) {
  if (pRDAT) {
    m_DataSize = size;
    try {
      CheckedReader Reader(pRDAT, size);
      RuntimeDataHeader RDATHeader = Reader.Read<RuntimeDataHeader>();
      if (RDATHeader.Version < RDAT_Version_10) {
        return false;
      }
      const uint32_t *offsets =
          Reader.ReadArray<uint32_t>(RDATHeader.PartCount);
      for (uint32_t i = 0; i < RDATHeader.PartCount; ++i) {
        // Advance() only moves forward, so part offsets must be ascending.
        Reader.Advance(offsets[i]);
        RuntimeDataPartHeader part = Reader.Read<RuntimeDataPartHeader>();
        // PR is a sub-reader limited to this part's payload.
        CheckedReader PR(Reader.ReadArray<char>(part.Size), part.Size);
        switch (part.Type) {
        case RuntimeDataPartType::StringBuffer: {
          m_Context.StringBuffer.Init(PR.ReadArray<char>(part.Size), part.Size);
          break;
        }
        case RuntimeDataPartType::IndexArrays: {
          uint32_t count = part.Size / sizeof(uint32_t);
          m_Context.IndexTable.Init(PR.ReadArray<uint32_t>(count), count);
          break;
        }
        case RuntimeDataPartType::RawBytes: {
          m_Context.RawBytes.Init(PR.ReadArray<char>(part.Size), part.Size);
          break;
        }
// Once per table.
#define RDAT_STRUCT_TABLE(type, table)                                         \
  case RuntimeDataPartType::table:                                             \
    InitTable(m_Context, PR, RecordTableIndex::table);                         \
    break;
#define DEF_RDAT_TYPES DEF_RDAT_DEFAULTS
#include "dxc/DxilContainer/RDAT_Macros.inl"
        default:
          continue; // Skip unrecognized parts
        }
      }
#ifndef NDEBUG
      return Validate();
#else  // NDEBUG
      return true;
#endif // NDEBUG
    } catch (const CheckedReader::exception &) {
      // Caught by const reference so the derived exception is not sliced or
      // copied; the object itself is unused until error reporting is added.
      // TODO: error handling
      // throw hlsl::Exception(DXC_E_MALFORMED_CONTAINER, e.what());
      return false;
    }
  }
  m_DataSize = 0;
  return false;
}
// TODO: Incorporate field names and report errors in error stream
// TODO: Low-pri: Check other things like that all the index, string,
// and binary buffer space is actually used.
// Returns true when `id` is the null reference or a valid row index into the
// table associated with _RecordType.
template <typename _RecordType>
static bool ValidateRecordRef(const RDATContext &ctx, uint32_t id) {
  // The null reference is always acceptable; otherwise the id must index an
  // existing row of the record type's table.
  return id == RDAT_NULL_REF ||
         id < ctx.Table(RecordTraits<_RecordType>::TableIndex()).Count();
}
// Returns true when `id` is the null reference or refers to a well-formed
// array in the index table: the element at `id` holds the array length and
// the whole array must lie inside the table.
static bool ValidateIndexArrayRef(const RDATContext &ctx, uint32_t id) {
  if (id == RDAT_NULL_REF)
    return true;
  uint32_t size = ctx.IndexTable.Count();
  // check that id < size of index array
  if (id >= size)
    return false;
  // check that array size fits in remaining index space
  // Compared against (size - id), which is at least 1 here, rather than
  // computing id + length: the sum could wrap around for a large stored length
  // and make an out-of-range array look valid.
  if (ctx.IndexTable.Data()[id] >= size - id)
    return false;
  return true;
}
template <typename _RecordType>
static bool ValidateRecordArrayRef(const RDATContext &ctx, uint32_t id) {
// Make sure index array is well-formed
if (!ValidateIndexArrayRef(ctx, id))
return false;
// Make sure each record id is a valid record ref in the table
auto ids = ctx.IndexTable.getRow(id);
for (unsigned i = 0; i < ids.Count(); i++) {
if (!ValidateRecordRef<_RecordType>(ctx, ids.At(i)))
return false;
}
return true;
}
// Returns true when `id` is the null reference or an offset that lies inside
// the string buffer.
static bool ValidateStringRef(const RDATContext &ctx, uint32_t id) {
  if (id != RDAT_NULL_REF) {
    const uint32_t bufferSize = ctx.StringBuffer.Size();
    return id < bufferSize;
  }
  return true;
}
// Returns true when `id` is the null reference, or refers to a well-formed
// index array whose every element is a valid string reference.
static bool ValidateStringArrayRef(const RDATContext &ctx, uint32_t id) {
  if (id == RDAT_NULL_REF)
    return true;
  // The index array itself has to be well-formed before its row is read.
  if (!ValidateIndexArrayRef(ctx, id))
    return false;
  // Every entry of the row must be a valid offset into the string buffer.
  auto row = ctx.IndexTable.getRow(id);
  unsigned idx = 0;
  while (idx < row.Count()) {
    if (!ValidateStringRef(ctx, row.At(idx)))
      return false;
    ++idx;
  }
  return true;
}
// Primary template, specialized for each record type. This unspecialized
// fallback reports the record as invalid.
template <typename _RecordType>
bool ValidateRecord(const RDATContext &, const _RecordType *) {
  return false;
}
#define DEF_RDAT_TYPES DEF_RDAT_STRUCT_VALIDATION
#include "dxc/DxilContainer/RDAT_Macros.inl"
// Validates a record as every version of its type, from the base up to the
// latest one that still fits within the table's record stride.
class RecursiveRecordValidator {
  const hlsl::RDAT::RDATContext &m_Context;
  uint32_t m_RecordStride;

public:
  RecursiveRecordValidator(const hlsl::RDAT::RDATContext &ctx,
                           uint32_t recordStride)
      : m_Context(ctx), m_RecordStride(recordStride) {}

  // Validates pRecord as _RecordType, then hands it to ValidateDerived.
  // A null record, or a record type larger than the stride, is not checked
  // and counts as valid.
  template <typename _RecordType>
  bool Validate(const _RecordType *pRecord) const {
    if (!pRecord || sizeof(_RecordType) > m_RecordStride)
      return true;
    return ValidateRecord(m_Context, pRecord) &&
           ValidateDerived<_RecordType>(pRecord);
  }

  // Specialized for base type to recurse into derived; the unspecialized
  // version has nothing further to check.
  template <typename _RecordType>
  bool ValidateDerived(const _RecordType *) const {
    return true;
  }
};
// Validates every record of the table at `tableIndex` as _RecordType (and as
// any derived versions that fit in the table's stride).
// Returns false as soon as one record fails validation, true otherwise
// (including for an empty table).
template <typename _RecordType>
static bool ValidateRecordTable(RDATContext &ctx, RecordTableIndex tableIndex) {
  // iterate through records, bounds-checking all refs and index arrays
  auto &table = ctx.Table(tableIndex);
  const RecursiveRecordValidator validator(ctx, table.Stride());
  for (unsigned i = 0; i < table.Count(); i++) {
    // Propagate the result: previously it was discarded, so this function
    // reported success even when a record was invalid.
    if (!validator.Validate<_RecordType>(table.Row<_RecordType>(i)))
      return false;
  }
  return true;
}
// Validates the loaded runtime data.
// Returns false when a non-empty string buffer does not end in a null byte,
// or when validation of any record table fails; true otherwise.
bool DxilRuntimeData::Validate() {
  if (m_Context.StringBuffer.Size()) {
    if (m_Context.StringBuffer.Data()[m_Context.StringBuffer.Size() - 1] != 0)
      return false;
  }
// Once per table.
// The result of each table's validation is checked; previously it was
// discarded, so record validation could never fail this function.
#define RDAT_STRUCT_TABLE(type, table)                                         \
  if (!ValidateRecordTable<type>(m_Context, RecordTableIndex::table))          \
    return false;
// As an assumption of the way record types are versioned, derived record
// types must always be larger than base record types.
#define RDAT_STRUCT_DERIVED(type, base)                                        \
  static_assert(sizeof(type) > sizeof(base),                                   \
                "otherwise, derived record type " #type                        \
                " is not larger than base record type " #base ".");
#define DEF_RDAT_TYPES DEF_RDAT_DEFAULTS
#include "dxc/DxilContainer/RDAT_Macros.inl"
  return true;
}
} // namespace RDAT
} // namespace hlsl