WebNN: Support gruCell MLOperator
This CL adds the IDL and Mojo definitions for gruCell and implements
`MLGraphBuilder::gruCell()`.
This CL also adds gruCell validation tests in WPT and unit tests in
`WebNNGraphImplTest`.
Bug: 40206287
Change-Id: I0873a5c979bdd3e44a0dd30cba7a53410141ef78
Cq-Include-Trybots: luci.chromium.try:gpu-fyi-try-win11-qualcomm-rel-64, win11-blink-rel
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/5376901
Reviewed-by: ningxin hu <ningxin.hu@intel.com>
Reviewed-by: Austin Sullivan <asully@chromium.org>
Commit-Queue: Bin Miao <bin.miao@intel.com>
Cr-Commit-Position: refs/heads/main@{#1280572}
diff --git a/webnn/validation_tests/gruCell.https.any.js b/webnn/validation_tests/gruCell.https.any.js
index 5056f8b..3cd9d32 100644
--- a/webnn/validation_tests/gruCell.https.any.js
+++ b/webnn/validation_tests/gruCell.https.any.js
@@ -14,6 +14,8 @@
// Dimensions required of optional inputs.
const kValidBiasDimensions = [3 * hiddenSize];
const kValidRecurrentBiasDimensions = [3 * hiddenSize];
+// Expected dimensions of the operand returned by gruCell() for valid inputs.
+const kValidOutputDimensions = [batchSize, hiddenSize];
// Example descriptors which are valid according to the above dimensions.
const kExampleInputDescriptor = {
@@ -40,6 +42,320 @@
dataType: 'float32',
dimensions: kValidRecurrentBiasDimensions
};
+const kExampleOutputDescriptor = {
+ dataType: 'float32',
+ dimensions: kValidOutputDimensions
+ };
+
+const tests = [
+ {
+ name: '[gruCell] Test with default options',
+ input: kExampleInputDescriptor,
+ weight: kExampleWeightDescriptor,
+ recurrentWeight: kExampleRecurrentWeightDescriptor,
+ hiddenState: kExampleHiddenStateDescriptor,
+ hiddenSize: hiddenSize,
+ output: kExampleOutputDescriptor
+ },
+ {
+ name: '[gruCell] Test with given options',
+ input: kExampleInputDescriptor,
+ weight: kExampleWeightDescriptor,
+ recurrentWeight: kExampleRecurrentWeightDescriptor,
+ hiddenState: kExampleHiddenStateDescriptor,
+ hiddenSize: hiddenSize,
+ options: {
+ bias: kExampleBiasDescriptor,
+ recurrentBias: kExampleRecurrentBiasDescriptor,
+ restAfter: true,
+ layout: 'rzn',
+ activations: ['sigmoid', 'relu']
+ },
+ output: kExampleOutputDescriptor
+ },
+ {
+ name: '[gruCell] Throw if hiddenSize equals to zero',
+ input: kExampleInputDescriptor,
+ weight: kExampleWeightDescriptor,
+ recurrentWeight: kExampleRecurrentWeightDescriptor,
+ hiddenState: kExampleHiddenStateDescriptor,
+ hiddenSize: 0
+ },
+ {
+ name: '[gruCell] Throw if hiddenSize is too large',
+ input: kExampleInputDescriptor,
+ weight: kExampleWeightDescriptor,
+ recurrentWeight: kExampleRecurrentWeightDescriptor,
+ hiddenState: kExampleHiddenStateDescriptor,
+ hiddenSize: 4294967295,
+ },
+ {
+ name:
+ '[gruCell] Throw if the data type of the inputs is not one of the floating point types',
+ input: { dataType: 'uint32', dimensions: kValidInputDimensions },
+ weight: { dataType: 'uint32', dimensions: kValidWeightDimensions },
+ recurrentWeight: {
+ dataType: 'uint32',
+ dimensions: kValidRecurrentWeightDimensions
+ },
+ hiddenState: {
+ dataType: 'uint32',
+ dimensions: kValidHiddenStateDimensions
+ },
+ hiddenSize: hiddenSize
+ },
+ {
+ name:
+ '[gruCell] Throw if the rank of input is not 2',
+ input: { dataType: 'float32', dimensions: [batchSize] },
+ weight: kExampleWeightDescriptor,
+ recurrentWeight: kExampleRecurrentWeightDescriptor,
+ hiddenState: kExampleHiddenStateDescriptor,
+ hiddenSize: hiddenSize
+ },
+ {
+ name:
+ '[gruCell] Throw if the input.dimensions[1] is incorrect',
+ input: { dataType: 'float32', dimensions: [inputSize, inputSize] },
+ weight: kExampleWeightDescriptor,
+ recurrentWeight: kExampleRecurrentWeightDescriptor,
+ hiddenState: kExampleHiddenStateDescriptor,
+ hiddenSize: hiddenSize
+ },
+ {
+ name: '[gruCell] Throw if data type of weight is not one of the floating point types',
+ input: kExampleInputDescriptor,
+ weight: {
+ dataType: 'int8',
+ dimensions: [3 * hiddenSize, inputSize]
+ },
+ recurrentWeight: kExampleRecurrentWeightDescriptor,
+ hiddenState: kExampleHiddenStateDescriptor,
+ hiddenSize: hiddenSize
+ },
+ {
+ name: '[gruCell] Throw if rank of weight is not 2',
+ input: kExampleInputDescriptor,
+ weight: {
+ dataType: 'float32',
+ dimensions: [3 * hiddenSize]
+ },
+ recurrentWeight: kExampleRecurrentWeightDescriptor,
+ hiddenState: kExampleHiddenStateDescriptor,
+ hiddenSize: hiddenSize
+ },
+ {
+ name: '[gruCell] Throw if weight.dimensions[0] is not 3 * hiddenSize',
+ input: kExampleInputDescriptor,
+ weight: {
+ dataType: 'float32',
+ dimensions: [4 * hiddenSize, inputSize]
+ },
+ recurrentWeight: kExampleRecurrentWeightDescriptor,
+ hiddenState: kExampleHiddenStateDescriptor,
+ hiddenSize: hiddenSize
+ },
+ {
+ name: '[gruCell] Throw if data type of recurrentWeight is not one of the floating point types',
+ input: kExampleInputDescriptor,
+ weight: kExampleWeightDescriptor,
+ recurrentWeight: {
+ dataType: 'int32',
+ dimensions: [3 * hiddenSize, hiddenSize]
+ },
+ hiddenState: kExampleHiddenStateDescriptor,
+ hiddenSize: hiddenSize
+ },
+ {
+ name:
+ '[gruCell] Throw if the rank of recurrentWeight is not 2',
+ input: kExampleInputDescriptor,
+ weight: kExampleWeightDescriptor,
+ recurrentWeight:
+ { dataType: 'float32', dimensions: [3 * hiddenSize] },
+ hiddenState: kExampleHiddenStateDescriptor,
+ hiddenSize: hiddenSize
+ },
+ {
+ name:
+ '[gruCell] Throw if the recurrentWeight.dimensions is invalid',
+ input: kExampleInputDescriptor,
+ weight: kExampleWeightDescriptor,
+ recurrentWeight:
+ { dataType: 'float32', dimensions: [4 * hiddenSize, inputSize] },
+ hiddenState: kExampleHiddenStateDescriptor,
+ hiddenSize: hiddenSize
+ },
+ {
+ name:
+ '[gruCell] Throw if data type of hiddenState is not one of the floating point types',
+ input: kExampleInputDescriptor,
+ weight: kExampleWeightDescriptor,
+ recurrentWeight:
+ kExampleRecurrentWeightDescriptor,
+ hiddenState: {
+ dataType: 'uint32',
+ dimensions: [batchSize, hiddenSize]
+ },
+ hiddenSize: hiddenSize
+ },
+ {
+ name:
+ '[gruCell] Throw if the rank of hiddenState is not 2',
+ input: kExampleInputDescriptor,
+ weight: kExampleWeightDescriptor,
+ recurrentWeight:
+ kExampleRecurrentWeightDescriptor,
+ hiddenState: {
+ dataType: 'float32',
+ dimensions: [hiddenSize]
+ },
+ hiddenSize: hiddenSize
+ },
+ {
+ name:
+ '[gruCell] Throw if the hiddenState.dimensions is invalid',
+ input: kExampleInputDescriptor,
+ weight: kExampleWeightDescriptor,
+ recurrentWeight: kExampleRecurrentWeightDescriptor,
+ hiddenState: {
+ dataType: 'float32',
+ dimensions: [batchSize, 3 * hiddenSize]
+ },
+ hiddenSize: hiddenSize
+ },
+ {
+ name:
+ '[gruCell] Throw if the size of options.activations is not 2',
+ input: kExampleInputDescriptor,
+ weight: kExampleWeightDescriptor,
+ recurrentWeight: kExampleRecurrentWeightDescriptor,
+ hiddenState: kExampleHiddenStateDescriptor,
+ hiddenSize: hiddenSize,
+ options: { activations: ['sigmoid', 'tanh', 'relu'] }
+ },
+ {
+ name:
+ '[gruCell] Throw if data type of options.bias is not one of the floating point types',
+ input: kExampleInputDescriptor,
+ weight: kExampleWeightDescriptor,
+ recurrentWeight: kExampleRecurrentWeightDescriptor,
+ hiddenState: kExampleHiddenStateDescriptor,
+ hiddenSize: hiddenSize,
+ options: { bias: { dataType: 'uint8', dimensions: [3 * hiddenSize] } }
+ },
+ {
+ name:
+ '[gruCell] Throw if the rank of options.bias is not 1',
+ input: kExampleInputDescriptor,
+ weight: kExampleWeightDescriptor,
+ recurrentWeight: kExampleRecurrentWeightDescriptor,
+ hiddenState: kExampleHiddenStateDescriptor,
+ hiddenSize: hiddenSize,
+ options: { bias: { dataType: 'float32', dimensions: [batchSize, 3 * hiddenSize] } }
+ },
+ {
+ name:
+ '[gruCell] Throw if options.bias.dimensions[0] is not 3 * hiddenSize',
+ input: kExampleInputDescriptor,
+ weight: kExampleWeightDescriptor,
+ recurrentWeight: kExampleRecurrentWeightDescriptor,
+ hiddenState: kExampleHiddenStateDescriptor,
+ hiddenSize: hiddenSize,
+ options: { bias: { dataType: 'float32', dimensions: [2 * hiddenSize] } }
+ },
+ {
+ name:
+ '[gruCell] Throw if data type of options.recurrentBias is not one of the floating point types',
+ input: kExampleInputDescriptor,
+ weight: kExampleWeightDescriptor,
+ recurrentWeight: kExampleRecurrentWeightDescriptor,
+ hiddenState: kExampleHiddenStateDescriptor,
+ hiddenSize: hiddenSize,
+ options: { recurrentBias: { dataType: 'int8', dimensions: [3 * hiddenSize] } }
+ },
+ {
+ name:
+ '[gruCell] Throw if the rank of options.recurrentBias is not 1',
+ input: kExampleInputDescriptor,
+ weight: kExampleWeightDescriptor,
+ recurrentWeight: kExampleRecurrentWeightDescriptor,
+ hiddenState: kExampleHiddenStateDescriptor,
+ hiddenSize: hiddenSize,
+ options: { recurrentBias: { dataType: 'float32', dimensions: [batchSize, 3 * hiddenSize] } }
+ },
+ {
+ name:
+ '[gruCell] Throw if options.recurrentBias.dimensions[0] is not 3 * hiddenSize',
+ input: kExampleInputDescriptor,
+ weight: kExampleWeightDescriptor,
+ recurrentWeight: kExampleRecurrentWeightDescriptor,
+ hiddenState: kExampleHiddenStateDescriptor,
+ hiddenSize: hiddenSize,
+ options: {
+ recurrentBias: { dataType: 'float16', dimensions: [4 * hiddenSize] }
+ }
+ }
+];
+
+tests.forEach(
+ test => promise_test(async t => {
+ const input = builder.input(
+ 'input',
+ { dataType: test.input.dataType, dimensions: test.input.dimensions });
+ const weight = builder.input(
+ 'weight',
+ { dataType: test.weight.dataType, dimensions: test.weight.dimensions });
+ const recurrentWeight = builder.input('recurrentWeight', {
+ dataType: test.recurrentWeight.dataType,
+ dimensions: test.recurrentWeight.dimensions
+ });
+ const hiddenState = builder.input('hiddenState', {
+ dataType: test.hiddenState.dataType,
+ dimensions: test.hiddenState.dimensions
+ });
+
+ const options = {};
+ if (test.options) {
+ if (test.options.bias) {
+ options.bias = builder.input('bias', {
+ dataType: test.options.bias.dataType,
+ dimensions: test.options.bias.dimensions
+ });
+ }
+ if (test.options.recurrentBias) {
+ options.bias = builder.input('recurrentBias', {
+ dataType: test.options.recurrentBias.dataType,
+ dimensions: test.options.recurrentBias.dimensions
+ });
+ }
+ if (test.options.resetAfter) {
+ options.resetAfter = test.options.resetAfter;
+ }
+ if (test.options.layout) {
+ options.layout = test.options.layout;
+ }
+ if (test.options.activations) {
+ options.activations = [];
+ test.options.activations.forEach(
+ activation => options.activations.push(builder[activation]()));
+ }
+ }
+
+ if (test.output) {
+ const output = builder.gruCell(
+ input, weight, recurrentWeight, hiddenState, test.hiddenSize,
+ options);
+ assert_equals(output.dataType(), test.output.dataType);
+ assert_array_equals(output.shape(), test.output.dimensions);
+ } else {
+ assert_throws_js(
+ TypeError,
+ () => builder.gruCell(
+ input, weight, recurrentWeight, hiddenState, test.hiddenSize,
+ options));
+ }
+ }, test.name));
multi_builder_test(async (t, builder, otherBuilder) => {
const inputFromOtherBuilder =