blob: c0b84f1898822df3bac0ce5102636005d6eba1b5 [file] [log] [blame] [edit]
/*
* Copyright (c) 2023 Apple Inc. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
* are met:
* 1. Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY APPLE INC. ``AS IS'' AND ANY
* EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
* PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL APPLE INC. OR
* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
* EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
* PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
* PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
* OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#pragma once
#include "ConstantValue.h"
#include "Constraints.h"
#include "Types.h"
#include <bit>
#include <numbers>
#include <wtf/Assertions.h>
#include <wtf/DataLog.h>
#include <wtf/text/MakeString.h>
namespace WGSL {
// Zero values
// Outcome of evaluating a constant function: either the computed ConstantValue
// or a String describing why evaluation failed.
using ConstantResult = Expected<ConstantValue, String>;
// Pointer to a constant-evaluation function. It receives the type of the
// result and the already-evaluated argument values.
using ConstantFunction = ConstantResult(*)(const Type*, const FixedVector<ConstantValue>&);
// Emits the signature of a static function named `constant<name>` taking
// `resultType` and `arguments`. Follow it with `;` for a forward declaration
// or with a `{ ... }` body for a definition.
#define CONSTANT_FUNCTION(name) \
static Expected<ConstantValue, String>(constant ## name)(const Type* resultType, const FixedVector<ConstantValue>& arguments)
// Implementation detail of CALL: invokes constant<__fnName>(args...), stores the
// Expected in `__tmp`, returns its error from the enclosing function on failure,
// and otherwise declares `__variable` initialized by moving the value out.
// Expands to several statements, so it must not be used as the unbraced body of
// an `if`/`for`.
#define CALL_(__tmp, __variable, __fnName, ...) \
auto __tmp = constant##__fnName(__VA_ARGS__); \
if (!__tmp) \
return makeUnexpected(__tmp.error()); \
auto __variable = WTF::move(*__tmp)
// CALL(variable, Name, args...): declares `variable` holding the result of
// constant<Name>(args...), or propagates the error. The temporary's name is
// built from `tmp` and __COUNTER__ via WTF_LAZY_JOIN.
#define CALL(__variable, __fnName, ...) \
CALL_(WTF_LAZY_JOIN(tmp, __COUNTER__), __variable, __fnName, __VA_ARGS__)
// Implementation detail of CALL_MOVE: invokes constant<__fnName>(args...) inside
// its own do/while scope. On failure it returns the error from the enclosing
// function; on success it move-assigns the value into the existing lvalue
// `__target`.
#define CALL_MOVE_(__tmp, __target, __fnName, ...) \
    do { \
        auto __tmp = constant##__fnName(__VA_ARGS__); \
        if (!__tmp) \
            return makeUnexpected(__tmp.error()); \
        __target = WTF::move(*__tmp); \
    } while (0)
// CALL_MOVE(target, Name, args...): like CALL, but assigns into an existing
// variable instead of declaring a new one. The temporary's name goes through
// WTF_LAZY_JOIN (as CALL does) so that __COUNTER__ is expanded before being
// pasted; pasting it directly with ## yields the literal `tmp__COUNTER__`.
#define CALL_MOVE(__target, __fnName, ...) \
    CALL_MOVE_(WTF_LAZY_JOIN(tmp, __COUNTER__), __target, __fnName, __VA_ARGS__)
// Produces the zero value for `type`. Scalars map to their literal zero,
// vectors/arrays/matrices are filled with the zero value of their element type,
// and structs get a zero value for every field. Any other kind of type hits a
// release assertion.
static ConstantValue zeroValue(const Type* type)
{
    return WTF::switchOn(*type,
        [&](const Types::Primitive& primitiveType) -> ConstantValue {
            switch (primitiveType.kind) {
            case Types::Primitive::AbstractInt:
                return static_cast<int64_t>(0);
            case Types::Primitive::I32:
                return 0;
            case Types::Primitive::U32:
                return 0u;
            case Types::Primitive::AbstractFloat:
                return 0.0;
            case Types::Primitive::F32:
                return 0.0f;
            case Types::Primitive::F16:
                return static_cast<half>(0.0);
            case Types::Primitive::Bool:
                return false;
            case Types::Primitive::Void:
            case Types::Primitive::Sampler:
            case Types::Primitive::SamplerComparison:
            case Types::Primitive::TextureExternal:
            case Types::Primitive::AccessMode:
            case Types::Primitive::TexelFormat:
            case Types::Primitive::AddressSpace:
                RELEASE_ASSERT_NOT_REACHED();
            }
        },
        [&](const Types::Vector& vectorType) -> ConstantValue {
            ConstantVector zeroVector(vectorType.size);
            auto zeroElement = zeroValue(vectorType.element);
            for (unsigned index = 0; index < vectorType.size; ++index)
                zeroVector.elements[index] = zeroElement;
            return zeroVector;
        },
        [&](const Types::Array& arrayType) -> ConstantValue {
            ASSERT(std::holds_alternative<unsigned>(arrayType.size));
            auto elementCount = std::get<unsigned>(arrayType.size);
            ConstantArray zeroArray(elementCount);
            auto zeroElement = zeroValue(arrayType.element);
            for (unsigned index = 0; index < elementCount; ++index)
                zeroArray.elements[index] = zeroElement;
            return zeroArray;
        },
        [&](const Types::Struct& structType) -> ConstantValue {
            HashMap<String, ConstantValue> zeroFields;
            for (auto& [fieldName, fieldType] : structType.fields)
                zeroFields.set(fieldName, zeroValue(fieldType));
            return ConstantStruct { WTF::move(zeroFields) };
        },
        [&](const Types::PrimitiveStruct&) -> ConstantValue {
            // Primitive structs can't be zero initialized
            RELEASE_ASSERT_NOT_REACHED();
        },
        [&](const Types::Matrix& matrixType) -> ConstantValue {
            ConstantMatrix zeroMatrix(matrixType.columns, matrixType.rows);
            auto zeroElement = zeroValue(matrixType.element);
            for (unsigned index = 0; index < zeroMatrix.elements.size(); ++index)
                zeroMatrix.elements[index] = zeroElement;
            return zeroMatrix;
        },
        [&](const Types::Reference&) -> ConstantValue {
            RELEASE_ASSERT_NOT_REACHED();
        },
        [&](const Types::Pointer&) -> ConstantValue {
            RELEASE_ASSERT_NOT_REACHED();
        },
        [&](const Types::Function&) -> ConstantValue {
            RELEASE_ASSERT_NOT_REACHED();
        },
        [&](const Types::Texture&) -> ConstantValue {
            RELEASE_ASSERT_NOT_REACHED();
        },
        [&](const Types::TextureStorage&) -> ConstantValue {
            RELEASE_ASSERT_NOT_REACHED();
        },
        [&](const Types::TextureDepth&) -> ConstantValue {
            RELEASE_ASSERT_NOT_REACHED();
        },
        [&](const Types::Atomic&) -> ConstantValue {
            RELEASE_ASSERT_NOT_REACHED();
        },
        [&](const Types::TypeConstructor&) -> ConstantValue {
            RELEASE_ASSERT_NOT_REACHED();
        });
}
// Helpers
// Applies `fn` to the single argument, component-wise when it is a vector.
// Only the scalar types enabled in `constraint` are tried; an operand holding
// any other type hits a release assertion.
template<Constraint constraint, typename Functor>
static ConstantResult constantUnaryOperation(const FixedVector<ConstantValue>& arguments, const Functor& fn)
{
    ASSERT(arguments.size() == 1);
    const auto evaluate = [&](auto& operand) -> ConstantResult {
        if constexpr (constraint & Constraints::Bool) {
            if (std::holds_alternative<bool>(operand))
                return { { fn(std::get<bool>(operand)) } };
        }
        if constexpr (constraint & Constraints::I32) {
            if (std::holds_alternative<int32_t>(operand))
                return { { fn(std::get<int32_t>(operand)) } };
        }
        if constexpr (constraint & Constraints::U32) {
            if (std::holds_alternative<uint32_t>(operand))
                return { { fn(std::get<uint32_t>(operand)) } };
        }
        if constexpr (constraint & Constraints::AbstractInt) {
            if (std::holds_alternative<int64_t>(operand))
                return { { fn(std::get<int64_t>(operand)) } };
        }
        if constexpr (constraint & Constraints::F32) {
            if (std::holds_alternative<float>(operand))
                return { { fn(std::get<float>(operand)) } };
        }
        if constexpr (constraint & Constraints::F16) {
            if (std::holds_alternative<half>(operand))
                return { { fn(std::get<half>(operand)) } };
        }
        if constexpr (constraint & Constraints::AbstractFloat) {
            if (std::holds_alternative<double>(operand))
                return { { fn(std::get<double>(operand)) } };
        }
        RELEASE_ASSERT_NOT_REACHED();
    };
    return scalarOrVector(evaluate, arguments[0]);
}
// Applies `fn` to the two arguments, component-wise when either is a vector.
// The scalar type is picked from the left operand (restricted to the types
// enabled in `constraint`); the right operand is read as that same type.
template<Constraint constraint, typename Functor>
static ConstantResult constantBinaryOperation(const FixedVector<ConstantValue>& arguments, const Functor& fn)
{
    ASSERT(arguments.size() == 2);
    const auto evaluate = [&](auto& lhs, auto& rhs) -> ConstantResult {
        if constexpr (constraint & Constraints::Bool) {
            if (std::holds_alternative<bool>(lhs))
                return { { fn(std::get<bool>(lhs), std::get<bool>(rhs)) } };
        }
        if constexpr (constraint & Constraints::I32) {
            if (std::holds_alternative<int32_t>(lhs))
                return { { fn(std::get<int32_t>(lhs), std::get<int32_t>(rhs)) } };
        }
        if constexpr (constraint & Constraints::U32) {
            if (std::holds_alternative<uint32_t>(lhs))
                return { { fn(std::get<uint32_t>(lhs), std::get<uint32_t>(rhs)) } };
        }
        if constexpr (constraint & Constraints::AbstractInt) {
            if (std::holds_alternative<int64_t>(lhs))
                return { { fn(std::get<int64_t>(lhs), std::get<int64_t>(rhs)) } };
        }
        if constexpr (constraint & Constraints::F32) {
            if (std::holds_alternative<float>(lhs))
                return { { fn(std::get<float>(lhs), std::get<float>(rhs)) } };
        }
        if constexpr (constraint & Constraints::F16) {
            if (std::holds_alternative<half>(lhs))
                return { { fn(std::get<half>(lhs), std::get<half>(rhs)) } };
        }
        if constexpr (constraint & Constraints::AbstractFloat) {
            if (std::holds_alternative<double>(lhs))
                return { { fn(std::get<double>(lhs), std::get<double>(rhs)) } };
        }
        RELEASE_ASSERT_NOT_REACHED();
    };
    return scalarOrVector(evaluate, arguments[0], arguments[1]);
}
// Applies `fn` to the three arguments, component-wise when any is a vector.
// The scalar type is picked from the first operand (restricted to the types
// enabled in `constraint`); the other two operands are read as that same type.
template<Constraint constraint, typename Functor>
static ConstantResult constantTernaryOperation(const FixedVector<ConstantValue>& arguments, const Functor& fn)
{
    ASSERT(arguments.size() == 3);
    const auto evaluate = [&](auto& a, auto& b, auto& c) -> ConstantResult {
        if constexpr (constraint & Constraints::Bool) {
            if (std::holds_alternative<bool>(a))
                return { { fn(std::get<bool>(a), std::get<bool>(b), std::get<bool>(c)) } };
        }
        if constexpr (constraint & Constraints::I32) {
            if (std::holds_alternative<int32_t>(a))
                return { { fn(std::get<int32_t>(a), std::get<int32_t>(b), std::get<int32_t>(c)) } };
        }
        if constexpr (constraint & Constraints::U32) {
            if (std::holds_alternative<uint32_t>(a))
                return { { fn(std::get<uint32_t>(a), std::get<uint32_t>(b), std::get<uint32_t>(c)) } };
        }
        if constexpr (constraint & Constraints::AbstractInt) {
            if (std::holds_alternative<int64_t>(a))
                return { { fn(std::get<int64_t>(a), std::get<int64_t>(b), std::get<int64_t>(c)) } };
        }
        if constexpr (constraint & Constraints::F32) {
            if (std::holds_alternative<float>(a))
                return { { fn(std::get<float>(a), std::get<float>(b), std::get<float>(c)) } };
        }
        if constexpr (constraint & Constraints::F16) {
            if (std::holds_alternative<half>(a))
                return { { fn(std::get<half>(a), std::get<half>(b), std::get<half>(c)) } };
        }
        if constexpr (constraint & Constraints::AbstractFloat) {
            if (std::holds_alternative<double>(a))
                return { { fn(std::get<double>(a), std::get<double>(b), std::get<double>(c)) } };
        }
        RELEASE_ASSERT_NOT_REACHED();
    };
    return scalarOrVector(evaluate, arguments[0], arguments[1], arguments[2]);
}
// Conversion policy for convertValue: produces a U from any source value using
// static_cast.
template<typename U>
struct StaticCast {
    template<typename T>
    static U cast(T source)
    {
        return static_cast<U>(source);
    }
};
// Conversion policy for convertValue: reinterprets the bits of the source value
// as a U using std::bit_cast.
template<typename U>
struct BitwiseCast {
template<typename T>
static U cast(T t) { return std::bit_cast<U>(t); }
};
// Converts the scalar held by `value` to DestinationType using the `Cast`
// policy. With StaticCast every scalar source type is accepted. With any other
// policy only source types matching the checked destination size are tried:
// i32/u32/f32 when DestinationType is 4 bytes, f16 when it is 2 bytes.
// A value that matches no enabled source type hits a release assertion.
template<typename DestinationType, template <typename U> typename Cast = StaticCast>
static ConstantValue convertValue(ConstantValue value)
{
    using Converter = Cast<DestinationType>;
    constexpr bool usesStaticCast = std::is_same_v<Converter, StaticCast<DestinationType>>;

    if constexpr (usesStaticCast || sizeof(DestinationType) == 4) {
        if (std::holds_alternative<int32_t>(value))
            return Converter::cast(std::get<int32_t>(value));
        if (std::holds_alternative<uint32_t>(value))
            return Converter::cast(std::get<uint32_t>(value));
        if (std::holds_alternative<float>(value))
            return Converter::cast(std::get<float>(value));
    }

    if constexpr (usesStaticCast || sizeof(DestinationType) == 2) {
        if (std::holds_alternative<half>(value))
            return Converter::cast(std::get<half>(value));
    }

    if constexpr (usesStaticCast) {
        if (std::holds_alternative<bool>(value))
            return Converter::cast(std::get<bool>(value));
        if (std::holds_alternative<int64_t>(value))
            return Converter::cast(std::get<int64_t>(value));
        if (std::holds_alternative<double>(value))
            return Converter::cast(std::get<double>(value));
    }

    RELEASE_ASSERT_NOT_REACHED();
}
// Converts `value` to the scalar type described by `targetType`, which must be
// a primitive type. Non-scalar primitive kinds hit a release assertion.
template<template <typename U> typename Cast = StaticCast>
static ConstantValue convertValue(const Type* targetType, ConstantValue value)
{
    ASSERT(std::holds_alternative<Types::Primitive>(*targetType));
    const auto& targetPrimitive = std::get<Types::Primitive>(*targetType);

    switch (targetPrimitive.kind) {
    case Types::Primitive::Bool:
        return convertValue<bool, Cast>(value);
    case Types::Primitive::I32:
        return convertValue<int32_t, Cast>(value);
    case Types::Primitive::U32:
        return convertValue<uint32_t, Cast>(value);
    case Types::Primitive::AbstractInt:
        return convertValue<int64_t, Cast>(value);
    case Types::Primitive::F16:
        return convertValue<half, Cast>(value);
    case Types::Primitive::F32:
        return convertValue<float, Cast>(value);
    case Types::Primitive::AbstractFloat:
        return convertValue<double, Cast>(value);
    default:
        RELEASE_ASSERT_NOT_REACHED();
    }
}
// Scalar constructor: with no arguments yields the zero value of `resultType`,
// otherwise converts the single argument to DestinationType.
template<typename DestinationType>
static ConstantValue constantConstructor(const Type* resultType, const FixedVector<ConstantValue>& arguments)
{
    if (!arguments.isEmpty()) {
        ASSERT(arguments.size() == 1);
        return convertValue<DestinationType>(arguments[0]);
    }
    return zeroValue(resultType);
}
// Invokes `functor` on the given arguments. If none of them is a vector the
// functor is called once with the arguments as-is. Otherwise it is called once
// per component: vector arguments contribute their i-th element and scalar
// arguments are passed unchanged to every call. The per-component results are
// collected in a ConstantVector, and the first error stops evaluation.
// The vector size is taken from the first vector argument found.
template<typename Functor, typename... Arguments>
static ConstantResult scalarOrVector(const Functor& functor, Arguments&&... unpackedArguments)
{
    unsigned vectorSize = 0;
    std::initializer_list<ConstantValue> arguments { unpackedArguments... };
    // Iterate by reference: we only inspect the argument's type here, so there
    // is no reason to copy each value (which may be an entire vector).
    for (const auto& argument : arguments) {
        if (auto* vector = std::get_if<ConstantVector>(&argument)) {
            vectorSize = vector->elements.size();
            break;
        }
    }
    if (!vectorSize)
        return functor(std::forward<Arguments>(unpackedArguments)...);

    constexpr auto argumentCount = sizeof...(Arguments);
    ConstantVector result(vectorSize);
    std::array<ConstantValue, argumentCount> scalars;
    for (unsigned i = 0; i < vectorSize; ++i) {
        // Gather the i-th component of every argument, broadcasting scalars.
        unsigned j = 0;
        for (const auto& argument : arguments) {
            if (auto* vector = std::get_if<ConstantVector>(&argument))
                scalars[j++] = vector->elements[i];
            else
                scalars[j++] = argument;
        }
        ConstantResult tmp = std::apply(functor, scalars);
        if (!tmp)
            return makeUnexpected(tmp.error());
        result.elements[i] = WTF::move(*tmp);
    }
    return { { result } };
}
// Vector constructor for a vector of `size` components.
//  - no arguments: the zero value of `resultType`;
//  - one vector argument: each component converted to the result element type;
//  - one scalar argument: that scalar converted and repeated `size` times;
//  - several arguments: scalars and vector components are concatenated in
//    order, without conversion.
static ConstantValue constantVector(const Type* resultType, const FixedVector<ConstantValue>& arguments, unsigned size)
{
    ConstantVector result(size);
    ASSERT(std::holds_alternative<Types::Vector>(*resultType));

    const auto argumentCount = arguments.size();
    if (!argumentCount)
        return zeroValue(resultType);

    if (argumentCount == 1) {
        auto& resultVectorType = std::get<Types::Vector>(*resultType);
        auto* elementType = resultVectorType.element;
        const auto& onlyArgument = arguments[0];
        if (auto* source = std::get_if<ConstantVector>(&onlyArgument)) {
            ASSERT(source->elements.size() == resultVectorType.size);
            for (unsigned index = 0; index < source->elements.size(); ++index)
                result.elements[index] = convertValue(elementType, source->elements[index]);
            return result;
        }
        for (unsigned index = 0; index < size; ++index)
            result.elements[index] = convertValue(elementType, onlyArgument);
        return result;
    }

    unsigned nextSlot = 0;
    for (const auto& argument : arguments) {
        if (const auto* nested = std::get_if<ConstantVector>(&argument)) {
            for (const auto& component : nested->elements)
                result.elements[nextSlot++] = component;
        } else
            result.elements[nextSlot++] = argument;
    }
    ASSERT(nextSlot == size);
    return { result };
}
// Matrix constructor for a `columns` x `rows` matrix.
//  - no arguments: the zero value of `resultType`;
//  - one matrix argument: every element converted to the result element type;
//  - columns * rows arguments: used directly as the matrix elements;
//  - `columns` vector arguments: their components are concatenated in order.
static ConstantValue constantMatrix(const Type* resultType, const FixedVector<ConstantValue>& arguments, unsigned columns, unsigned rows)
{
    if (arguments.isEmpty())
        return zeroValue(resultType);

    if (arguments.size() == 1) {
        auto& resultMatrixType = std::get<Types::Matrix>(*resultType);
        auto* elementType = resultMatrixType.element;
        auto& source = std::get<ConstantMatrix>(arguments[0]);
        ASSERT(source.elements.size() == resultMatrixType.rows * resultMatrixType.columns);
        ConstantMatrix converted(columns, rows);
        unsigned slot = 0;
        for (auto& sourceElement : source.elements) {
            converted.elements[slot] = convertValue(elementType, sourceElement);
            ++slot;
        }
        return converted;
    }

    if (arguments.size() == columns * rows)
        return ConstantMatrix { columns, rows, arguments };

    RELEASE_ASSERT(arguments.size() == columns);
    ConstantMatrix assembled(columns, rows);
    unsigned slot = 0;
    for (auto& columnArgument : arguments) {
        ASSERT(columnArgument.isVector());
        auto& column = columnArgument.toVector();
        ASSERT(column.elements.size() == rows);
        for (auto& component : column.elements) {
            assembled.elements[slot] = component;
            ++slot;
        }
    }
    ASSERT(slot == columns * rows);
    return assembled;
}
// Each of the three macros below defines constant<name> as a thin wrapper that
// ignores `resultType` and forwards `arguments` and the callable `fn` to the
// matching constant{Unary,Binary,Ternary}Operation helper, restricted to
// Constraints::<constraint>.
// `fn` is a macro argument, so any comma in it must sit inside parentheses.
#define UNARY_OPERATION(name, constraint, fn) \
CONSTANT_FUNCTION(name) \
{ \
UNUSED_PARAM(resultType); \
return constantUnaryOperation<Constraints::constraint>(arguments, fn); \
}
#define BINARY_OPERATION(name, constraint, fn) \
CONSTANT_FUNCTION(name) \
{ \
UNUSED_PARAM(resultType); \
return constantBinaryOperation<Constraints::constraint>(arguments, fn); \
}
#define TERNARY_OPERATION(name, constraint, fn) \
CONSTANT_FUNCTION(name) \
{ \
UNUSED_PARAM(resultType); \
return constantTernaryOperation<Constraints::constraint>(arguments, fn); \
}
// Defines constant<name> as the scalar constructor producing a value of C++
// type `type` (see constantConstructor).
#define CONSTANT_CONSTRUCTOR(name, type) \
CONSTANT_FUNCTION(name) \
{ \
return { constantConstructor<type>(resultType, arguments) }; \
}
// Defines constantMat<columns>x<rows> as the matrix constructor for that shape
// (see constantMatrix).
#define MATRIX_CONSTRUCTOR(columns, rows) \
CONSTANT_FUNCTION(Mat ## columns ## x ## rows) \
{ \
return constantMatrix(resultType, arguments, columns, rows); \
}
// Defines constant<name> as a unary operation over Constraints::Float that
// calls std::<fn>. WRAP_STD is defined just below; that is fine because it is
// only expanded where CONSTANT_TRIGONOMETRIC is used.
#define CONSTANT_TRIGONOMETRIC(name, fn) UNARY_OPERATION(name, Float, WRAP_STD(fn))
// Wraps std::<fn> in a generic lambda whose return type is the type of its
// first argument, so the result is converted back to that type.
#define WRAP_STD(fn) \
[&]<typename T, typename... Args>(T arg, Args&&... args) -> T { return std::fn(arg, std::forward<Args>(args)...); }
// Arithmetic operators
// Addition. Two matrices are added element by element (recursing into Add for
// each pair of elements); everything else goes through the numeric binary
// operation helper, which also handles vectors.
CONSTANT_FUNCTION(Add)
{
    ASSERT(arguments.size() == 2);

    const auto* lhsMatrix = std::get_if<ConstantMatrix>(&arguments[0]);
    if (!lhsMatrix) {
        return constantBinaryOperation<Constraints::Number>(arguments, [&]<typename T>(T lhs, T rhs) -> T {
            return lhs + rhs;
        });
    }

    auto& rhsMatrix = std::get<ConstantMatrix>(arguments[1]);
    auto* elementType = std::get<Types::Matrix>(*resultType).element;
    ASSERT(lhsMatrix->columns == rhsMatrix.columns);
    ASSERT(lhsMatrix->rows == rhsMatrix.rows);

    ConstantMatrix sum(lhsMatrix->columns, lhsMatrix->rows);
    for (unsigned index = 0; index < sum.elements.size(); ++index)
        CALL_MOVE(sum.elements[index], Add, elementType, { lhsMatrix->elements[index], rhsMatrix.elements[index] });
    return { { sum } };
}
// Unary negation (one argument) or subtraction (two arguments).
// Two matrices are subtracted element by element; `resultType` is only
// dereferenced on that matrix path, so callers that pass only scalars or
// vectors may pass a null `resultType` (constantDistance does).
CONSTANT_FUNCTION(Minus)
{
    if (arguments.size() == 1) {
        return constantUnaryOperation<Constraints::Number>(arguments, [&]<typename T>(T arg) -> T {
            // Negating the most negative signed integer overflows, which is
            // undefined behaviour in C++. WGSL defines the result to be the
            // operand itself in that case.
            if constexpr (std::is_integral_v<T> && std::is_signed_v<T>) {
                if (arg == std::numeric_limits<T>::lowest())
                    return arg;
            }
            return -arg;
        });
    }

    ASSERT(arguments.size() == 2);
    if (auto* left = std::get_if<ConstantMatrix>(&arguments[0])) {
        auto& right = std::get<ConstantMatrix>(arguments[1]);
        auto* elementType = std::get<Types::Matrix>(*resultType).element;
        ASSERT(left->columns == right.columns);
        ASSERT(left->rows == right.rows);
        ConstantMatrix result(left->columns, left->rows);
        for (unsigned i = 0; i < result.elements.size(); ++i)
            CALL_MOVE(result.elements[i], Minus, elementType, { left->elements[i], right.elements[i] });
        return { { result } };
    }
    return constantBinaryOperation<Constraints::Number>(arguments, [&]<typename T>(T left, T right) -> T {
        return left - right;
    });
}
// Forward declaration: Multiply uses Dot for matrix-vector products, and Dot is
// defined later in terms of Multiply.
CONSTANT_FUNCTION(Dot);
// Multiplication. Handles matrix * matrix, matrix * vector, vector * matrix,
// matrix * scalar and scalar * matrix; every other combination goes through
// the numeric binary operation helper. Matrix elements are addressed as
// elements[column * rows + row].
CONSTANT_FUNCTION(Multiply)
{
ASSERT(arguments.size() == 2);
auto* leftMatrix = std::get_if<ConstantMatrix>(&arguments[0]);
auto* rightMatrix = std::get_if<ConstantMatrix>(&arguments[1]);
if (leftMatrix && rightMatrix) {
// matrix * matrix: result(i, j) accumulates left(k, j) * right(i, k) over k.
ASSERT(leftMatrix->columns == rightMatrix->rows);
auto* elementType = std::get<Types::Matrix>(*resultType).element;
ConstantMatrix result(rightMatrix->columns, leftMatrix->rows);
for (unsigned i = 0; i < rightMatrix->columns; ++i) {
for (unsigned j = 0; j < leftMatrix->rows; ++j) {
ConstantValue value = zeroValue(elementType);
for (unsigned k = 0; k < leftMatrix->columns; ++k) {
CALL(tmp, Multiply, elementType, { leftMatrix->elements[k * leftMatrix->rows + j], rightMatrix->elements[i * rightMatrix->rows + k] });
CALL_MOVE(value, Add, elementType, { value, tmp });
}
result.elements[i * result.rows + j] = value;
}
}
return { { result } };
}
if (leftMatrix || rightMatrix) {
// Exactly one operand is a matrix here. If arguments[1] is a vector then
// the matrix must be arguments[0], so leftMatrix is non-null below.
if (auto* rightVector = std::get_if<ConstantVector>(&arguments[1])) {
auto* elementType = std::get<Types::Vector>(*resultType).element;
auto columns = leftMatrix->columns;
auto rows = leftMatrix->rows;
ConstantVector result(rows);
ConstantVector leftVector(columns);
for (unsigned i = 0; i < rows; ++i) {
// Gather row i of the matrix, then dot it with the vector.
for (unsigned j = 0; j < columns; ++j)
leftVector.elements[j] = leftMatrix->elements[j * rows + i];
CALL_MOVE(result.elements[i], Dot, elementType, { leftVector, *rightVector });
}
return { { result } };
}
// Symmetric case: if arguments[0] is a vector, the matrix is arguments[1].
if (auto* leftVector = std::get_if<ConstantVector>(&arguments[0])) {
auto* elementType = std::get<Types::Vector>(*resultType).element;
auto columns = rightMatrix->columns;
auto rows = rightMatrix->rows;
ConstantVector result(columns);
ConstantVector rightVector(rows);
for (unsigned i = 0; i < columns; ++i) {
// Gather column i of the matrix, then dot the vector with it.
for (unsigned j = 0; j < rows; ++j)
rightVector.elements[j] = rightMatrix->elements[i * rows + j];
CALL_MOVE(result.elements[i], Dot, elementType, { *leftVector, rightVector });
}
return { { result } };
}
// Remaining case: matrix and scalar, in either order. Every element is
// multiplied as (element, scalar).
const ConstantMatrix* matrix;
ConstantValue scalar;
if (leftMatrix) {
matrix = leftMatrix;
scalar = arguments[1];
} else {
matrix = rightMatrix;
scalar = arguments[0];
}
auto* elementType = std::get<Types::Matrix>(*resultType).element;
ConstantMatrix result(matrix->columns, matrix->rows);
for (unsigned i = 0; i < result.elements.size(); ++i)
CALL_MOVE(result.elements[i], Multiply, elementType, { matrix->elements[i], scalar });
return { { result } };
}
// No matrix involved: scalar or component-wise vector multiplication.
return constantBinaryOperation<Constraints::Number>(arguments, [&]<typename T>(T left, T right) -> T {
return left * right;
});
}
// Division. For integer operands, a zero divisor and the overflowing case
// (lowest value divided by -1) are reported as errors; other operand types
// divide directly.
BINARY_OPERATION(Divide, Number, [&]<typename T>(T dividend, T divisor) -> ConstantResult {
    if constexpr (std::is_integral_v<T>) {
        if (!divisor)
            return makeUnexpected("invalid division by zero"_s);
        if constexpr (std::is_signed_v<T>) {
            const bool overflows = dividend == std::numeric_limits<T>::lowest() && divisor == -1;
            if (overflows)
                return makeUnexpected("invalid division overflow"_s);
        }
    }
    const auto quotient = static_cast<T>(dividend / divisor);
    return { { quotient } };
});
// Remainder. Floating-point operands (including half) use fmod. For integer
// operands, a zero divisor and the overflowing case (lowest value modulo -1)
// are reported as errors.
BINARY_OPERATION(Modulo, Number, [&]<typename T>(T dividend, T divisor) -> ConstantResult {
    if constexpr (std::is_floating_point_v<T> || std::is_same_v<T, half>) {
        return { { static_cast<T>(fmod(dividend, divisor)) } };
    } else {
        if (!divisor)
            return makeUnexpected("invalid modulo by zero"_s);
        if constexpr (std::is_signed_v<T>) {
            const bool overflows = dividend == std::numeric_limits<T>::lowest() && divisor == -1;
            if (overflows)
                return makeUnexpected("invalid modulo overflow"_s);
        }
        return { { dividend % divisor } };
    }
});
// Comparison Operations
// Each comparison applies the matching C++ operator to two operands of the same
// scalar type (component-wise for vectors).
BINARY_OPERATION(Equal, Scalar, [](auto lhs, auto rhs) { return lhs == rhs; })
BINARY_OPERATION(NotEqual, Scalar, [](auto lhs, auto rhs) { return lhs != rhs; })
BINARY_OPERATION(Lt, Scalar, [](auto lhs, auto rhs) { return lhs < rhs; })
BINARY_OPERATION(LtEq, Scalar, [](auto lhs, auto rhs) { return lhs <= rhs; })
BINARY_OPERATION(Gt, Scalar, [](auto lhs, auto rhs) { return lhs > rhs; })
BINARY_OPERATION(GtEq, Scalar, [](auto lhs, auto rhs) { return lhs >= rhs; })
// Logical Expressions
// Boolean negation, disjunction and conjunction (component-wise for vectors).
UNARY_OPERATION(Not, Bool, [](bool operand) { return !operand; })
BINARY_OPERATION(Or, Bool, [](bool lhs, bool rhs) { return lhs || rhs; })
BINARY_OPERATION(And, Bool, [](bool lhs, bool rhs) { return lhs && rhs; })
// Bit Expressions
// Bitwise OR over bool, i32, u32 or abstract-int operands (component-wise for
// vectors). The operand type is picked from the left operand.
CONSTANT_FUNCTION(BitwiseOr)
{
    UNUSED_PARAM(resultType);
    const auto bitwiseOr = [&](const auto& lhs, auto& rhs) -> ConstantValue {
        if (std::holds_alternative<bool>(lhs)) {
            const int lhsBit = static_cast<int>(std::get<bool>(lhs));
            const int rhsBit = static_cast<int>(std::get<bool>(rhs));
            return static_cast<bool>(lhsBit | rhsBit);
        }
        if (std::holds_alternative<int32_t>(lhs))
            return std::get<int32_t>(lhs) | std::get<int32_t>(rhs);
        if (std::holds_alternative<uint32_t>(lhs))
            return std::get<uint32_t>(lhs) | std::get<uint32_t>(rhs);
        if (std::holds_alternative<int64_t>(lhs))
            return std::get<int64_t>(lhs) | std::get<int64_t>(rhs);
        RELEASE_ASSERT_NOT_REACHED();
    };
    return scalarOrVector(bitwiseOr, arguments[0], arguments[1]);
}
// Bitwise AND over bool, i32, u32 or abstract-int operands (component-wise for
// vectors). The operand type is picked from the left operand.
CONSTANT_FUNCTION(BitwiseAnd)
{
    UNUSED_PARAM(resultType);
    const auto bitwiseAnd = [&](const auto& lhs, auto& rhs) -> ConstantValue {
        if (std::holds_alternative<bool>(lhs)) {
            const int lhsBit = static_cast<int>(std::get<bool>(lhs));
            const int rhsBit = static_cast<int>(std::get<bool>(rhs));
            return static_cast<bool>(lhsBit & rhsBit);
        }
        if (std::holds_alternative<int32_t>(lhs))
            return std::get<int32_t>(lhs) & std::get<int32_t>(rhs);
        if (std::holds_alternative<uint32_t>(lhs))
            return std::get<uint32_t>(lhs) & std::get<uint32_t>(rhs);
        if (std::holds_alternative<int64_t>(lhs))
            return std::get<int64_t>(lhs) & std::get<int64_t>(rhs);
        RELEASE_ASSERT_NOT_REACHED();
    };
    return scalarOrVector(bitwiseAnd, arguments[0], arguments[1]);
}
// Bitwise complement and exclusive OR over integer operands (component-wise for
// vectors).
UNARY_OPERATION(BitwiseNot, Integer, [](auto operand) { return ~operand; })
BINARY_OPERATION(BitwiseXor, Integer, [](auto lhs, auto rhs) { return lhs ^ rhs; })
// Left shift of an i32, u32 or abstract-int value (component-wise for vectors)
// by a u32 amount. Fails when the amount is not smaller than the bit width of
// the shifted type, or when the bits shifted out are significant (see the
// masks below).
CONSTANT_FUNCTION(BitwiseShiftLeft)
{
// We can't use a BINARY_OPERATION here since the arguments might not all have the same type
// i.e. we accept (u32, u32) as well as (i32, u32)
UNUSED_PARAM(resultType);
ASSERT(arguments.size() == 2);
const auto& shift = [&]<typename T>(T left, uint32_t right) -> ConstantResult {
constexpr auto bitSize = sizeof(T) * 8;
if (right >= bitSize)
return makeUnexpected(makeString("shift left value must be less than the bit width of the shifted value, which is "_s, bitSize));
if constexpr (std::is_unsigned_v<T>) {
// Unsigned: the mask covers bit positions (bitSize - right) and above;
// any of them set in `left` is reported as an overflow.
uint64_t mask = -1ull << (bitSize - right);
if (left & mask)
return makeUnexpected("shift left overflows"_s);
} else {
// Signed: the mask additionally covers the bit one position lower (the
// one that lands in the top position). The masked bits must be either
// all zero or all one, otherwise the shift is reported as an overflow.
uint64_t mask = -1ull << (bitSize - (right + 1));
auto leftBits = left & mask;
if (leftBits && leftBits != mask)
return makeUnexpected("shift left overflows"_s);
}
return { left << right };
};
return scalarOrVector([&](const auto& left, const auto& rightValue) -> ConstantResult {
// The shift amount is always read as u32; the shifted value's type selects
// the instantiation of `shift`.
auto right = std::get<uint32_t>(rightValue);
if (auto* i32 = std::get_if<int32_t>(&left))
return shift(*i32, right);
if (auto* u32 = std::get_if<uint32_t>(&left))
return shift(*u32, right);
if (auto* abstractInt = std::get_if<int64_t>(&left))
return shift(*abstractInt, right);
RELEASE_ASSERT_NOT_REACHED();
}, arguments[0], arguments[1]);
}
// Right shift of an i32, u32 or abstract-int value (component-wise for vectors)
// by a u32 amount. Fails when the amount is not smaller than the bit width of
// the shifted type.
CONSTANT_FUNCTION(BitwiseShiftRight)
{
    // A BINARY_OPERATION does not fit here because the two operands may have
    // different types: (u32, u32) is accepted as well as (i32, u32).
    UNUSED_PARAM(resultType);
    ASSERT(arguments.size() == 2);

    const auto shiftBy = [&]<typename T>(T value, uint32_t amount) -> ConstantResult {
        constexpr auto bitWidth = sizeof(T) * 8;
        if (amount >= bitWidth)
            return makeUnexpected(makeString("shift right value must be less than the bit width of the shifted value, which is "_s, bitWidth));
        return { value >> amount };
    };

    const auto evaluate = [&](const auto& shifted, const auto& amountValue) -> ConstantResult {
        auto amount = std::get<uint32_t>(amountValue);
        if (std::holds_alternative<int32_t>(shifted))
            return shiftBy(std::get<int32_t>(shifted), amount);
        if (std::holds_alternative<uint32_t>(shifted))
            return shiftBy(std::get<uint32_t>(shifted), amount);
        if (std::holds_alternative<int64_t>(shifted))
            return shiftBy(std::get<int64_t>(shifted), amount);
        RELEASE_ASSERT_NOT_REACHED();
    };
    return scalarOrVector(evaluate, arguments[0], arguments[1]);
}
// Constructors
// constantBool / constantF32 / constantF16: zero value when called without
// arguments, otherwise the single argument converted with static_cast.
CONSTANT_CONSTRUCTOR(Bool, bool)
CONSTANT_CONSTRUCTOR(F32, float)
CONSTANT_CONSTRUCTOR(F16, half)
// i32 constructor. An abstract-int argument is range-checked through
// convertInteger and rejected with an error when it does not fit; every other
// case defers to the generic scalar constructor.
CONSTANT_FUNCTION(I32)
{
    if (!arguments.isEmpty()) {
        if (auto* abstractInt = std::get_if<int64_t>(&arguments[0])) {
            if (auto converted = convertInteger<int32_t>(*abstractInt))
                return { *converted };
            return makeUnexpected(makeString("value "_s, *abstractInt, " cannot be represented as 'i32'"_s));
        }
    }
    return { constantConstructor<int32_t>(resultType, arguments) };
}
// u32 constructor. An abstract-int argument is range-checked through
// convertInteger and rejected with an error when it does not fit; every other
// case defers to the generic scalar constructor.
CONSTANT_FUNCTION(U32)
{
    if (!arguments.isEmpty()) {
        if (auto* abstractInt = std::get_if<int64_t>(&arguments[0])) {
            if (auto converted = convertInteger<uint32_t>(*abstractInt))
                return { *converted };
            return makeUnexpected(makeString("value "_s, *abstractInt, " cannot be represented as 'u32'"_s));
        }
    }
    return { constantConstructor<uint32_t>(resultType, arguments) };
}
// Vector constructors: each forwards to constantVector with its fixed
// component count (2, 3 or 4).
CONSTANT_FUNCTION(Vec2)
{
return constantVector(resultType, arguments, 2);
}
CONSTANT_FUNCTION(Vec3)
{
return constantVector(resultType, arguments, 3);
}
CONSTANT_FUNCTION(Vec4)
{
return constantVector(resultType, arguments, 4);
}
// Matrix constructors constantMat<C>x<R> for every combination of 2, 3 and 4
// columns and rows; each forwards to constantMatrix with its fixed shape.
MATRIX_CONSTRUCTOR(2, 2);
MATRIX_CONSTRUCTOR(2, 3);
MATRIX_CONSTRUCTOR(2, 4);
MATRIX_CONSTRUCTOR(3, 2);
MATRIX_CONSTRUCTOR(3, 3);
MATRIX_CONSTRUCTOR(3, 4);
MATRIX_CONSTRUCTOR(4, 2);
MATRIX_CONSTRUCTOR(4, 3);
MATRIX_CONSTRUCTOR(4, 4);
// Logical Built-in Functions
// all(): a bool argument is returned unchanged; for a vector the result is
// true only if every component is true.
CONSTANT_FUNCTION(All)
{
    UNUSED_PARAM(resultType);
    ASSERT(arguments.size() == 1);
    const auto& operand = arguments[0];
    if (operand.isBool())
        return operand;

    ASSERT(operand.isVector());
    bool everyComponentSet = true;
    for (const auto& component : operand.toVector().elements) {
        if (!component.toBool()) {
            everyComponentSet = false;
            break;
        }
    }
    return { { everyComponentSet } };
}
// any(): a bool argument is returned unchanged; for a vector the result is
// true if at least one component is true.
CONSTANT_FUNCTION(Any)
{
    UNUSED_PARAM(resultType);
    ASSERT(arguments.size() == 1);
    const auto& operand = arguments[0];
    if (operand.isBool())
        return operand;

    ASSERT(operand.isVector());
    bool anyComponentSet = false;
    for (const auto& component : operand.toVector().elements) {
        if (component.toBool()) {
            anyComponentSet = true;
            break;
        }
    }
    return { { anyComponentSet } };
}
// select(falseValue, trueValue, condition). A bool condition picks one of the
// two values whole. A vector condition picks component by component from the
// two vector operands.
CONSTANT_FUNCTION(Select)
{
    UNUSED_PARAM(resultType);
    ASSERT(arguments.size() == 3);
    const auto& whenFalse = arguments[0];
    const auto& whenTrue = arguments[1];
    const auto& condition = arguments[2];

    if (condition.isBool()) {
        if (condition.toBool())
            return whenTrue;
        return whenFalse;
    }

    ASSERT(whenFalse.isVector());
    ASSERT(whenTrue.isVector());
    ASSERT(condition.isVector());
    const auto& falseComponents = whenFalse.toVector().elements;
    const auto& trueComponents = whenTrue.toVector().elements;
    const auto& conditionComponents = condition.toVector().elements;
    auto componentCount = conditionComponents.size();
    ASSERT(componentCount == falseComponents.size());
    ASSERT(componentCount == trueComponents.size());

    ConstantVector selected(componentCount);
    for (unsigned index = 0; index < componentCount; ++index) {
        if (conditionComponents[index].toBool())
            selected.elements[index] = trueComponents[index];
        else
            selected.elements[index] = falseComponents[index];
    }
    return { { selected } };
}
// Numeric Built-in Functions
// Trigonometric and hyperbolic functions: each defines constant<Name> as a
// unary floating-point operation that calls the std:: function of the same
// (lower-case) name.
CONSTANT_TRIGONOMETRIC(Acos, acos);
CONSTANT_TRIGONOMETRIC(Asin, asin);
CONSTANT_TRIGONOMETRIC(Atan, atan);
CONSTANT_TRIGONOMETRIC(Cos, cos);
CONSTANT_TRIGONOMETRIC(Sin, sin);
CONSTANT_TRIGONOMETRIC(Tan, tan);
CONSTANT_TRIGONOMETRIC(Acosh, acosh);
CONSTANT_TRIGONOMETRIC(Asinh, asinh);
CONSTANT_TRIGONOMETRIC(Atanh, atanh);
CONSTANT_TRIGONOMETRIC(Cosh, cosh);
CONSTANT_TRIGONOMETRIC(Sinh, sinh);
CONSTANT_TRIGONOMETRIC(Tanh, tanh);
// abs(): absolute value of a numeric scalar (component-wise for vectors).
// u32 values are returned unchanged. For signed integers the most negative
// value is also returned unchanged: its magnitude is not representable, so
// std::abs would be undefined behaviour, and WGSL defines the result to be the
// input in that case.
UNARY_OPERATION(Abs, Number, [&]<typename T>(T n) -> T {
    if constexpr (std::is_same_v<decltype(n), uint32_t>)
        return n;
    else {
        if constexpr (std::is_integral_v<T> && std::is_signed_v<T>) {
            if (n == std::numeric_limits<T>::lowest())
                return n;
        }
        return std::abs(n);
    }
});
// atan2(): floating-point binary operation calling std::atan2.
BINARY_OPERATION(Atan2, Float, WRAP_STD(atan2))
// ceil(): floating-point unary operation calling std::ceil.
UNARY_OPERATION(Ceil, Float, WRAP_STD(ceil))
// clamp(e, low, high): computed as min(max(e, low), high).
TERNARY_OPERATION(Clamp, Number, [&](auto e, auto low, auto high) { return std::min(std::max(e, low), high); })
// Bit-counting built-ins over concrete integers. The argument is cast to
// `unsigned` before calling the <bit> function, and the count is converted back
// to the argument's type.
UNARY_OPERATION(CountLeadingZeros, ConcreteInteger, [&]<typename T>(T arg) -> T {
return std::countl_zero(static_cast<unsigned>(arg));
})
UNARY_OPERATION(CountOneBits, ConcreteInteger, [&]<typename T>(T arg) -> T {
return std::popcount(static_cast<unsigned>(arg));
})
UNARY_OPERATION(CountTrailingZeros, ConcreteInteger, [&]<typename T>(T arg) -> T {
return std::countr_zero(static_cast<unsigned>(arg));
})
// cross(): cross product of two 3-component vectors of f32, f16 or
// abstract-float elements. The element type is taken from the first component
// of the left operand.
CONSTANT_FUNCTION(Cross)
{
    UNUSED_PARAM(resultType);
    ASSERT(arguments.size() == 2);
    auto& u = arguments[0].toVector();
    auto& v = arguments[1].toVector();

    const auto crossProduct = [&]<typename T>() -> ConstantResult {
        auto ux = std::get<T>(u.elements[0]);
        auto uy = std::get<T>(u.elements[1]);
        auto uz = std::get<T>(u.elements[2]);
        auto vx = std::get<T>(v.elements[0]);
        auto vy = std::get<T>(v.elements[1]);
        auto vz = std::get<T>(v.elements[2]);

        ConstantVector product(3);
        product.elements[0] = static_cast<T>(uy * vz - uz * vy);
        product.elements[1] = static_cast<T>(uz * vx - ux * vz);
        product.elements[2] = static_cast<T>(ux * vy - uy * vx);
        return { { product } };
    };

    const auto& firstComponent = u.elements[0];
    if (std::holds_alternative<float>(firstComponent))
        return crossProduct.operator()<float>();
    if (std::holds_alternative<half>(firstComponent))
        return crossProduct.operator()<half>();
    if (std::holds_alternative<double>(firstComponent))
        return crossProduct.operator()<double>();
    RELEASE_ASSERT_NOT_REACHED();
}
// degrees(): multiplies the argument by 180 / pi. The factor is a double, so the
// product is converted back to the argument's type on return.
UNARY_OPERATION(Degrees, Float, [&]<typename T>(T arg) -> T { return arg * (180 / std::numbers::pi); })
// determinant(): determinant of a 2x2, 3x3 or 4x4 matrix of f32, f16 or
// abstract-float elements. The size is selected from the matrix's column count,
// and the element type from its first element.
CONSTANT_FUNCTION(Determinant)
{
UNUSED_PARAM(resultType);
ASSERT(arguments.size() == 1);
auto& matrix = std::get<ConstantMatrix>(arguments[0]);
auto columns = matrix.columns;
// 2x2 determinant of the grid
//   a b
//   c d
auto solve2 = [&]<typename T>(
T a, T b,
T c, T d
) -> T {
return a * d - b * c;
};
// 3x3 determinant, written out as the six signed triple products.
auto solve3 = [&]<typename T>(
T a, T b, T c,
T d, T e, T f,
T g, T h, T i
) -> T {
return a * e * i + b * f * g + c * d * h - c * e * g - b * d * i - a * f * h;
};
// 4x4 determinant by cofactor expansion along the first row (a, b, c, d),
// with alternating signs and a 3x3 determinant per entry.
auto solve4 = [&]<typename T>(
T a, T b, T c, T d,
T e, T f, T g, T h,
T i, T j, T k, T l,
T m, T n, T o, T p
) -> T {
return a * solve3(f, g, h, j, k, l, n, o, p) - b * solve3(e, g, h, i, k, l, m, o, p) + c * solve3(e, f, h, i, j, l, m, n, p) - d * solve3(e, f, g, i, j, k, m, n, o);
};
// Reads the elements as T and dispatches on the matrix size. Each row of
// arguments passed to solveN takes every N-th stored element, i.e. the
// storage is read transposed relative to the parameter grid.
const auto& determinant = [&]<typename T>() {
switch (columns) {
case 2:
return solve2(
std::get<T>(matrix.elements[0]), std::get<T>(matrix.elements[2]),
std::get<T>(matrix.elements[1]), std::get<T>(matrix.elements[3])
);
case 3:
return solve3(
std::get<T>(matrix.elements[0]), std::get<T>(matrix.elements[3]), std::get<T>(matrix.elements[6]),
std::get<T>(matrix.elements[1]), std::get<T>(matrix.elements[4]), std::get<T>(matrix.elements[7]),
std::get<T>(matrix.elements[2]), std::get<T>(matrix.elements[5]), std::get<T>(matrix.elements[8])
);
case 4:
return solve4(
std::get<T>(matrix.elements[0]), std::get<T>(matrix.elements[4]), std::get<T>(matrix.elements[8]), std::get<T>(matrix.elements[12]),
std::get<T>(matrix.elements[1]), std::get<T>(matrix.elements[5]), std::get<T>(matrix.elements[9]), std::get<T>(matrix.elements[13]),
std::get<T>(matrix.elements[2]), std::get<T>(matrix.elements[6]), std::get<T>(matrix.elements[10]), std::get<T>(matrix.elements[14]),
std::get<T>(matrix.elements[3]), std::get<T>(matrix.elements[7]), std::get<T>(matrix.elements[11]), std::get<T>(matrix.elements[15])
);
default:
RELEASE_ASSERT_NOT_REACHED();
}
};
if (std::holds_alternative<float>(matrix.elements[0]))
return { { determinant.operator()<float>() } };
if (std::holds_alternative<half>(matrix.elements[0]))
return { { determinant.operator()<half>() } };
if (std::holds_alternative<double>(matrix.elements[0]))
return { { determinant.operator()<double>() } };
RELEASE_ASSERT_NOT_REACHED();
}
// Forward declaration: constantLength is defined further down but is called by constantDistance.
CONSTANT_FUNCTION(Length);
// Constant-evaluates distance(e1, e2) as length(e1 - e2).
CONSTANT_FUNCTION(Distance)
{
    // The type of the intermediate difference cannot be produced here, so
    // nullptr is passed instead; if the subtraction ever starts relying on
    // its type argument this will crash rather than misbehave silently.
    CALL(difference, Minus, nullptr, arguments);
    return constantLength(resultType, { difference });
}
// Constant-evaluates the dot product: multiplies the arguments
// component-wise, then sums the components starting from zero.
CONSTANT_FUNCTION(Dot)
{
    CALL(products, Multiply, resultType, arguments);
    ConstantValue sum = zeroValue(resultType);
    for (auto& term : products.toVector().elements)
        CALL_MOVE(sum, Add, resultType, { sum, term });
    return { { sum } };
}
// Treats each u32 argument as four packed unsigned 8-bit lanes and returns
// the sum of the lane-wise products as a u32.
CONSTANT_FUNCTION(Dot4U8Packed)
{
    UNUSED_PARAM(resultType);
    const auto a = std::bit_cast<std::array<uint8_t, 4>>(std::get<uint32_t>(arguments[0]));
    const auto b = std::bit_cast<std::array<uint8_t, 4>>(std::get<uint32_t>(arguments[1]));
    uint32_t sum = 0;
    for (unsigned lane = 0; lane < 4; ++lane)
        sum += a[lane] * b[lane];
    return { { sum } };
}
// Treats each u32 argument as four packed signed 8-bit lanes and returns
// the sum of the lane-wise products as an i32.
CONSTANT_FUNCTION(Dot4I8Packed)
{
    UNUSED_PARAM(resultType);
    const auto a = std::bit_cast<std::array<int8_t, 4>>(std::get<uint32_t>(arguments[0]));
    const auto b = std::bit_cast<std::array<int8_t, 4>>(std::get<uint32_t>(arguments[1]));
    int32_t sum = 0;
    for (unsigned lane = 0; lane < 4; ++lane)
        sum += a[lane] * b[lane];
    return { { sum } };
}
// Forward declaration: constantSqrt is defined further down but is called by constantLength.
CONSTANT_FUNCTION(Sqrt);
// Constant-evaluates length(e): the absolute value for a scalar argument,
// or the square root of the sum of squared components for a vector.
CONSTANT_FUNCTION(Length)
{
    ASSERT(arguments.size() == 1);
    const auto& value = arguments[0];
    if (value.isVector()) {
        ConstantValue sumOfSquares = zeroValue(resultType);
        for (auto& component : value.toVector().elements) {
            CALL(squared, Multiply, resultType, { component, component });
            CALL_MOVE(sumOfSquares, Add, resultType, { sumOfSquares, squared });
        }
        return constantSqrt(resultType, { sumOfSquares });
    }
    return constantAbs(resultType, arguments);
}
// exp over Float arguments; WRAP_STD(exp) is expected to forward to std::exp.
UNARY_OPERATION(Exp, Float, WRAP_STD(exp));
// exp2 over Float arguments; WRAP_STD(exp2) is expected to forward to std::exp2.
UNARY_OPERATION(Exp2, Float, WRAP_STD(exp2));
// Constant-evaluates extractBits(e, offset, count): reads `count` bits of
// `e` starting at bit `offset`. For i32 the extracted field is sign-extended;
// for u32 it is zero-extended. `offset` and `count` are clamped so that the
// field never extends past bit 31.
CONSTANT_FUNCTION(ExtractBits)
{
    // We can't use a TERNARY_OPERATION here since the arguments might not all have the same type
    // i.e. we accept (u32, u32, u32) as well as (i32, u32, u32)
    UNUSED_PARAM(resultType);
    ASSERT(arguments.size() == 3);
    auto offset = std::get<uint32_t>(arguments[1]);
    auto count = std::get<uint32_t>(arguments[2]);
    const auto& extractBits = [&]<typename T>(T e) {
        constexpr unsigned w = 32;
        unsigned o = std::min(offset, w);
        unsigned c = std::min(count, w - o);
        if (!c)
            return static_cast<T>(0);
        if (c == w)
            return e;
        // Here 0 < c < 32. The masks are built in unsigned arithmetic: with a
        // signed literal, `(1 << 31) - 1` overflows `int` when c == 31.
        unsigned dstMask = (1u << c) - 1u;
        unsigned srcMask = dstMask << o;
        T result = (e & srcMask) >> o;
        if constexpr (std::is_signed_v<T>) {
            // Sign-extend when the top bit of the extracted field is set.
            if (result & (1u << (c - 1)))
                result |= ~dstMask;
        }
        return result;
    };
    return scalarOrVector([&](const auto& e) -> ConstantValue {
        if (auto* i32 = std::get_if<int32_t>(&e))
            return extractBits(*i32);
        if (auto* u32 = std::get_if<uint32_t>(&e))
            return extractBits(*u32);
        RELEASE_ASSERT_NOT_REACHED();
    }, arguments[0]);
}
// Constant-evaluates faceForward(e1, e2, e3): returns e1 when dot(e2, e3)
// is negative, and -e1 otherwise.
CONSTANT_FUNCTION(FaceForward)
{
    ASSERT(arguments.size() == 3);
    const auto& e1 = arguments[0];
    const auto& e2 = arguments[1];
    const auto& e3 = arguments[2];

    auto* scalarType = std::get<Types::Vector>(*resultType).element;
    CALL(projection, Dot, scalarType, { e2, e3 });
    CALL(isNegative, Lt, scalarType, { projection, zeroValue(scalarType) });
    if (!isNegative.toBool())
        return constantMinus(resultType, { e1 });
    return e1;
}
// firstLeadingBit: for i32, the index of the most significant bit that
// differs from the sign bit (-1 for 0 and -1); for u32, the index of the
// most significant set bit (all ones for 0).
UNARY_OPERATION(FirstLeadingBit, ConcreteInteger, [&](auto e) {
    using Arg = decltype(e);
    if constexpr (std::is_same_v<Arg, int32_t>) {
        // 0 and -1 consist solely of copies of the sign bit.
        if (!e || e == -1)
            return static_cast<int32_t>(-1);
        const unsigned bits = e;
        const unsigned signRun = e < 0 ? std::countl_one(bits) : std::countl_zero(bits);
        return static_cast<int32_t>(31 - signRun);
    } else {
        if (e)
            return static_cast<uint32_t>(31 - std::countl_zero(e));
        return static_cast<uint32_t>(-1);
    }
})
// firstTrailingBit: index of the least significant set bit, or -1 converted
// to the argument type when the argument is zero.
UNARY_OPERATION(FirstTrailingBit, ConcreteInteger, [&](auto e) {
    using Arg = decltype(e);
    if (e) {
        const unsigned bits = e;
        return static_cast<Arg>(std::countr_zero(bits));
    }
    return static_cast<Arg>(-1);
})
// floor over Float arguments; WRAP_STD(floor) is expected to forward to std::floor.
UNARY_OPERATION(Floor, Float, WRAP_STD(floor));
// Constant-evaluates fma(e1, e2, e3) as a multiplication followed by a
// separate addition: (e1 * e2) + e3.
CONSTANT_FUNCTION(Fma)
{
    ASSERT(arguments.size() == 3);
    const auto& multiplicand = arguments[0];
    const auto& multiplier = arguments[1];
    const auto& addend = arguments[2];
    CALL(scaled, Multiply, resultType, { multiplicand, multiplier });
    return constantAdd(resultType, { scaled, addend });
}
// Constant-evaluates fract(e) as e - floor(e).
CONSTANT_FUNCTION(Fract)
{
    ASSERT(arguments.size() == 1);
    const auto& value = arguments[0];
    CALL(wholePart, Floor, resultType, { value });
    return constantMinus(resultType, { value, wholePart });
}
// Constant-evaluates frexp(e): splits each float into a fraction and an
// exponent using std::frexp, and returns them as a struct with the fields
// "fract" and "exp". Vector arguments are processed component by component.
CONSTANT_FUNCTION(Frexp)
{
    UNUSED_PARAM(resultType);
    ASSERT(arguments.size() == 1);

    using Parts = std::tuple<ConstantValue, ConstantValue>;

    // Decomposes one concrete float. The exponent is stored as int64_t for a
    // double input and as int otherwise.
    const auto& decompose = [&](auto value) -> Parts {
        using T = decltype(value);
        using Exp = std::conditional_t<std::is_same_v<T, double>, int64_t, int>;
        int exponent;
        auto significand = std::frexp(value, &exponent);
        return { ConstantValue(static_cast<T>(significand)), ConstantValue(static_cast<Exp>(exponent)) };
    };

    // Dispatches on the floating-point alternative held by a scalar value.
    const auto& decomposeScalar = [&](auto value) {
        if (auto* f32 = std::get_if<float>(&value))
            return decompose(*f32);
        if (auto* f16 = std::get_if<half>(&value))
            return decompose(*f16);
        if (auto* abstractFloat = std::get_if<double>(&value))
            return decompose(*abstractFloat);
        RELEASE_ASSERT_NOT_REACHED();
    };

    auto [fractPart, expPart] = [&]() -> Parts {
        auto& arg = arguments[0];
        if (auto* vector = std::get_if<ConstantVector>(&arg)) {
            const auto size = vector->elements.size();
            ConstantVector fracts(size);
            ConstantVector exps(size);
            for (unsigned i = 0; i < size; ++i) {
                auto [componentFract, componentExp] = decomposeScalar(vector->elements[i]);
                fracts.elements[i] = componentFract;
                exps.elements[i] = componentExp;
            }
            return { fracts, exps };
        }
        return decomposeScalar(arg);
    }();

    return { ConstantStruct({ { { "fract"_s, fractPart }, { "exp"_s, expPart } } }) };
}
// Constant-evaluates insertBits(e, newbits, offset, count): replaces `count`
// bits of `e`, starting at bit `offset`, with the low `count` bits of
// `newbits`. `offset` and `count` are clamped so that the field never
// extends past bit 31. `e` and `newbits` are both i32 or both u32.
CONSTANT_FUNCTION(InsertBits)
{
    UNUSED_PARAM(resultType);
    return scalarOrVector([&](auto eValue, auto newbitsValue, auto offset, auto count) -> ConstantValue {
        constexpr unsigned w = 32;
        unsigned o = std::min(std::get<uint32_t>(offset), w);
        unsigned c = std::min(std::get<uint32_t>(count), w - o);
        // Nothing to insert: the result is `e` unchanged.
        if (!c)
            return eValue;
        // The field covers every bit: the result is `newbits`.
        if (c == w)
            return newbitsValue;
        const auto& fn = [&](auto e, auto newbits) {
            using T = decltype(e);
            // Here 0 < c < 32 and o + c <= 32. All masking is done in
            // uint32_t: with a signed literal, `(1 << 31) - 1` overflows
            // `int` when c == 31.
            const uint32_t mask = ((1u << c) - 1u) << o;
            const uint32_t kept = static_cast<uint32_t>(e) & ~mask;
            const uint32_t inserted = (static_cast<uint32_t>(newbits) << o) & mask;
            return static_cast<T>(kept | inserted);
        };
        if (auto* e = std::get_if<int32_t>(&eValue))
            return fn(*e, std::get<int32_t>(newbitsValue));
        return fn(std::get<uint32_t>(eValue), std::get<uint32_t>(newbitsValue));
    }, arguments[0], arguments[1], arguments[2], arguments[3]);
}
// inverseSqrt: reciprocal of the square root of the argument.
UNARY_OPERATION(InverseSqrt, Float, [&]<typename T>(T arg) -> T {
    const auto root = std::sqrt(arg);
    return 1.0 / root;
})
// Constant-evaluates ldexp(e1, e2) = e1 * 2^e2. The exponent is an
// int64_t when e1 is a double, and an int32_t for float and half. Results
// whose exponent is at or below -bias collapse to zero; exponents above
// bias + 1 are reported as an error.
CONSTANT_FUNCTION(Ldexp)
{
    UNUSED_PARAM(resultType);

    // Range checks and scaling shared by all three floating-point widths.
    const auto& scale = [&]<typename Float, typename Exponent>(Float significand, Exponent exponent, Exponent bias) -> ConstantResult {
        if (exponent + bias <= 0)
            return { static_cast<Float>(0) };
        if (exponent > bias + 1)
            return makeUnexpected(makeString("e2 must be less than or equal to "_s, bias + 1));
        return { static_cast<Float>(std::ldexp(significand, exponent)) };
    };

    return scalarOrVector([&](const auto& e1, auto& e2) -> ConstantResult {
        // The double case must be handled first: its exponent is an int64_t,
        // not an int32_t.
        if (auto* abstractE1 = std::get_if<double>(&e1)) {
            constexpr int64_t bias = 1023;
            return scale(*abstractE1, std::get<int64_t>(e2), bias);
        }

        auto i32E2 = std::get<int32_t>(e2);
        if (auto* f32E1 = std::get_if<float>(&e1)) {
            constexpr int32_t bias = 127;
            return scale(*f32E1, i32E2, bias);
        }
        if (auto* f16E1 = std::get_if<half>(&e1)) {
            constexpr int32_t bias = 15;
            return scale(*f16E1, i32E2, bias);
        }
        RELEASE_ASSERT_NOT_REACHED();
    }, arguments[0], arguments[1]);
}
// log over Float arguments; WRAP_STD(log) is expected to forward to std::log.
UNARY_OPERATION(Log, Float, WRAP_STD(log))
// log2 over Float arguments; WRAP_STD(log2) is expected to forward to std::log2.
UNARY_OPERATION(Log2, Float, WRAP_STD(log2))
// max over Number arguments; WRAP_STD(max) is expected to forward to std::max.
BINARY_OPERATION(Max, Number, WRAP_STD(max))
// min over Number arguments; WRAP_STD(min) is expected to forward to std::min.
BINARY_OPERATION(Min, Number, WRAP_STD(min))
// mix(e1, e2, e3): linear blend e1 * (1 - e3) + e2 * e3.
TERNARY_OPERATION(Mix, Number, [&]<typename T>(T e1, T e2, T e3) -> T { return e1 * (1 - e3) + e2 * e3; })
// Constant-evaluates modf(e): splits each float into fractional and whole
// parts using std::modf, and returns them as a struct with the fields
// "fract" and "whole". Vector arguments are processed component by component.
CONSTANT_FUNCTION(Modf)
{
    UNUSED_PARAM(resultType);
    ASSERT(arguments.size() == 1);

    using Parts = std::tuple<ConstantValue, ConstantValue>;

    // Splits one concrete float. When T is half the whole part is received
    // in a float and converted back to half afterwards.
    const auto& splitValue = [&](auto value) -> Parts {
        using T = decltype(value);
        using Whole = std::conditional_t<std::is_same_v<T, half>, float, T>;
        Whole integral;
        T fractional = std::modf(value, &integral);
        return { ConstantValue(fractional), ConstantValue(static_cast<T>(integral)) };
    };

    // Dispatches on the floating-point alternative held by a scalar value.
    const auto& splitScalar = [&](auto value) {
        if (auto* f32 = std::get_if<float>(&value))
            return splitValue(*f32);
        if (auto* f16 = std::get_if<half>(&value))
            return splitValue(*f16);
        if (auto* abstractFloat = std::get_if<double>(&value))
            return splitValue(*abstractFloat);
        RELEASE_ASSERT_NOT_REACHED();
    };

    auto [fractPart, wholePart] = [&]() -> Parts {
        auto& arg = arguments[0];
        if (auto* vector = std::get_if<ConstantVector>(&arg)) {
            const auto size = vector->elements.size();
            ConstantVector fracts(size);
            ConstantVector wholes(size);
            for (unsigned i = 0; i < size; ++i) {
                auto [componentFract, componentWhole] = splitScalar(vector->elements[i]);
                fracts.elements[i] = componentFract;
                wholes.elements[i] = componentWhole;
            }
            return { fracts, wholes };
        }
        return splitScalar(arg);
    }();

    return { ConstantStruct({ { { "fract"_s, fractPart }, { "whole"_s, wholePart } } }) };
}
// Constant-evaluates normalize(e) as e / length(e). The length is computed
// with the vector's element type.
CONSTANT_FUNCTION(Normalize)
{
    ASSERT(arguments.size() == 1);
    const auto& vector = arguments[0];
    auto* scalarType = std::get<Types::Vector>(*resultType).element;
    CALL(magnitude, Length, scalarType, arguments);
    return constantDivide(resultType, { vector, magnitude });
}
// pow over Float arguments; WRAP_STD(pow) is expected to forward to std::pow.
BINARY_OPERATION(Pow, Float, WRAP_STD(pow))
// quantizeToF16: converts an f32 to half and back to float. Fails with an
// error message when convertFloat<half> yields no value.
UNARY_OPERATION(QuantizeToF16, F32, [&](float arg) -> ConstantResult {
    UNUSED_PARAM(resultType);
    if (auto asHalf = convertFloat<half>(arg))
        return { { static_cast<float>(*asHalf) } };
    return makeUnexpected(makeString("value "_s, String::number(arg), " cannot be represented as 'f16'"_s));
});
// radians: multiplies the argument by pi / 180 and converts the product back to T.
UNARY_OPERATION(Radians, Float, [&]<typename T>(T arg) -> T {
return arg * (std::numbers::pi / static_cast<T>(180));
})
// Constant-evaluates reflect(e1, e2) = e1 - 2 * dot(e2, e1) * e2.
// The literal 2 is created in the vector's element type (f32, f16 or
// abstract float).
CONSTANT_FUNCTION(Reflect)
{
    ASSERT(arguments.size() == 2);
    const auto& incident = arguments[0];
    const auto& normal = arguments[1];

    auto* scalarType = std::get<Types::Vector>(*resultType).element;
    const auto kind = std::get<Types::Primitive>(*scalarType).kind;

    const auto& compute = [&]<typename T>() -> ConstantResult {
        CALL(projection, Dot, scalarType, { normal, incident });
        CALL(twiceProjection, Multiply, resultType, { static_cast<T>(2), projection });
        CALL(correction, Multiply, resultType, { twiceProjection, normal });
        return constantMinus(resultType, { incident, correction });
    };

    if (kind == Types::Primitive::F32)
        return compute.operator()<float>();
    if (kind == Types::Primitive::F16)
        return compute.operator()<half>();
    if (kind == Types::Primitive::AbstractFloat)
        return compute.operator()<double>();
    RELEASE_ASSERT_NOT_REACHED();
}
// Constant-evaluates refract(e1, e2, e3):
//   k = 1 - e3^2 * (1 - dot(e2, e1)^2)
//   result = k < 0 ? zero vector : e3 * e1 - (e3 * dot(e2, e1) + sqrt(k)) * e2
// The scalar type of e3 selects the type used for the literals.
CONSTANT_FUNCTION(Refract)
{
    ASSERT(arguments.size() == 3);
    const auto& incident = arguments[0];
    const auto& normal = arguments[1];
    const auto& ratioValue = arguments[2];

    const auto& compute = [&]<typename T>(T ratio) -> ConstantResult {
        auto* scalarType = std::get<Types::Vector>(*resultType).element;

        CALL(cosine, Dot, scalarType, { normal, incident });
        CALL(cosineSquared, Pow, scalarType, { cosine, static_cast<T>(2.0) });
        CALL(sineSquared, Minus, scalarType, { static_cast<T>(1.0), cosineSquared });
        CALL(ratioSquared, Pow, scalarType, { ratio, static_cast<T>(2.0) });
        CALL(scaledSine, Multiply, scalarType, { ratioSquared, sineSquared });
        CALL(k, Minus, scalarType, { static_cast<T>(1.0), scaledSine });
        CALL(isNegative, Lt, scalarType, { k, static_cast<T>(0.0) });

        // A negative k yields the zero vector.
        if (isNegative.toBool())
            return zeroValue(resultType);

        CALL(scaledCosine, Multiply, scalarType, { ratio, cosine });
        CALL(rootK, Sqrt, scalarType, { k });
        CALL(normalScale, Add, scalarType, { scaledCosine, rootK });
        CALL(scaledIncident, Multiply, resultType, { ratio, incident });
        CALL(scaledNormal, Multiply, resultType, { normalScale, normal });
        return constantMinus(resultType, { scaledIncident, scaledNormal });
    };

    if (auto* f32 = std::get_if<float>(&ratioValue))
        return compute(*f32);
    if (auto* f16 = std::get_if<half>(&ratioValue))
        return compute(*f16);
    if (auto* abstractFloat = std::get_if<double>(&ratioValue))
        return compute(*abstractFloat);
    RELEASE_ASSERT_NOT_REACHED();
}
// reverseBits: mirrors the low 32 bits of the argument, so that source bit
// (31 - k) becomes result bit k.
UNARY_OPERATION(ReverseBits, Integer, [&]<typename T>(T e) -> T {
    const unsigned source = e;
    T reversed = 0;
    for (unsigned k = 0; k < 32; ++k) {
        // `!!` reduces the masked bit to 0 or 1 before it is shifted into place.
        reversed |= !!(source & (1 << (31 - k))) << k;
    }
    return reversed;
})
// round over Float arguments, implemented with rint rather than round
// (WRAP_STD(rint) is expected to forward to std::rint). std::rint rounds
// according to the current rounding mode, which by default is to nearest
// with ties to even; std::round would round ties away from zero.
UNARY_OPERATION(Round, Float, WRAP_STD(rint))
// saturate: clamps the argument to the range [0, 1].
UNARY_OPERATION(Saturate, Float, [&](auto e) {
    using Arg = decltype(e);
    const Arg lower = static_cast<Arg>(0);
    const Arg upper = static_cast<Arg>(1);
    return std::min(std::max(e, lower), upper);
})
// sign: 1 for positive arguments, -1 for negative arguments, 0 otherwise.
UNARY_OPERATION(Sign, Number, [&]<typename T>(T e) -> T {
    if (e > 0)
        return static_cast<T>(1);
    if (e < 0)
        return static_cast<T>(-1);
    return static_cast<T>(0);
})
// smoothstep(low, high, x): t * t * (3 - 2 * t), where t is the position of
// x between low and high clamped to [0, 1].
TERNARY_OPERATION(Smoothstep, Float, [&](auto low, auto high, auto x) {
// Position of x relative to the [low, high] interval, before clamping.
auto e = (x - low) / (high - low);
// Clamp to [0, 1].
auto t = std::min(std::max(e, static_cast<decltype(e)>(0)), static_cast<decltype(e)>(1));
return t * t * (3.0 - 2.0 * t);
})
// sqrt over Float arguments; WRAP_STD(sqrt) is expected to forward to std::sqrt.
UNARY_OPERATION(Sqrt, Float, WRAP_STD(sqrt))
// step(edge, x): 1.0 when edge <= x, otherwise 0.0.
BINARY_OPERATION(Step, Float, [&](auto edge, auto x) { return edge <= x ? 1.0 : 0.0; })
// Constant-evaluates transpose(m): the result is constructed with the
// source's row and column counts swapped, and the source element at
// (column, row) is copied to (row, column).
CONSTANT_FUNCTION(Transpose)
{
    UNUSED_PARAM(resultType);
    ASSERT(arguments.size() == 1);

    auto& source = std::get<ConstantMatrix>(arguments[0]);
    const auto sourceColumns = source.columns;
    const auto sourceRows = source.rows;

    ConstantMatrix transposed(sourceRows, sourceColumns);
    for (unsigned column = 0; column < sourceColumns; ++column) {
        for (unsigned row = 0; row < sourceRows; ++row)
            transposed.elements[row * sourceColumns + column] = source.elements[column * sourceRows + row];
    }
    return { { transposed } };
}
// trunc over Float arguments; WRAP_STD(trunc) is expected to forward to std::trunc.
UNARY_OPERATION(Trunc, Float, WRAP_STD(trunc))
// Data Packing
// pack4x8snorm: clamps each of four floats to [-1, 1], scales by 127,
// rounds with floor(0.5 + x) and packs the resulting signed bytes into a u32.
CONSTANT_FUNCTION(Pack4x8snorm)
{
    UNUSED_PARAM(resultType);
    auto& input = std::get<ConstantVector>(arguments[0]);
    std::array<int8_t, 4> bytes;
    for (unsigned lane = 0; lane < 4; ++lane) {
        const auto component = std::get<float>(input.elements[lane]);
        const auto clamped = std::min(1.f, std::max(-1.f, component));
        bytes[lane] = static_cast<int8_t>(std::floor(0.5 + 127 * clamped));
    }
    return { { std::bit_cast<uint32_t>(bytes) } };
}
// pack4x8unorm: clamps each of four floats to [0, 1], scales by 255,
// rounds with floor(0.5 + x) and packs the resulting bytes into a u32.
CONSTANT_FUNCTION(Pack4x8unorm)
{
    UNUSED_PARAM(resultType);
    auto& input = std::get<ConstantVector>(arguments[0]);
    std::array<uint8_t, 4> bytes;
    for (unsigned lane = 0; lane < 4; ++lane) {
        const auto component = std::get<float>(input.elements[lane]);
        const auto clamped = std::min(1.f, std::max(0.f, component));
        bytes[lane] = static_cast<uint8_t>(std::floor(0.5 + 255 * clamped));
    }
    return { { std::bit_cast<uint32_t>(bytes) } };
}
// pack4xI8: keeps the low 8 bits of each of four i32 components and packs
// them into a u32.
CONSTANT_FUNCTION(Pack4xI8)
{
    UNUSED_PARAM(resultType);
    auto& input = std::get<ConstantVector>(arguments[0]);
    std::array<uint8_t, 4> bytes;
    for (unsigned lane = 0; lane < 4; ++lane) {
        const auto component = std::get<int32_t>(input.elements[lane]);
        bytes[lane] = static_cast<uint8_t>(component);
    }
    return { { std::bit_cast<uint32_t>(bytes) } };
}
// pack4xU8: keeps the low 8 bits of each of four u32 components and packs
// them into a u32.
CONSTANT_FUNCTION(Pack4xU8)
{
    UNUSED_PARAM(resultType);
    auto& input = std::get<ConstantVector>(arguments[0]);
    std::array<uint8_t, 4> bytes;
    for (unsigned lane = 0; lane < 4; ++lane) {
        const auto component = std::get<uint32_t>(input.elements[lane]);
        bytes[lane] = static_cast<uint8_t>(component);
    }
    return { { std::bit_cast<uint32_t>(bytes) } };
}
// pack4xI8Clamp: clamps each of four i32 components to [-128, 127], then
// packs the low 8 bits of each into a u32.
CONSTANT_FUNCTION(Pack4xI8Clamp)
{
    UNUSED_PARAM(resultType);
    auto& input = std::get<ConstantVector>(arguments[0]);
    std::array<uint8_t, 4> bytes;
    for (unsigned lane = 0; lane < 4; ++lane) {
        const auto component = std::get<int32_t>(input.elements[lane]);
        const auto clamped = std::min(127, std::max(-128, component));
        bytes[lane] = static_cast<uint8_t>(clamped);
    }
    return { { std::bit_cast<uint32_t>(bytes) } };
}
// pack4xU8Clamp: clamps each of four u32 components to at most 255, then
// packs them into a u32.
CONSTANT_FUNCTION(Pack4xU8Clamp)
{
    UNUSED_PARAM(resultType);
    auto& input = std::get<ConstantVector>(arguments[0]);
    std::array<uint8_t, 4> bytes;
    for (unsigned lane = 0; lane < 4; ++lane) {
        const auto component = std::get<uint32_t>(input.elements[lane]);
        const auto clamped = std::min(255u, component);
        bytes[lane] = static_cast<uint8_t>(clamped);
    }
    return { { std::bit_cast<uint32_t>(bytes) } };
}
// pack2x16snorm: clamps each of two floats to [-1, 1], scales by 32767,
// rounds with floor(0.5 + x) and packs the signed 16-bit results into a u32.
CONSTANT_FUNCTION(Pack2x16snorm)
{
    UNUSED_PARAM(resultType);
    auto& input = std::get<ConstantVector>(arguments[0]);
    std::array<int16_t, 2> words;
    for (unsigned lane = 0; lane < 2; ++lane) {
        const auto component = std::get<float>(input.elements[lane]);
        const auto clamped = std::min(1.f, std::max(-1.f, component));
        words[lane] = static_cast<int16_t>(std::floor(0.5 + 32767 * clamped));
    }
    return { { std::bit_cast<uint32_t>(words) } };
}
// pack2x16unorm: clamps each of two floats to [0, 1], scales by 65535,
// rounds with floor(0.5 + x) and packs the 16-bit results into a u32.
CONSTANT_FUNCTION(Pack2x16unorm)
{
    UNUSED_PARAM(resultType);
    auto& input = std::get<ConstantVector>(arguments[0]);
    std::array<uint16_t, 2> words;
    for (unsigned lane = 0; lane < 2; ++lane) {
        const auto component = std::get<float>(input.elements[lane]);
        const auto clamped = std::min(1.f, std::max(0.f, component));
        words[lane] = static_cast<uint16_t>(std::floor(0.5 + 65535 * clamped));
    }
    return { { std::bit_cast<uint32_t>(words) } };
}
// pack2x16float: converts each of two floats to half and packs the two
// halves into a u32. Fails with an error message when convertFloat<half>
// yields no value for a component.
CONSTANT_FUNCTION(Pack2x16float)
{
    UNUSED_PARAM(resultType);
    auto& input = std::get<ConstantVector>(arguments[0]);
    std::array<half, 2> halves;
    for (unsigned lane = 0; lane < 2; ++lane) {
        auto component = std::get<float>(input.elements[lane]);
        if (auto asHalf = convertFloat<half>(component)) {
            halves[lane] = *asHalf;
            continue;
        }
        return makeUnexpected(makeString("value "_s, component, " cannot be represented as 'f16'"_s));
    }
    return { { std::bit_cast<uint32_t>(halves) } };
}
// Data Unpacking
// unpack4x8snorm: splits a u32 into four signed bytes and maps each to a
// float as max(byte / 127, -1).
CONSTANT_FUNCTION(Unpack4x8snorm)
{
    UNUSED_PARAM(resultType);
    const auto bytes = std::bit_cast<std::array<int8_t, 4>>(std::get<uint32_t>(arguments[0]));
    ConstantVector unpacked(4);
    for (unsigned lane = 0; lane < 4; ++lane) {
        const auto scaled = static_cast<float>(bytes[lane]) / 127.f;
        unpacked.elements[lane] = std::max(scaled, -1.f);
    }
    return { unpacked };
}
// unpack4x8unorm: splits a u32 into four unsigned bytes and maps each to a
// float as byte / 255.
CONSTANT_FUNCTION(Unpack4x8unorm)
{
    UNUSED_PARAM(resultType);
    const auto bytes = std::bit_cast<std::array<uint8_t, 4>>(std::get<uint32_t>(arguments[0]));
    ConstantVector unpacked(4);
    for (unsigned lane = 0; lane < 4; ++lane) {
        const auto value = static_cast<float>(bytes[lane]);
        unpacked.elements[lane] = value / 255.f;
    }
    return { unpacked };
}
// unpack4xI8: splits a u32 into four signed bytes and sign-extends each to i32.
CONSTANT_FUNCTION(Unpack4xI8)
{
    UNUSED_PARAM(resultType);
    const auto bytes = std::bit_cast<std::array<int8_t, 4>>(std::get<uint32_t>(arguments[0]));
    ConstantVector unpacked(4);
    for (unsigned lane = 0; lane < 4; ++lane)
        unpacked.elements[lane] = static_cast<int32_t>(bytes[lane]);
    return { unpacked };
}
// unpack4xU8: splits a u32 into four unsigned bytes and zero-extends each to u32.
CONSTANT_FUNCTION(Unpack4xU8)
{
    UNUSED_PARAM(resultType);
    const auto bytes = std::bit_cast<std::array<uint8_t, 4>>(std::get<uint32_t>(arguments[0]));
    ConstantVector unpacked(4);
    for (unsigned lane = 0; lane < 4; ++lane)
        unpacked.elements[lane] = static_cast<uint32_t>(bytes[lane]);
    return { unpacked };
}
// unpack2x16snorm: splits a u32 into two signed 16-bit values and maps each
// to a float as max(value / 32767, -1).
CONSTANT_FUNCTION(Unpack2x16snorm)
{
    UNUSED_PARAM(resultType);
    const auto words = std::bit_cast<std::array<int16_t, 2>>(std::get<uint32_t>(arguments[0]));
    ConstantVector unpacked(2);
    for (unsigned lane = 0; lane < 2; ++lane) {
        const auto scaled = static_cast<float>(words[lane]) / 32767.f;
        unpacked.elements[lane] = std::max(scaled, -1.f);
    }
    return { unpacked };
}
// unpack2x16unorm: splits a u32 into two unsigned 16-bit values and maps
// each to a float as value / 65535.
CONSTANT_FUNCTION(Unpack2x16unorm)
{
    UNUSED_PARAM(resultType);
    const auto words = std::bit_cast<std::array<uint16_t, 2>>(std::get<uint32_t>(arguments[0]));
    ConstantVector unpacked(2);
    for (unsigned lane = 0; lane < 2; ++lane) {
        const auto value = static_cast<float>(words[lane]);
        unpacked.elements[lane] = value / 65535.f;
    }
    return { unpacked };
}
// unpack2x16float: reinterprets a u32 as two halves and widens each to float.
CONSTANT_FUNCTION(Unpack2x16float)
{
    UNUSED_PARAM(resultType);
    const auto halves = std::bit_cast<std::array<half, 2>>(std::get<uint32_t>(arguments[0]));
    ConstantVector unpacked(2);
    for (unsigned lane = 0; lane < 2; ++lane)
        unpacked.elements[lane] = static_cast<float>(halves[lane]);
    return { unpacked };
}
// Constant-evaluates bitcast<T>(e), reinterpreting the bits of `e` as the
// type described by `resultType`. Handles:
//   - vector -> vector of the same size (element-wise bitcast),
//   - vec2 of 32-bit scalars <-> vec4<f16>, and 32-bit scalar <-> vec2<f16>,
//   - scalar -> scalar, where abstract ints/floats are first converted to a
//     concrete 32-bit type.
// Returns an error when an abstract value cannot be represented in the
// concrete type it has to be converted to.
CONSTANT_FUNCTION(Bitcast)
{
    // Reinterprets one 32-bit scalar as two f16 values and stores them at
    // result.elements[offset] and result.elements[offset + 1]. Abstract
    // values are converted to i32 / f32 first; when that conversion fails the
    // error message is returned and `result` is left untouched.
    const auto& split = [&](ConstantVector& result, const ConstantValue& argument, unsigned offset) -> std::optional<String> {
        uint32_t value;
        if (auto* i32 = std::get_if<int32_t>(&argument))
            value = std::bit_cast<uint32_t>(*i32);
        else if (auto* u32 = std::get_if<uint32_t>(&argument))
            value = *u32;
        else if (auto* f32 = std::get_if<float>(&argument))
            value = std::bit_cast<uint32_t>(*f32);
        else if (auto* abstractInt = std::get_if<int64_t>(&argument)) {
            auto i32 = convertInteger<int32_t>(*abstractInt);
            if (!i32)
                return { makeString("value "_s, String::number(*abstractInt), " cannot be represented as 'i32'"_s) };
            value = std::bit_cast<uint32_t>(*i32);
        } else if (auto* abstractFloat = std::get_if<double>(&argument)) {
            auto f32 = convertFloat<float>(*abstractFloat);
            if (!f32)
                return { makeString("value "_s, String::number(*abstractFloat), " cannot be represented as 'f32'"_s) };
            value = std::bit_cast<uint32_t>(*f32);
        } else {
            RELEASE_ASSERT_NOT_REACHED();
            value = 0;
        }
        auto parts = std::bit_cast<std::array<half, 2>>(value);
        result.elements[offset] = parts[0];
        result.elements[offset + 1] = parts[1];
        return std::nullopt;
    };

    // Combines the f16 values at vector.elements[offset] (low 16 bits) and
    // vector.elements[offset + 1] (high 16 bits) into one 32-bit pattern and
    // bit-casts it to `type`.
    const auto& join = [&](const Type* type, const ConstantVector& vector, unsigned offset) -> ConstantValue {
        uint32_t value = 0;
        value |= std::bit_cast<uint16_t>(std::get<half>(vector.elements[offset]));
        value |= static_cast<uint32_t>(std::bit_cast<uint16_t>(std::get<half>(vector.elements[offset + 1]))) << 16;
        return convertValue<BitwiseCast>(type, value);
    };

    // Vector -> vector. Equal sizes are bit-cast element by element;
    // otherwise a two-element source is split into four f16 values, or a
    // four-element f16 source is joined into two 32-bit values.
    const auto& vectorVector = [&](const Types::Vector& dst, const ConstantVector& src) -> ConstantResult {
        if (dst.size == src.elements.size()) {
            return scalarOrVector([&](auto& value) {
                return constantBitcast(dst.element, { value });
            }, src);
        }
        ConstantVector result(dst.size);
        if (dst.size == 4) {
            if (auto error = split(result, src.elements[0], 0))
                return makeUnexpected(*error);
            if (auto error = split(result, src.elements[1], 2))
                return makeUnexpected(*error);
        } else {
            result.elements[0] = join(dst.element, src, 0);
            result.elements[1] = join(dst.element, src, 2);
        }
        return { result };
    };

    auto& argument = arguments[0];
    if (auto* dstVector = std::get_if<Types::Vector>(resultType)) {
        if (auto* srcVector = std::get_if<ConstantVector>(&argument))
            return vectorVector(*dstVector, *srcVector);
        // Scalar -> vector: only 32-bit scalar -> vec2<f16> is possible.
        RELEASE_ASSERT(dstVector->size == 2);
        ConstantVector result(2);
        // split() fails when an abstract value does not fit i32 / f32;
        // propagate that error instead of returning an unfilled vector.
        if (auto error = split(result, argument, 0))
            return makeUnexpected(*error);
        return { result };
    }

    // Vector -> scalar: vec2<f16> -> 32-bit scalar.
    if (auto* srcVector = std::get_if<ConstantVector>(&argument))
        return { join(resultType, *srcVector, 0) };

    // Abstract int -> scalar: convert to u32 when that is the target,
    // otherwise to i32 followed by a bitwise cast to the target type.
    if (auto* abstractInt = std::get_if<int64_t>(&argument)) {
        auto& primitive = std::get<Types::Primitive>(*resultType);
        if (primitive.kind == Types::Primitive::U32) {
            auto result = convertInteger<uint32_t>(*abstractInt);
            if (!result.has_value())
                return makeUnexpected(makeString("value "_s, *abstractInt, " cannot be represented as 'u32'"_s));
            return { *result };
        }
        auto result = convertInteger<int32_t>(*abstractInt);
        if (!result.has_value())
            return makeUnexpected(makeString("value "_s, *abstractInt, " cannot be represented as 'i32'"_s));
        return { convertValue<BitwiseCast>(resultType, *result) };
    }

    // Abstract float -> scalar: convert to f32, then bitwise cast.
    if (auto* abstractFloat = std::get_if<double>(&argument)) {
        auto result = convertFloat<float>(*abstractFloat);
        if (!result.has_value())
            return makeUnexpected(makeString("value "_s, *abstractFloat, " cannot be represented as 'f32'"_s));
        return { convertValue<BitwiseCast>(resultType, *result) };
    }

    // Concrete scalar -> scalar.
    return { convertValue<BitwiseCast>(resultType, argument) };
}
// Type checker helpers
// Returns true when `value` (or, per constantAny, any of its components)
// compares equal to the zero value of `valueType`. Returns false when the
// comparison itself reports an error.
static bool containsZero(ConstantValue value, const Type* valueType)
{
    const auto hasZero = [&]() -> Expected<bool, String> {
        CALL(isZero, Equal, nullptr, { value, zeroValue(valueType) });
        CALL(anyZero, Any, nullptr, { isZero });
        return anyZero.toBool();
    }();
    if (hasZero)
        return *hasZero;
    return false;
}
#undef CALL_
#undef CALL
#undef CALL_MOVE_
#undef CALL_MOVE
#undef CONSTANT_FUNCTION
} // namespace WGSL