| #!/usr/bin/env python3 |
| # SPDX-License-Identifier: Apache-2.0 |
| # ----------------------------------------------------------------------------- |
| # Copyright 2021 Arm Limited |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); you may not |
| # use this file except in compliance with the License. You may obtain a copy |
| # of the License at: |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |
| # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |
| # License for the specific language governing permissions and limitations |
| # under the License. |
| # ----------------------------------------------------------------------------- |
| """ |
| The ``astc_trace_analysis`` utility provides a tool to analyze trace files. |
| |
| WARNING: Trace files are an engineering tool, and not part of the standard |
| product, so traces and their associated tools are volatile and may change |
| significantly without notice. |
| """ |
| |
| import argparse |
| from collections import defaultdict as ddict |
| import json |
| import numpy as np |
| import sys |
| |
# Maps the "weight_quant" index stored in a trace candidate to the value
# reported as the quant range in the statistics output.
QUANT_TABLE = dict(enumerate((2, 3, 4, 5, 6, 8, 10, 12, 16, 20, 24, 32)))
| |
# Maps a channel index to the single-letter name used when printing.
CHANNEL_TABLE = dict(enumerate("RGBA"))
| |
class Trace:
    """
    Root container for a parsed trace.

    Holds the block dimensions and an ordered list of blocks. Supports
    ``len()``, indexing and ``del`` on the stored blocks.
    """

    def __init__(self, block_x, block_y, block_z):
        """
        Create an empty trace for the given block dimensions.
        """
        self.block_x, self.block_y, self.block_z = block_x, block_y, block_z
        self.blocks = []

    def add_block(self, block):
        """
        Append a block to the end of the trace.
        """
        self.blocks.append(block)

    def __getitem__(self, i):
        stored = self.blocks
        return stored[i]

    def __delitem__(self, i):
        stored = self.blocks
        del stored[i]

    def __len__(self):
        stored = self.blocks
        return len(stored)
| |
class Block:
    """
    One block recorded in the trace.

    Stores the block position, the per-channel minimum and maximum values
    (both as supplied and scaled to an 8-bit range), the error target, and
    the list of passes. Supports ``len()``, indexing and ``del`` on the
    stored passes.
    """

    def __init__(self, pos_x, pos_y, pos_z, error_target):
        """
        Create a block with no channel ranges and no passes.
        """
        self.pos_x = pos_x
        self.pos_y = pos_y
        self.pos_z = pos_z

        self.raw_min = None
        self.raw_max = None

        self.ldr_min = None
        self.ldr_max = None

        self.error_target = error_target
        self.passes = []
        self.qualityHit = None

    @staticmethod
    def _to_ldr(value):
        """
        Scale a channel value from the 0-65535 range to 0-255 (truncated).
        """
        cmax = 65535.0
        return int((value / cmax) * 255.0)

    def add_minimums(self, r, g, b, a):
        """
        Record the per-channel minimum values.

        The raw values are stored in ``raw_min``; each channel is scaled
        independently into ``ldr_min``.
        """
        self.raw_min = (r, g, b, a)
        # Convert every channel from its own value, not just from red
        self.ldr_min = tuple(self._to_ldr(value) for value in self.raw_min)

    def add_maximums(self, r, g, b, a):
        """
        Record the per-channel maximum values.

        The raw values are stored in ``raw_max``; each channel is scaled
        independently into ``ldr_max``.
        """
        self.raw_max = (r, g, b, a)
        # Convert every channel from its own value, not just from red
        self.ldr_max = tuple(self._to_ldr(value) for value in self.raw_max)

    def add_pass(self, pas):
        """
        Append a pass to this block.
        """
        self.passes.append(pas)

    def __getitem__(self, i):
        return self.passes[i]

    def __delitem__(self, i):
        del self.passes[i]

    def __len__(self):
        return len(self.passes)
| |
| |
class Pass:
    """
    One compression pass recorded for a block.

    Stores the pass configuration and an ordered list of candidates.
    Supports ``len()``, indexing and ``del`` on the stored candidates.
    """

    def __init__(self, partitions, partition, planes, target_hit, mode, component):
        """
        Create a pass with no candidates.

        A ``partition`` of ``None`` is stored as partition index 0.
        """
        self.partitions = partitions
        if partition is None:
            self.partition_index = 0
        else:
            self.partition_index = partition
        self.planes = planes
        self.plane2_component = component
        self.target_hit = target_hit
        self.search_mode = mode
        self.candidates = []

    def add_candidate(self, candidate):
        """
        Append a candidate to this pass.
        """
        self.candidates.append(candidate)

    def __getitem__(self, i):
        stored = self.candidates
        return stored[i]

    def __delitem__(self, i):
        stored = self.candidates
        del stored[i]

    def __len__(self):
        stored = self.candidates
        return len(stored)
| |
| |
class Candidate:
    """
    One trial encoding recorded for a pass.

    Stores the weight grid dimensions, the weight quantization value and
    the list of error values added for this candidate.
    """

    def __init__(self, weight_x, weight_y, weight_z, weight_quant):
        """
        Create a candidate with no refinement errors recorded.
        """
        self.weight_x, self.weight_y, self.weight_z = weight_x, weight_y, weight_z
        self.weight_quant = weight_quant
        self.refinement_errors = []

    def add_refinement(self, errorval):
        """
        Append an error value to the refinement history.
        """
        history = self.refinement_errors
        history.append(errorval)
| |
| |
def get_attrib(data, name, multiple=False, hard_fail=True):
    """
    Look up an attribute value in a list of trace entries.

    An entry matches when it has exactly two elements and its first
    element equals ``name``; its second element is the value.

    Args:
        data: The list of trace entries to search.
        name: The attribute name to match.
        multiple: If True return a list of all matching values, otherwise
            return the single matching value.
        hard_fail: If True a missing attribute is an error, otherwise a
            missing attribute returns an empty result.

    Returns:
        The single value, or a list of values if ``multiple`` is set. If
        nothing matched and ``hard_fail`` is False, returns ``None`` (or an
        empty list if ``multiple`` is set).

    Raises:
        AssertionError: The attribute was missing and ``hard_fail`` is
            set, or more than one match was found and ``multiple`` is not
            set. The searched data is printed as JSON first.
    """
    results = [attrib[1] for attrib in data
               if len(attrib) == 2 and attrib[0] == name]

    if not results:
        if hard_fail:
            print(json.dumps(data, indent=2))
            # Raise explicitly so the check survives "python -O"
            raise AssertionError("Attribute %s not found" % name)
        return [] if multiple else None

    if multiple:
        return results

    if len(results) > 1:
        print(json.dumps(data, indent=2))
        # Raise explicitly so the check survives "python -O"
        raise AssertionError(
            "Attribute %s found %u times" % (name, len(results)))

    return results[0]
| |
| |
def rev_enumerate(seq):
    """
    Iterate over (index, item) pairs of a sequence from last to first.

    Args:
        seq: A sequence supporting ``len()`` and integer indexing.

    Returns:
        An iterator of (index, item) tuples in descending index order.
    """
    descending = range(len(seq) - 1, -1, -1)
    return ((index, seq[index]) for index in descending)
| |
def foreach_block(data):
    """
    Generate every block in the trace, in order.
    """
    yield from data
| |
def foreach_pass(data):
    """
    Generate a (block, pass) tuple for every pass of every block.
    """
    for block in data:
        yield from ((block, pas) for pas in block)
| |
def foreach_candidate(data):
    """
    Generate a (block, pass, candidate) tuple for every candidate.

    A pass holding no candidates still produces one tuple, with ``None``
    in the candidate slot.
    """
    for block, pas in ((blk, p) for blk in data for p in blk):
        # Special case - None candidates for 0 partition
        if len(pas) == 0:
            yield (block, pas, None)

        yield from ((block, pas, candidate) for candidate in pas)
| |
def get_node(data, name, multiple=False, hard_fail=True):
    """
    Look up a child node in a list of trace entries.

    An entry matches when it has exactly three elements, its first element
    is ``"node"`` and its second element equals ``name``; its third element
    is the node content.

    Args:
        data: The list of trace entries to search.
        name: The node name to match.
        multiple: If True return a list of all matching nodes, otherwise
            return the single matching node.
        hard_fail: If True a missing node is an error, otherwise a missing
            node returns an empty result.

    Returns:
        The single node content, or a list of node contents if ``multiple``
        is set. If nothing matched and ``hard_fail`` is False, returns
        ``None`` (or an empty list if ``multiple`` is set, matching
        ``get_attrib``).

    Raises:
        AssertionError: The node was missing and ``hard_fail`` is set, or
            more than one match was found and ``multiple`` is not set. The
            searched data is printed as JSON first.
    """
    results = [entry[2] for entry in data
               if len(entry) == 3 and entry[0] == "node" and entry[1] == name]

    if not results:
        if hard_fail:
            print(json.dumps(data, indent=2))
            # Raise explicitly so the check survives "python -O"
            raise AssertionError("Node %s not found" % name)
        # Keep the soft-fail result iterable when a list was requested
        return [] if multiple else None

    if multiple:
        return results

    if len(results) > 1:
        print(json.dumps(data, indent=2))
        # Raise explicitly so the check survives "python -O"
        raise AssertionError(
            "Node %s found %u times" % (name, len(results)))

    return results[0]
| |
| |
def find_best_pass_and_candidate(block):
    """
    Find the pass and candidate with the lowest final error in a block.

    Passes are visited in order. A pass that hit its target with zero
    partitions is returned immediately with no candidate. Otherwise the
    candidate whose last recorded error is lowest wins; on a tie the later
    candidate is kept.

    Args:
        block: An iterable of passes, each an iterable of candidates.

    Returns:
        tuple: ``(best_pass, best_candidate)``; the candidate is ``None``
        for the zero partition special case.

    Raises:
        AssertionError: No pass or candidate could be selected.
    """
    best_error = 1e30
    best_pass = None
    best_candidate = None

    for pas in block:
        # Special case for constant color blocks - no trial candidates
        if pas.target_hit and pas.partitions == 0:
            return (pas, None)

        for candidate in pas:
            # A candidate with no recorded error cannot be ranked
            if not candidate.refinement_errors:
                continue

            errorval = candidate.refinement_errors[-1]
            if errorval <= best_error:
                best_error = errorval
                best_pass = pas
                best_candidate = candidate

    # Every other return type must have both best pass and best candidate.
    # Compare against None explicitly: truthiness of a pass depends on its
    # length, and a bare assert would be stripped under "python -O".
    if best_pass is None or best_candidate is None:
        raise AssertionError("No best pass and candidate found for block")

    return (best_pass, best_candidate)
| |
| |
def generate_database(data):
    """
    Build the in-memory trace database from loaded trace JSON.

    Skipped passes and failed candidates are not copied into the database.

    Args:
        data: The loaded trace; a ``["node", "root", content]`` list.

    Returns:
        Trace: The populated trace database.
    """
    # Skip header
    assert data[0] == "node"
    assert data[1] == "root"
    root = data[2]

    block_dims = [get_attrib(root, key)
                  for key in ("block_x", "block_y", "block_z")]
    trace = Trace(*block_dims)

    for block_node in get_node(root, "block", multiple=True):
        position = [get_attrib(block_node, key)
                    for key in ("pos_x", "pos_y", "pos_z")]
        minimums = [get_attrib(block_node, key)
                    for key in ("min_r", "min_g", "min_b", "min_a")]
        maximums = [get_attrib(block_node, key)
                    for key in ("max_r", "max_g", "max_b", "max_a")]
        threshold = get_attrib(block_node, "tune_error_threshold")

        block_entry = Block(position[0], position[1], position[2], threshold)
        block_entry.add_minimums(*minimums)
        block_entry.add_maximums(*maximums)
        trace.add_block(block_entry)

        for pass_node in get_node(block_node, "pass", multiple=True):
            # Don't copy across passes we skipped due to heuristics
            if get_attrib(pass_node, "skip", multiple=False, hard_fail=False):
                continue

            partition_count = get_attrib(pass_node, "partition_count")
            partition_index = get_attrib(
                pass_node, "partition_index", multiple=False, hard_fail=False)
            plane_count = get_attrib(pass_node, "plane_count")
            # NOTE(review): the component is only mandatory when the plane
            # count exceeds 2, while the statistics code reads it for any
            # pass with more than 1 plane - confirm "> 2" is intended.
            component = get_attrib(
                pass_node, "plane_component",
                multiple=False, hard_fail=plane_count > 2)
            search_mode = get_attrib(
                pass_node, "search_mode", multiple=False, hard_fail=False)
            exit_reason = get_attrib(
                pass_node, "exit", multiple=False, hard_fail=False)
            quality_hit = exit_reason == "quality hit"

            pass_entry = Pass(partition_count, partition_index, plane_count,
                              quality_hit, search_mode, component)
            block_entry.add_pass(pass_entry)

            # Constant color blocks don't have any candidates
            if partition_count == 0:
                continue

            for cand_node in get_node(pass_node, "candidate", multiple=True):
                # Don't copy across candidates we couldn't encode
                if get_attrib(cand_node, "failed",
                              multiple=False, hard_fail=False):
                    continue

                weight_x = get_attrib(cand_node, "weight_x")
                weight_y = get_attrib(cand_node, "weight_y")
                weight_z = get_attrib(cand_node, "weight_z")
                quant = QUANT_TABLE[get_attrib(cand_node, "weight_quant")]
                pre_errors = get_attrib(
                    cand_node, "error_prerealign",
                    multiple=True, hard_fail=False)
                post_errors = get_attrib(
                    cand_node, "error_postrealign",
                    multiple=True, hard_fail=False)

                cand_entry = Candidate(weight_x, weight_y, weight_z, quant)
                pass_entry.add_candidate(cand_entry)

                # Pre-realign errors are recorded before post-realign ones
                for error_list in (pre_errors, post_errors):
                    for value in error_list:
                        cand_entry.add_refinement(value)

    return trace
| |
| |
def filter_database(data):
    """
    Reduce every block to its best pass and best candidate, in place.

    All other passes are deleted from each block. If the best pass has a
    best candidate, all of its other candidates are deleted too.

    Args:
        data: The trace database to filter.
    """
    for block in data:
        keep_pass, keep_candidate = find_best_pass_and_candidate(block)

        # Walk backwards so a deletion never shifts an index still to visit
        for pass_index, pas in rev_enumerate(block):
            if pas != keep_pass:
                del block[pass_index]
            elif keep_candidate is not None:
                for cand_index, candidate in rev_enumerate(pas):
                    if candidate != keep_candidate:
                        del pas[cand_index]
| |
| |
def generate_pass_statistics(data):
    """
    Placeholder for per-pass statistics; currently does nothing.

    Args:
        data: The trace database (unused).
    """
    return None
| |
| |
def _print_compressor_config(data):
    """Print the block dimensions stored on the trace."""
    print("Compressor Config")
    print("=================")

    # Only 3D block sizes report the Z dimension
    if data.block_z > 1:
        dat = (data.block_x, data.block_y, data.block_z)
        print(" - Block size: %ux%ux%u" % dat)
    else:
        dat = (data.block_x, data.block_y)
        print(" - Block size: %ux%u" % dat)

    print("")


def _print_channel_range(data, range_quant=16):
    """Print a histogram of the largest per-block LDR channel range."""
    histogram = ddict(int)

    for block in foreach_block(data):
        max_range = max(block.ldr_max[i] - block.ldr_min[i]
                        for i in range(0, 4))
        # Round down to the start of the histogram bucket
        bucket = int(max_range / range_quant) * range_quant
        histogram[bucket] += 1

    print("Channel Range")
    print("=============")
    for key in sorted(histogram):
        dat = (key, key + range_quant - 1, histogram[key])
        print(" - %3u-%3u dynamic range = %6u blocks" % dat)

    print("")


def _print_partition_count(data):
    """Print pass totals and distinct partition indices per partition count."""
    totals = ddict(int)
    indices = ddict(set)

    for _, pas in foreach_pass(data):
        totals[pas.partitions] += 1
        indices[pas.partitions].add(pas.partition_index)

    print("Partition Count")
    print("===============")
    for key in sorted(totals):
        dat = (key, totals[key], len(indices[key]))
        print(" - %u partition(s) = %6u blocks / %4u indicies" % dat)

    print("")


def _print_plane_usage(data):
    """Print plane counts per partition count, with second plane channels."""
    plane_counts = ddict(lambda: ddict(int))
    channel_counts = ddict(lambda: ddict(int))

    for _, pas in foreach_pass(data):
        plane_counts[pas.partitions][pas.planes] += 1
        if pas.planes > 1:
            channel_counts[pas.partitions][pas.plane2_component] += 1

    print("Plane Usage")
    print("===========")
    for partitions in sorted(plane_counts):
        for planes in sorted(plane_counts[partitions]):
            dat = (partitions, planes, plane_counts[partitions][planes])
            print(" - %u partition(s) %u plane(s) = %6u blocks" % dat)

            # Only dual plane rows get the per-channel breakdown
            if planes == 2:
                for channel in sorted(channel_counts[partitions]):
                    dat = (CHANNEL_TABLE[channel],
                           channel_counts[partitions][channel])
                    print(" - %s plane = %6u blocks" % dat)

    print("")


def _print_decimation_usage(data, merge_rotations=True):
    """Print weight grid sizes and quant ranges used by the candidates."""
    decim_count = ddict(lambda: ddict(int))
    quant_count = ddict(lambda: ddict(lambda: ddict(int)))

    for _, _, candidate in foreach_candidate(data):
        # Skip constant color blocks
        if candidate is None:
            continue

        wx = candidate.weight_x
        wy = candidate.weight_y

        # Count an AxB grid and a BxA grid as the same entry
        if merge_rotations and wx < wy:
            wx, wy = wy, wx

        decim_count[wx][wy] += 1
        quant_count[wx][wy][candidate.weight_quant] += 1

    print("Decimation Usage")
    print("================")

    if merge_rotations:
        print(" - Note: data merging grid rotations")

    for x_key in sorted(decim_count):
        for y_key in sorted(decim_count[x_key]):
            dat = (x_key, y_key, decim_count[x_key][y_key])
            print(" - %ux%u weights = %6u blocks" % dat)

            for q_key in sorted(quant_count[x_key][y_key]):
                dat = (q_key, quant_count[x_key][y_key][q_key])
                print(" - %2u quant range = %6u blocks" % dat)

    print("")


def _print_refinement_usage(data):
    """Print how the recorded error values evolved for each candidate."""
    total_count = 0
    better_count = 0
    could_have_count = 0
    success_count = 0

    refinement_step = []

    for block, _, candidate in foreach_candidate(data):
        # Ignore zero partition blocks - they don't use refinement
        if candidate is None:
            continue

        errors = candidate.refinement_errors
        # Nothing recorded for this candidate, so nothing to compare
        if not errors:
            continue

        target_error = block.error_target
        start_error = errors[0]
        end_error = errors[-1]

        # The relative step is undefined for a zero starting error; leave
        # that sample out of the mean rather than dividing by zero
        if start_error:
            step = abs(float(start_error - end_error) / float(len(errors)))
            refinement_step.append(step / start_error)

        total_count += 1
        if end_error <= start_error:
            better_count += 1

        if end_error <= target_error:
            success_count += 1
        elif any(value <= target_error for value in errors):
            # An earlier step met the target but the final one did not
            could_have_count += 1

    # np.mean of an empty list is nan and emits a RuntimeWarning
    mean_step = np.mean(refinement_step) if refinement_step else 0.0

    print("Refinement Usage")
    print("================")
    print(" - %u refinements(s)" % total_count)
    print(" - %u refinements(s) improved" % better_count)
    print(" - %u refinements(s) worsened" % (total_count - better_count))
    print(" - %u refinements(s) could hit target, but didn't" % could_have_count)
    print(" - %u refinements(s) hit target" % success_count)
    print(" - %f mean step improvement" % mean_step)


def generate_feature_statistics(data):
    """
    Print the feature usage report for a trace database to stdout.

    The report sections are, in order: compressor config, channel range,
    partition count, plane usage, decimation usage and refinement usage.

    Args:
        data: The trace database to report on.
    """
    _print_compressor_config(data)
    _print_channel_range(data)
    _print_partition_count(data)
    _print_plane_usage(data)
    _print_decimation_usage(data)
    _print_refinement_usage(data)
| |
| |
def parse_command_line():
    """
    Parse the command line.

    The single positional ``trace`` argument is opened for reading by
    argparse.

    Returns:
        Namespace: The parsed command line container.
    """
    trace_options = {
        "type": argparse.FileType("r"),
        "help": "The trace file to analyze",
    }

    parser = argparse.ArgumentParser()
    parser.add_argument("trace", **trace_options)

    args = parser.parse_args()
    return args
| |
| |
def main():
    """
    The main function.

    Loads the trace named on the command line, reduces it to the best pass
    and candidate per block, and prints the feature statistics.

    Returns:
        int: The process return code.
    """
    args = parse_command_line()

    # argparse opened the trace file for us; close it once it is loaded
    with args.trace as trace_file:
        data = json.load(trace_file)

    db = generate_database(data)
    filter_database(db)

    generate_feature_statistics(db)

    return 0
| |
| |
# Script entry point: exit the process with the code returned by main()
if __name__ == "__main__":
    raise SystemExit(main())