| // META: title=validation tests for WebNN API softmax operation |
| // META: global=window |
| // META: variant=?cpu |
| // META: variant=?gpu |
| // META: variant=?npu |
| // META: script=../resources/utils_validation.js |
| |
| 'use strict'; |
| |
// Cases for softmax() called without an axis argument.
// Each case has:
//   name   - subtest name (must be unique within the file).
//   input  - operand descriptor passed to builder.input().
//   output - expected dataType/shape of the result; omitted when building
//            the graph is expected to throw.
const tests_without_axis = [
  {
    name: '[softmax] Test building Softmax with float32 input without axis.',
    input: {dataType: 'float32', shape: [4, 3]},
    output: {dataType: 'float32', shape: [4, 3]},
  },
  {
    name: '[softmax] Test building Softmax with float16 input without axis.',
    input: {dataType: 'float16', shape: [3, 5]},
    output: {dataType: 'float16', shape: [3, 5]},
  },
  {
    // int32 input: building is expected to fail.
    name:
        '[softmax] Throw if the input is not floating-point data (without axis).',
    input: {dataType: 'int32', shape: [3, 2]},
  },
  {
    // Rank-3 input: only rank-2 input is accepted when no axis is given.
    name: '[softmax] Throw if the input rank is not 2.',
    input: {dataType: 'float32', shape: [1, 4, 3]},
  },
];
| |
// Registers one promise_test per entry of tests_without_axis.
// Cases with `output` build softmax(input) and check the result's dataType
// and shape. Cases without `output` must make softmax(input, options) throw
// an error whose message carries the '[softmax_xxx]' label.
tests_without_axis.forEach((test) => {
  promise_test(async (t) => {
    const builder = new MLGraphBuilder(context);
    const input = builder.input('input', test.input);
    if (test.output) {
      const output = builder.softmax(input);
      assert_equals(output.dataType, test.output.dataType);
      assert_array_equals(output.shape, test.output.shape);
    } else {
      const label = 'softmax_xxx';
      const options = {label};
      const regexp = new RegExp(`\\[${label}\\]`);
      // Use the same helper as the axis-taking cases instead of a bare
      // try/catch: a bare try/catch lets the test pass when softmax()
      // unexpectedly succeeds. NOTE(review): assert_throws_with_label comes
      // from utils_validation.js; it is assumed to fail when the call does
      // not throw and to match the error message against `regexp`.
      assert_throws_with_label(() => builder.softmax(input, options), regexp);
    }
  }, test.name);
});
| |
// softmax() without an axis must reject an operand that was created by a
// different MLGraphBuilder, throwing a TypeError.
multi_builder_test(async (t, builder, otherBuilder) => {
  const foreignInput =
      otherBuilder.input('input', {dataType: 'float32', shape: [2, 3]});
  const callSoftmax = () => builder.softmax(foreignInput);
  assert_throws_js(TypeError, callSoftmax);
}, '[softmax without axis] throw if any input is from another builder');
| |
// Cases for softmax(input, axis).
// Each case has:
//   name   - subtest name (must be unique within the file).
//   input  - operand descriptor passed to builder.input().
//   axis   - axis argument passed to builder.softmax().
//   output - expected dataType/shape of the result; omitted when building
//            the graph is expected to throw.
const tests = [
  {
    name: '[softmax] Test building Softmax with float32 input.',
    input: {dataType: 'float32', shape: [4, 4, 3]},
    axis: 1,
    output: {dataType: 'float32', shape: [4, 4, 3]},
  },
  {
    name: '[softmax] Test building Softmax with float16 input.',
    input: {dataType: 'float16', shape: [3, 1, 5, 2]},
    axis: 2,
    output: {dataType: 'float16', shape: [3, 1, 5, 2]},
  },
  {
    // int32 input: building is expected to fail.
    name: '[softmax] Throw if the input is not floating-point data.',
    input: {dataType: 'int32', shape: [3, 1, 5, 2]},
    axis: 3,
  },
  {
    // Rank-4 input, so the largest valid axis is 3.
    name: '[softmax] Throw if the axis is greater than input rank - 1.',
    input: {dataType: 'float16', shape: [3, 1, 5, 2]},
    axis: 4,
  },
];
| |
// Registers one promise_test per entry of `tests`. Failure cases (no
// `output`) must make softmax(input, axis, {label}) throw with the label in
// the message. Success cases check the built operand's dataType and shape.
tests.forEach((test) => {
  promise_test(async (t) => {
    const builder = new MLGraphBuilder(context);
    const input = builder.input('input', test.input);

    if (!test.output) {
      const label = 'softmax_xxx';
      const labelPattern = new RegExp(`\\[${label}\\]`);
      assert_throws_with_label(
          () => builder.softmax(input, test.axis, {label}), labelPattern);
      return;
    }

    const output = builder.softmax(input, test.axis);
    assert_equals(output.dataType, test.output.dataType);
    assert_array_equals(output.shape, test.output.shape);
  }, test.name);
});
| |
// softmax(input, axis) must reject an operand that was created by a
// different MLGraphBuilder, throwing a TypeError.
multi_builder_test(async (t, builder, otherBuilder) => {
  const foreignInput =
      otherBuilder.input('input', {dataType: 'float32', shape: [1, 2, 3]});
  assert_throws_js(
      TypeError, () => builder.softmax(foreignInput, /* axis= */ 1));
}, '[softmax] throw if any input is from another builder');