diff --git a/.gitignore b/.gitignore index ee2ba210..3aee11ae 100644 --- a/.gitignore +++ b/.gitignore @@ -10,6 +10,7 @@ __pycache__ Scratch Proto +.idea # Precompiled reference binaries for comparison tests bin @@ -19,6 +20,7 @@ Binaries # Build artifacts astcenc build* +cmake-build* # General build artifacts Test/DocOut diff --git a/Source/astcenc.h b/Source/astcenc.h index 3d04b4ea..de9752eb 100644 --- a/Source/astcenc.h +++ b/Source/astcenc.h @@ -571,6 +571,36 @@ struct astcenc_config */ float tune_search_mode0_enable; + /** + * @brief Enable Rate Distortion Optimization (RDO) post-processing. + */ + bool rdo_enabled; + + /** + * @brief RDO quality scalar (lambda). + */ + float rdo_quality; + + /** + * @brief RDO lookback size in blocks. + */ + unsigned int rdo_lookback; + + /** + * @brief RDO task partitions. + */ + unsigned int rdo_partitions; + + /** + * @brief RDO max smooth block error scale. + */ + float rdo_max_smooth_block_error_scale; + + /** + * @brief RDO max smooth block standard deviation. + */ + float rdo_max_smooth_block_std_dev; + /** * @brief The progress callback, can be @c nullptr. * diff --git a/Source/astcenc_entry.cpp b/Source/astcenc_entry.cpp index 5dc38016..bd088c85 100644 --- a/Source/astcenc_entry.cpp +++ b/Source/astcenc_entry.cpp @@ -56,6 +56,7 @@ struct astcenc_preset_config float tune_3partition_early_out_limit_factor; float tune_2plane_early_out_limit_correlation; float tune_search_mode0_enable; + unsigned int rdo_lookback; }; /** @@ -64,22 +65,22 @@ struct astcenc_preset_config static const std::array preset_configs_high {{ { ASTCENC_PRE_FASTEST, - 2, 10, 6, 4, 43, 2, 2, 2, 2, 2, 85.2f, 63.2f, 3.5f, 1.0f, 1.0f, 0.85f, 0.0f + 2, 10, 6, 4, 43, 2, 2, 2, 2, 2, 85.2f, 63.2f, 3.5f, 1.0f, 1.0f, 0.85f, 0.0f, 64 }, { ASTCENC_PRE_FAST, - 3, 18, 10, 8, 55, 3, 3, 2, 2, 2, 85.2f, 63.2f, 3.5f, 1.0f, 1.0f, 0.90f, 0.0f + 3, 18, 10, 8, 55, 3, 3, 2, 2, 2, 85.2f, 63.2f, 3.5f, 1.0f, 1.0f, 0.90f, 0.0f, 128 }, { ASTCENC_PRE_MEDIUM, - 4, 34, 28, 16, 77, 3, 3, 2, 2, 2, 95.0f, 70.0f, 2.5f, 1.1f, 1.05f, 0.95f, 0.0f + 4, 34, 28, 16, 77, 3, 3, 2, 2, 2, 95.0f, 70.0f, 2.5f, 1.1f, 1.05f, 0.95f, 0.0f, 256 }, { ASTCENC_PRE_THOROUGH, - 4, 82, 60, 30, 94, 4, 4, 3, 2, 2, 105.0f, 77.0f, 10.0f, 1.35f, 1.15f, 0.97f, 0.0f + 4, 82, 60, 30, 94, 4, 4, 3, 2, 2, 105.0f, 77.0f, 10.0f, 1.35f, 1.15f, 0.97f, 0.0f, 256 }, { ASTCENC_PRE_VERYTHOROUGH, - 4, 256, 128, 64, 98, 4, 6, 8, 6, 4, 200.0f, 200.0f, 10.0f, 1.6f, 1.4f, 0.98f, 0.0f + 4, 256, 128, 64, 98, 4, 6, 8, 6, 4, 200.0f, 200.0f, 10.0f, 1.6f, 1.4f, 0.98f, 0.0f, 256 }, { ASTCENC_PRE_EXHAUSTIVE, - 4, 512, 512, 512, 100, 4, 8, 8, 8, 8, 200.0f, 200.0f, 10.0f, 2.0f, 2.0f, 0.99f, 0.0f + 4, 512, 512, 512, 100, 4, 8, 8, 8, 8, 200.0f, 200.0f, 10.0f, 2.0f, 2.0f, 0.99f, 0.0f, 256 } }}; @@ -89,22 +90,22 @@ static const std::array preset_configs_high {{ static const std::array preset_configs_mid {{ { ASTCENC_PRE_FASTEST, - 2, 10, 6, 4, 43, 2, 2, 2, 2, 2, 85.2f, 63.2f, 3.5f, 1.0f, 1.0f, 0.80f, 1.0f + 2, 10, 6, 4, 43, 2, 2, 2, 2, 2, 85.2f, 63.2f, 3.5f, 1.0f, 1.0f, 0.80f, 1.0f, 64 }, { ASTCENC_PRE_FAST, - 3, 18, 12, 10, 55, 3, 3, 2, 2, 2, 85.2f, 63.2f, 3.5f, 1.0f, 1.0f, 0.85f, 1.0f + 3, 18, 12, 10, 55, 3, 3, 2, 2, 2, 85.2f, 63.2f, 3.5f, 1.0f, 1.0f, 0.85f, 1.0f, 128 }, { ASTCENC_PRE_MEDIUM, - 3, 34, 28, 16, 77, 3, 3, 2, 2, 2, 95.0f, 70.0f, 3.0f, 1.1f, 1.05f, 0.90f, 1.0f + 3, 34, 28, 16, 77, 3, 3, 2, 2, 2, 95.0f, 70.0f, 3.0f, 1.1f, 1.05f, 0.90f, 1.0f, 256 }, { ASTCENC_PRE_THOROUGH, - 4, 82, 60, 30, 94, 4, 4, 3, 2, 2, 105.0f, 77.0f, 10.0f, 1.4f, 1.2f, 0.95f, 0.0f + 4, 82, 60, 30, 94, 4, 4, 3, 2, 2, 105.0f, 77.0f, 10.0f, 1.4f, 1.2f, 0.95f, 0.0f, 256 }, { ASTCENC_PRE_VERYTHOROUGH, - 4, 256, 128, 64, 98, 4, 6, 8, 6, 3, 200.0f, 200.0f, 10.0f, 1.6f, 1.4f, 0.98f, 0.0f + 4, 256, 128, 64, 98, 4, 6, 8, 6, 3, 200.0f, 200.0f, 10.0f, 1.6f, 1.4f, 0.98f, 0.0f, 256 }, { ASTCENC_PRE_EXHAUSTIVE, - 4, 256, 256, 256, 100, 4, 8, 8, 8, 8, 200.0f, 200.0f, 10.0f, 2.0f, 2.0f, 0.99f, 0.0f + 4, 256, 256, 256, 100, 4, 8, 8, 8, 8, 200.0f, 200.0f, 10.0f, 2.0f, 2.0f, 0.99f, 0.0f, 256 } }}; @@ -114,22 +115,22 @@ static const std::array preset_configs_mid {{ static const std::array preset_configs_low {{ { ASTCENC_PRE_FASTEST, - 2, 10, 6, 4, 40, 2, 2, 2, 2, 2, 85.0f, 63.0f, 3.5f, 1.0f, 1.0f, 0.80f, 1.0f + 2, 10, 6, 4, 40, 2, 2, 2, 2, 2, 85.0f, 63.0f, 3.5f, 1.0f, 1.0f, 0.80f, 1.0f, 64 }, { ASTCENC_PRE_FAST, - 2, 18, 12, 10, 55, 3, 3, 2, 2, 2, 85.0f, 63.0f, 3.5f, 1.0f, 1.0f, 0.85f, 1.0f + 2, 18, 12, 10, 55, 3, 3, 2, 2, 2, 85.0f, 63.0f, 3.5f, 1.0f, 1.0f, 0.85f, 1.0f, 128 }, { ASTCENC_PRE_MEDIUM, - 3, 34, 28, 16, 77, 3, 3, 2, 2, 2, 95.0f, 70.0f, 3.5f, 1.1f, 1.05f, 0.90f, 1.0f + 3, 34, 28, 16, 77, 3, 3, 2, 2, 2, 95.0f, 70.0f, 3.5f, 1.1f, 1.05f, 0.90f, 1.0f, 256 }, { ASTCENC_PRE_THOROUGH, - 4, 82, 60, 30, 93, 4, 4, 3, 2, 2, 105.0f, 77.0f, 10.0f, 1.3f, 1.2f, 0.97f, 1.0f + 4, 82, 60, 30, 93, 4, 4, 3, 2, 2, 105.0f, 77.0f, 10.0f, 1.3f, 1.2f, 0.97f, 1.0f, 256 }, { ASTCENC_PRE_VERYTHOROUGH, - 4, 256, 128, 64, 98, 4, 6, 8, 5, 2, 200.0f, 200.0f, 10.0f, 1.6f, 1.4f, 0.98f, 1.0f + 4, 256, 128, 64, 98, 4, 6, 8, 5, 2, 200.0f, 200.0f, 10.0f, 1.6f, 1.4f, 0.98f, 1.0f, 256 }, { ASTCENC_PRE_EXHAUSTIVE, - 4, 256, 256, 256, 100, 4, 8, 8, 8, 8, 200.0f, 200.0f, 10.0f, 2.0f, 2.0f, 0.99f, 1.0f + 4, 256, 256, 256, 100, 4, 8, 8, 8, 8, 200.0f, 200.0f, 10.0f, 2.0f, 2.0f, 0.99f, 1.0f, 256 } }}; @@ -412,6 +413,11 @@ static astcenc_error validate_config( config.tune_3partition_early_out_limit_factor = astc::max(config.tune_3partition_early_out_limit_factor, 0.0f); config.tune_2plane_early_out_limit_correlation = astc::max(config.tune_2plane_early_out_limit_correlation, 0.0f); + config.rdo_quality = astc::clamp(config.rdo_quality, 0.001f, 50.0f); + config.rdo_lookback = astc::clamp(config.rdo_lookback, 4u, 4096u); + config.rdo_max_smooth_block_error_scale = astc::clamp(config.rdo_max_smooth_block_error_scale, 1.0f, 300.0f); + config.rdo_max_smooth_block_std_dev = astc::clamp(config.rdo_max_smooth_block_std_dev, 0.01f, 65536.0f); + // Specifying a zero weight color component is not allowed; force to small value float max_weight = astc::max(astc::max(config.cw_r_weight, config.cw_g_weight), astc::max(config.cw_b_weight, config.cw_a_weight)); @@ -528,6 +534,7 @@ astcenc_error astcenc_config_init( config.tune_3partition_early_out_limit_factor = (*preset_configs)[start].tune_3partition_early_out_limit_factor; config.tune_2plane_early_out_limit_correlation = (*preset_configs)[start].tune_2plane_early_out_limit_correlation; config.tune_search_mode0_enable = (*preset_configs)[start].tune_search_mode0_enable; + config.rdo_lookback = (*preset_configs)[start].rdo_lookback; } // Start and end node are not the same - so interpolate between them else @@ -567,11 +574,16 @@ astcenc_error astcenc_config_init( config.tune_3partition_early_out_limit_factor = LERP(tune_3partition_early_out_limit_factor); config.tune_2plane_early_out_limit_correlation = LERP(tune_2plane_early_out_limit_correlation); config.tune_search_mode0_enable = LERP(tune_search_mode0_enable); + config.rdo_lookback = LERPUI(rdo_lookback); #undef LERP #undef LERPI #undef LERPUI } + config.rdo_quality = 0.5f; + config.rdo_max_smooth_block_error_scale = 10.0f; + config.rdo_max_smooth_block_std_dev = 18.0f; + // Set heuristics to the defaults for each color profile config.cw_r_weight = 1.0f; config.cw_g_weight = 1.0f; @@ -1099,6 +1111,8 @@ astcenc_error astcenc_compress_image( // Only the first thread to arrive actually runs the term ctxo->manage_compress.term(term_compress); + rate_distortion_optimize(*ctxo, image, *swizzle, data_out); + return ASTCENC_SUCCESS; #endif } @@ -1119,6 +1133,7 @@ astcenc_error astcenc_compress_reset( ctxo->manage_avg.reset(); ctxo->manage_compress.reset(); + ctxo->manage_rdo.reset(); return ASTCENC_SUCCESS; #endif } diff --git a/Source/astcenc_internal.h b/Source/astcenc_internal.h index df6e07f9..74187625 100644 --- a/Source/astcenc_internal.h +++ b/Source/astcenc_internal.h @@ -1231,6 +1231,8 @@ struct astcenc_contexti #if !defined(ASTCENC_DECOMPRESS_ONLY) /** @brief The pixel region and variance worker arguments. */ avg_args avg_preprocess_args; + + struct astcenc_rdo_context* rdo_context; #endif #if defined(ASTCENC_DIAGNOSTICS) @@ -1966,7 +1968,7 @@ unsigned int compute_ideal_endpoint_formats( * @param pi The partition info for the current trial. * @param di The weight grid decimation table. * @param dec_weights_uquant The quantized weight set. - * @param[in,out] ep The color endpoints (modifed in place). + * @param[in,out] ep The color endpoints (modified in place). * @param[out] rgbs_vectors The RGB+scale vectors for LDR blocks. * @param[out] rgbo_vectors The RGB+offset vectors for HDR blocks. */ @@ -1990,7 +1992,7 @@ void recompute_ideal_colors_1plane( * @param di The weight grid decimation table. * @param dec_weights_uquant_plane1 The quantized weight set for plane 1. * @param dec_weights_uquant_plane2 The quantized weight set for plane 2. - * @param[in,out] ep The color endpoints (modifed in place). + * @param[in,out] ep The color endpoints (modified in place). * @param[out] rgbs_vector The RGB+scale color for LDR blocks. * @param[out] rgbo_vector The RGB+offset color for HDR blocks. * @param plane2_component The component assigned to plane 2. @@ -2163,7 +2165,7 @@ void symbolic_to_physical( * flagged as an error block if the encoding is invalid. * * @param bsd The block size information. - * @param pcb The physical compresesd block input. + * @param pcb The physical compressed block input. * @param[out] scb The output symbolic representation. */ void physical_to_symbolic( @@ -2171,6 +2173,20 @@ void physical_to_symbolic( const uint8_t pcb[16], symbolic_compressed_block& scb); +/** + * @brief Rate-distortion optimization main entry. + * + * @param ctxo The compressor context and configuration. + * @param image The input image data. + * @param swizzle The swizzle applied on store. + * @param[in,out] buffer The compressed buffer to be optimized (modified in place) + */ +void rate_distortion_optimize( + astcenc_context& ctxo, + const astcenc_image& image, + const astcenc_swizzle& swizzle, + uint8_t* buffer); + /* ============================================================================ Platform-specific functions. ============================================================================ */ diff --git a/Source/astcenc_internal_entry.h b/Source/astcenc_internal_entry.h index c283c5ac..d60278e3 100644 --- a/Source/astcenc_internal_entry.h +++ b/Source/astcenc_internal_entry.h @@ -164,13 +164,14 @@ class ParallelManager * @param init_func Callable which executes the stage initialization. It must return the * total number of tasks in the stage. */ - void init(std::function init_func) + void init(std::function init_func, astcenc_progress_callback callback = nullptr) { std::lock_guard lck(m_lock); if (!m_init_done) { m_task_count = init_func(); m_init_done = true; + if (callback) m_callback = callback; } } @@ -322,6 +323,9 @@ struct astcenc_context /** @brief The parallel manager for compression. */ ParallelManager manage_compress; + + /** @brief The parallel manager for rate-distortion optimization. */ + ParallelManager manage_rdo; #endif /** @brief The parallel manager for decompression. */ diff --git a/Source/astcenc_rate_distortion.cpp b/Source/astcenc_rate_distortion.cpp new file mode 100644 index 00000000..d6ec828f --- /dev/null +++ b/Source/astcenc_rate_distortion.cpp @@ -0,0 +1,481 @@ +// SPDX-License-Identifier: Apache-2.0 +// ---------------------------------------------------------------------------- +// Copyright 2011-2024 Arm Limited +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy +// of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. +// ---------------------------------------------------------------------------- + +#if !defined(ASTCENC_DECOMPRESS_ONLY) + +#include "astcenc_internal_entry.h" +#include "ert.h" + +#include + +struct astcenc_rdo_context +{ + ert::reduce_entropy_params m_ert_params; + std::vector m_blocks; + + uint32_t m_image_x = 0; + uint32_t m_image_y = 0; + uint32_t m_image_z = 0; +}; + +#define ASTCENC_RDO_SPECIALIZE_DIFF 1 + +static constexpr uint32_t ASTCENC_BYTES_PER_BLOCK = 16; + +template T sqr(T v) { return v * v; } + +extern "C" void rdo_progress_emitter( + float value +) { + static float previous_value = 100.0f; + if (previous_value == 100.0f) + { + printf("\n\n"); + printf("Rate-distortion optimization\n"); + printf("============================\n\n"); + } + previous_value = value; + + const unsigned int bar_size = 25; + unsigned int parts = static_cast(value / 4.0f); + + char buffer[bar_size + 3]; + buffer[0] = '['; + + for (unsigned int i = 0; i < parts; i++) + { + buffer[i + 1] = '='; + } + + for (unsigned int i = parts; i < bar_size; i++) + { + buffer[i + 1] = ' '; + } + + buffer[bar_size + 1] = ']'; + buffer[bar_size + 2] = '\0'; + + printf(" Progress: %s %03.1f%%\r", buffer, static_cast(value)); + fflush(stdout); +} + +static uint32_t init_rdo_context( + astcenc_contexti& ctx, + const astcenc_image& image, + const astcenc_swizzle& swz +) { + ctx.rdo_context = new astcenc_rdo_context; + astcenc_rdo_context& rdo_ctx = *ctx.rdo_context; + + uint32_t block_dim_x = ctx.bsd->xdim; + uint32_t block_dim_y = ctx.bsd->ydim; + uint32_t block_dim_z = ctx.bsd->zdim; + uint32_t xblocks = (image.dim_x + block_dim_x - 1u) / block_dim_x; + uint32_t yblocks = (image.dim_y + block_dim_y - 1u) / block_dim_y; + uint32_t zblocks = (image.dim_z + block_dim_z - 1u) / block_dim_z; + uint32_t total_blocks = xblocks * yblocks * zblocks; + + // Generate quality parameters + auto& ert_params = rdo_ctx.m_ert_params; + ert_params.m_lambda = ctx.config.rdo_quality; + ert_params.m_lookback_window_size = ctx.config.rdo_lookback * ASTCENC_BYTES_PER_BLOCK; + ert_params.m_smooth_block_max_mse_scale = ctx.config.rdo_max_smooth_block_error_scale; + ert_params.m_max_smooth_block_std_dev = ctx.config.rdo_max_smooth_block_std_dev; + + ert_params.m_try_two_matches = true; + + rdo_ctx.m_blocks.resize(total_blocks); + rdo_ctx.m_image_x = image.dim_x; + rdo_ctx.m_image_y = image.dim_y; + rdo_ctx.m_image_z = image.dim_z; + + vfloat4 channel_weight = vfloat4(ctx.config.cw_r_weight, + ctx.config.cw_g_weight, + ctx.config.cw_b_weight, + ctx.config.cw_a_weight); + channel_weight = channel_weight / hadd_s(channel_weight); + + for (uint32_t block_z = 0, block_idx = 0; block_z < zblocks; ++block_z) + { + for (uint32_t block_y = 0; block_y < yblocks; ++block_y) + { + for (uint32_t block_x = 0; block_x < xblocks; ++block_x, ++block_idx) + { + image_block& blk = rdo_ctx.m_blocks[block_idx]; + blk.decode_unorm8 = ctx.config.flags & ASTCENC_FLG_USE_DECODE_UNORM8; + blk.texel_count = ctx.bsd->texel_count; + + load_image_block(ctx.config.profile, image, blk, *ctx.bsd, + block_dim_x * block_x, block_dim_y * block_y, block_dim_z * block_z, swz); + + if (ctx.config.flags & ASTCENC_FLG_USE_ALPHA_WEIGHT) + { + float alpha_scale = blk.data_max.lane<3>() * (1.0f / 65535.0f); + blk.channel_weight = vfloat4(ctx.config.cw_r_weight * alpha_scale, + ctx.config.cw_g_weight * alpha_scale, + ctx.config.cw_b_weight * alpha_scale, + ctx.config.cw_a_weight); + blk.channel_weight = blk.channel_weight / hadd_s(blk.channel_weight); + } + else + { + blk.channel_weight = channel_weight; + } + } + } + } + + return total_blocks; +} + +static float compute_block_std_dev( + const image_block& blk, + float scale +) { + if (all(blk.data_min == blk.data_max)) return 0.0f; + + vfloatacc summav = vfloatacc::zero(); + vint lane_id = vint::lane_id(); + uint32_t texel_count = blk.texel_count; + vfloat color_mean_r(blk.data_mean.lane<0>() * scale); + vfloat color_mean_g(blk.data_mean.lane<1>() * scale); + vfloat color_mean_b(blk.data_mean.lane<2>() * scale); + vfloat color_mean_a(blk.data_mean.lane<3>() * scale); + + for (uint32_t i = 0; i < texel_count; i += ASTCENC_SIMD_WIDTH) + { + vfloat color_orig_r = loada(blk.data_r + i) * scale; + vfloat color_orig_g = loada(blk.data_g + i) * scale; + vfloat color_orig_b = loada(blk.data_b + i) * scale; + vfloat color_orig_a = loada(blk.data_a + i) * scale; + + vfloat color_error_r = min(abs(color_orig_r - color_mean_r), vfloat(1e15f)); + vfloat color_error_g = min(abs(color_orig_g - color_mean_g), vfloat(1e15f)); + vfloat color_error_b = min(abs(color_orig_b - color_mean_b), vfloat(1e15f)); + vfloat color_error_a = min(abs(color_orig_a - color_mean_a), vfloat(1e15f)); + + // Compute squared error metric + color_error_r = color_error_r * color_error_r; + color_error_g = color_error_g * color_error_g; + color_error_b = color_error_b * color_error_b; + color_error_a = color_error_a * color_error_a; + + vfloat metric = astc::max(color_error_r * blk.channel_weight.lane<0>(), + color_error_g * blk.channel_weight.lane<1>(), + color_error_b * blk.channel_weight.lane<2>(), + color_error_a * blk.channel_weight.lane<3>()); + + // Mask off bad lanes + vmask mask = lane_id < vint(texel_count); + lane_id += vint(ASTCENC_SIMD_WIDTH); + haccumulate(summav, metric, mask); + } + + return sqrtf(hadd_s(summav) / texel_count); +} + +#if ASTCENC_RDO_SPECIALIZE_DIFF +static float compute_symbolic_block_difference_constant( + const astcenc_config& config, + const block_size_descriptor& bsd, + symbolic_compressed_block scb, + const image_block& blk +) { + vfloat4 color(0.0f); + + // UNORM16 constant color block + if (scb.block_type == SYM_BTYPE_CONST_U16) + { + vint4 colori(scb.constant_color); + + // Determine the UNORM8 rounding on the decode + vmask4 u8_mask = get_u8_component_mask(config.profile, blk); + + // The real decoder would just use the top 8 bits, but we rescale + // in to a 16-bit value that rounds correctly. + vint4 colori_u8 = asr<8>(colori) * 257; + colori = select(colori, colori_u8, u8_mask); + + vint4 colorf16 = unorm16_to_sf16(colori); + color = float16_to_float(colorf16); + } + // FLOAT16 constant color block + else + { + switch (config.profile) + { + case ASTCENC_PRF_LDR_SRGB: + case ASTCENC_PRF_LDR: + return -ERROR_CALC_DEFAULT; + case ASTCENC_PRF_HDR_RGB_LDR_A: + case ASTCENC_PRF_HDR: + // Constant-color block; unpack from FP16 to FP32. + color = float16_to_float(vint4(scb.constant_color)); + break; + } + } + + if (all(blk.data_min == blk.data_max)) // Original block is also constant + { + vfloat4 color_error = min(abs(blk.origin_texel - color) * 65535.0f, vfloat4(1e15f)); + return dot_s(color_error * color_error, blk.channel_weight) * bsd.texel_count; + } + + vfloatacc summav = vfloatacc::zero(); + vint lane_id = vint::lane_id(); + uint32_t texel_count = bsd.texel_count; + + vfloat color_r(color.lane<0>() * 65535.0f); + vfloat color_g(color.lane<1>() * 65535.0f); + vfloat color_b(color.lane<2>() * 65535.0f); + vfloat color_a(color.lane<3>() * 65535.0f); + + for (uint32_t i = 0; i < texel_count; i += ASTCENC_SIMD_WIDTH) + { + vfloat color_orig_r = loada(blk.data_r + i); + vfloat color_orig_g = loada(blk.data_g + i); + vfloat color_orig_b = loada(blk.data_b + i); + vfloat color_orig_a = loada(blk.data_a + i); + + vfloat color_error_r = min(abs(color_orig_r - color_r), vfloat(1e15f)); + vfloat color_error_g = min(abs(color_orig_g - color_g), vfloat(1e15f)); + vfloat color_error_b = min(abs(color_orig_b - color_b), vfloat(1e15f)); + vfloat color_error_a = min(abs(color_orig_a - color_a), vfloat(1e15f)); + + // Compute squared error metric + color_error_r = color_error_r * color_error_r; + color_error_g = color_error_g * color_error_g; + color_error_b = color_error_b * color_error_b; + color_error_a = color_error_a * color_error_a; + + vfloat metric = color_error_r * blk.channel_weight.lane<0>() + + color_error_g * blk.channel_weight.lane<1>() + + color_error_b * blk.channel_weight.lane<2>() + + color_error_a * blk.channel_weight.lane<3>(); + + // Mask off bad lanes + vmask mask = lane_id < vint(texel_count); + lane_id += vint(ASTCENC_SIMD_WIDTH); + haccumulate(summav, metric, mask); + } + + return hadd_s(summav); +} +#else +static float compute_block_mse( + const image_block& orig, + const image_block& cmp, + const block_size_descriptor& bsd, + uint32_t image_x, + uint32_t image_y, + uint32_t image_z, + float orig_scale, + float cmp_scale +) { + vfloatacc summav = vfloatacc::zero(); + vint lane_id = vint::lane_id(); + uint32_t texel_count = orig.texel_count; + + uint32_t block_x = astc::min(image_x - orig.xpos, (uint32_t)bsd.xdim); + uint32_t block_y = astc::min(image_y - orig.ypos, (uint32_t)bsd.ydim); + uint32_t block_z = astc::min(image_z - orig.zpos, (uint32_t)bsd.zdim); + + for (uint32_t i = 0; i < texel_count; i += ASTCENC_SIMD_WIDTH, lane_id += vint(ASTCENC_SIMD_WIDTH)) + { + vfloat color_orig_r = loada(orig.data_r + i) * orig_scale; + vfloat color_orig_g = loada(orig.data_g + i) * orig_scale; + vfloat color_orig_b = loada(orig.data_b + i) * orig_scale; + vfloat color_orig_a = loada(orig.data_a + i) * orig_scale; + + vfloat color_cmp_r = loada(cmp.data_r + i) * cmp_scale; + vfloat color_cmp_g = loada(cmp.data_g + i) * cmp_scale; + vfloat color_cmp_b = loada(cmp.data_b + i) * cmp_scale; + vfloat color_cmp_a = loada(cmp.data_a + i) * cmp_scale; + + vfloat color_error_r = min(abs(color_orig_r - color_cmp_r), vfloat(1e15f)); + vfloat color_error_g = min(abs(color_orig_g - color_cmp_g), vfloat(1e15f)); + vfloat color_error_b = min(abs(color_orig_b - color_cmp_b), vfloat(1e15f)); + vfloat color_error_a = min(abs(color_orig_a - color_cmp_a), vfloat(1e15f)); + + // Compute squared error metric + color_error_r = color_error_r * color_error_r; + color_error_g = color_error_g * color_error_g; + color_error_b = color_error_b * color_error_b; + color_error_a = color_error_a * color_error_a; + + vfloat metric = color_error_r * orig.channel_weight.lane<0>() + + color_error_g * orig.channel_weight.lane<1>() + + color_error_b * orig.channel_weight.lane<2>() + + color_error_a * orig.channel_weight.lane<3>(); + + // Mask off bad lanes + vint lane_id_z(float_to_int(int_to_float(lane_id) / float(bsd.xdim * bsd.ydim))); + vint rem_idx = lane_id - lane_id_z * vint(bsd.xdim * bsd.ydim); + vint lane_id_y = float_to_int(int_to_float(rem_idx) / bsd.xdim); + vint lane_id_x = rem_idx - lane_id_y * vint(bsd.xdim); + vmask mask = (lane_id_x < vint(block_x)) & (lane_id_y < vint(block_y)) & (lane_id_z < vint(block_z)); + haccumulate(summav, metric, mask); + } + + return hadd_s(summav) / (block_x * block_y * block_z); +} +#endif + +struct local_rdo_context +{ + const astcenc_contexti* ctx; + uint32_t base_offset; +}; + +static bool is_transparent(int v) { return (v & 0xFF) != 0xFF; } + +static bool has_any_transparency( + astcenc_profile decode_mode, + const symbolic_compressed_block& scb +) { + if (scb.block_type != SYM_BTYPE_NONCONST) return is_transparent(scb.constant_color[3]); + + vint4 ep0; + vint4 ep1; + bool rgb_lns; + bool a_lns; + + for (int i = 0; i < scb.partition_count; i++) + { + unpack_color_endpoints(decode_mode, scb.color_formats[i], scb.color_values[i], rgb_lns, a_lns, ep0, ep1); + if (is_transparent(ep0.lane<3>()) || is_transparent(ep1.lane<3>())) return true; + } + return false; +} + +static float compute_block_difference( + void* user_data, + const uint8_t* pcb, + uint32_t local_block_idx, + float* out_max_std_dev +) { + const local_rdo_context& local_ctx = *(local_rdo_context*)user_data; + const astcenc_contexti& ctx = *local_ctx.ctx; + + symbolic_compressed_block scb; + physical_to_symbolic(*ctx.bsd, pcb, scb); + + // Trial blocks may not be valid at all + if (scb.block_type == SYM_BTYPE_ERROR) return -ERROR_CALC_DEFAULT; + bool is_dual_plane = scb.block_type == SYM_BTYPE_NONCONST && ctx.bsd->get_block_mode(scb.block_mode).is_dual_plane; + if (is_dual_plane && scb.partition_count != 1) return -ERROR_CALC_DEFAULT; + if (ctx.config.cw_a_weight < 0.01f && has_any_transparency(ctx.config.profile, scb)) return -ERROR_CALC_DEFAULT; + + const astcenc_rdo_context& rdo_ctx = *ctx.rdo_context; + uint32_t block_idx = local_block_idx + local_ctx.base_offset; + const image_block& blk = rdo_ctx.m_blocks[block_idx]; + + if (out_max_std_dev) + { + // ERT expects texel values to be in [0, 255] + *out_max_std_dev = compute_block_std_dev(blk, 255.0f / 65535.0f); + } + +#if ASTCENC_RDO_SPECIALIZE_DIFF + float squared_error = 0.0f; + + if (scb.block_type != SYM_BTYPE_NONCONST) + squared_error = compute_symbolic_block_difference_constant(ctx.config, *ctx.bsd, scb, blk); + else if (is_dual_plane) + squared_error = compute_symbolic_block_difference_2plane(ctx.config, *ctx.bsd, scb, blk); + else if (scb.partition_count == 1) + squared_error = compute_symbolic_block_difference_1plane_1partition(ctx.config, *ctx.bsd, scb, blk); + else + squared_error = compute_symbolic_block_difference_1plane(ctx.config, *ctx.bsd, scb, blk); + + // ERT expects texel values to be in [0, 255] + return squared_error / blk.texel_count * sqr(255.0f / 65535.0f); +#else + image_block decoded_blk; + decoded_blk.decode_unorm8 = blk.decode_unorm8; + decoded_blk.texel_count = blk.texel_count; + decoded_blk.channel_weight = blk.channel_weight; + + decompress_symbolic_block(ctx.config.profile, *ctx.bsd, blk.xpos, blk.ypos, blk.zpos, scb, decoded_blk); + + // ERT expects texel values to be in [0, 255] + return compute_block_mse(blk, decoded_blk, *ctx.bsd, rdo_ctx.m_image_x, rdo_ctx.m_image_y, rdo_ctx.m_image_z, 255.0f / 65535.0f, 255.0f); +#endif +} + +void rate_distortion_optimize( + astcenc_context& ctxo, + const astcenc_image& image, + const astcenc_swizzle& swizzle, + uint8_t* buffer +) { + if (!ctxo.context.config.rdo_enabled) + { + return; + } + + // Only the first thread actually runs the initializer + ctxo.manage_rdo.init([&ctxo, &image, &swizzle] + { + return init_rdo_context(ctxo.context, image, swizzle); + }, + ctxo.context.config.progress_callback ? rdo_progress_emitter : nullptr); + + const astcenc_contexti& ctx = ctxo.context; + uint32_t xblocks = (image.dim_x + ctx.bsd->xdim - 1u) / ctx.bsd->xdim; + uint32_t yblocks = (image.dim_y + ctx.bsd->ydim - 1u) / ctx.bsd->ydim; + uint32_t zblocks = (image.dim_z + ctx.bsd->zdim - 1u) / ctx.bsd->zdim; + uint32_t total_blocks = xblocks * yblocks * zblocks; + + uint32_t blocks_per_task = astc::min(ctx.config.rdo_lookback, total_blocks); + // There is no way to losslessly partition the job (sequentially dependent on previous output) + // So we reserve up to one task for each thread to minimize the quality impact. + uint32_t partitions = ctx.config.rdo_partitions ? ctx.config.rdo_partitions : ctx.thread_count; + blocks_per_task = astc::max(blocks_per_task, (total_blocks - 1) / partitions + 1); + + uint32_t total_modified = 0; + while (true) + { + uint32_t count; + uint32_t base = ctxo.manage_rdo.get_task_assignment(blocks_per_task, count); + if (!count) + { + break; + } + + local_rdo_context local_ctx{ &ctx, base }; + + ert::reduce_entropy(buffer + base * ASTCENC_BYTES_PER_BLOCK, count, + ASTCENC_BYTES_PER_BLOCK, ASTCENC_BYTES_PER_BLOCK, + ctx.rdo_context->m_ert_params, total_modified, + &compute_block_difference, &local_ctx); + + ctxo.manage_rdo.complete_task_assignment(count); + } + + // Wait for rdo to complete before freeing memory + ctxo.manage_rdo.wait(); + + // Only the first thread to arrive actually runs the term + ctxo.manage_rdo.term([&ctxo] + { + delete ctxo.context.rdo_context; + ctxo.context.rdo_context = nullptr; + }); +} + +#endif diff --git a/Source/astcenccli_toplevel.cpp b/Source/astcenccli_toplevel.cpp index 39eb586a..41c36a09 100644 --- a/Source/astcenccli_toplevel.cpp +++ b/Source/astcenccli_toplevel.cpp @@ -1175,6 +1175,83 @@ static int edit_astcenc_config( argidx += 1; cli_config.diagnostic_images = true; } + else if (!strcmp(argv[argidx], "-rdo")) + { + argidx += 1; + config.rdo_enabled = true; + } + else if (!strcmp(argv[argidx], "-rdo-quality")) + { + argidx += 2; + if (argidx > argc) + { + print_error("ERROR: -rdo-quality switch with no argument\n"); + return 1; + } + + config.rdo_enabled = true; + config.rdo_quality = static_cast(atof(argv[argidx - 1])); + } + else if (!strcmp(argv[argidx], "-rdo-dict-size")) + { + argidx += 2; + if (argidx > argc) + { + print_error("ERROR: -rdo-dict-size switch with no argument\n"); + return 1; + } + + config.rdo_enabled = true; + config.rdo_lookback = atoi(argv[argidx - 1]) / 16; + } + else if (!strcmp(argv[argidx], "-rdo-lookback")) + { + argidx += 2; + if (argidx > argc) + { + print_error("ERROR: -rdo-lookback switch with no argument\n"); + return 1; + } + + config.rdo_enabled = true; + config.rdo_lookback = atoi(argv[argidx - 1]); + } + else if (!strcmp(argv[argidx], "-rdo-partitions")) + { + argidx += 2; + if (argidx > argc) + { + print_error("ERROR: -rdo-partitions switch with no argument\n"); + return 1; + } + + config.rdo_enabled = true; + config.rdo_partitions = atoi(argv[argidx - 1]); + } + else if (!strcmp(argv[argidx], "-rdo-max-smooth-block-error-scale")) + { + argidx += 2; + if (argidx > argc) + { + print_error("ERROR: -rdo-max-smooth-block-error-scale switch with no argument\n"); + return 1; + } + + config.rdo_enabled = true; + config.rdo_max_smooth_block_error_scale = static_cast(atof(argv[argidx - 1])); + } + else if (!strcmp(argv[argidx], "-rdo-max-smooth-block-std-dev")) + { + argidx += 2; + if (argidx > argc) + { + print_error("ERROR: -rdo-max-smooth-block-std-dev switch with no argument\n"); + return 1; + } + + config.rdo_enabled = true; + config.rdo_max_smooth_block_std_dev = static_cast(atof(argv[argidx - 1])); + } else // check others as well { print_error("ERROR: Argument '%s' not recognized\n", argv[argidx]); @@ -1265,6 +1342,15 @@ static void print_astcenc_config( printf(" Candidate cutoff: %u candidates\n", config.tune_candidate_limit); printf(" Refinement cutoff: %u iterations\n", config.tune_refinement_limit); printf(" Compressor thread count: %d\n", cli_config.thread_count); + printf(" Rate-distortion opt: %s\n", config.rdo_enabled ? "Enabled" : "Disabled"); + if (config.rdo_enabled) + { + printf(" RDO quality: %g\n", static_cast(config.rdo_quality)); + printf(" RDO lookback: %u blocks\n", config.rdo_lookback); + printf(" RDO max error scale: %g\n", static_cast(config.rdo_max_smooth_block_error_scale)); + printf(" RDO max standard deviation: %g\n", static_cast(config.rdo_max_smooth_block_std_dev)); + if (config.rdo_partitions) printf(" RDO partitions: %u\n", config.rdo_partitions); + } printf("\n"); } } @@ -1569,7 +1655,8 @@ static void print_diagnostic_image( static void print_diagnostic_images( astcenc_context* context, const astc_compressed_image& image, - const std::string& output_file + const std::string& output_file, + astcenc_operation operation ) { if (image.dim_z != 1) { @@ -1586,6 +1673,13 @@ static void print_diagnostic_images( auto diag_image = alloc_image(8, image.dim_x, image.dim_y, image.dim_z); + // ---- ---- ---- ---- Compressed Output ---- ---- ---- ---- + if ((operation & ASTCENC_STAGE_ST_COMP) == 0) + { + std::string fname = stem + "_diag.astc"; + store_cimage(image, fname.c_str()); + } + // ---- ---- ---- ---- Partitioning ---- ---- ---- ---- auto partition_func = [](astcenc_block_info& info, size_t texel_x, size_t texel_y) { const vint4 colors[] { @@ -1989,6 +2083,12 @@ int astcenc_main( return 1; } + // Unpacking RDO trials blocks requires full initialization + if (config.rdo_enabled && static_cast(config.flags & ASTCENC_FLG_SELF_DECOMPRESS_ONLY)) + { + config.flags &= ~ASTCENC_FLG_SELF_DECOMPRESS_ONLY; + } + // Enable progress callback if not in silent mode and using a terminal #if defined(_WIN32) int stdoutfno = _fileno(stdout); @@ -2329,7 +2429,7 @@ int astcenc_main( // Store diagnostic images if (cli_config.diagnostic_images && !is_null) { - print_diagnostic_images(codec_context, image_comp, output_filename); + print_diagnostic_images(codec_context, image_comp, output_filename, operation); } free_image(image_uncomp_in); diff --git a/Source/astcenccli_toplevel_help.cpp b/Source/astcenccli_toplevel_help.cpp index 71b9a42d..3aa8767f 100644 --- a/Source/astcenccli_toplevel_help.cpp +++ b/Source/astcenccli_toplevel_help.cpp @@ -385,6 +385,43 @@ ADVANCED COMPRESSION -thorough : 0.95 -verythorough : 0.98 -exhaustive : 0.99 + + -rdo + Enable Rate Distortion Optimization (RDO) post-processing. + + -rdo-quality + RDO quality scalar (lambda). Lower values yield higher + quality/larger LZ compressed files, higher values yield lower + quality/smaller LZ compressed files. A good range to try is [.2,4]. + Full range is [.001,50.0]. Default to 0.5. + + -rdo-lookback + RDO look back size in blocks. Lower values=faster, + but give less compression. Range is [4,4096]. Preset defaults are: + + -fastest : 64 + -fast : 128 + -medium : 256 + -thorough : 256 + -verythorough : 256 + -exhaustive : 256 + + -rdo-dict-size + Same as rdo-lookback, but in bytes. + + -rdo-partitions + RDO task partitions. Default to current number of threads. + Customize this for a deterministic output regardless of the running hardware. + + -rdo-max-smooth-block-error-scale + RDO max smooth block error scale. Range is [1,300]. + Default is 10.0, 1.0 is disabled. Larger values suppress more + artifacts (and allocate more bits) on smooth blocks. + + -rdo-max-smooth-block-std-dev + RDO max smooth block standard deviation. Range is + [.01,65536.0]. Default is 18.0. Larger values expand the range of + blocks considered smooth. )" // This split in the literals is needed for Visual Studio; the compiler // will concatenate these two strings together ... diff --git a/Source/cmake_core.cmake b/Source/cmake_core.cmake index e78eb70b..12d5c4cc 100644 --- a/Source/cmake_core.cmake +++ b/Source/cmake_core.cmake @@ -42,6 +42,8 @@ set(is_clang "$>") add_library(${ASTCENC_TARGET}-static STATIC + ert.cpp + astcenc_rate_distortion.cpp astcenc_averages_and_directions.cpp astcenc_block_sizes.cpp astcenc_color_quantize.cpp diff --git a/Source/ert.cpp b/Source/ert.cpp new file mode 100644 index 00000000..7758cc0c --- /dev/null +++ b/Source/ert.cpp @@ -0,0 +1,540 @@ +// SPDX-License-Identifier: Unlicense +// ---------------------------------------------------------------------------- +// This is free and unencumbered software released into the public domain. +// Anyone is free to copy, modify, publish, use, compile, sell, or distribute this +// software, either in source code form or as a compiled binary, for any purpose, +// commercial or non - commercial, and by any means. +// In jurisdictions that recognize copyright laws, the author or authors of this +// software dedicate any and all copyright interest in the software to the public +// domain. We make this dedication for the benefit of the public at large and to +// the detriment of our heirs and successors.We intend this dedication to be an +// overt act of relinquishment in perpetuity of all present and future rights to +// this software under copyright law. +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE +// AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN +// ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +// ---------------------------------------------------------------------------- + +#include "ert.h" + +#include +#include + +#define ERT_FAVOR_CONT_AND_REP0_MATCHES (1) +#define ERT_FAVOR_REP0_MATCHES (0) +#define ERT_ENABLE_DEBUG (0) + +namespace ert +{ + const uint32_t MAX_BLOCK_SIZE_IN_BYTES = 256; + const uint32_t MIN_MATCH_LEN = 3; + const float LITERAL_BITS = 13.0f; + const float MATCH_CONTINUE_BITS = 1.0f; + const float MATCH_REP0_BITS = 4.0f; + + static inline float clampf(float value, float low, float high) { if (value < low) value = low; else if (value > high) value = high; return value; } + template inline F lerp(F a, F b, F s) { return a + (b - a) * s; } + + static const uint8_t g_tdefl_small_dist_extra[512] = + { + 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, + 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, + 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, + 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7 + }; + + static const uint8_t g_tdefl_large_dist_extra[128] = + { + 0, 0, 8, 8, 9, 9, 9, 9, 10, 10, 10, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, + 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, + 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13 + }; + + static inline uint32_t compute_match_cost_estimate(uint32_t dist, uint32_t match_len_in_bytes) + { + assert(match_len_in_bytes <= 258); + + uint32_t len_cost = 6; + if (match_len_in_bytes >= 12) + len_cost = 9; + else if (match_len_in_bytes >= 8) + len_cost = 8; + else if (match_len_in_bytes >= 6) + len_cost = 7; + + uint32_t dist_cost = 5; + if (dist < 512) + dist_cost += g_tdefl_small_dist_extra[dist & 511]; + else + { + dist_cost += g_tdefl_large_dist_extra[std::min(dist, 32767) >> 8]; + while (dist >= 32768) + { + dist_cost++; + dist >>= 1; + } + } + return len_cost + dist_cost; + } + + uint32_t hash_hsieh(const uint8_t* pBuf, size_t len, uint32_t salt) + { + if (!pBuf || !len) + return 0; + + uint32_t h = static_cast(len + (salt << 16)); + + const uint32_t bytes_left = len & 3; + len >>= 2; + + while (len--) + { + const uint16_t* pWords = reinterpret_cast(pBuf); + + h += pWords[0]; + + const uint32_t t = (pWords[1] << 11) ^ h; + h = (h << 16) ^ t; + + pBuf += sizeof(uint32_t); + + h += h >> 11; + } + + switch (bytes_left) + { + case 1: + h += *reinterpret_cast(pBuf); + h ^= h << 10; + h += h >> 1; + break; + case 2: + h += *reinterpret_cast(pBuf); + h ^= h << 11; + h += h >> 17; + break; + case 3: + h += *reinterpret_cast(pBuf); + h ^= h << 16; + h ^= (static_cast(pBuf[sizeof(uint16_t)])) << 18; + h += h >> 11; + break; + default: + break; + } + + h ^= h << 3; + h += h >> 5; + h ^= h << 4; + h += h >> 17; + h ^= h << 25; + h += h >> 6; + + return h; + } + + // BC7 entropy reduction transform with Deflate/LZMA/LZHAM optimizations + bool reduce_entropy(uint8_t* pBlock_bytes, uint32_t num_blocks, + uint32_t total_block_stride_in_bytes, uint32_t block_size_to_optimize_in_bytes, + const reduce_entropy_params& params, uint32_t& total_modified, + diff_block_func_type* pDiff_block_func, void* pDiff_block_func_user_data, + const float* pBlock_mse_scales) + { + assert(total_block_stride_in_bytes && block_size_to_optimize_in_bytes); + assert(total_block_stride_in_bytes >= block_size_to_optimize_in_bytes); + + assert((block_size_to_optimize_in_bytes >= MIN_MATCH_LEN) && (block_size_to_optimize_in_bytes <= MAX_BLOCK_SIZE_IN_BYTES)); + if ((block_size_to_optimize_in_bytes < MIN_MATCH_LEN) || (block_size_to_optimize_in_bytes > MAX_BLOCK_SIZE_IN_BYTES)) + return false; + + const int total_blocks_to_check = std::max(1U, params.m_lookback_window_size / total_block_stride_in_bytes); + +#if ERT_ENABLE_DEBUG + uint32_t len_hist[MAX_BLOCK_SIZE_IN_BYTES + 1]; + uint32_t second_len_hist[MAX_BLOCK_SIZE_IN_BYTES + 1]; + uint32_t total_second_matches = 0; + uint32_t total_smooth_blocks = 0; +#endif + + int prev_match_window_ofs_to_favor_cont = -1, prev_match_dist_to_favor = -1; + + const uint32_t HASH_SIZE = 8192; + uint32_t hash[HASH_SIZE]; + + for (uint32_t block_index = 0; block_index < num_blocks; block_index++) + { + if ((block_index & 0xFF) == 0) + memset(hash, 0, sizeof(hash)); + + uint8_t* pOrig_block = &pBlock_bytes[block_index * total_block_stride_in_bytes]; + + float max_std_dev = 0.0f; + float cur_mse = (*pDiff_block_func)(pDiff_block_func_user_data, pOrig_block, block_index, &max_std_dev); + if (cur_mse < 0.0f) + return false; + + if ((params.m_skip_zero_mse_blocks) && (cur_mse == 0.0f)) + continue; + + float yl = clampf(max_std_dev / params.m_max_smooth_block_std_dev, 0.0f, 1.0f); + yl = yl * yl; + float smooth_block_mse_scale = lerp(params.m_smooth_block_max_mse_scale, 1.0f, yl); + + if (pBlock_mse_scales) + { + if (pBlock_mse_scales[block_index] > 0.0f) + { + smooth_block_mse_scale = pBlock_mse_scales[block_index]; + } + } + +#if ERT_ENABLE_DEBUG + if (smooth_block_mse_scale > 1.0f) + total_smooth_blocks++; +#endif + + float cur_bits = (LITERAL_BITS * block_size_to_optimize_in_bytes); + float cur_t = cur_mse * smooth_block_mse_scale + cur_bits * params.m_lambda; + + int first_block_to_check = std::max(0, block_index - total_blocks_to_check); + int last_block_to_check = block_index - 1; + + uint8_t best_block[MAX_BLOCK_SIZE_IN_BYTES]; + memcpy(best_block, pOrig_block, block_size_to_optimize_in_bytes); + + float best_t = cur_t; + uint32_t best_match_len = 0, best_match_src_window_ofs = 0, best_match_dst_window_ofs = 0, best_match_dst_block_ofs = 0; + float best_match_bits = 0; + + // Don't let thresh_ms_err be 0 to let zero error blocks have slightly increased distortion + const float thresh_ms_err = params.m_max_allowed_rms_increase_ratio * params.m_max_allowed_rms_increase_ratio * std::max(cur_mse, 1.0f); + + for (int prev_block_index = last_block_to_check; prev_block_index >= first_block_to_check; --prev_block_index) + { + const uint8_t* pPrev_blk = &pBlock_bytes[prev_block_index * total_block_stride_in_bytes]; + + for (uint32_t len = block_size_to_optimize_in_bytes; len >= MIN_MATCH_LEN; len--) + { + if (params.m_allow_relative_movement) + { + for (uint32_t src_ofs = 0; src_ofs <= (block_size_to_optimize_in_bytes - len); src_ofs++) + { + assert(len + src_ofs <= block_size_to_optimize_in_bytes); + + const uint32_t src_match_window_ofs = prev_block_index * total_block_stride_in_bytes + src_ofs; + + for (uint32_t dst_ofs = 0; dst_ofs <= (block_size_to_optimize_in_bytes - len); dst_ofs++) + { + assert(len + dst_ofs <= block_size_to_optimize_in_bytes); + + const uint32_t dst_match_window_ofs = block_index * total_block_stride_in_bytes + dst_ofs; + + const uint32_t match_dist = dst_match_window_ofs - src_match_window_ofs; + + float trial_match_bits, trial_total_bits; + + uint32_t hs = hash_hsieh(pPrev_blk + src_ofs, len, dst_ofs); + +#if ERT_FAVOR_CONT_AND_REP0_MATCHES + // Continue a previous match (which would cross block boundaries) + if (((int)src_match_window_ofs == prev_match_window_ofs_to_favor_cont) && (dst_ofs == 0)) + { + trial_match_bits = MATCH_CONTINUE_BITS; + trial_total_bits = (block_size_to_optimize_in_bytes - len) * LITERAL_BITS + MATCH_CONTINUE_BITS; + } + // Exploit REP0 matches + else if ((prev_match_dist_to_favor != -1) && (src_match_window_ofs == (dst_match_window_ofs - prev_match_dist_to_favor))) + { + trial_match_bits = MATCH_REP0_BITS; + trial_total_bits = (block_size_to_optimize_in_bytes - len) * LITERAL_BITS + MATCH_REP0_BITS; + } + else + { + trial_match_bits = (float)compute_match_cost_estimate(match_dist, len); + trial_total_bits = (block_size_to_optimize_in_bytes - len) * LITERAL_BITS + trial_match_bits; + + uint32_t hash_check = hash[hs & (HASH_SIZE - 1)]; + if ((hash_check & 0xFF) == (block_index & 0xFF)) + { + if ((hash_check >> 8) == (hs >> 8)) + continue; + } + } +#else + uint32_t hash_check = hash[hs & (HASH_SIZE - 1)]; + if ((hash_check & 0xFF) == (block_index & 0xFF)) + { + if ((hash_check >> 8) == (hs >> 8)) + continue; + } +#endif + + hash[hs & (HASH_SIZE - 1)] = (hs & 0xFFFFFF00) | (block_index & 0xFF); + + const float trial_total_bits_times_lambda = trial_total_bits * params.m_lambda; + + uint8_t trial_block[MAX_BLOCK_SIZE_IN_BYTES]; + memcpy(trial_block, pOrig_block, block_size_to_optimize_in_bytes); + memcpy(trial_block + dst_ofs, pPrev_blk + src_ofs, len); + + float trial_mse = (*pDiff_block_func)(pDiff_block_func_user_data, trial_block, block_index, nullptr); + if (trial_mse < 0.0f) + continue; + + if (trial_mse < thresh_ms_err) + { + float t = trial_mse * smooth_block_mse_scale + trial_total_bits_times_lambda; + + if (t < best_t) + { + best_t = t; + memcpy(best_block, trial_block, block_size_to_optimize_in_bytes); + best_match_len = len; + best_match_src_window_ofs = src_match_window_ofs; + best_match_dst_window_ofs = dst_match_window_ofs; + best_match_dst_block_ofs = dst_ofs; + best_match_bits = trial_match_bits; + } + } + + } // dst_ofs + } // src_ofs + } + else + { + const uint32_t match_dist = (block_index - prev_block_index) * total_block_stride_in_bytes; + + // Assume the block has 1 match and block_size_to_optimize_in_bytes-match_len literals. + const float trial_match_bits = (float)compute_match_cost_estimate(match_dist, len); + const float trial_total_bits = (block_size_to_optimize_in_bytes - len) * LITERAL_BITS + trial_match_bits; + const float trial_total_bits_times_lambda = trial_total_bits * params.m_lambda; + + for (uint32_t ofs = 0; ofs <= (block_size_to_optimize_in_bytes - len); ofs++) + { + assert(len + ofs <= block_size_to_optimize_in_bytes); + + const uint32_t dst_match_window_ofs = block_index * total_block_stride_in_bytes + ofs; + const uint32_t src_match_window_ofs = prev_block_index * total_block_stride_in_bytes + ofs; + + float trial_match_bits_to_use = trial_match_bits; + float trial_total_bits_times_lambda_to_use = trial_total_bits_times_lambda; + + uint32_t hs = hash_hsieh(pPrev_blk + ofs, len, ofs); + +#if ERT_FAVOR_CONT_AND_REP0_MATCHES + // Continue a previous match (which would cross block boundaries) + if (((int)src_match_window_ofs == prev_match_window_ofs_to_favor_cont) && (ofs == 0)) + { + float continue_match_trial_bits = (block_size_to_optimize_in_bytes - len) * LITERAL_BITS + MATCH_CONTINUE_BITS; + trial_match_bits_to_use = MATCH_CONTINUE_BITS; + trial_total_bits_times_lambda_to_use = continue_match_trial_bits * params.m_lambda; + } + // Exploit REP0 matches + else if ((prev_match_dist_to_favor != -1) && (src_match_window_ofs == (dst_match_window_ofs - prev_match_dist_to_favor))) + { + float continue_match_trial_bits = (block_size_to_optimize_in_bytes - len) * LITERAL_BITS + MATCH_REP0_BITS; + trial_match_bits_to_use = MATCH_REP0_BITS; + trial_total_bits_times_lambda_to_use = continue_match_trial_bits * params.m_lambda; + } + else + { + uint32_t hash_check = hash[hs & (HASH_SIZE - 1)]; + if ((hash_check & 0xFF) == (block_index & 0xFF)) + { + if ((hash_check >> 8) == (hs >> 8)) + continue; + } + } +#else + uint32_t hash_check = hash[hs & (HASH_SIZE - 1)]; + if ((hash_check & 0xFF) == (block_index & 0xFF)) + { + if ((hash_check >> 8) == (hs >> 8)) + continue; + } +#endif + + hash[hs & (HASH_SIZE - 1)] = (hs & 0xFFFFFF00) | (block_index & 0xFF); + + uint8_t trial_block[MAX_BLOCK_SIZE_IN_BYTES]; + memcpy(trial_block, pOrig_block, block_size_to_optimize_in_bytes); + memcpy(trial_block + ofs, pPrev_blk + ofs, len); + + float trial_mse = (*pDiff_block_func)(pDiff_block_func_user_data, trial_block, block_index, nullptr); + if (trial_mse < 0.0f) + continue; + + if (trial_mse < thresh_ms_err) + { + float t = trial_mse * smooth_block_mse_scale + trial_total_bits_times_lambda_to_use; + + if (t < best_t) + { + best_t = t; + memcpy(best_block, trial_block, block_size_to_optimize_in_bytes); + best_match_len = len; + best_match_src_window_ofs = src_match_window_ofs; + best_match_dst_window_ofs = dst_match_window_ofs; + best_match_dst_block_ofs = ofs; + best_match_bits = trial_match_bits_to_use; + } + } + } // ofs + } + + } // len + + } // prev_block_index + + if (best_t < cur_t) + { + uint32_t best_second_match_len = 0, best_second_match_src_window_ofs = 0, best_second_match_dst_window_ofs = 0, best_second_match_dst_block_ofs = 0; + + // Try injecting a second match, being sure it does't overlap with the first. + if ((params.m_try_two_matches) && (best_match_len <= (block_size_to_optimize_in_bytes - 3))) + { + uint8_t matched_flags[MAX_BLOCK_SIZE_IN_BYTES]{}; + memset(matched_flags + best_match_dst_block_ofs, 1, best_match_len); + + uint8_t orig_best_block[MAX_BLOCK_SIZE_IN_BYTES]; + memcpy(orig_best_block, best_block, block_size_to_optimize_in_bytes); + + for (int prev_block_index = last_block_to_check; prev_block_index >= first_block_to_check; --prev_block_index) + { + const uint8_t* pPrev_blk = &pBlock_bytes[prev_block_index * total_block_stride_in_bytes]; + + const uint32_t match_dist = (block_index - prev_block_index) * total_block_stride_in_bytes; + + for (uint32_t len = 3; len <= (block_size_to_optimize_in_bytes - best_match_len); len++) + { + const float trial_total_bits = (block_size_to_optimize_in_bytes - len - best_match_len) * LITERAL_BITS + compute_match_cost_estimate(match_dist, len) + best_match_bits; + + const float trial_total_bits_times_lambda = trial_total_bits * params.m_lambda; + + for (uint32_t ofs = 0; ofs <= (block_size_to_optimize_in_bytes - len); ofs++) + { + int i; + for (i = 0; i < (int)len; i++) + if (matched_flags[ofs + i]) + break; + if (i != (int)len) + continue; + + assert(len + ofs <= block_size_to_optimize_in_bytes); + + const uint32_t dst_match_window_ofs = block_index * total_block_stride_in_bytes + ofs; + const uint32_t src_match_window_ofs = prev_block_index * total_block_stride_in_bytes + ofs; + + uint8_t trial_block[MAX_BLOCK_SIZE_IN_BYTES]; + memcpy(trial_block, orig_best_block, block_size_to_optimize_in_bytes); + memcpy(trial_block + ofs, pPrev_blk + ofs, len); + + float trial_mse = (*pDiff_block_func)(pDiff_block_func_user_data, trial_block, block_index, nullptr); + if (trial_mse < 0.0f) + continue; + + if (trial_mse < thresh_ms_err) + { + float t = trial_mse * smooth_block_mse_scale + trial_total_bits_times_lambda; + + if (t < best_t) + { + best_t = t; + memcpy(best_block, trial_block, block_size_to_optimize_in_bytes); + best_second_match_len = len; + best_second_match_src_window_ofs = src_match_window_ofs; + best_second_match_dst_window_ofs = dst_match_window_ofs; + best_second_match_dst_block_ofs = ofs; + } + } + } + } + } + } + + memcpy(pOrig_block, best_block, block_size_to_optimize_in_bytes); + total_modified++; + + if ((best_second_match_len == 0) || (best_match_dst_window_ofs > best_second_match_dst_window_ofs)) + { + int best_match_dist = best_match_dst_window_ofs - best_match_src_window_ofs; + assert(best_match_dist >= 1); + (void)best_match_dist; + + if (block_size_to_optimize_in_bytes == total_block_stride_in_bytes) + { + // If the match goes all the way to the end of a block, we can try to continue it on the next encoded block. + if ((best_match_dst_block_ofs + best_match_len) == total_block_stride_in_bytes) + prev_match_window_ofs_to_favor_cont = best_match_src_window_ofs + best_match_len; + else + prev_match_window_ofs_to_favor_cont = -1; + } + +#if ERT_FAVOR_REP0_MATCHES + // Compute the window offset where a cheaper REP0 match would be available + prev_match_dist_to_favor = best_match_dist; +#endif + } + else + { + int best_match_dist = best_second_match_dst_window_ofs - best_second_match_src_window_ofs; + assert(best_match_dist >= 1); + (void)best_match_dist; + + if (block_size_to_optimize_in_bytes == total_block_stride_in_bytes) + { + // If the match goes all the way to the end of a block, we can try to continue it on the next encoded block. + if ((best_second_match_dst_block_ofs + best_second_match_len) == total_block_stride_in_bytes) + prev_match_window_ofs_to_favor_cont = best_second_match_src_window_ofs + best_second_match_len; + else + prev_match_window_ofs_to_favor_cont = -1; + } + +#if ERT_FAVOR_REP0_MATCHES + // Compute the window offset where a cheaper REP0 match would be available + prev_match_dist_to_favor = best_match_dist; +#endif + } + +#if ERT_ENABLE_DEBUG + len_hist[best_match_len]++; + + if (best_second_match_len) + { + second_len_hist[best_second_match_len]++; + total_second_matches++; + } +#endif + } + else + { + prev_match_window_ofs_to_favor_cont = -1; + } + + } // block_index +#if ERT_ENABLE_DEBUG + if (params.m_debug_output) + { + printf("Total smooth blocks: %3.2f%%\n", total_smooth_blocks * 100.0f / num_blocks); + + printf("Match length histogram:\n"); + for (uint32_t i = MIN_MATCH_LEN; i <= block_size_to_optimize_in_bytes; i++) + printf("%u%c", len_hist[i], (i < block_size_to_optimize_in_bytes) ? ',' : '\n'); + + printf("Total second matches: %u %3.2f%%\n", total_second_matches, total_second_matches * 100.0f / num_blocks); + printf("Secod match length histogram:\n"); + for (uint32_t i = MIN_MATCH_LEN; i <= block_size_to_optimize_in_bytes; i++) + printf("%u%c", second_len_hist[i], (i < block_size_to_optimize_in_bytes) ? ',' : '\n'); + } +#endif + return true; + } + +} // namespace ert diff --git a/Source/ert.h b/Source/ert.h new file mode 100644 index 00000000..1ad1d400 --- /dev/null +++ b/Source/ert.h @@ -0,0 +1,99 @@ +// SPDX-License-Identifier: Unlicense +// ---------------------------------------------------------------------------- +// This is free and unencumbered software released into the public domain. +// Anyone is free to copy, modify, publish, use, compile, sell, or distribute this +// software, either in source code form or as a compiled binary, for any purpose, +// commercial or non - commercial, and by any means. +// In jurisdictions that recognize copyright laws, the author or authors of this +// software dedicate any and all copyright interest in the software to the public +// domain. We make this dedication for the benefit of the public at large and to +// the detriment of our heirs and successors.We intend this dedication to be an +// overt act of relinquishment in perpetuity of all present and future rights to +// this software under copyright law. +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE +// AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN +// ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +// ---------------------------------------------------------------------------- + +#pragma once + +#include +#include +#include + +// Based on https://github.com/richgel999/bc7enc_rdo/blob/master/ert.h +// With interface tweaks that can make format integration more general & performant. + +namespace ert +{ + struct reduce_entropy_params + { + // m_lambda: The post-processor tries to reduce distortion*smooth_block_scale + rate*lambda (rate is approximate LZ bits and distortion is scaled MS error multiplied against the smooth block MSE weighting factor). + // Larger values push the postprocessor towards optimizing more for lower rate, and smaller values more for distortion. 0=minimal distortion. + float m_lambda; + + // m_lookback_window_size: The number of bytes the encoder can look back from each block to find matches. The larger this value, the slower the encoder but the higher the quality per LZ compressed bit. + uint32_t m_lookback_window_size; + + // m_max_allowed_rms_increase_ratio: How much the RMS error of a block is allowed to increase before a trial is rejected. 1.0=no increase allowed, 1.05=5% increase allowed, etc. + float m_max_allowed_rms_increase_ratio; + + float m_max_smooth_block_std_dev; + float m_smooth_block_max_mse_scale; + + bool m_try_two_matches; + bool m_allow_relative_movement; + bool m_skip_zero_mse_blocks; + bool m_debug_output; + + reduce_entropy_params() { clear(); } + + void clear() + { + m_lookback_window_size = 256; + m_lambda = 1.0f; + m_max_allowed_rms_increase_ratio = 10.0f; + m_max_smooth_block_std_dev = 18.0f; + m_smooth_block_max_mse_scale = 10.0f; + m_try_two_matches = false; + m_allow_relative_movement = false; + m_skip_zero_mse_blocks = false; + m_debug_output = false; + } + + void print() const + { + printf("lambda: %f\n", (double)m_lambda); + printf("Lookback window size: %u\n", m_lookback_window_size); + printf("Max allowed RMS increase ratio: %f\n", (double)m_max_allowed_rms_increase_ratio); + printf("Max smooth block std dev: %f\n", (double)m_max_smooth_block_std_dev); + printf("Smooth block max MSE scale: %f\n", (double)m_smooth_block_max_mse_scale); + printf("Try two matches: %u\n", m_try_two_matches); + printf("Allow relative movement: %u\n", m_allow_relative_movement); + printf("Skip zero MSE blocks: %u\n", m_skip_zero_mse_blocks); + } + }; + + /** + * @brief Callback to compute trial block differences. + * + * All comparing texel values should be in range [0, 255]. + * + * @param[out] out_max_std_dev The channel-wise maximum standard deviation for the original block, + * can be @c nullptr if not requested. + * + * @return Should return the mean squared error for current trial block, or any negative value to indicate errors. + */ + typedef float diff_block_func_type(void* pUser_data, const uint8_t* pBlock, uint32_t block_index, float* out_max_std_dev); + + // BC7 entropy reduction transform with Deflate/LZMA/LZHAM optimizations + bool reduce_entropy(uint8_t* pBlock_bytes, uint32_t num_blocks, + uint32_t total_block_stride_in_bytes, uint32_t block_size_to_optimize_in_bytes, + const reduce_entropy_params& params, uint32_t& total_modified, + diff_block_func_type* pDiff_block_func, void* pDiff_block_func_user_data, + const float* pBlock_mse_scales = nullptr); + +} // namespace ert