blob: 8e638b7d89a336f758099b006392c7d7cce100bc [file] [log] [blame]
import { globalTestConfig } from '../../../../common/framework/test_config.js';
import { assert, objectEquals, unreachable } from '../../../../common/util/util.js';
import { GPUTest } from '../../../gpu_test.js';
import { Comparator, ComparatorImpl } from '../../../util/compare.js';
import { kValue } from '../../../util/constants.js';
import {
MatrixType,
ScalarValue,
ScalarType,
Type,
VectorType,
Value,
VectorValue,
isAbstractType,
scalarTypeOf,
ArrayType,
elementTypeOf,
} from '../../../util/conversion.js';
import { align } from '../../../util/math.js';
import { Case } from './case.js';
import { toComparator } from './expectation.js';
/** The input value source */
export type InputSource =
| 'const' // Shader creation time constant values (@const)
| 'uniform' // Uniform buffer
| 'storage_r' // Read-only storage buffer
| 'storage_rw'; // Read-write storage buffer
/** All possible input sources */
export const allInputSources: InputSource[] = ['const', 'uniform', 'storage_r', 'storage_rw'];
/** Just constant input source */
export const onlyConstInputSource: InputSource[] = ['const'];
/** All input sources except const */
export const allButConstInputSource: InputSource[] = ['uniform', 'storage_r', 'storage_rw'];
/**
* An enumerator of methods the const-expression is evaluated and assigned to the output.
* direct: Each case has a separate assignment statement to the output buffer, where the RHS of
* the assignment holds the case's evaluated expression.
* unrolled: The case expressions are all evaluated and stored in a module-scope 'const' array.
* This array is indexed and the value is copied to the output buffer using an unrolled
* sequence of assignment statements.
* loop: The case expressions are all evaluated and stored in a module-scope 'const' array.
* This array is indexed and the value is copied to the output buffer using a for loop.
*/
export type ConstEvaluationMode = 'direct' | 'unrolled' | 'loop';
/** Configuration for running a expression test */
export type Config = {
/** Where the input values are read from */
inputSource: InputSource;
/**
* If defined, scalar test cases will be packed into vectors of the given
* width, which must be 2, 3 or 4.
* Requires that all parameters of the expression overload are of a scalar
* type, and the return type of the expression overload is also a scalar type.
* If the number of test cases is not a multiple of the vector width, then the
* last scalar value is repeated to fill the last vector value.
*/
vectorize?: number;
/**
* The evaluation mode used when 'inputSource' is 'const'. If undefined, then an appropriate mode
* will be picked based on the input types.
*/
constEvaluationMode?: ConstEvaluationMode;
};
/**
* @returns the size and alignment in bytes of the type 'ty', taking into
* consideration storage alignment constraints and abstract numerics, which are
* encoded as a struct of holding two u32s.
*/
function sizeAndAlignmentOf(ty: Type, source: InputSource): { size: number; alignment: number } {
if (ty instanceof ScalarType) {
if (ty.kind === 'abstract-float' || ty.kind === 'abstract-int') {
// AbstractFloats and AbstractInts are passed out of the shader via structs of
// 2x u32s and unpacking containers as arrays
return { size: 8, alignment: 8 };
}
return { size: ty.size, alignment: ty.alignment };
}
if (ty instanceof VectorType) {
const out = sizeAndAlignmentOf(ty.elementType, source);
const n = ty.width === 3 ? 4 : ty.width;
out.size *= n;
out.alignment *= n;
return out;
}
if (ty instanceof MatrixType) {
const out = sizeAndAlignmentOf(ty.elementType, source);
const n = ty.rows === 3 ? 4 : ty.rows;
out.size *= n * ty.cols;
out.alignment *= n;
return out;
}
if (ty instanceof ArrayType) {
const out = sizeAndAlignmentOf(ty.elementType, source);
// MAINTENANCE_TODO(#4485): Remove this when all implementors support uniform_buffer_standard_layout.
if (source === 'uniform') {
out.alignment = align(out.alignment, 16);
}
out.size *= ty.count;
return out;
}
unreachable(`unhandled type: ${ty}`);
}
/**
* @returns the stride in bytes of the type 'ty', taking into consideration abstract numerics,
* which are encoded as a struct of 2 x u32.
*/
function strideOf(ty: Type, source: InputSource): number {
const sizeAndAlign = sizeAndAlignmentOf(ty, source);
return align(sizeAndAlign.size, sizeAndAlign.alignment);
}
/**
* Calls 'callback' with the layout information of each structure member with the types 'members'.
* @returns the byte size, stride and alignment of the structure.
*/
export function structLayout(
members: Type[],
source: InputSource,
callback?: (m: {
index: number;
type: Type;
size: number;
alignment: number;
offset: number;
}) => void
): { size: number; stride: number; alignment: number } {
let offset = 0;
let alignment = 1;
for (let i = 0; i < members.length; i++) {
const member = members[i];
const sizeAndAlign = sizeAndAlignmentOf(member, source);
offset = align(offset, sizeAndAlign.alignment);
if (callback) {
callback({
index: i,
type: member,
size: sizeAndAlign.size,
alignment: sizeAndAlign.alignment,
offset,
});
}
offset += sizeAndAlign.size;
alignment = Math.max(alignment, sizeAndAlign.alignment);
}
// MAINTENANCE_TODO(#4485): Remove this when all implementors support uniform_buffer_standard_layout.
if (source === 'uniform') {
alignment = align(alignment, 16);
}
const size = offset;
const stride = align(size, alignment);
return { size, stride, alignment };
}
/** @returns the stride in bytes between two consecutive structures with the given members */
export function structStride(members: Type[], source: InputSource): number {
return structLayout(members, source).stride;
}
/** @returns the WGSL to describe the structure members in 'members' */
function wgslMembers(members: Type[], source: InputSource, memberName: (i: number) => string) {
const lines: string[] = [];
const layout = structLayout(members, source, m => {
lines.push(` @size(${m.size}) ${memberName(lines.length)} : ${m.type},`);
});
const padding = layout.stride - layout.size;
if (padding > 0) {
// Pad with a 'f16' if the padding requires an odd multiple of 2 bytes.
// This is required as 'i32' has an alignment and size of 4 bytes.
const ty = (padding & 2) !== 0 ? 'f16' : 'i32';
lines.push(` @size(${padding}) padding : ${ty},`);
}
return lines.join('\n');
}
// Helper for returning the WGSL storage type for the given Type.
function storageType(ty: Type): Type {
if (ty instanceof ScalarType) {
assert(ty.kind !== 'f64', `No storage type defined for 'f64' values`);
assert(ty.kind !== 'abstract-int', `Custom handling is implemented for 'abstract-int' values`);
assert(
ty.kind !== 'abstract-float',
`Custom handling is implemented for 'abstract-float' values`
);
if (ty.kind === 'bool') {
return Type.u32;
}
}
if (ty instanceof VectorType) {
return Type.vec(ty.width, storageType(ty.elementType) as ScalarType);
}
if (ty instanceof ArrayType) {
return Type.array(ty.count, storageType(ty.elementType));
}
return ty;
}
/** Structure used to hold [from|to]Storage conversion helpers */
type TypeConversionHelpers = {
// The module-scope WGSL to emit with the shader.
wgsl: string;
// A function that generates a unique WGSL identifier.
uniqueID: () => string;
};
// Helper for converting a value of the type 'ty' from the storage type.
function fromStorage(ty: Type, expr: string, helpers: TypeConversionHelpers): string {
if (ty instanceof ScalarType) {
assert(ty.kind !== 'abstract-int', `'abstract-int' values should not be in input storage`);
assert(ty.kind !== 'abstract-float', `'abstract-float' values should not be in input storage`);
assert(ty.kind !== 'f64', `'No storage type defined for 'f64' values`);
if (ty.kind === 'bool') {
return `${expr} != 0u`;
}
}
if (ty instanceof VectorType) {
assert(
ty.elementType.kind !== 'abstract-int',
`'abstract-int' values cannot appear in input storage`
);
assert(
ty.elementType.kind !== 'abstract-float',
`'abstract-float' values cannot appear in input storage`
);
assert(ty.elementType.kind !== 'f64', `'No storage type defined for 'f64' values`);
if (ty.elementType.kind === 'bool') {
return `(${expr} != vec${ty.width}<u32>(0u))`;
}
}
if (ty instanceof ArrayType && elementTypeOf(ty) === Type.bool) {
// array<u32, N> -> array<bool, N>
const conv = helpers.uniqueID();
const inTy = Type.array(ty.count, Type.u32);
helpers.wgsl += `
fn ${conv}(in : ${inTy}) -> ${ty} {
var out : ${ty};
for (var i = 0; i < ${ty.count}; i++) {
out[i] = in[i] != 0;
}
return out;
}
`;
return `${conv}(${expr})`;
}
return expr;
}
// Helper for converting a value of the type 'ty' to the storage type.
function toStorage(ty: Type, expr: string, helpers: TypeConversionHelpers): string {
if (ty instanceof ScalarType) {
assert(
ty.kind !== 'abstract-int',
`'abstract-int' values have custom code for writing to storage`
);
assert(
ty.kind !== 'abstract-float',
`'abstract-float' values have custom code for writing to storage`
);
assert(ty.kind !== 'f64', `No storage type defined for 'f64' values`);
if (ty.kind === 'bool') {
return `select(0u, 1u, ${expr})`;
}
}
if (ty instanceof VectorType) {
assert(
ty.elementType.kind !== 'abstract-int',
`'abstract-int' values have custom code for writing to storage`
);
assert(
ty.elementType.kind !== 'abstract-float',
`'abstract-float' values have custom code for writing to storage`
);
assert(ty.elementType.kind !== 'f64', `'No storage type defined for 'f64' values`);
if (ty.elementType.kind === 'bool') {
return `select(vec${ty.width}<u32>(0u), vec${ty.width}<u32>(1u), ${expr})`;
}
}
if (ty instanceof ArrayType && elementTypeOf(ty) === Type.bool) {
// array<bool, N> -> array<u32, N>
const conv = helpers.uniqueID();
const outTy = Type.array(ty.count, Type.u32);
helpers.wgsl += `
fn ${conv}(in : ${ty}) -> ${outTy} {
var out : ${outTy};
for (var i = 0; i < ${ty.count}; i++) {
out[i] = select(0u, 1u, in[i]);
}
return out;
}
`;
return `${conv}(${expr})`;
}
return expr;
}
// A Pipeline is a map of WGSL shader source to a built pipeline
type PipelineCache = Map<String, GPUComputePipeline>;
/**
* Searches for an entry with the given key, adding and returning the result of calling
* `create` if the entry was not found.
* @param map the cache map
* @param key the entry's key
* @param create the function used to construct a value, if not found in the cache
* @returns the value, either fetched from the cache, or newly built.
*/
function getOrCreate<K, V>(map: Map<K, V>, key: K, create: () => V) {
const existing = map.get(key);
if (existing !== undefined) {
return existing;
}
const value = create();
map.set(key, value);
return value;
}
/**
* Runs the list of expression tests, possibly splitting the tests into multiple
* dispatches to keep the input data within the buffer binding limits.
* run() will pack the scalar test cases into smaller set of vectorized tests
* if `cfg.vectorize` is defined.
* @param t the GPUTest
* @param shaderBuilder the shader builder function
* @param parameterTypes the list of expression parameter types
* @param resultType the return type for the expression overload
* @param cfg test configuration values
* @param cases list of test cases
* @param batch_size override the calculated casesPerBatch.
*/
export async function run(
t: GPUTest,
shaderBuilder: ShaderBuilder,
parameterTypes: Array<Type>,
resultType: Type,
cfg: Config = { inputSource: 'storage_r' },
cases: Case[],
batch_size?: number
) {
// If the 'vectorize' config option was provided, pack the cases into vectors.
if (cfg.vectorize !== undefined) {
const packed = packScalarsToVector(parameterTypes, resultType, cases, cfg.vectorize);
cases = packed.cases;
parameterTypes = packed.parameterTypes;
resultType = packed.resultType;
}
// The size of the input buffer may exceed the maximum buffer binding size,
// so chunk the tests up into batches that fit into the limits. We also split
// the cases into smaller batches to help with shader compilation performance.
const casesPerBatch = (function () {
if (batch_size) {
return batch_size;
}
switch (cfg.inputSource) {
case 'const':
// Some drivers are slow to optimize shaders with many constant values,
// or statements. 32 is an empirically picked number of cases that works
// well for most drivers.
return 32;
case 'uniform':
// Some drivers are slow to build pipelines with large uniform buffers.
// 2k appears to be a sweet-spot when benchmarking.
return Math.floor(
Math.min(1024 * 2, t.device.limits.maxUniformBufferBindingSize) /
structStride(parameterTypes, cfg.inputSource)
);
case 'storage_r':
case 'storage_rw':
return Math.floor(
t.device.limits.maxStorageBufferBindingSize /
structStride(parameterTypes, cfg.inputSource)
);
}
})();
// A cache to hold built shader pipelines.
const pipelineCache = new Map<String, GPUComputePipeline>();
// Submit all the cases in batches, rate-limiting to ensure not too many
// batches are in flight simultaneously.
const maxBatchesInFlight = 5;
let batchesInFlight = 0;
let resolvePromiseBlockingBatch: (() => void) | undefined = undefined;
const batchFinishedCallback = () => {
batchesInFlight -= 1;
// If there is any batch waiting on a previous batch to finish,
// unblock it now, and clear the resolve callback.
if (resolvePromiseBlockingBatch) {
resolvePromiseBlockingBatch();
resolvePromiseBlockingBatch = undefined;
}
};
const processBatch = async (batchCases: Case[]) => {
const shaderBuilderParams: ShaderBuilderParams = {
parameterTypes,
resultType,
cases: batchCases,
inputSource: cfg.inputSource,
constEvaluationMode: cfg.constEvaluationMode,
};
const checkBatch = await submitBatch(t, shaderBuilder, shaderBuilderParams, pipelineCache);
checkBatch();
await t.queue.onSubmittedWorkDone();
};
const pendingBatches = [];
for (let i = 0; i < cases.length; i += casesPerBatch) {
const batchCases = cases.slice(i, Math.min(i + casesPerBatch, cases.length));
if (batchesInFlight > maxBatchesInFlight) {
await new Promise<void>(resolve => {
// There should only be one batch waiting at a time.
assert(resolvePromiseBlockingBatch === undefined);
resolvePromiseBlockingBatch = resolve;
});
}
batchesInFlight += 1;
pendingBatches.push(
processBatch(batchCases)
.catch(err => {
if (err instanceof GPUPipelineError) {
t.fail(`Pipeline Creation Error, ${err.reason}: ${err.message}`);
} else {
throw err;
}
})
.finally(batchFinishedCallback)
);
}
await Promise.all(pendingBatches);
}
/**
* Submits the list of expression tests. The input data must fit within the
* buffer binding limits of the given inputSource.
* @param t the GPUTest
* @param shaderBuilder the shader builder function
* @param shaderBuilderParams the shader builder parameters
* @param pipelineCache the cache of compute pipelines, shared between batches
* @returns a function that checks the results are as expected
*/
async function submitBatch(
t: GPUTest,
shaderBuilder: ShaderBuilder,
shaderBuilderParams: ShaderBuilderParams,
pipelineCache: PipelineCache
): Promise<() => void> {
const { resultType, cases } = shaderBuilderParams;
// Construct a buffer to hold the results of the expression tests
const outputStride = structStride([resultType], 'storage_rw');
const outputBufferSize = align(cases.length * outputStride, 4);
const outputBuffer = t.createBufferTracked({
size: outputBufferSize,
usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST | GPUBufferUsage.STORAGE,
});
const [pipeline, group] = await buildPipeline(
t,
shaderBuilder,
shaderBuilderParams,
outputBuffer,
pipelineCache
);
const encoder = t.device.createCommandEncoder();
const pass = encoder.beginComputePass();
pass.setPipeline(pipeline);
pass.setBindGroup(0, group);
pass.dispatchWorkgroups(1);
pass.end();
// Heartbeat to ensure CTS runners know we're alive.
globalTestConfig.testHeartbeatCallback();
t.queue.submit([encoder.finish()]);
// Return a function that can check the results of the shader
return () => {
const checkExpectation = (outputData: Uint8Array) => {
// Read the outputs from the output buffer
const outputs = new Array<Value>(cases.length);
for (let i = 0; i < cases.length; i++) {
outputs[i] = resultType.read(outputData, i * outputStride);
}
// The list of expectation failures
const errs: string[] = [];
// For each case...
for (let caseIdx = 0; caseIdx < cases.length; caseIdx++) {
const c = cases[caseIdx];
const got = outputs[caseIdx];
const cmp = toComparator(c.expected).compare(got);
if (!cmp.matched) {
errs.push(`(${c.input instanceof Array ? c.input.join(', ') : c.input})
returned: ${cmp.got}
expected: ${cmp.expected}`);
}
}
return errs.length > 0 ? new Error(errs.join('\n\n')) : undefined;
};
// Heartbeat to ensure CTS runners know we're alive.
globalTestConfig.testHeartbeatCallback();
t.expectGPUBufferValuesPassCheck(outputBuffer, checkExpectation, {
type: Uint8Array,
typedLength: outputBufferSize,
});
};
}
/**
* map is a helper for returning a new array with each element of `v`
* transformed with `fn`.
* If `v` is not an array, then `fn` is called with (v, 0).
*/
function map<T, U>(v: T | readonly T[], fn: (value: T, index?: number) => U): U[] {
if (v instanceof Array) {
return v.map(fn);
}
return [fn(v, 0)];
}
/** The structured arguments for a ShaderBuilder function */
export type ShaderBuilderParams = {
/** the list of expression parameter types */
parameterTypes: Array<Type>;
/** the return type for the expression overload */
resultType: Type;
/** list of test cases that fit within the binding limits of the device */
cases: Case[];
/** the source of the input values */
inputSource: InputSource;
/** the optional evaluation mode when 'inputSource' is 'const' */
constEvaluationMode?: ConstEvaluationMode;
};
/** ShaderBuilder is a function used to construct the WGSL shader used by an expression test. */
export type ShaderBuilder = (params: ShaderBuilderParams) => string;
/**
* Helper that returns the WGSL to declare the output storage buffer for a shader
*/
function wgslOutputs(resultType: Type, count: number): string {
let output_struct = undefined;
if (
scalarTypeOf(resultType).kind !== 'abstract-float' &&
scalarTypeOf(resultType).kind !== 'abstract-int'
) {
output_struct = `
struct Output {
@size(${strideOf(resultType, 'storage_rw')}) value : ${storageType(resultType)}
};`;
} else {
if (resultType instanceof ScalarType) {
output_struct = `struct AF {
low: u32,
high: u32,
};
struct Output {
@size(${strideOf(resultType, 'storage_rw')}) value: AF,
};`;
}
if (resultType instanceof VectorType) {
const dim = resultType.width;
output_struct = `struct AF {
low: u32,
high: u32,
};
struct Output {
@size(${strideOf(resultType, 'storage_rw')}) value: array<AF, ${dim}>,
};`;
}
if (resultType instanceof MatrixType) {
const cols = resultType.cols;
const rows = resultType.rows === 2 ? 2 : 4; // 3 element rows have a padding element
output_struct = `struct AF {
low: u32,
high: u32,
};
struct Output {
@size(${strideOf(resultType, 'storage_rw')}) value: array<array<AF, ${rows}>, ${cols}>,
};`;
}
assert(output_struct !== undefined, `No implementation for result type '${resultType}'`);
}
return `${output_struct}
@group(0) @binding(0) var<storage, read_write> outputs : array<Output, ${count}>;
`;
}
/**
* Helper that returns the WGSL to declare the values array for a shader
*/
function wgslValuesArray(cases: Case[], expressionBuilder: ExpressionBuilder): string {
return `
const values = array(
${cases.map(c => expressionBuilder(map(c.input, v => v.wgsl()))).join(',\n ')}
);`;
}
/**
* Helper that returns the WGSL 'var' declaration for the given input source
*/
function wgslInputVar(inputSource: InputSource, count: number) {
switch (inputSource) {
case 'storage_r':
return `@group(0) @binding(1) var<storage, read> inputs : array<Input, ${count}>;`;
case 'storage_rw':
return `@group(0) @binding(1) var<storage, read_write> inputs : array<Input, ${count}>;`;
case 'uniform':
return `@group(0) @binding(1) var<uniform> inputs : array<Input, ${count}>;`;
}
throw new Error(`InputSource ${inputSource} does not use an input var`);
}
/**
* Helper that returns the WGSL header before any other declaration, currently include f16
* enable directive if necessary.
*/
function wgslHeader(parameterTypes: Array<Type>, resultType: Type) {
const usedF16 =
scalarTypeOf(resultType).kind === 'f16' ||
parameterTypes.some((ty: Type) => scalarTypeOf(ty).kind === 'f16');
const header = usedF16 ? 'enable f16;\n' : '';
return header;
}
/**
* ExpressionBuilder returns the WGSL used to evaluate an expression with the
* given input values.
*/
export type ExpressionBuilder = (values: ReadonlyArray<string>) => string;
/**
* @returns the WGSL for a basic expression test shader.
* @param expressionBuilder the expression builder
*/
function basicExpressionShaderBody(
expressionBuilder: ExpressionBuilder,
params: ShaderBuilderParams
): string {
const { parameterTypes, resultType, cases, inputSource } = params;
assert(
scalarTypeOf(resultType).kind !== 'abstract-int',
`abstractIntShaderBuilder should be used when result type is 'abstract-int'`
);
assert(
scalarTypeOf(resultType).kind !== 'abstract-float',
`abstractFloatShaderBuilder should be used when result type is 'abstract-float'`
);
let nextUniqueIDSuffix = 0;
const convHelpers: TypeConversionHelpers = {
wgsl: '',
uniqueID: () => `cts_symbol_${nextUniqueIDSuffix++}`,
};
if (inputSource === 'const') {
let constEvaluationMode = params.constEvaluationMode;
if (constEvaluationMode === undefined) {
if (parameterTypes.some(ty => isAbstractType(scalarTypeOf(ty)))) {
// Directly assign the expression to the output, to avoid an
// intermediate store, which will concretize the value early
constEvaluationMode = 'direct';
} else {
constEvaluationMode = globalTestConfig.unrollConstEvalLoops ? 'unrolled' : 'loop';
}
}
//////////////////////////////////////////////////////////////////////////
// Constant eval
//////////////////////////////////////////////////////////////////////////
let body = '';
let valuesArray = '';
switch (constEvaluationMode) {
case 'direct': {
body = cases
.map(
(c, i) =>
` outputs[${i}].value = ${toStorage(
resultType,
expressionBuilder(map(c.input, v => v.wgsl())),
convHelpers
)};`
)
.join('\n ');
break;
}
case 'unrolled': {
body = cases
.map((_, i) => {
const value = `values[${i}]`;
return ` outputs[${i}].value = ${toStorage(resultType, value, convHelpers)};`;
})
.join('\n ');
valuesArray = wgslValuesArray(cases, expressionBuilder);
break;
}
case 'loop': {
body = `
for (var i = 0u; i < ${cases.length}; i++) {
outputs[i].value = ${toStorage(resultType, `values[i]`, convHelpers)};
}`;
valuesArray = wgslValuesArray(cases, expressionBuilder);
break;
}
}
return `
${wgslOutputs(resultType, cases.length)}
${valuesArray}
${convHelpers.wgsl}
@compute @workgroup_size(1)
fn main() {
${body}
}
`;
} else {
//////////////////////////////////////////////////////////////////////////
// Runtime eval
//////////////////////////////////////////////////////////////////////////
// returns the WGSL expression to load the ith parameter of the given type from the input buffer
const paramExpr = (ty: Type, i: number) => fromStorage(ty, `inputs[i].param${i}`, convHelpers);
// resolves to the expression that calls the builtin
const expr = toStorage(
resultType,
expressionBuilder(parameterTypes.map(paramExpr)),
convHelpers
);
return `
struct Input {
${wgslMembers(parameterTypes.map(storageType), inputSource, i => `param${i}`)}
}
${wgslOutputs(resultType, cases.length)}
${wgslInputVar(inputSource, cases.length)}
${convHelpers.wgsl}
@compute @workgroup_size(1)
fn main() {
for (var i = 0; i < ${cases.length}; i++) {
outputs[i].value = ${expr};
}
}
`;
}
}
/**
* Returns a ShaderBuilder that builds a basic expression test shader.
* @param expressionBuilder the expression builder
*/
export function basicExpressionBuilder(expressionBuilder: ExpressionBuilder): ShaderBuilder {
return (params: ShaderBuilderParams) => {
return `\
${wgslHeader(params.parameterTypes, params.resultType)}
${basicExpressionShaderBody(expressionBuilder, params)}`;
};
}
/**
* Returns a ShaderBuilder that builds a basic expression test shader with given predeclaration
* string goes after WGSL header (i.e. enable directives) if any but before anything else.
* @param expressionBuilder the expression builder
* @param predeclaration the predeclaration string
*/
export function basicExpressionWithPredeclarationBuilder(
expressionBuilder: ExpressionBuilder,
predeclaration: string
): ShaderBuilder {
return (params: ShaderBuilderParams) => {
return `\
${wgslHeader(params.parameterTypes, params.resultType)}
${predeclaration}
${basicExpressionShaderBody(expressionBuilder, params)}`;
};
}
/**
* Returns a ShaderBuilder that builds a compound assignment operator test shader.
* @param op the compound operator
*/
export function compoundAssignmentBuilder(op: string): ShaderBuilder {
return (params: ShaderBuilderParams) => {
const { parameterTypes, resultType, cases, inputSource } = params;
//////////////////////////////////////////////////////////////////////////
// Input validation
//////////////////////////////////////////////////////////////////////////
if (parameterTypes.length !== 2) {
throw new Error(`compoundBinaryOp() requires exactly two parameters values per case`);
}
const lhsType = parameterTypes[0];
const rhsType = parameterTypes[1];
if (!objectEquals(lhsType, resultType)) {
throw new Error(
`compoundBinaryOp() requires result type (${resultType}) to be equal to the LHS type (${lhsType})`
);
}
if (inputSource === 'const') {
//////////////////////////////////////////////////////////////////////////
// Constant eval
//////////////////////////////////////////////////////////////////////////
let body = '';
if (globalTestConfig.unrollConstEvalLoops) {
body = cases
.map((_, i) => {
return `
var ret_${i} = lhs[${i}];
ret_${i} ${op} rhs[${i}];
outputs[${i}].value = ${storageType(resultType)}(ret_${i});`;
})
.join('\n ');
} else {
body = `
for (var i = 0u; i < ${cases.length}; i++) {
var ret = lhs[i];
ret ${op} rhs[i];
outputs[i].value = ${storageType(resultType)}(ret);
}`;
}
const values = cases.map(c => (c.input as Value[]).map(v => v.wgsl()));
return `
${wgslHeader(parameterTypes, resultType)}
${wgslOutputs(resultType, cases.length)}
const lhs = array(
${values.map(c => `${c[0]}`).join(',\n ')}
);
const rhs = array(
${values.map(c => `${c[1]}`).join(',\n ')}
);
@compute @workgroup_size(1)
fn main() {
${body}
}`;
} else {
//////////////////////////////////////////////////////////////////////////
// Runtime eval
//////////////////////////////////////////////////////////////////////////
let operation = '';
if (inputSource === 'storage_rw' && objectEquals(resultType, storageType(resultType))) {
operation = `
outputs[i].value = ${storageType(resultType)}(inputs[i].lhs);
outputs[i].value ${op} ${rhsType}(inputs[i].rhs);`;
} else {
operation = `
var ret = ${lhsType}(inputs[i].lhs);
ret ${op} ${rhsType}(inputs[i].rhs);
outputs[i].value = ${storageType(resultType)}(ret);`;
}
return `
${wgslHeader(parameterTypes, resultType)}
${wgslOutputs(resultType, cases.length)}
struct Input {
${wgslMembers([lhsType, rhsType].map(storageType), inputSource, i => ['lhs', 'rhs'][i])}
}
${wgslInputVar(inputSource, cases.length)}
@compute @workgroup_size(1)
fn main() {
for (var i = 0; i < ${cases.length}; i++) {
${operation}
}
}
`;
}
};
}
/**
* @returns a string that extracts the value of an AbstractFloat into an output
* destination
* @param expr expression for an AbstractFloat value, if working with vectors or
* matrices, this string needs to include indexing into the
* container.
* @param case_idx index in the case output array to assign the result
* @param accessor string representing how access to the AF that needs to be
* operated on.
* For scalars this should be left as ''.
* For vectors this will be an indexing operation,
* i.e. '[i]'
* For matrices this will double indexing operation,
* i.e. '[c][r]'
*/
function abstractFloatSnippet(expr: string, case_idx: number, accessor: string = ''): string {
// AbstractFloats are f64s under the hood. WebGPU does not support
// putting f64s in buffers, so the result needs to be split up into u32s
// and rebuilt in the test framework.
//
// Since there is no 64-bit data type that can be used as an element for a
// vector or a matrix in WGSL, the testing framework needs to pass the u32s
// via a struct with two u32s, and deconstruct vectors and matrices into
// arrays.
//
// This is complicated by the fact that user defined functions cannot
// take/return AbstractFloats, and AbstractFloats cannot be stored in
// variables, so the code cannot just inject a simple utility function
// at the top of the shader, instead this snippet needs to be inlined
// everywhere the test needs to return an AbstractFloat.
//
// select is used below, since ifs are not available during constant
// eval. This has the side effect of short-circuiting doesn't occur, so
// both sides of the select have to evaluate and be valid.
//
// This snippet implements FTZ for subnormals to bypass the need for
// complex subnormal specific logic.
//
// Expressions resulting in subnormals can still be reasonably tested,
// since this snippet will return 0 with the correct sign, which is
// always in the acceptance interval for a subnormal result, since an
// implementation may FTZ.
//
// Documentation for the snippet working with scalar results is included here
// in this code block, since shader length affects compilation time
// significantly on some backends. The code for vectors and matrices basically
// the same thing, with extra indexing operations.
//
// Snippet with documentation:
// const kExponentBias = 1022;
//
// // Detect if the value is zero or subnormal, so that FTZ behaviour
// // can occur
// const subnormal_or_zero : bool = (${expr} <= ${kValue.f64.positive.subnormal.max}) && (${expr} >= ${kValue.f64.negative.subnormal.min});
//
// // MSB of the upper u32 is 1 if the value is negative, otherwise 0
// // Extract the sign bit early, so that abs() can be used with
// // frexp() so negative cases do not need to be handled
// const sign_bit : u32 = select(0, 0x80000000, ${expr} < 0);
//
// // Use frexp() to obtain the exponent and fractional parts, and
// // then perform FTZ if needed
// const f = frexp(abs(${expr}));
// const f_fract = select(f.fract, 0, subnormal_or_zero);
// const f_exp = select(f.exp, -kExponentBias, subnormal_or_zero);
//
// // Adjust for the exponent bias and shift for storing in bits
// // [20..31] of the upper u32
// const exponent_bits : u32 = (f_exp + kExponentBias) << 20;
//
// // Extract the portion of the mantissa that appears in upper u32 as
// // a float for later use
// const high_mantissa = ldexp(f_fract, 21);
//
// // Extract the portion of the mantissa that appears in upper u32 as
// // as bits. This value is masked, because normals will explicitly
// // have the implicit leading 1 that should not be in the final
// // result.
// const high_mantissa_bits : u32 = u32(ldexp(f_fract, 21)) & 0x000fffff;
//
// // Calculate the mantissa stored in the lower u32 as a float
// const low_mantissa = f_fract - ldexp(floor(high_mantissa), -21);
//
// // Convert the lower u32 mantissa to bits
// const low_mantissa_bits = u32(ldexp(low_mantissa, 53));
//
// outputs[${i}].value.high = sign_bit | exponent_bits | high_mantissa_bits;
// outputs[${i}].value.low = low_mantissa_bits;
// prettier-ignore
return ` {
const kExponentBias = 1022;
const subnormal_or_zero : bool = (${expr}${accessor} <= ${kValue.f64.positive.subnormal.max}) && (${expr}${accessor} >= ${kValue.f64.negative.subnormal.min});
const sign_bit : u32 = select(0, 0x80000000, ${expr}${accessor} < 0);
const f = frexp(abs(${expr}${accessor}));
const f_fract = select(f.fract, 0, subnormal_or_zero);
const f_exp = select(f.exp, -kExponentBias, subnormal_or_zero);
const exponent_bits : u32 = (f_exp + kExponentBias) << 20;
const high_mantissa = ldexp(f_fract, 21);
const high_mantissa_bits : u32 = u32(ldexp(f_fract, 21)) & 0x000fffff;
const low_mantissa = f_fract - ldexp(floor(high_mantissa), -21);
const low_mantissa_bits = u32(ldexp(low_mantissa, 53));
outputs[${case_idx}].value${accessor}.high = sign_bit | exponent_bits | high_mantissa_bits;
outputs[${case_idx}].value${accessor}.low = low_mantissa_bits;
}`;
}
/** @returns a string for a specific case that has a AbstractFloat result */
function abstractFloatCaseBody(expr: string, resultType: Type, i: number): string {
if (resultType instanceof ScalarType) {
return abstractFloatSnippet(expr, i);
}
if (resultType instanceof VectorType) {
return [...Array(resultType.width).keys()]
.map(idx => abstractFloatSnippet(expr, i, `[${idx}]`))
.join(' \n');
}
if (resultType instanceof MatrixType) {
const cols = resultType.cols;
const rows = resultType.rows;
const results: String[] = [...Array(cols * rows)];
for (let c = 0; c < cols; c++) {
for (let r = 0; r < rows; r++) {
results[c * rows + r] = abstractFloatSnippet(expr, i, `[${c}][${r}]`);
}
}
return results.join(' \n');
}
unreachable(`Results of type '${resultType}' not yet implemented`);
}
/**
* @returns a ShaderBuilder that builds a test shader hands AbstractFloat results.
* @param expressionBuilder an expression builder that will return AbstractFloats
*/
export function abstractFloatShaderBuilder(expressionBuilder: ExpressionBuilder): ShaderBuilder {
return (params: ShaderBuilderParams) => {
const { parameterTypes, resultType, cases, inputSource } = params;
assert(inputSource === 'const', `'abstract-float' results are only defined for const-eval`);
assert(
scalarTypeOf(resultType).kind === 'abstract-float',
`Expected resultType of 'abstract-float', received '${scalarTypeOf(resultType).kind}' instead`
);
const body = cases
.map((c, i) => {
const expr = `${expressionBuilder(map(c.input, v => v.wgsl()))}`;
return abstractFloatCaseBody(expr, resultType, i);
})
.join('\n ');
return `
${wgslHeader(parameterTypes, resultType)}
${wgslOutputs(resultType, cases.length)}
@compute @workgroup_size(1)
fn main() {
${body}
}`;
};
}
/**
* @returns a string that extracts the value of an AbstractInt into an output
* destination
* @param expr expression for an AbstractInt value, if working with vectors,
* this string needs to include indexing into the container.
* @param case_idx index in the case output array to assign the result
* @param accessor string representing how access to the AbstractInt that needs
* to be operated on.
* For scalars this should be left as ''.
* For vectors this will be an indexing operation,
* i.e. '[i]'
*/
function abstractIntSnippet(expr: string, case_idx: number, accessor: string = ''): string {
// AbstractInts are i64s under the hood. WebGPU does not support
// putting i64s in buffers, or any 64-bit simple types, so the result needs to
// be split up into u32 bitfields
//
// Since there is no 64-bit data type that can be used as an element for a
// vector or a matrix in WGSL, the testing framework needs to pass the u32s
// via a struct with two u32s, and deconstruct vectors into arrays.
//
// This is complicated by the fact that user defined functions cannot
// take/return AbstractInts, and AbstractInts cannot be stored in
// variables, so the code cannot just inject a simple utility function
// at the top of the shader, instead this snippet needs to be inlined
// everywhere the test needs to return an AbstractInt.
return ` {
outputs[${case_idx}].value${accessor}.high = bitcast<u32>(i32(${expr}${accessor} >> 32)) & 0xFFFFFFFF;
const low_sign = (${expr}${accessor} & (1 << 31));
outputs[${case_idx}].value${accessor}.low = bitcast<u32>((${expr}${accessor} & 0x7FFFFFFF)) | low_sign;
}`;
}
/** @returns a string for a specific case that has a AbstractInt result */
function abstractIntCaseBody(expr: string, resultType: Type, i: number): string {
if (resultType instanceof ScalarType) {
return abstractIntSnippet(expr, i);
}
if (resultType instanceof VectorType) {
return [...Array(resultType.width).keys()]
.map(idx => abstractIntSnippet(expr, i, `[${idx}]`))
.join(' \n');
}
unreachable(`Results of type '${resultType}' not yet implemented`);
}
/**
* @returns a ShaderBuilder that builds a test shader hands AbstractInt results.
* @param expressionBuilder an expression builder that will return AbstractInts
*/
export function abstractIntShaderBuilder(expressionBuilder: ExpressionBuilder): ShaderBuilder {
return (params: ShaderBuilderParams) => {
const { parameterTypes, resultType, cases, inputSource } = params;
assert(inputSource === 'const', `'abstract-int' results are only defined for const-eval`);
assert(
scalarTypeOf(resultType).kind === 'abstract-int',
`Expected resultType of 'abstract-int', received '${scalarTypeOf(resultType).kind}' instead`
);
const body = cases
.map((c, i) => {
const expr = `${expressionBuilder(map(c.input, v => v.wgsl()))}`;
return abstractIntCaseBody(expr, resultType, i);
})
.join('\n ');
return `
${wgslHeader(parameterTypes, resultType)}
${wgslOutputs(resultType, cases.length)}
@compute @workgroup_size(1)
fn main() {
${body}
}`;
};
}
/**
* Constructs and returns a GPUComputePipeline and GPUBindGroup for running a
* batch of test cases. If a pre-created pipeline can be found in
* `pipelineCache`, then this may be returned instead of creating a new
* pipeline.
* @param t the GPUTest
* @param shaderBuilder the shader builder
* @param shaderBuilderParams the parameters for the shader builder
* @param outputBuffer the buffer that will hold the output values of the tests
* @param pipelineCache the cache of compute pipelines, shared between batches
*/
async function buildPipeline(
t: GPUTest,
shaderBuilder: ShaderBuilder,
shaderBuilderParams: ShaderBuilderParams,
outputBuffer: GPUBuffer,
pipelineCache: PipelineCache
): Promise<[GPUComputePipeline, GPUBindGroup]> {
const { parameterTypes, cases, inputSource } = shaderBuilderParams;
cases.forEach(c => {
const inputTypes = c.input instanceof Array ? c.input.map(i => i.type) : [c.input.type];
if (!objectEquals(inputTypes, parameterTypes)) {
const input_str = `[${inputTypes.join(',')}]`;
const param_str = `[${parameterTypes.join(',')}]`;
throw new Error(
`case input types ${input_str} do not match provided runner parameter types ${param_str}`
);
}
});
const source = shaderBuilder(shaderBuilderParams);
switch (inputSource) {
case 'const': {
// build the shader module
const module = t.device.createShaderModule({ code: source });
// build the pipeline
const pipeline = await t.device.createComputePipelineAsync({
layout: 'auto',
compute: { module, entryPoint: 'main' },
});
// build the bind group
const group = t.device.createBindGroup({
layout: pipeline.getBindGroupLayout(0),
entries: [{ binding: 0, resource: { buffer: outputBuffer } }],
});
return [pipeline, group];
}
case 'uniform':
case 'storage_r':
case 'storage_rw': {
// Input values come from a uniform or storage buffer
// size in bytes of the input buffer
const caseStride = structStride(parameterTypes, inputSource);
const inputSize = align(cases.length * caseStride, 4);
// Holds all the parameter values for all cases
const inputData = new Uint8Array(inputSize);
// Pack all the input parameter values into the inputData buffer
for (let caseIdx = 0; caseIdx < cases.length; caseIdx++) {
const offset = caseIdx * caseStride;
structLayout(parameterTypes, inputSource, m => {
const arg = cases[caseIdx].input;
if (arg instanceof Array) {
arg[m.index].copyTo(inputData, offset + m.offset);
} else {
arg.copyTo(inputData, offset + m.offset);
}
});
}
// build the compute pipeline, if the shader hasn't been compiled already.
const pipeline = getOrCreate(pipelineCache, source, () => {
// build the shader module
const module = t.device.createShaderModule({ code: source });
// build the pipeline
return t.device.createComputePipeline({
layout: 'auto',
compute: { module, entryPoint: 'main' },
});
});
// build the input buffer
const inputBuffer = t.makeBufferWithContents(
inputData,
GPUBufferUsage.COPY_SRC |
(inputSource === 'uniform' ? GPUBufferUsage.UNIFORM : GPUBufferUsage.STORAGE)
);
// build the bind group
const group = t.device.createBindGroup({
layout: pipeline.getBindGroupLayout(0),
entries: [
{ binding: 0, resource: { buffer: outputBuffer } },
{ binding: 1, resource: { buffer: inputBuffer } },
],
});
return [pipeline, group];
}
}
}
/**
* Packs a list of scalar test cases into a smaller list of vector cases.
* Requires that all parameters of the expression overload are of a scalar type,
* and the return type of the expression overload is also a scalar type.
* If `cases.length` is not a multiple of `vectorWidth`, then the last scalar
* test case value is repeated to fill the vector value.
*/
export function packScalarsToVector(
parameterTypes: Array<Type>,
resultType: Type,
cases: Case[],
vectorWidth: number
): { cases: Case[]; parameterTypes: Array<Type>; resultType: Type } {
// Validate that the parameters and return type are all vectorizable
for (let i = 0; i < parameterTypes.length; i++) {
const ty = parameterTypes[i];
if (!(ty instanceof ScalarType)) {
throw new Error(
`packScalarsToVector() can only be used on scalar parameter types, but the ${i}'th parameter type is a ${ty}'`
);
}
}
if (!(resultType instanceof ScalarType)) {
throw new Error(
`packScalarsToVector() can only be used with a scalar return type, but the return type is a ${resultType}'`
);
}
const packedCases: Array<Case> = [];
const packedParameterTypes = parameterTypes.map(p => Type.vec(vectorWidth, p as ScalarType));
const packedResultType = Type.vec(vectorWidth, resultType);
const clampCaseIdx = (idx: number) => Math.min(idx, cases.length - 1);
let caseIdx = 0;
while (caseIdx < cases.length) {
// Construct the vectorized inputs from the scalar cases
const packedInputs = new Array<VectorValue>(parameterTypes.length);
for (let paramIdx = 0; paramIdx < parameterTypes.length; paramIdx++) {
const inputElements = new Array<ScalarValue>(vectorWidth);
for (let i = 0; i < vectorWidth; i++) {
const input = cases[clampCaseIdx(caseIdx + i)].input;
inputElements[i] = (input instanceof Array ? input[paramIdx] : input) as ScalarValue;
}
packedInputs[paramIdx] = new VectorValue(inputElements);
}
// Gather the comparators for the packed cases
const cmp_impls = new Array<ComparatorImpl>(vectorWidth);
for (let i = 0; i < vectorWidth; i++) {
cmp_impls[i] = toComparator(cases[clampCaseIdx(caseIdx + i)].expected).compare;
}
const comparators: Comparator = {
compare: (got: Value) => {
let matched = true;
const gElements = new Array<string>(vectorWidth);
const eElements = new Array<string>(vectorWidth);
for (let i = 0; i < vectorWidth; i++) {
const d = cmp_impls[i]((got as VectorValue).elements[i]);
matched = matched && d.matched;
gElements[i] = d.got;
eElements[i] = d.expected;
}
return {
matched,
got: `${packedResultType}(${gElements.join(', ')})`,
expected: `${packedResultType}(${eElements.join(', ')})`,
};
},
kind: 'packed',
};
// Append the new packed case
packedCases.push({ input: packedInputs, expected: comparators });
caseIdx += vectorWidth;
}
return {
cases: packedCases,
parameterTypes: packedParameterTypes,
resultType: packedResultType,
};
}