| // |
| // Copyright (c) 2021 The Khronos Group Inc. |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| // |
| #include "procs.h" |
| #include "subhelpers.h" |
| #include "subgroup_common_templates.h" |
| #include "harness/typeWrappers.h" |
| #include <bitset> |
| |
| namespace { |
| // Test for ballot functions |
| template <typename Ty> struct BALLOT |
| { |
| static void gen(Ty *x, Ty *t, cl_int *m, const WorkGroupParams &test_params) |
| { |
| // no work here |
| int gws = test_params.global_workgroup_size; |
| int lws = test_params.local_workgroup_size; |
| int sbs = test_params.subgroup_size; |
| int non_uniform_size = gws % lws; |
| log_info(" sub_group_ballot...\n"); |
| if (non_uniform_size) |
| { |
| log_info(" non uniform work group size mode ON\n"); |
| } |
| } |
| |
| static int chk(Ty *x, Ty *y, Ty *mx, Ty *my, cl_int *m, |
| const WorkGroupParams &test_params) |
| { |
| int wi_id, wg_id, sb_id; |
| int gws = test_params.global_workgroup_size; |
| int lws = test_params.local_workgroup_size; |
| int sbs = test_params.subgroup_size; |
| int sb_number = (lws + sbs - 1) / sbs; |
| int current_sbs = 0; |
| cl_uint expected_result, device_result; |
| int non_uniform_size = gws % lws; |
| int wg_number = gws / lws; |
| wg_number = non_uniform_size ? wg_number + 1 : wg_number; |
| int last_subgroup_size = 0; |
| |
| for (wg_id = 0; wg_id < wg_number; ++wg_id) |
| { // for each work_group |
| if (non_uniform_size && wg_id == wg_number - 1) |
| { |
| set_last_workgroup_params(non_uniform_size, sb_number, sbs, lws, |
| last_subgroup_size); |
| } |
| |
| for (wi_id = 0; wi_id < lws; ++wi_id) |
| { // inside the work_group |
| // read device outputs for work_group |
| my[wi_id] = y[wi_id]; |
| } |
| |
| for (sb_id = 0; sb_id < sb_number; ++sb_id) |
| { // for each subgroup |
| int wg_offset = sb_id * sbs; |
| if (last_subgroup_size && sb_id == sb_number - 1) |
| { |
| current_sbs = last_subgroup_size; |
| } |
| else |
| { |
| current_sbs = wg_offset + sbs > lws ? lws - wg_offset : sbs; |
| } |
| for (wi_id = 0; wi_id < current_sbs; ++wi_id) |
| { |
| device_result = my[wg_offset + wi_id]; |
| expected_result = 1; |
| if (!compare(device_result, expected_result)) |
| { |
| log_error( |
| "ERROR: sub_group_ballot mismatch for local id " |
| "%d in sub group %d in group %d obtained {%d}, " |
| "expected {%d} \n", |
| wi_id, sb_id, wg_id, device_result, |
| expected_result); |
| return TEST_FAIL; |
| } |
| } |
| } |
| y += lws; |
| m += 4 * lws; |
| } |
| log_info(" sub_group_ballot... passed\n"); |
| return TEST_PASS; |
| } |
| }; |
| |
| // Test for bit extract ballot functions |
| template <typename Ty, BallotOp operation> struct BALLOT_BIT_EXTRACT |
| { |
| static void gen(Ty *x, Ty *t, cl_int *m, const WorkGroupParams &test_params) |
| { |
| int wi_id, sb_id, wg_id, l; |
| int gws = test_params.global_workgroup_size; |
| int lws = test_params.local_workgroup_size; |
| int sbs = test_params.subgroup_size; |
| int sb_number = (lws + sbs - 1) / sbs; |
| int wg_number = gws / lws; |
| int limit_sbs = sbs > 100 ? 100 : sbs; |
| int non_uniform_size = gws % lws; |
| log_info(" sub_group_%s(%s)...\n", operation_names(operation), |
| TypeManager<Ty>::name()); |
| |
| if (non_uniform_size) |
| { |
| log_info(" non uniform work group size mode ON\n"); |
| } |
| |
| for (wg_id = 0; wg_id < wg_number; ++wg_id) |
| { // for each work_group |
| for (sb_id = 0; sb_id < sb_number; ++sb_id) |
| { // for each subgroup |
| int wg_offset = sb_id * sbs; |
| int current_sbs = wg_offset + sbs > lws ? lws - wg_offset : sbs; |
| // rand index to bit extract |
| int index_for_odd = (int)(genrand_int32(gMTdata) & 0x7fffffff) |
| % (limit_sbs > current_sbs ? current_sbs : limit_sbs); |
| int index_for_even = (int)(genrand_int32(gMTdata) & 0x7fffffff) |
| % (limit_sbs > current_sbs ? current_sbs : limit_sbs); |
| for (wi_id = 0; wi_id < current_sbs; ++wi_id) |
| { |
| // index of the third element int the vector. |
| int midx = 4 * wg_offset + 4 * wi_id + 2; |
| // storing information about index to bit extract |
| m[midx] = (cl_int)index_for_odd; |
| m[++midx] = (cl_int)index_for_even; |
| } |
| set_randomdata_for_subgroup<Ty>(t, wg_offset, current_sbs); |
| } |
| |
| // Now map into work group using map from device |
| for (wi_id = 0; wi_id < lws; ++wi_id) |
| { |
| x[wi_id] = t[wi_id]; |
| } |
| |
| x += lws; |
| m += 4 * lws; |
| } |
| } |
| |
| static int chk(Ty *x, Ty *y, Ty *mx, Ty *my, cl_int *m, |
| const WorkGroupParams &test_params) |
| { |
| int wi_id, wg_id, l, sb_id; |
| int gws = test_params.global_workgroup_size; |
| int lws = test_params.local_workgroup_size; |
| int sbs = test_params.subgroup_size; |
| int sb_number = (lws + sbs - 1) / sbs; |
| int wg_number = gws / lws; |
| cl_uint4 expected_result, device_result; |
| int last_subgroup_size = 0; |
| int current_sbs = 0; |
| int non_uniform_size = gws % lws; |
| |
| for (wg_id = 0; wg_id < wg_number; ++wg_id) |
| { // for each work_group |
| if (non_uniform_size && wg_id == wg_number - 1) |
| { |
| set_last_workgroup_params(non_uniform_size, sb_number, sbs, lws, |
| last_subgroup_size); |
| } |
| // Map to array indexed to array indexed by local ID and sub group |
| for (wi_id = 0; wi_id < lws; ++wi_id) |
| { // inside the work_group |
| // read host inputs for work_group |
| mx[wi_id] = x[wi_id]; |
| // read device outputs for work_group |
| my[wi_id] = y[wi_id]; |
| } |
| |
| for (sb_id = 0; sb_id < sb_number; ++sb_id) |
| { // for each subgroup |
| int wg_offset = sb_id * sbs; |
| if (last_subgroup_size && sb_id == sb_number - 1) |
| { |
| current_sbs = last_subgroup_size; |
| } |
| else |
| { |
| current_sbs = wg_offset + sbs > lws ? lws - wg_offset : sbs; |
| } |
| // take index of array where info which work_item will |
| // be broadcast its value is stored |
| int midx = 4 * wg_offset + 2; |
| // take subgroup local id of this work_item |
| int index_for_odd = (int)m[midx]; |
| int index_for_even = (int)m[++midx]; |
| |
| for (wi_id = 0; wi_id < current_sbs; ++wi_id) |
| { // for each subgroup |
| int bit_value = 0; |
| // from which value of bitfield bit |
| // verification will be done |
| int take_shift = |
| (wi_id & 1) ? index_for_odd % 32 : index_for_even % 32; |
| int bit_mask = 1 << take_shift; |
| |
| if (wi_id < 32) |
| (mx[wg_offset + wi_id].s0 & bit_mask) > 0 |
| ? bit_value = 1 |
| : bit_value = 0; |
| if (wi_id >= 32 && wi_id < 64) |
| (mx[wg_offset + wi_id].s1 & bit_mask) > 0 |
| ? bit_value = 1 |
| : bit_value = 0; |
| if (wi_id >= 64 && wi_id < 96) |
| (mx[wg_offset + wi_id].s2 & bit_mask) > 0 |
| ? bit_value = 1 |
| : bit_value = 0; |
| if (wi_id >= 96 && wi_id < 128) |
| (mx[wg_offset + wi_id].s3 & bit_mask) > 0 |
| ? bit_value = 1 |
| : bit_value = 0; |
| |
| if (wi_id & 1) |
| { |
| bit_value ? expected_result = { 1, 0, 0, 1 } |
| : expected_result = { 0, 0, 0, 1 }; |
| } |
| else |
| { |
| bit_value ? expected_result = { 1, 0, 0, 2 } |
| : expected_result = { 0, 0, 0, 2 }; |
| } |
| |
| device_result = my[wg_offset + wi_id]; |
| if (!compare(device_result, expected_result)) |
| { |
| log_error( |
| "ERROR: sub_group_%s mismatch for local id %d in " |
| "sub group %d in group %d obtained {%d, %d, %d, " |
| "%d}, expected {%d, %d, %d, %d}\n", |
| operation_names(operation), wi_id, sb_id, wg_id, |
| device_result.s0, device_result.s1, |
| device_result.s2, device_result.s3, |
| expected_result.s0, expected_result.s1, |
| expected_result.s2, expected_result.s3); |
| return TEST_FAIL; |
| } |
| } |
| } |
| x += lws; |
| y += lws; |
| m += 4 * lws; |
| } |
| log_info(" sub_group_%s(%s)... passed\n", operation_names(operation), |
| TypeManager<Ty>::name()); |
| return TEST_PASS; |
| } |
| }; |
| |
| template <typename Ty, BallotOp operation> struct BALLOT_INVERSE |
| { |
| static void gen(Ty *x, Ty *t, cl_int *m, const WorkGroupParams &test_params) |
| { |
| int gws = test_params.global_workgroup_size; |
| int lws = test_params.local_workgroup_size; |
| int sbs = test_params.subgroup_size; |
| int non_uniform_size = gws % lws; |
| log_info(" sub_group_inverse_ballot...\n"); |
| if (non_uniform_size) |
| { |
| log_info(" non uniform work group size mode ON\n"); |
| } |
| // no work here |
| } |
| |
| static int chk(Ty *x, Ty *y, Ty *mx, Ty *my, cl_int *m, |
| const WorkGroupParams &test_params) |
| { |
| int wi_id, wg_id, sb_id; |
| int gws = test_params.global_workgroup_size; |
| int lws = test_params.local_workgroup_size; |
| int sbs = test_params.subgroup_size; |
| int sb_number = (lws + sbs - 1) / sbs; |
| cl_uint4 expected_result, device_result; |
| int non_uniform_size = gws % lws; |
| int wg_number = gws / lws; |
| int last_subgroup_size = 0; |
| int current_sbs = 0; |
| if (non_uniform_size) wg_number++; |
| |
| for (wg_id = 0; wg_id < wg_number; ++wg_id) |
| { // for each work_group |
| if (non_uniform_size && wg_id == wg_number - 1) |
| { |
| set_last_workgroup_params(non_uniform_size, sb_number, sbs, lws, |
| last_subgroup_size); |
| } |
| // Map to array indexed to array indexed by local ID and sub group |
| for (wi_id = 0; wi_id < lws; ++wi_id) |
| { // inside the work_group |
| mx[wi_id] = x[wi_id]; // read host inputs for work_group |
| my[wi_id] = y[wi_id]; // read device outputs for work_group |
| } |
| |
| for (sb_id = 0; sb_id < sb_number; ++sb_id) |
| { // for each subgroup |
| int wg_offset = sb_id * sbs; |
| if (last_subgroup_size && sb_id == sb_number - 1) |
| { |
| current_sbs = last_subgroup_size; |
| } |
| else |
| { |
| current_sbs = wg_offset + sbs > lws ? lws - wg_offset : sbs; |
| } |
| // take index of array where info which work_item will |
| // be broadcast its value is stored |
| int midx = 4 * wg_offset + 2; |
| // take subgroup local id of this work_item |
| // Check result |
| for (wi_id = 0; wi_id < current_sbs; ++wi_id) |
| { // for each subgroup work item |
| |
| wi_id & 1 ? expected_result = { 1, 0, 0, 1 } |
| : expected_result = { 1, 0, 0, 2 }; |
| |
| device_result = my[wg_offset + wi_id]; |
| if (!compare(device_result, expected_result)) |
| { |
| log_error( |
| "ERROR: sub_group_%s mismatch for local id %d in " |
| "sub group %d in group %d obtained {%d, %d, %d, " |
| "%d}, expected {%d, %d, %d, %d}\n", |
| operation_names(operation), wi_id, sb_id, wg_id, |
| device_result.s0, device_result.s1, |
| device_result.s2, device_result.s3, |
| expected_result.s0, expected_result.s1, |
| expected_result.s2, expected_result.s3); |
| return TEST_FAIL; |
| } |
| } |
| } |
| x += lws; |
| y += lws; |
| m += 4 * lws; |
| } |
| |
| log_info(" sub_group_inverse_ballot... passed\n"); |
| return TEST_PASS; |
| } |
| }; |
| |
| |
| // Test for bit count/inclusive and exclusive scan/ find lsb msb ballot function |
| template <typename Ty, BallotOp operation> struct BALLOT_COUNT_SCAN_FIND |
| { |
| static void gen(Ty *x, Ty *t, cl_int *m, const WorkGroupParams &test_params) |
| { |
| int wi_id, wg_id, sb_id; |
| int gws = test_params.global_workgroup_size; |
| int lws = test_params.local_workgroup_size; |
| int sbs = test_params.subgroup_size; |
| int sb_number = (lws + sbs - 1) / sbs; |
| int non_uniform_size = gws % lws; |
| int wg_number = gws / lws; |
| int last_subgroup_size = 0; |
| int current_sbs = 0; |
| |
| log_info(" sub_group_%s(%s)...\n", operation_names(operation), |
| TypeManager<Ty>::name()); |
| if (non_uniform_size) |
| { |
| log_info(" non uniform work group size mode ON\n"); |
| wg_number++; |
| } |
| int e; |
| for (wg_id = 0; wg_id < wg_number; ++wg_id) |
| { // for each work_group |
| if (non_uniform_size && wg_id == wg_number - 1) |
| { |
| set_last_workgroup_params(non_uniform_size, sb_number, sbs, lws, |
| last_subgroup_size); |
| } |
| for (sb_id = 0; sb_id < sb_number; ++sb_id) |
| { // for each subgroup |
| int wg_offset = sb_id * sbs; |
| if (last_subgroup_size && sb_id == sb_number - 1) |
| { |
| current_sbs = last_subgroup_size; |
| } |
| else |
| { |
| current_sbs = wg_offset + sbs > lws ? lws - wg_offset : sbs; |
| } |
| if (operation == BallotOp::ballot_bit_count |
| || operation == BallotOp::ballot_inclusive_scan |
| || operation == BallotOp::ballot_exclusive_scan) |
| { |
| set_randomdata_for_subgroup<Ty>(t, wg_offset, current_sbs); |
| } |
| else if (operation == BallotOp::ballot_find_lsb |
| || operation == BallotOp::ballot_find_msb) |
| { |
| // Regarding to the spec, find lsb and find msb result is |
| // undefined behavior if input value is zero, so generate |
| // only non-zero values. |
| for (wi_id = 0; wi_id < current_sbs; ++wi_id) |
| { |
| char x = (genrand_int32(gMTdata)) & 0xff; |
| // undefined behaviour in case of 0; |
| x = x ? x : 1; |
| memset(&t[wg_offset + wi_id], x, sizeof(Ty)); |
| } |
| } |
| else |
| { |
| log_error("Unknown operation..."); |
| } |
| } |
| |
| // Now map into work group using map from device |
| for (wi_id = 0; wi_id < lws; ++wi_id) |
| { |
| x[wi_id] = t[wi_id]; |
| } |
| |
| x += lws; |
| m += 4 * lws; |
| } |
| } |
| |
| static bs128 getImportantBits(cl_uint sub_group_local_id, |
| cl_uint sub_group_size) |
| { |
| bs128 mask; |
| if (operation == BallotOp::ballot_bit_count |
| || operation == BallotOp::ballot_find_lsb |
| || operation == BallotOp::ballot_find_msb) |
| { |
| for (cl_uint i = 0; i < sub_group_size; ++i) mask.set(i); |
| } |
| else if (operation == BallotOp::ballot_inclusive_scan |
| || operation == BallotOp::ballot_exclusive_scan) |
| { |
| for (cl_uint i = 0; i <= sub_group_local_id; ++i) mask.set(i); |
| if (operation == BallotOp::ballot_exclusive_scan) |
| mask.reset(sub_group_local_id); |
| } |
| return mask; |
| } |
| |
| static int chk(Ty *x, Ty *y, Ty *mx, Ty *my, cl_int *m, |
| const WorkGroupParams &test_params) |
| { |
| int wi_id, wg_id, sb_id; |
| int gws = test_params.global_workgroup_size; |
| int lws = test_params.local_workgroup_size; |
| int sbs = test_params.subgroup_size; |
| int sb_number = (lws + sbs - 1) / sbs; |
| int non_uniform_size = gws % lws; |
| int wg_number = gws / lws; |
| wg_number = non_uniform_size ? wg_number + 1 : wg_number; |
| cl_uint4 expected_result, device_result; |
| int last_subgroup_size = 0; |
| int current_sbs = 0; |
| |
| for (wg_id = 0; wg_id < wg_number; ++wg_id) |
| { // for each work_group |
| if (non_uniform_size && wg_id == wg_number - 1) |
| { |
| set_last_workgroup_params(non_uniform_size, sb_number, sbs, lws, |
| last_subgroup_size); |
| } |
| // Map to array indexed to array indexed by local ID and sub group |
| for (wi_id = 0; wi_id < lws; ++wi_id) |
| { // inside the work_group |
| // read host inputs for work_group |
| mx[wi_id] = x[wi_id]; |
| // read device outputs for work_group |
| my[wi_id] = y[wi_id]; |
| } |
| |
| for (sb_id = 0; sb_id < sb_number; ++sb_id) |
| { // for each subgroup |
| int wg_offset = sb_id * sbs; |
| if (last_subgroup_size && sb_id == sb_number - 1) |
| { |
| current_sbs = last_subgroup_size; |
| } |
| else |
| { |
| current_sbs = wg_offset + sbs > lws ? lws - wg_offset : sbs; |
| } |
| // Check result |
| expected_result = { 0, 0, 0, 0 }; |
| for (wi_id = 0; wi_id < current_sbs; ++wi_id) |
| { // for subgroup element |
| bs128 bs; |
| // convert cl_uint4 input into std::bitset<128> |
| bs |= bs128(mx[wg_offset + wi_id].s0) |
| | (bs128(mx[wg_offset + wi_id].s1) << 32) |
| | (bs128(mx[wg_offset + wi_id].s2) << 64) |
| | (bs128(mx[wg_offset + wi_id].s3) << 96); |
| bs &= getImportantBits(wi_id, current_sbs); |
| device_result = my[wg_offset + wi_id]; |
| if (operation == BallotOp::ballot_inclusive_scan |
| || operation == BallotOp::ballot_exclusive_scan |
| || operation == BallotOp::ballot_bit_count) |
| { |
| expected_result.s0 = bs.count(); |
| if (!compare(device_result, expected_result)) |
| { |
| log_error("ERROR: sub_group_%s " |
| "mismatch for local id %d in sub group " |
| "%d in group %d obtained {%d, %d, %d, " |
| "%d}, expected {%d, %d, %d, %d}\n", |
| operation_names(operation), wi_id, sb_id, |
| wg_id, device_result.s0, device_result.s1, |
| device_result.s2, device_result.s3, |
| expected_result.s0, expected_result.s1, |
| expected_result.s2, expected_result.s3); |
| return TEST_FAIL; |
| } |
| } |
| else if (operation == BallotOp::ballot_find_lsb) |
| { |
| for (int id = 0; id < current_sbs; ++id) |
| { |
| if (bs.test(id)) |
| { |
| expected_result.s0 = id; |
| break; |
| } |
| } |
| if (!compare(device_result, expected_result)) |
| { |
| log_error("ERROR: sub_group_ballot_find_lsb " |
| "mismatch for local id %d in sub group " |
| "%d in group %d obtained {%d, %d, %d, " |
| "%d}, expected {%d, %d, %d, %d}\n", |
| wi_id, sb_id, wg_id, device_result.s0, |
| device_result.s1, device_result.s2, |
| device_result.s3, expected_result.s0, |
| expected_result.s1, expected_result.s2, |
| expected_result.s3); |
| return TEST_FAIL; |
| } |
| } |
| else if (operation == BallotOp::ballot_find_msb) |
| { |
| for (int id = current_sbs - 1; id >= 0; --id) |
| { |
| if (bs.test(id)) |
| { |
| expected_result.s0 = id; |
| break; |
| } |
| } |
| if (!compare(device_result, expected_result)) |
| { |
| log_error("ERROR: sub_group_ballot_find_msb " |
| "mismatch for local id %d in sub group " |
| "%d in group %d obtained {%d, %d, %d, " |
| "%d}, expected {%d, %d, %d, %d}\n", |
| wi_id, sb_id, wg_id, device_result.s0, |
| device_result.s1, device_result.s2, |
| device_result.s3, expected_result.s0, |
| expected_result.s1, expected_result.s2, |
| expected_result.s3); |
| return TEST_FAIL; |
| } |
| } |
| } |
| } |
| x += lws; |
| y += lws; |
| m += 4 * lws; |
| } |
| log_info(" sub_group_ballot_%s(%s)... passed\n", |
| operation_names(operation), TypeManager<Ty>::name()); |
| return TEST_PASS; |
| } |
| }; |
| |
| // test mask functions |
| template <typename Ty, BallotOp operation> struct SMASK |
| { |
| static void gen(Ty *x, Ty *t, cl_int *m, const WorkGroupParams &test_params) |
| { |
| int wi_id, wg_id, l, sb_id; |
| int gws = test_params.global_workgroup_size; |
| int lws = test_params.local_workgroup_size; |
| int sbs = test_params.subgroup_size; |
| int sb_number = (lws + sbs - 1) / sbs; |
| int wg_number = gws / lws; |
| log_info(" get_sub_group_%s_mask...\n", operation_names(operation)); |
| for (wg_id = 0; wg_id < wg_number; ++wg_id) |
| { // for each work_group |
| for (sb_id = 0; sb_id < sb_number; ++sb_id) |
| { // for each subgroup |
| int wg_offset = sb_id * sbs; |
| int current_sbs = wg_offset + sbs > lws ? lws - wg_offset : sbs; |
| // Produce expected masks for each work item in the subgroup |
| for (wi_id = 0; wi_id < current_sbs; ++wi_id) |
| { |
| int midx = 4 * wg_offset + 4 * wi_id; |
| cl_uint max_sub_group_size = m[midx + 2]; |
| cl_uint4 expected_mask = { 0 }; |
| expected_mask = generate_bit_mask( |
| wi_id, operation_names(operation), max_sub_group_size); |
| set_value(t[wg_offset + wi_id], expected_mask); |
| } |
| } |
| |
| // Now map into work group using map from device |
| for (wi_id = 0; wi_id < lws; ++wi_id) |
| { |
| x[wi_id] = t[wi_id]; |
| } |
| x += lws; |
| m += 4 * lws; |
| } |
| } |
| |
| static int chk(Ty *x, Ty *y, Ty *mx, Ty *my, cl_int *m, |
| const WorkGroupParams &test_params) |
| { |
| int wi_id, wg_id, sb_id; |
| int gws = test_params.global_workgroup_size; |
| int lws = test_params.local_workgroup_size; |
| int sbs = test_params.subgroup_size; |
| int sb_number = (lws + sbs - 1) / sbs; |
| Ty expected_result, device_result; |
| int wg_number = gws / lws; |
| |
| for (wg_id = 0; wg_id < wg_number; ++wg_id) |
| { // for each work_group |
| for (wi_id = 0; wi_id < lws; ++wi_id) |
| { // inside the work_group |
| mx[wi_id] = x[wi_id]; // read host inputs for work_group |
| my[wi_id] = y[wi_id]; // read device outputs for work_group |
| } |
| |
| for (sb_id = 0; sb_id < sb_number; ++sb_id) |
| { |
| int wg_offset = sb_id * sbs; |
| int current_sbs = wg_offset + sbs > lws ? lws - wg_offset : sbs; |
| |
| // Check result |
| for (wi_id = 0; wi_id < current_sbs; ++wi_id) |
| { // inside the subgroup |
| expected_result = |
| mx[wg_offset + wi_id]; // read host input for subgroup |
| device_result = |
| my[wg_offset |
| + wi_id]; // read device outputs for subgroup |
| if (!compare(device_result, expected_result)) |
| { |
| log_error("ERROR: get_sub_group_%s_mask... mismatch " |
| "for local id %d in sub group %d in group " |
| "%d, obtained %d, expected %d\n", |
| operation_names(operation), wi_id, sb_id, |
| wg_id, device_result, expected_result); |
| return TEST_FAIL; |
| } |
| } |
| } |
| x += lws; |
| y += lws; |
| m += 4 * lws; |
| } |
| log_info(" get_sub_group_%s_mask... passed\n", |
| operation_names(operation)); |
| return TEST_PASS; |
| } |
| }; |
| |
// Kernel for sub_group_non_uniform_broadcast: work items whose xy.x is below
// NR_OF_ACTIVE_WORK_ITEMS broadcast from the lane in xy.z, the rest from the
// lane in xy.w, each inside its own branch.
static const char *bcast_non_uniform_source =
    "__kernel void test_bcast_non_uniform("
    "const __global Type *in, __global int4 *xy, __global Type *out)\n"
    "{\n"
    " int gid = get_global_id(0);\n"
    " XY(xy,gid);\n"
    " Type x = in[gid];\n"
    " if (xy[gid].x < NR_OF_ACTIVE_WORK_ITEMS) {\n"
    " out[gid] = sub_group_non_uniform_broadcast(x, xy[gid].z);\n"
    " } else {\n"
    " out[gid] = sub_group_non_uniform_broadcast(x, xy[gid].w);\n"
    " }\n"
    "}\n";

// Kernel for sub_group_broadcast_first: the same call is made in both arms of
// a branch on xy.x so that it executes in divergent control flow.
static const char *bcast_first_source =
    "__kernel void test_bcast_first("
    "const __global Type *in, __global int4 *xy, __global Type *out)\n"
    "{\n"
    " int gid = get_global_id(0);\n"
    " XY(xy,gid);\n"
    " Type x = in[gid];\n"
    " if (xy[gid].x < NR_OF_ACTIVE_WORK_ITEMS) {\n"
    " out[gid] = sub_group_broadcast_first(x);\n"
    " } else {\n"
    " out[gid] = sub_group_broadcast_first(x);\n"
    " }\n"
    "}\n";
| |
// Kernels for the ballot reductions/scans/searches. Each stores the scalar
// result in .x of a uint4 (other components 0) so the host can compare it.
static const char *ballot_bit_count_source =
    "__kernel void test_sub_group_ballot_bit_count(const __global Type *in, "
    "__global int4 *xy, __global Type *out)\n"
    "{\n"
    " int gid = get_global_id(0);\n"
    " XY(xy,gid);\n"
    " Type x = in[gid];\n"
    " uint4 value = (uint4)(0,0,0,0);\n"
    " value = (uint4)(sub_group_ballot_bit_count(x),0,0,0);\n"
    " out[gid] = value;\n"
    "}\n";

static const char *ballot_inclusive_scan_source =
    "__kernel void test_sub_group_ballot_inclusive_scan(const __global Type "
    "*in, __global int4 *xy, __global Type *out)\n"
    "{\n"
    " int gid = get_global_id(0);\n"
    " XY(xy,gid);\n"
    " Type x = in[gid];\n"
    " uint4 value = (uint4)(0,0,0,0);\n"
    " value = (uint4)(sub_group_ballot_inclusive_scan(x),0,0,0);\n"
    " out[gid] = value;\n"
    "}\n";

static const char *ballot_exclusive_scan_source =
    "__kernel void test_sub_group_ballot_exclusive_scan(const __global Type "
    "*in, __global int4 *xy, __global Type *out)\n"
    "{\n"
    " int gid = get_global_id(0);\n"
    " XY(xy,gid);\n"
    " Type x = in[gid];\n"
    " uint4 value = (uint4)(0,0,0,0);\n"
    " value = (uint4)(sub_group_ballot_exclusive_scan(x),0,0,0);\n"
    " out[gid] = value;\n"
    "}\n";

static const char *ballot_find_lsb_source =
    "__kernel void test_sub_group_ballot_find_lsb(const __global Type *in, "
    "__global int4 *xy, __global Type *out)\n"
    "{\n"
    " int gid = get_global_id(0);\n"
    " XY(xy,gid);\n"
    " Type x = in[gid];\n"
    " uint4 value = (uint4)(0,0,0,0);\n"
    " value = (uint4)(sub_group_ballot_find_lsb(x),0,0,0);\n"
    " out[gid] = value;\n"
    "}\n";

// Same layout as the find_lsb kernel, one statement per line, so compiler
// diagnostics and logged sources line up with the siblings.
static const char *ballot_find_msb_source =
    "__kernel void test_sub_group_ballot_find_msb(const __global Type *in, "
    "__global int4 *xy, __global Type *out)\n"
    "{\n"
    " int gid = get_global_id(0);\n"
    " XY(xy,gid);\n"
    " Type x = in[gid];\n"
    " uint4 value = (uint4)(0,0,0,0);\n"
    " value = (uint4)(sub_group_ballot_find_msb(x),0,0,0);\n"
    " out[gid] = value;\n"
    "}\n";
| |
// Kernels for the get_sub_group_<op>_mask built-ins. Each records
// get_max_sub_group_size() in xy.z (read back by the host when building the
// expected mask) and stores the uint4 mask. Every statement ends in "\n" so
// logged sources and compiler diagnostics are one statement per line.
static const char *get_subgroup_ge_mask_source =
    "__kernel void test_get_sub_group_ge_mask(const __global Type *in, "
    "__global int4 *xy, __global Type *out)\n"
    "{\n"
    " int gid = get_global_id(0);\n"
    " XY(xy,gid);\n"
    " xy[gid].z = get_max_sub_group_size();\n"
    " Type x = in[gid];\n"
    " uint4 mask = get_sub_group_ge_mask();\n"
    " out[gid] = mask;\n"
    "}\n";

static const char *get_subgroup_gt_mask_source =
    "__kernel void test_get_sub_group_gt_mask(const __global Type *in, "
    "__global int4 *xy, __global Type *out)\n"
    "{\n"
    " int gid = get_global_id(0);\n"
    " XY(xy,gid);\n"
    " xy[gid].z = get_max_sub_group_size();\n"
    " Type x = in[gid];\n"
    " uint4 mask = get_sub_group_gt_mask();\n"
    " out[gid] = mask;\n"
    "}\n";

static const char *get_subgroup_le_mask_source =
    "__kernel void test_get_sub_group_le_mask(const __global Type *in, "
    "__global int4 *xy, __global Type *out)\n"
    "{\n"
    " int gid = get_global_id(0);\n"
    " XY(xy,gid);\n"
    " xy[gid].z = get_max_sub_group_size();\n"
    " Type x = in[gid];\n"
    " uint4 mask = get_sub_group_le_mask();\n"
    " out[gid] = mask;\n"
    "}\n";

static const char *get_subgroup_lt_mask_source =
    "__kernel void test_get_sub_group_lt_mask(const __global Type *in, "
    "__global int4 *xy, __global Type *out)\n"
    "{\n"
    " int gid = get_global_id(0);\n"
    " XY(xy,gid);\n"
    " xy[gid].z = get_max_sub_group_size();\n"
    " Type x = in[gid];\n"
    " uint4 mask = get_sub_group_lt_mask();\n"
    " out[gid] = mask;\n"
    "}\n";

static const char *get_subgroup_eq_mask_source =
    "__kernel void test_get_sub_group_eq_mask(const __global Type *in, "
    "__global int4 *xy, __global Type *out)\n"
    "{\n"
    " int gid = get_global_id(0);\n"
    " XY(xy,gid);\n"
    " xy[gid].z = get_max_sub_group_size();\n"
    " Type x = in[gid];\n"
    " uint4 mask = get_sub_group_eq_mask();\n"
    " out[gid] = mask;\n"
    "}\n";
| |
// Kernel for sub_group_ballot: a ballot taken in divergent control flow
// (odd vs. even subgroup local ids) must equal the full ballot masked with
// the matching bit pattern; the boolean outcome is stored per work item.
static const char *ballot_source =
    "__kernel void test_sub_group_ballot(const __global Type *in, "
    "__global int4 *xy, __global Type *out)\n"
    "{\n"
    "uint4 full_ballot = sub_group_ballot(1);\n"
    "uint divergence_mask;\n"
    "uint4 partial_ballot;\n"
    "uint gid = get_global_id(0);\n"
    "XY(xy,gid);\n"
    "if (get_sub_group_local_id() & 1) {\n"
    " divergence_mask = 0xaaaaaaaa;\n"
    " partial_ballot = sub_group_ballot(1);\n"
    "} else {\n"
    " divergence_mask = 0x55555555;\n"
    " partial_ballot = sub_group_ballot(1);\n"
    "}\n"
    " size_t lws = get_local_size(0);\n"
    "uint4 masked_ballot = full_ballot;\n"
    "masked_ballot.x &= divergence_mask;\n"
    "masked_ballot.y &= divergence_mask;\n"
    "masked_ballot.z &= divergence_mask;\n"
    "masked_ballot.w &= divergence_mask;\n"
    "out[gid] = all(masked_ballot == partial_ballot);\n"
    "} \n";

// Kernel for sub_group_inverse_ballot: odd subgroup local ids query the
// 0xAAAAAAAA pattern, even ids 0x55555555; .w tags the branch (1 odd,
// 2 even) and .x holds the returned predicate.
static const char *ballot_source_inverse =
    "__kernel void test_sub_group_ballot_inverse(const __global "
    "Type *in, "
    "__global int4 *xy, __global Type *out)\n"
    "{\n"
    " int gid = get_global_id(0);\n"
    " XY(xy,gid);\n"
    " Type x = in[gid];\n"
    " uint4 value = (uint4)(10,0,0,0);\n"
    " if (get_sub_group_local_id() & 1) {\n"
    " uint4 partial_ballot_mask = "
    "(uint4)(0xAAAAAAAA,0xAAAAAAAA,0xAAAAAAAA,0xAAAAAAAA);\n"
    " if (sub_group_inverse_ballot(partial_ballot_mask)) {\n"
    " value = (uint4)(1,0,0,1);\n"
    " } else {\n"
    " value = (uint4)(0,0,0,1);\n"
    " }\n"
    " } else {\n"
    " uint4 partial_ballot_mask = "
    "(uint4)(0x55555555,0x55555555,0x55555555,0x55555555);\n"
    " if (sub_group_inverse_ballot(partial_ballot_mask)) {\n"
    " value = (uint4)(1,0,0,2);\n"
    " } else {\n"
    " value = (uint4)(0,0,0,2);\n"
    " }\n"
    " }\n"
    " out[gid] = value;\n"
    "}\n";

// Kernel for sub_group_ballot_bit_extract: odd subgroup local ids extract the
// bit index in xy.z, even ids the one in xy.w; .w tags the branch (1 odd,
// 2 even) and .x holds the extracted bit.
static const char *ballot_bit_extract_source =
    "__kernel void test_sub_group_ballot_bit_extract(const __global Type *in, "
    "__global int4 *xy, __global Type *out)\n"
    "{\n"
    " int gid = get_global_id(0);\n"
    " XY(xy,gid);\n"
    " Type x = in[gid];\n"
    " uint index = xy[gid].z;\n"
    " uint4 value = (uint4)(10,0,0,0);\n"
    " if (get_sub_group_local_id() & 1) {\n"
    " if (sub_group_ballot_bit_extract(x, xy[gid].z)) {\n"
    " value = (uint4)(1,0,0,1);\n"
    " } else {\n"
    " value = (uint4)(0,0,0,1);\n"
    " }\n"
    " } else {\n"
    " if (sub_group_ballot_bit_extract(x, xy[gid].w)) {\n"
    " value = (uint4)(1,0,0,2);\n"
    " } else {\n"
    " value = (uint4)(0,0,0,2);\n"
    " }\n"
    " }\n"
    " out[gid] = value;\n"
    "}\n";
| |
| template <typename T> int run_non_uniform_broadcast_for_type(RunTestForType rft) |
| { |
| int error = |
| rft.run_impl<T, BC<T, SubgroupsBroadcastOp::non_uniform_broadcast>>( |
| "test_bcast_non_uniform", bcast_non_uniform_source); |
| return error; |
| } |
| |
| |
| } |
| |
| int test_subgroup_functions_ballot(cl_device_id device, cl_context context, |
| cl_command_queue queue, int num_elements) |
| { |
| std::vector<std::string> required_extensions = { "cl_khr_subgroup_ballot" }; |
| constexpr size_t global_work_size = 170; |
| constexpr size_t local_work_size = 64; |
| WorkGroupParams test_params(global_work_size, local_work_size, |
| required_extensions); |
| RunTestForType rft(device, context, queue, num_elements, test_params); |
| |
| // non uniform broadcast functions |
| int error = run_non_uniform_broadcast_for_type<cl_int>(rft); |
| error |= run_non_uniform_broadcast_for_type<cl_int2>(rft); |
| error |= run_non_uniform_broadcast_for_type<subgroups::cl_int3>(rft); |
| error |= run_non_uniform_broadcast_for_type<cl_int4>(rft); |
| error |= run_non_uniform_broadcast_for_type<cl_int8>(rft); |
| error |= run_non_uniform_broadcast_for_type<cl_int16>(rft); |
| |
| error |= run_non_uniform_broadcast_for_type<cl_uint>(rft); |
| error |= run_non_uniform_broadcast_for_type<cl_uint2>(rft); |
| error |= run_non_uniform_broadcast_for_type<subgroups::cl_uint3>(rft); |
| error |= run_non_uniform_broadcast_for_type<cl_uint4>(rft); |
| error |= run_non_uniform_broadcast_for_type<cl_uint8>(rft); |
| error |= run_non_uniform_broadcast_for_type<cl_uint16>(rft); |
| |
| error |= run_non_uniform_broadcast_for_type<cl_char>(rft); |
| error |= run_non_uniform_broadcast_for_type<cl_char2>(rft); |
| error |= run_non_uniform_broadcast_for_type<subgroups::cl_char3>(rft); |
| error |= run_non_uniform_broadcast_for_type<cl_char4>(rft); |
| error |= run_non_uniform_broadcast_for_type<cl_char8>(rft); |
| error |= run_non_uniform_broadcast_for_type<cl_char16>(rft); |
| |
| error |= run_non_uniform_broadcast_for_type<cl_uchar>(rft); |
| error |= run_non_uniform_broadcast_for_type<cl_uchar2>(rft); |
| error |= run_non_uniform_broadcast_for_type<subgroups::cl_uchar3>(rft); |
| error |= run_non_uniform_broadcast_for_type<cl_uchar4>(rft); |
| error |= run_non_uniform_broadcast_for_type<cl_uchar8>(rft); |
| error |= run_non_uniform_broadcast_for_type<cl_uchar16>(rft); |
| |
| error |= run_non_uniform_broadcast_for_type<cl_short>(rft); |
| error |= run_non_uniform_broadcast_for_type<cl_short2>(rft); |
| error |= run_non_uniform_broadcast_for_type<subgroups::cl_short3>(rft); |
| error |= run_non_uniform_broadcast_for_type<cl_short4>(rft); |
| error |= run_non_uniform_broadcast_for_type<cl_short8>(rft); |
| error |= run_non_uniform_broadcast_for_type<cl_short16>(rft); |
| |
| error |= run_non_uniform_broadcast_for_type<cl_ushort>(rft); |
| error |= run_non_uniform_broadcast_for_type<cl_ushort2>(rft); |
| error |= run_non_uniform_broadcast_for_type<subgroups::cl_ushort3>(rft); |
| error |= run_non_uniform_broadcast_for_type<cl_ushort4>(rft); |
| error |= run_non_uniform_broadcast_for_type<cl_ushort8>(rft); |
| error |= run_non_uniform_broadcast_for_type<cl_ushort16>(rft); |
| |
| error |= run_non_uniform_broadcast_for_type<cl_long>(rft); |
| error |= run_non_uniform_broadcast_for_type<cl_long2>(rft); |
| error |= run_non_uniform_broadcast_for_type<subgroups::cl_long3>(rft); |
| error |= run_non_uniform_broadcast_for_type<cl_long4>(rft); |
| error |= run_non_uniform_broadcast_for_type<cl_long8>(rft); |
| error |= run_non_uniform_broadcast_for_type<cl_long16>(rft); |
| |
| error |= run_non_uniform_broadcast_for_type<cl_ulong>(rft); |
| error |= run_non_uniform_broadcast_for_type<cl_ulong2>(rft); |
| error |= run_non_uniform_broadcast_for_type<subgroups::cl_ulong3>(rft); |
| error |= run_non_uniform_broadcast_for_type<cl_ulong4>(rft); |
| error |= run_non_uniform_broadcast_for_type<cl_ulong8>(rft); |
| error |= run_non_uniform_broadcast_for_type<cl_ulong16>(rft); |
| |
| error |= run_non_uniform_broadcast_for_type<cl_float>(rft); |
| error |= run_non_uniform_broadcast_for_type<cl_float2>(rft); |
| error |= run_non_uniform_broadcast_for_type<subgroups::cl_float3>(rft); |
| error |= run_non_uniform_broadcast_for_type<cl_float4>(rft); |
| error |= run_non_uniform_broadcast_for_type<cl_float8>(rft); |
| error |= run_non_uniform_broadcast_for_type<cl_float16>(rft); |
| |
| error |= run_non_uniform_broadcast_for_type<cl_double>(rft); |
| error |= run_non_uniform_broadcast_for_type<cl_double2>(rft); |
| error |= run_non_uniform_broadcast_for_type<subgroups::cl_double3>(rft); |
| error |= run_non_uniform_broadcast_for_type<cl_double4>(rft); |
| error |= run_non_uniform_broadcast_for_type<cl_double8>(rft); |
| error |= run_non_uniform_broadcast_for_type<cl_double16>(rft); |
| |
| error |= run_non_uniform_broadcast_for_type<subgroups::cl_half>(rft); |
| error |= run_non_uniform_broadcast_for_type<subgroups::cl_half2>(rft); |
| error |= run_non_uniform_broadcast_for_type<subgroups::cl_half3>(rft); |
| error |= run_non_uniform_broadcast_for_type<subgroups::cl_half4>(rft); |
| error |= run_non_uniform_broadcast_for_type<subgroups::cl_half8>(rft); |
| error |= run_non_uniform_broadcast_for_type<subgroups::cl_half16>(rft); |
| |
| // broadcast first functions |
| error |= |
| rft.run_impl<cl_int, BC<cl_int, SubgroupsBroadcastOp::broadcast_first>>( |
| "test_bcast_first", bcast_first_source); |
| error |= rft.run_impl<cl_uint, |
| BC<cl_uint, SubgroupsBroadcastOp::broadcast_first>>( |
| "test_bcast_first", bcast_first_source); |
| error |= rft.run_impl<cl_long, |
| BC<cl_long, SubgroupsBroadcastOp::broadcast_first>>( |
| "test_bcast_first", bcast_first_source); |
| error |= rft.run_impl<cl_ulong, |
| BC<cl_ulong, SubgroupsBroadcastOp::broadcast_first>>( |
| "test_bcast_first", bcast_first_source); |
| error |= rft.run_impl<cl_short, |
| BC<cl_short, SubgroupsBroadcastOp::broadcast_first>>( |
| "test_bcast_first", bcast_first_source); |
| error |= rft.run_impl<cl_ushort, |
| BC<cl_ushort, SubgroupsBroadcastOp::broadcast_first>>( |
| "test_bcast_first", bcast_first_source); |
| error |= rft.run_impl<cl_char, |
| BC<cl_char, SubgroupsBroadcastOp::broadcast_first>>( |
| "test_bcast_first", bcast_first_source); |
| error |= rft.run_impl<cl_uchar, |
| BC<cl_uchar, SubgroupsBroadcastOp::broadcast_first>>( |
| "test_bcast_first", bcast_first_source); |
| error |= rft.run_impl<cl_float, |
| BC<cl_float, SubgroupsBroadcastOp::broadcast_first>>( |
| "test_bcast_first", bcast_first_source); |
| error |= rft.run_impl<cl_double, |
| BC<cl_double, SubgroupsBroadcastOp::broadcast_first>>( |
| "test_bcast_first", bcast_first_source); |
| error |= rft.run_impl< |
| subgroups::cl_half, |
| BC<subgroups::cl_half, SubgroupsBroadcastOp::broadcast_first>>( |
| "test_bcast_first", bcast_first_source); |
| |
| // mask functions |
| error |= rft.run_impl<cl_uint4, SMASK<cl_uint4, BallotOp::eq_mask>>( |
| "test_get_sub_group_eq_mask", get_subgroup_eq_mask_source); |
| error |= rft.run_impl<cl_uint4, SMASK<cl_uint4, BallotOp::ge_mask>>( |
| "test_get_sub_group_ge_mask", get_subgroup_ge_mask_source); |
| error |= rft.run_impl<cl_uint4, SMASK<cl_uint4, BallotOp::gt_mask>>( |
| "test_get_sub_group_gt_mask", get_subgroup_gt_mask_source); |
| error |= rft.run_impl<cl_uint4, SMASK<cl_uint4, BallotOp::le_mask>>( |
| "test_get_sub_group_le_mask", get_subgroup_le_mask_source); |
| error |= rft.run_impl<cl_uint4, SMASK<cl_uint4, BallotOp::lt_mask>>( |
| "test_get_sub_group_lt_mask", get_subgroup_lt_mask_source); |
| |
| // ballot functions |
| error |= rft.run_impl<cl_uint, BALLOT<cl_uint>>("test_sub_group_ballot", |
| ballot_source); |
| error |= rft.run_impl<cl_uint4, |
| BALLOT_INVERSE<cl_uint4, BallotOp::inverse_ballot>>( |
| "test_sub_group_ballot_inverse", ballot_source_inverse); |
| error |= rft.run_impl< |
| cl_uint4, BALLOT_BIT_EXTRACT<cl_uint4, BallotOp::ballot_bit_extract>>( |
| "test_sub_group_ballot_bit_extract", ballot_bit_extract_source); |
| error |= rft.run_impl< |
| cl_uint4, BALLOT_COUNT_SCAN_FIND<cl_uint4, BallotOp::ballot_bit_count>>( |
| "test_sub_group_ballot_bit_count", ballot_bit_count_source); |
| error |= rft.run_impl< |
| cl_uint4, |
| BALLOT_COUNT_SCAN_FIND<cl_uint4, BallotOp::ballot_inclusive_scan>>( |
| "test_sub_group_ballot_inclusive_scan", ballot_inclusive_scan_source); |
| error |= rft.run_impl< |
| cl_uint4, |
| BALLOT_COUNT_SCAN_FIND<cl_uint4, BallotOp::ballot_exclusive_scan>>( |
| "test_sub_group_ballot_exclusive_scan", ballot_exclusive_scan_source); |
| error |= rft.run_impl< |
| cl_uint4, BALLOT_COUNT_SCAN_FIND<cl_uint4, BallotOp::ballot_find_lsb>>( |
| "test_sub_group_ballot_find_lsb", ballot_find_lsb_source); |
| error |= rft.run_impl< |
| cl_uint4, BALLOT_COUNT_SCAN_FIND<cl_uint4, BallotOp::ballot_find_msb>>( |
| "test_sub_group_ballot_find_msb", ballot_find_msb_source); |
| return error; |
| } |