Skip to content

Commit 390d989

Browse files
KyleSiefringbarrbrain
authored andcommitted
Reduce the number of checkpoints used in tx rdo
1 parent f37c13c commit 390d989

File tree

1 file changed

+25
-15
lines changed

1 file changed

+25
-15
lines changed

src/rdo.rs

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -615,7 +615,7 @@ pub fn rdo_tx_size_type<T: Pixel>(
615615
let do_rdo_tx_size =
616616
fi.tx_mode_select && fi.config.speed_settings.rdo_tx_decision && !is_inter;
617617
let rdo_tx_depth = if do_rdo_tx_size { 2 } else { 0 };
618-
let mut cw_checkpoint = None;
618+
let mut cw_checkpoint: Option<ContextWriterCheckpoint> = None;
619619

620620
for _ in 0..=rdo_tx_depth {
621621
let tx_set = get_tx_set(tx_size, is_inter, fi.use_reduced_tx_set);
@@ -628,18 +628,22 @@ pub fn rdo_tx_size_type<T: Pixel>(
628628
return (best_tx_size, best_tx_type);
629629
};
630630

631-
if cw_checkpoint.is_none() {
632-
// Only runs on the first iteration of the loop.
633-
// Avoids creating the checkpoint if we early exit above.
634-
cw_checkpoint = Some(cw.checkpoint());
635-
}
636-
637631
let tx_types =
638632
if do_rdo_tx_type { RAV1E_TX_TYPES } else { &[TxType::DCT_DCT] };
639633

640634
// Luma plane transform type decision
641635
let (tx_type, rd_cost) = rdo_tx_type_decision(
642-
fi, ts, cw, luma_mode, ref_frames, mvs, bsize, tile_bo, tx_size, tx_set,
636+
fi,
637+
ts,
638+
cw,
639+
&mut cw_checkpoint,
640+
luma_mode,
641+
ref_frames,
642+
mvs,
643+
bsize,
644+
tile_bo,
645+
tx_size,
646+
tx_set,
643647
tx_types,
644648
);
645649

@@ -656,7 +660,6 @@ pub fn rdo_tx_size_type<T: Pixel>(
656660
);
657661

658662
let next_tx_size = sub_tx_size_map[tx_size as usize];
659-
cw.rollback(cw_checkpoint.as_ref().unwrap());
660663

661664
if next_tx_size == tx_size {
662665
break;
@@ -1436,12 +1439,15 @@ pub fn rdo_cfl_alpha<T: Pixel>(
14361439
}
14371440
}
14381441

1439-
// RDO-based transform type decision
1442+
/// RDO-based transform type decision
1443+
/// If cw_checkpoint is None, a checkpoint for cw's (ContextWriter) current
1444+
/// state is created and stored for later use.
14401445
pub fn rdo_tx_type_decision<T: Pixel>(
14411446
fi: &FrameInvariants<T>, ts: &mut TileStateMut<'_, T>,
1442-
cw: &mut ContextWriter, mode: PredictionMode, ref_frames: [RefType; 2],
1443-
mvs: [MotionVector; 2], bsize: BlockSize, tile_bo: TileBlockOffset,
1444-
tx_size: TxSize, tx_set: TxSet, tx_types: &[TxType],
1447+
cw: &mut ContextWriter, cw_checkpoint: &mut Option<ContextWriterCheckpoint>,
1448+
mode: PredictionMode, ref_frames: [RefType; 2], mvs: [MotionVector; 2],
1449+
bsize: BlockSize, tile_bo: TileBlockOffset, tx_size: TxSize, tx_set: TxSet,
1450+
tx_types: &[TxType],
14451451
) -> (TxType, f64) {
14461452
let mut best_type = TxType::DCT_DCT;
14471453
let mut best_rd = std::f64::MAX;
@@ -1451,7 +1457,11 @@ pub fn rdo_tx_type_decision<T: Pixel>(
14511457

14521458
let is_inter = !mode.is_intra();
14531459

1454-
let cw_checkpoint = cw.checkpoint();
1460+
if cw_checkpoint.is_none() {
1461+
// Only run the first call
1462+
// Prevents creating multiple checkpoints for own version of cw
1463+
*cw_checkpoint = Some(cw.checkpoint());
1464+
}
14551465

14561466
let rdo_type = if fi.use_tx_domain_distortion {
14571467
RDOType::TxDistRealRate
@@ -1533,7 +1543,7 @@ pub fn rdo_tx_type_decision<T: Pixel>(
15331543
best_type = tx_type;
15341544
}
15351545

1536-
cw.rollback(&cw_checkpoint);
1546+
cw.rollback(cw_checkpoint.as_ref().unwrap());
15371547
}
15381548

15391549
assert!(best_rd >= 0_f64);

0 commit comments

Comments
 (0)