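Optimize coefficient decoding

Split the token-decoding loop in src/recon_tmpl.c into separate eob, ac,
dc and dc-only paths, so the loop no longer carries an is_last flag or
selects between the eob_base_tok and base_tok CDFs per coefficient.

get_coef_nz_ctx() and get_br_ctx() in src/env.h now take the x, y and
stride values already computed by the caller instead of re-deriving them
from rc and the transform dimensions. The eob context is computed inline
by the caller, and get_br_ctx() takes an "ac" flag in place of the
rc == 0 test.

The low eob bits are read with one dav1d_msac_decode_bools() call rather
than a loop of dav1d_msac_decode_bool_equi() calls.

In the sign/residual pass the DC coefficient is handled before the loop,
the (int64_t) multiply in dequantization is replaced by an unsigned
multiply, and iclip(sign ? -tok : tok, cf_min, cf_max) is replaced by
imin(tok - sign, cf_max) ^ -sign. The 0xfffff mask is now applied only
after a golomb residual has been read.

Debug printf() calls are re-indented under their "if (dbg)" guards.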
diff --git a/src/env.h b/src/env.h
index 2d4cc26..bd8ef08 100644
--- a/src/env.h
+++ b/src/env.h
@@ -609,25 +609,12 @@
     }
 }
 
-static inline int get_coef_nz_ctx(uint8_t *const levels, const int scan_idx,
-                                  const int rc, const int is_eob,
+static inline int get_coef_nz_ctx(uint8_t *const levels,
                                   const enum RectTxfmSize tx,
-                                  const enum TxClass tx_class)
+                                  const enum TxClass tx_class,
+                                  const int x, const int y,
+                                  const ptrdiff_t stride)
 {
-    const TxfmInfo *const t_dim = &dav1d_txfm_dimensions[tx];
-
-    if (is_eob) {
-        if (scan_idx == 0)         return 0;
-        const int eighth_sz = imin(t_dim->w, 8) * imin(t_dim->h, 8) * 2;
-        if (scan_idx <= eighth_sz) return 1;
-        const int quart_sz = eighth_sz * 2;
-        if (scan_idx <= quart_sz)  return 2;
-        return 3;
-    }
-
-    const int x = rc >> (2 + imin(t_dim->lh, 3));
-    const int y = rc & (4 * imin(t_dim->h, 8) - 1);
-    const ptrdiff_t stride = 4 * (imin(t_dim->h, 8) + 1);
     static const uint8_t offsets[3][5][2 /* x, y */] = {
         [TX_CLASS_2D] = {
             { 0, 1 }, { 1, 0 }, { 2, 0 }, { 0, 2 }, { 1, 1 }
@@ -643,8 +630,7 @@
         mag += imin(levels[(x + off[i][0]) * stride + (y + off[i][1])], 3);
     const int ctx = imin((mag + 1) >> 1, 4);
     if (tx_class == TX_CLASS_2D) {
-        return !rc ? 0 :
-            dav1d_nz_map_ctx_offset[tx][imin(y, 4)][imin(x, 4)] + ctx;
+        return dav1d_nz_map_ctx_offset[tx][imin(y, 4)][imin(x, 4)] + ctx;
     } else {
         return 26 + imin((tx_class == TX_CLASS_V) ? y : x, 2) * 5 + ctx;
     }
@@ -686,13 +672,10 @@
 }
 
 static inline int get_br_ctx(const uint8_t *const levels,
-                             const int rc, const enum RectTxfmSize tx,
-                             const enum TxClass tx_class)
+                             const int ac, const enum TxClass tx_class,
+                             const int x, const int y,
+                             const ptrdiff_t stride)
 {
-    const TxfmInfo *const t_dim = &dav1d_txfm_dimensions[tx];
-    const int x = rc >> (imin(t_dim->lh, 3) + 2);
-    const int y = rc & (4 * imin(t_dim->h, 8) - 1);
-    const int stride = 4 * (imin(t_dim->h, 8) + 1);
     int mag = 0;
     static const uint8_t offsets_from_txclass[3][3][2] = {
         [TX_CLASS_2D] = { { 0, 1 }, { 1, 0 }, { 1, 1 } },
@@ -704,7 +687,7 @@
         mag += levels[(x + offsets[i][1]) * stride + y + offsets[i][0]];
 
     mag = imin((mag + 1) >> 1, 6);
-    if (rc == 0) return mag;
+    if (!ac) return mag;
     switch (tx_class) {
     case TX_CLASS_2D:
         if (y < 2 && x < 2) return mag + 7;
diff --git a/src/recon_tmpl.c b/src/recon_tmpl.c
index 0cadaad..751ffab 100644
--- a/src/recon_tmpl.c
+++ b/src/recon_tmpl.c
@@ -69,19 +69,19 @@
     const TxfmInfo *const t_dim = &dav1d_txfm_dimensions[tx];
     const int dbg = DEBUG_BLOCK_INFO && plane && 0;
 
-    if (dbg) printf("Start: r=%d\n", ts->msac.rng);
+    if (dbg)
+        printf("Start: r=%d\n", ts->msac.rng);
 
     // does this block have any non-zero coefficients
     const int sctx = get_coef_skip_ctx(t_dim, bs, a, l, chroma, f->cur.p.layout);
     const int all_skip = dav1d_msac_decode_bool_adapt(&ts->msac,
                              ts->cdf.coef.skip[t_dim->ctx][sctx]);
     if (dbg)
-    printf("Post-non-zero[%d][%d][%d]: r=%d\n",
-           t_dim->ctx, sctx, all_skip, ts->msac.rng);
+        printf("Post-non-zero[%d][%d][%d]: r=%d\n",
+               t_dim->ctx, sctx, all_skip, ts->msac.rng);
     if (all_skip) {
         *res_ctx = 0x40;
-        *txtp = f->frame_hdr->segmentation.lossless[b->seg_id] ? WHT_WHT :
-                                                                DCT_DCT;
+        *txtp = f->frame_hdr->segmentation.lossless[b->seg_id] ? WHT_WHT : DCT_DCT;
         return -1;
     }
 
@@ -111,9 +111,9 @@
                      dav1d_msac_decode_symbol_adapt16)(&ts->msac, txtp_cdf, set_cnt);
 
             if (dbg)
-            printf("Post-txtp[%d->%d][%d->%d][%d][%d->%d]: r=%d\n",
-                   set, set_idx, tx, t_dim->min, intra ? (int)y_mode_nofilt : -1,
-                   idx, dav1d_tx_types_per_set[set][idx], ts->msac.rng);
+                printf("Post-txtp[%d->%d][%d->%d][%d][%d->%d]: r=%d\n",
+                       set, set_idx, tx, t_dim->min, intra ? (int)y_mode_nofilt : -1,
+                       idx, dav1d_tx_types_per_set[set][idx], ts->msac.rng);
         }
         *txtp = dav1d_tx_types_per_set[set][idx];
     }
@@ -140,26 +140,20 @@
 #undef case_sz
     }
     if (dbg)
-    printf("Post-eob_bin_%d[%d][%d][%d]: r=%d\n",
-           16 << tx2dszctx, chroma, is_1d, eob_bin, ts->msac.rng);
+        printf("Post-eob_bin_%d[%d][%d][%d]: r=%d\n",
+               16 << tx2dszctx, chroma, is_1d, eob_bin, ts->msac.rng);
     int eob;
     if (eob_bin > 1) {
-        eob = 1 << (eob_bin - 1);
         uint16_t *const eob_hi_bit_cdf =
             ts->cdf.coef.eob_hi_bit[t_dim->ctx][chroma][eob_bin];
-        const int eob_hi_bit = dav1d_msac_decode_bool_adapt(&ts->msac,
-                                                            eob_hi_bit_cdf);
+        const int eob_hi_bit = dav1d_msac_decode_bool_adapt(&ts->msac, eob_hi_bit_cdf);
         if (dbg)
-        printf("Post-eob_hi_bit[%d][%d][%d][%d]: r=%d\n",
-               t_dim->ctx, chroma, eob_bin, eob_hi_bit, ts->msac.rng);
-        unsigned mask = eob >> 1;
-        if (eob_hi_bit) eob |= mask;
-        for (mask >>= 1; mask; mask >>= 1) {
-            const int eob_bit = dav1d_msac_decode_bool_equi(&ts->msac);
-            if (eob_bit) eob |= mask;
-        }
+            printf("Post-eob_hi_bit[%d][%d][%d][%d]: r=%d\n",
+                   t_dim->ctx, chroma, eob_bin, eob_hi_bit, ts->msac.rng);
+        eob = ((eob_hi_bit | 2) << (eob_bin - 2)) |
+              dav1d_msac_decode_bools(&ts->msac, eob_bin - 2);
         if (dbg)
-        printf("Post-eob[%d]: r=%d\n", eob, ts->msac.rng);
+            printf("Post-eob[%d]: r=%d\n", eob, ts->msac.rng);
     } else {
         eob = eob_bin;
     }
@@ -168,98 +162,180 @@
     uint16_t (*const br_cdf)[5] =
         ts->cdf.coef.br_tok[imin(t_dim->ctx, 3)][chroma];
     const int16_t *const scan = dav1d_scans[tx][tx_class];
-    uint8_t levels[36 * 36];
-    ptrdiff_t stride = 4 * (imin(t_dim->h, 8) + 1);
-    memset(levels, 0, stride * 4 * (imin(t_dim->w, 8) + 1));
-    const int shift = 2 + imin(t_dim->lh, 3), mask = 4 * imin(t_dim->h, 8) - 1;
-    unsigned cul_level = 0;
-    for (int i = eob, is_last = 1; i >= 0; i--, is_last = 0) {
-        const int rc = scan[i], x = rc >> shift, y = rc & mask;
+    int dc_tok;
 
-        // lo tok
-        const int ctx = get_coef_nz_ctx(levels, i, rc, is_last, tx, tx_class);
-        uint16_t *const lo_cdf = is_last ?
-            ts->cdf.coef.eob_base_tok[t_dim->ctx][chroma][ctx] :
-            ts->cdf.coef.base_tok[t_dim->ctx][chroma][ctx];
-        int tok = dav1d_msac_decode_symbol_adapt4(&ts->msac, lo_cdf,
-                                                  4 - is_last) + is_last;
+    if (eob) {
+        ALIGN_STK_16(uint8_t, levels, 36 * 36,);
+        const int sw = imin(t_dim->w, 8), sh = imin(t_dim->h, 8);
+        const ptrdiff_t stride = 4 * (sh + 1);
+        memset(levels, 0, stride * 4 * (sw + 1));
+        const int shift = 2 + imin(t_dim->lh, 3), mask = 4 * sh - 1;
+
+        { // eob
+            const int rc = scan[eob], x = rc >> shift, y = rc & mask;
+
+            const int ctx = 1 + (eob > sw * sh * 2) + (eob > sw * sh * 4);
+            uint16_t *const lo_cdf = ts->cdf.coef.eob_base_tok[t_dim->ctx][chroma][ctx];
+            int tok = dav1d_msac_decode_symbol_adapt4(&ts->msac, lo_cdf, 3) + 1;
+            if (dbg)
+                printf("Post-lo_tok[%d][%d][%d][%d=%d=%d]: r=%d\n",
+                       t_dim->ctx, chroma, ctx, eob, rc, tok, ts->msac.rng);
+
+            if (tok == 3) {
+                const int br_ctx = get_br_ctx(levels, 1, tx_class, x, y, stride);
+                do {
+                    const int tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac,
+                                           br_cdf[br_ctx], 4);
+                    if (dbg)
+                        printf("Post-hi_tok[%d][%d][%d][%d=%d=%d->%d]: r=%d\n",
+                               imin(t_dim->ctx, 3), chroma, br_ctx,
+                               eob, rc, tok_br, tok, ts->msac.rng);
+                    tok += tok_br;
+                    if (tok_br < 3) break;
+                } while (tok < 15);
+            }
+
+            cf[rc] = tok;
+            levels[x * stride + y] = (uint8_t) tok;
+        }
+        for (int i = eob - 1; i > 0; i--) { // ac
+            const int rc = scan[i], x = rc >> shift, y = rc & mask;
+
+            // lo tok
+            const int ctx = get_coef_nz_ctx(levels, tx, tx_class, x, y, stride);
+            uint16_t *const lo_cdf = ts->cdf.coef.base_tok[t_dim->ctx][chroma][ctx];
+            int tok = dav1d_msac_decode_symbol_adapt4(&ts->msac, lo_cdf, 4);
+            if (dbg)
+                printf("Post-lo_tok[%d][%d][%d][%d=%d=%d]: r=%d\n",
+                       t_dim->ctx, chroma, ctx, i, rc, tok, ts->msac.rng);
+
+            // hi tok
+            if (tok == 3) {
+                const int br_ctx = get_br_ctx(levels, 1, tx_class, x, y, stride);
+                do {
+                    const int tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac,
+                                           br_cdf[br_ctx], 4);
+                    if (dbg)
+                        printf("Post-hi_tok[%d][%d][%d][%d=%d=%d->%d]: r=%d\n",
+                               imin(t_dim->ctx, 3), chroma, br_ctx,
+                               i, rc, tok_br, tok, ts->msac.rng);
+                    tok += tok_br;
+                    if (tok_br < 3) break;
+                } while (tok < 15);
+            }
+
+            cf[rc] = tok;
+            levels[x * stride + y] = (uint8_t) tok;
+        }
+        { // dc
+            int ctx = 0;
+            if (tx_class != TX_CLASS_2D)
+                ctx = get_coef_nz_ctx(levels, tx, tx_class, 0, 0, stride);
+            uint16_t *const lo_cdf = ts->cdf.coef.base_tok[t_dim->ctx][chroma][ctx];
+            dc_tok = dav1d_msac_decode_symbol_adapt4(&ts->msac, lo_cdf, 4);
+            if (dbg)
+                printf("Post-dc_lo_tok[%d][%d][%d][%d]: r=%d\n",
+                       t_dim->ctx, chroma, ctx, dc_tok, ts->msac.rng);
+
+            if (dc_tok == 3) {
+                const int br_ctx = get_br_ctx(levels, 0, tx_class, 0, 0, stride);
+                do {
+                    const int tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac,
+                                           br_cdf[br_ctx], 4);
+                    if (dbg)
+                        printf("Post-dc_hi_tok[%d][%d][%d][%d->%d]: r=%d\n",
+                               imin(t_dim->ctx, 3), chroma, br_ctx,
+                               tok_br, dc_tok, ts->msac.rng);
+                    dc_tok += tok_br;
+                    if (tok_br < 3) break;
+                } while (dc_tok < 15);
+            }
+        }
+    } else { // dc-only
+        uint16_t *const lo_cdf = ts->cdf.coef.eob_base_tok[t_dim->ctx][chroma][0];
+        dc_tok = dav1d_msac_decode_symbol_adapt4(&ts->msac, lo_cdf, 3) + 1;
         if (dbg)
-        printf("Post-lo_tok[%d][%d][%d][%d=%d=%d]: r=%d\n",
-               t_dim->ctx, chroma, ctx, i, rc, tok, ts->msac.rng);
-        if (!tok) continue;
+            printf("Post-dc_lo_tok[%d][%d][%d][%d]: r=%d\n",
+                   t_dim->ctx, chroma, 0, dc_tok, ts->msac.rng);
 
-        // hi tok
-        if (tok == 3) {
-            const int br_ctx = get_br_ctx(levels, rc, tx, tx_class);
+        if (dc_tok == 3) {
             do {
                 const int tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac,
-                                       br_cdf[br_ctx], 4);
+                                       br_cdf[0], 4);
                 if (dbg)
-                printf("Post-hi_tok[%d][%d][%d][%d=%d=%d->%d]: r=%d\n",
-                       imin(t_dim->ctx, 3), chroma, br_ctx,
-                       i, rc, tok_br, tok, ts->msac.rng);
-                tok += tok_br;
+                    printf("Post-dc_hi_tok[%d][%d][%d][%d->%d]: r=%d\n",
+                           imin(t_dim->ctx, 3), chroma, 0,
+                           tok_br, dc_tok, ts->msac.rng);
+                dc_tok += tok_br;
                 if (tok_br < 3) break;
-            } while (tok < 15);
+            } while (dc_tok < 15);
         }
-
-        cf[rc] = tok;
-        levels[x * stride + y] = (uint8_t) cf[rc];
     }
 
     // residual and sign
-    int dc_sign = 1;
+    int dc_sign = 1 << 6;
     const int lossless = f->frame_hdr->segmentation.lossless[b->seg_id];
     const uint16_t *const dq_tbl = ts->dq[b->seg_id][plane];
     const uint8_t *const qm_tbl = f->qm[lossless || is_1d || *txtp == IDTX][tx][plane];
     const int dq_shift = imax(0, t_dim->ctx - 2);
     const int bitdepth = BITDEPTH == 8 ? 8 : f->cur.p.bpc;
-    const int cf_min = -(1 << (7 + bitdepth));
     const int cf_max = (1 << (7 + bitdepth)) - 1;
-    for (int i = 0; i <= eob; i++) {
+    unsigned cul_level = 0;
+
+    if (dc_tok) { // dc
+        const int dc_sign_ctx = get_dc_sign_ctx(t_dim, a, l);
+        uint16_t *const dc_sign_cdf =
+            ts->cdf.coef.dc_sign[chroma][dc_sign_ctx];
+        const int sign = dav1d_msac_decode_bool_adapt(&ts->msac, dc_sign_cdf);
+        const unsigned dq = (dq_tbl[0] * qm_tbl[0] + 16) >> 5;
+        if (dbg)
+            printf("Post-dc_sign[%d][%d][%d]: r=%d\n",
+                   chroma, dc_sign_ctx, sign, ts->msac.rng);
+        dc_sign = (sign - 1) & (2 << 6);
+
+        if (dc_tok == 15) {
+            dc_tok += read_golomb(&ts->msac);
+            if (dbg)
+                printf("Post-dc_residual[%d->%d]: r=%d\n",
+                       dc_tok - 15, dc_tok, ts->msac.rng);
+
+            dc_tok &= 0xfffff;
+        }
+
+        cul_level += dc_tok;
+        dc_tok = ((dq * dc_tok) & 0xffffff) >> dq_shift;
+        cf[0] = imin(dc_tok - sign, cf_max) ^ -sign;
+    }
+    for (int i = 1; i <= eob; i++) { // ac
         const int rc = scan[i];
         int tok = cf[rc];
         if (!tok) continue;
-        int dq;
 
         // sign
-        int sign;
-        if (i == 0) {
-            const int dc_sign_ctx = get_dc_sign_ctx(t_dim, a, l);
-            uint16_t *const dc_sign_cdf =
-                ts->cdf.coef.dc_sign[chroma][dc_sign_ctx];
-            sign = dav1d_msac_decode_bool_adapt(&ts->msac, dc_sign_cdf);
-            if (dbg)
-            printf("Post-dc_sign[%d][%d][%d]: r=%d\n",
-                   chroma, dc_sign_ctx, sign, ts->msac.rng);
-            dc_sign = sign ? 0 : 2;
-            dq = (dq_tbl[0] * qm_tbl[0] + 16) >> 5;
-        } else {
-            sign = dav1d_msac_decode_bool_equi(&ts->msac);
-            if (dbg)
+        const int sign = dav1d_msac_decode_bool_equi(&ts->msac);
+        const unsigned dq = (dq_tbl[1] * qm_tbl[rc] + 16) >> 5;
+        if (dbg)
             printf("Post-sign[%d=%d=%d]: r=%d\n", i, rc, sign, ts->msac.rng);
-            dq = (dq_tbl[1] * qm_tbl[rc] + 16) >> 5;
-        }
 
         // residual
         if (tok == 15) {
             tok += read_golomb(&ts->msac);
             if (dbg)
-            printf("Post-residual[%d=%d=%d->%d]: r=%d\n",
-                   i, rc, tok - 15, tok, ts->msac.rng);
-        }
+                printf("Post-residual[%d=%d=%d->%d]: r=%d\n",
+                       i, rc, tok - 15, tok, ts->msac.rng);
 
-        // coefficient parsing, see 5.11.39
-        tok &= 0xfffff;
+            // coefficient parsing, see 5.11.39
+            tok &= 0xfffff;
+        }
 
         // dequant, see 7.12.3
         cul_level += tok;
-        tok = (((int64_t)dq * tok) & 0xffffff) >> dq_shift;
-        cf[rc] = iclip(sign ? -tok : tok, cf_min, cf_max);
+        tok = ((dq * tok) & 0xffffff) >> dq_shift;
+        cf[rc] = imin(tok - sign, cf_max) ^ -sign;
     }
 
     // context
-    *res_ctx = imin(cul_level, 63) | (dc_sign << 6);
+    *res_ctx = imin(cul_level, 63) | dc_sign;
 
     return eob;
 }