From aac127d01a68639dc7cbd5ba84c45adfc99c018b Mon Sep 17 00:00:00 2001 From: Steven Atkinson Date: Tue, 9 Jun 2026 17:50:25 -0700 Subject: [PATCH 1/2] Add get_dsp load-time configuration options --- NAM/container.cpp | 7 ++ NAM/dsp.h | 5 +- NAM/get_dsp.cpp | 151 +++++++++++++++++++++++------------- NAM/get_dsp.h | 39 ++++++++-- NAM/wavenet/slimmable.cpp | 6 ++ tools/run_tests.cpp | 3 + tools/test/test_get_dsp.cpp | 103 +++++++++++++++++++++++- 7 files changed, 252 insertions(+), 62 deletions(-) diff --git a/NAM/container.cpp b/NAM/container.cpp index ee7d9f16..e1ac16dd 100644 --- a/NAM/container.cpp +++ b/NAM/container.cpp @@ -100,6 +100,13 @@ void ContainerModel::SetSlimmableSize(const double val) { return; } + + if (!mHaveExternalSampleRate && GetMaxBufferSize() == 0) + { + _active_index.store(active_index, std::memory_order_release); + return; + } + // Setting _active_index puts the model in the RT path, so reset before doing that. const double sr = mHaveExternalSampleRate ? mExternalSampleRate : mExpectedSampleRate; _submodels[active_index].model->Reset(sr, GetMaxBufferSize()); diff --git a/NAM/dsp.h b/NAM/dsp.h index c714a197..01408d0a 100644 --- a/NAM/dsp.h +++ b/NAM/dsp.h @@ -327,8 +327,9 @@ struct dspData nlohmann::json config; ///< Model configuration JSON nlohmann::json metadata; ///< Model metadata JSON std::vector weights; ///< Model weights - double expected_sample_rate; ///< Expected sample rate in Hz. Most NAM models implicitly assume data at some sample - ///< rate. Use -1.0 for "I don't know". + double expected_sample_rate = NAM_UNKNOWN_EXPECTED_SAMPLE_RATE; ///< Expected sample rate in Hz. Most NAM models + ///< implicitly assume data at some sample rate. Use + ///< -1.0 for "I don't know". }; /// \brief Verify that the config version is supported by this plugin version diff --git a/NAM/get_dsp.cpp b/NAM/get_dsp.cpp index 3aa85924..c40f77ca 100644 --- a/NAM/get_dsp.cpp +++ b/NAM/get_dsp.cpp @@ -10,12 +10,27 @@ #include "json.hpp" #include "get_dsp.h" #include "model_config.h" +#include "slimmable.h" namespace nam { +std::vector GetWeights(nlohmann::json const& j); + namespace { +struct LoadOptions +{ + std::optional expectedSampleRate; + std::optional maxBufferSize; + std::optional slimmableSize; + + bool requires_initial_reset() const + { + return expectedSampleRate.has_value() || maxBufferSize.has_value() || slimmableSize.has_value(); + } +}; + class CoreVersionSupportChecker : public IVersionSupportChecker { public: @@ -52,6 +67,65 @@ std::mutex& version_support_registry_mutex() return registry_mutex; } +dspData parse_dsp_data(const nlohmann::json& config, std::optional expectedSampleRate) +{ + verify_config_version(config["version"].get()); + + dspData out; + out.version = config["version"].get(); + out.architecture = config["architecture"].get(); + out.config = config["config"]; + out.metadata = config.value("metadata", nlohmann::json()); + out.weights = GetWeights(config); + out.expected_sample_rate = expectedSampleRate.value_or(nam::get_sample_rate_from_nam_file(config)); + return out; +} + +void apply_initial_slimmable_size(DSP& dsp, const double slimmableSize) +{ + auto* slimmable = dynamic_cast(&dsp); + if (slimmable == nullptr) + throw std::runtime_error("Cannot set slimmable size on a model that is not slimmable."); + slimmable->SetSlimmableSize(slimmableSize); +} + +void apply_metadata(DSP& dsp, const ModelMetadata& metadata) +{ + if (metadata.loudness.has_value()) + dsp.SetLoudness(metadata.loudness.value()); + if (metadata.input_level.has_value()) + dsp.SetInputLevel(metadata.input_level.value()); + if (metadata.output_level.has_value()) + dsp.SetOutputLevel(metadata.output_level.value()); +} + +void configure_initial_state(DSP& dsp, const ModelMetadata& metadata, const LoadOptions& options) +{ + if (options.slimmableSize.has_value()) + apply_initial_slimmable_size(dsp, options.slimmableSize.value()); + + if (options.requires_initial_reset()) + { + const double sampleRate = options.expectedSampleRate.value_or(metadata.sample_rate); + const int maxBufferSize = options.maxBufferSize.value_or(NAM_DEFAULT_MAX_BUFFER_SIZE); + dsp.Reset(sampleRate, maxBufferSize); + } + else + { + // Preserve the historical load behavior when no load-time configuration is requested. + dsp.prewarm(); + } +} + +std::unique_ptr create_dsp_with_options(std::unique_ptr config, std::vector weights, + const ModelMetadata& metadata, const LoadOptions& options) +{ + auto out = config->create(std::move(weights), metadata.sample_rate); + apply_metadata(*out, metadata); + configure_initial_state(*out, metadata, options); + return out; +} + } // namespace Version ParseVersion(const std::string& versionStr) @@ -139,51 +213,37 @@ std::vector GetWeights(nlohmann::json const& j) throw std::runtime_error("Corrupted model file is missing weights."); } -std::unique_ptr get_dsp(const std::filesystem::path config_filename) +std::unique_ptr get_dsp(const std::filesystem::path config_filename, std::optional expectedSampleRate, + std::optional maxBufferSize, std::optional slimmableSize) { dspData temp; - return get_dsp(config_filename, temp); + return get_dsp(config_filename, temp, expectedSampleRate, maxBufferSize, slimmableSize); } -std::unique_ptr get_dsp(const nlohmann::json& config) +std::unique_ptr get_dsp(const nlohmann::json& config, std::optional expectedSampleRate, + std::optional maxBufferSize, std::optional slimmableSize) { dspData temp; - return get_dsp(config, temp); + return get_dsp(config, temp, expectedSampleRate, maxBufferSize, slimmableSize); } -std::unique_ptr get_dsp(const std::filesystem::path config_filename, dspData& returnedConfig) +std::unique_ptr get_dsp(const std::filesystem::path config_filename, dspData& returnedConfig, + std::optional expectedSampleRate, std::optional maxBufferSize, + std::optional slimmableSize) { if (!std::filesystem::exists(config_filename)) throw std::runtime_error("Config file doesn't exist!\n"); std::ifstream i(config_filename); nlohmann::json j; i >> j; - get_dsp(j, returnedConfig); - - /*Copy to a new dsp_config object for get_dsp below, - since not sure if weights actually get modified as being non-const references on some - model constructors inside get_dsp(dsp_config& conf). - We need to return unmodified version of dsp_config via returnedConfig.*/ - dspData conf = returnedConfig; - - return get_dsp(conf); + return get_dsp(j, returnedConfig, expectedSampleRate, maxBufferSize, slimmableSize); } -std::unique_ptr get_dsp(const nlohmann::json& config, dspData& returnedConfig) +std::unique_ptr get_dsp(const nlohmann::json& config, dspData& returnedConfig, + std::optional expectedSampleRate, std::optional maxBufferSize, + std::optional slimmableSize) { - verify_config_version(config["version"].get()); - - auto architecture = config["architecture"]; - nlohmann::json config_json = config["config"]; - std::vector weights = GetWeights(config); - - // Assign values to returnedConfig - returnedConfig.version = config["version"].get(); - returnedConfig.architecture = config["architecture"].get(); - returnedConfig.config = config_json; - returnedConfig.metadata = config.value("metadata", nlohmann::json()); - returnedConfig.weights = weights; - returnedConfig.expected_sample_rate = nam::get_sample_rate_from_nam_file(config); + returnedConfig = parse_dsp_data(config, expectedSampleRate); /*Copy to a new dsp_config object for get_dsp below, since not sure if weights actually get modified as being non-const references on some @@ -191,7 +251,7 @@ std::unique_ptr get_dsp(const nlohmann::json& config, dspData& returnedConf We need to return unmodified version of dsp_config via returnedConfig.*/ dspData conf = returnedConfig; - return get_dsp(conf); + return get_dsp(conf, expectedSampleRate, maxBufferSize, slimmableSize); } // ============================================================================= @@ -204,44 +264,27 @@ std::unique_ptr parse_model_config_json(const std::string& architec return ConfigParserRegistry::instance().parse(architecture, config, sample_rate); } -namespace -{ - -void apply_metadata(DSP& dsp, const ModelMetadata& metadata) -{ - if (metadata.loudness.has_value()) - dsp.SetLoudness(metadata.loudness.value()); - if (metadata.input_level.has_value()) - dsp.SetInputLevel(metadata.input_level.value()); - if (metadata.output_level.has_value()) - dsp.SetOutputLevel(metadata.output_level.value()); -} - -} // anonymous namespace - std::unique_ptr create_dsp(std::unique_ptr config, std::vector weights, const ModelMetadata& metadata) { - auto out = config->create(std::move(weights), metadata.sample_rate); - apply_metadata(*out, metadata); - // "pre-warm" the model to settle initial conditions - // Can this be removed now that it's part of Reset()? - out->prewarm(); - return out; + return create_dsp_with_options(std::move(config), std::move(weights), metadata, LoadOptions{}); } // ============================================================================= // get_dsp(dspData&) — now uses unified path // ============================================================================= -std::unique_ptr get_dsp(dspData& conf) +std::unique_ptr get_dsp(dspData& conf, std::optional expectedSampleRate, std::optional maxBufferSize, + std::optional slimmableSize) { verify_config_version(conf.version); + const double effectiveSampleRate = expectedSampleRate.value_or(conf.expected_sample_rate); + const LoadOptions options{expectedSampleRate, maxBufferSize, slimmableSize}; // Extract metadata from JSON ModelMetadata metadata; metadata.version = conf.version; - metadata.sample_rate = conf.expected_sample_rate; + metadata.sample_rate = effectiveSampleRate; if (!conf.metadata.is_null()) { @@ -255,8 +298,8 @@ std::unique_ptr get_dsp(dspData& conf) metadata.output_level = extract("output_level_dbu"); } - auto model_config = ConfigParserRegistry::instance().parse(conf.architecture, conf.config, conf.expected_sample_rate); - return create_dsp(std::move(model_config), std::move(conf.weights), metadata); + auto model_config = ConfigParserRegistry::instance().parse(conf.architecture, conf.config, effectiveSampleRate); + return create_dsp_with_options(std::move(model_config), std::move(conf.weights), metadata, options); } double get_sample_rate_from_nam_file(const nlohmann::json& j) diff --git a/NAM/get_dsp.h b/NAM/get_dsp.h index da874fe9..eed2678d 100644 --- a/NAM/get_dsp.h +++ b/NAM/get_dsp.h @@ -2,6 +2,7 @@ #include #include +#include #include #include "dsp.h" @@ -67,32 +68,60 @@ const std::string EARLIEST_SUPPORTED_NAM_FILE_VERSION = "0.5.0"; /// \brief Get NAM from a .nam file at the provided location /// \param config_filename Path to the .nam model file +/// \param expectedSampleRate Expected sample rate to configure the model with; std::nullopt uses the file default +/// \param maxBufferSize Maximum buffer size to configure the model with; std::nullopt uses NAM_DEFAULT_MAX_BUFFER_SIZE +/// \param slimmableSize Slimmable size to configure the model with; std::nullopt uses the model default /// \return Unique pointer to a DSP object -std::unique_ptr get_dsp(const std::filesystem::path config_filename); +std::unique_ptr get_dsp(const std::filesystem::path config_filename, + std::optional expectedSampleRate = std::nullopt, + std::optional maxBufferSize = std::nullopt, + std::optional slimmableSize = std::nullopt); /// \brief Get NAM from a provided configuration struct /// \param conf DSP data structure containing model configuration and weights +/// \param expectedSampleRate Expected sample rate to configure the model with; std::nullopt uses the config default +/// \param maxBufferSize Maximum buffer size to configure the model with; std::nullopt uses NAM_DEFAULT_MAX_BUFFER_SIZE +/// \param slimmableSize Slimmable size to configure the model with; std::nullopt uses the model default /// \return Unique pointer to a DSP object -std::unique_ptr get_dsp(dspData& conf); +std::unique_ptr get_dsp(dspData& conf, std::optional expectedSampleRate = std::nullopt, + std::optional maxBufferSize = std::nullopt, + std::optional slimmableSize = std::nullopt); /// \brief Get NAM from a .nam file and store its configuration /// /// Creates an instance of DSP and also returns a dspData struct that holds the data of the model. /// \param config_filename Path to the .nam model file /// \param returnedConfig Output parameter that will be filled with the model data +/// \param expectedSampleRate Expected sample rate to configure the model with; std::nullopt uses the file default +/// \param maxBufferSize Maximum buffer size to configure the model with; std::nullopt uses NAM_DEFAULT_MAX_BUFFER_SIZE +/// \param slimmableSize Slimmable size to configure the model with; std::nullopt uses the model default /// \return Unique pointer to a DSP object -std::unique_ptr get_dsp(const std::filesystem::path config_filename, dspData& returnedConfig); +std::unique_ptr get_dsp(const std::filesystem::path config_filename, dspData& returnedConfig, + std::optional expectedSampleRate = std::nullopt, + std::optional maxBufferSize = std::nullopt, + std::optional slimmableSize = std::nullopt); /// \brief Get NAM from a provided configuration JSON object /// \param config JSON configuration object /// \param returnedConfig Output parameter that will be filled with the model data +/// \param expectedSampleRate Expected sample rate to configure the model with; std::nullopt uses the file default +/// \param maxBufferSize Maximum buffer size to configure the model with; std::nullopt uses NAM_DEFAULT_MAX_BUFFER_SIZE +/// \param slimmableSize Slimmable size to configure the model with; std::nullopt uses the model default /// \return Unique pointer to a DSP object -std::unique_ptr get_dsp(const nlohmann::json& config, dspData& returnedConfig); +std::unique_ptr get_dsp(const nlohmann::json& config, dspData& returnedConfig, + std::optional expectedSampleRate = std::nullopt, + std::optional maxBufferSize = std::nullopt, + std::optional slimmableSize = std::nullopt); /// \brief Get NAM from a provided configuration JSON object (convenience overload) /// \param config JSON configuration object +/// \param expectedSampleRate Expected sample rate to configure the model with; std::nullopt uses the file default +/// \param maxBufferSize Maximum buffer size to configure the model with; std::nullopt uses NAM_DEFAULT_MAX_BUFFER_SIZE +/// \param slimmableSize Slimmable size to configure the model with; std::nullopt uses the model default /// \return Unique pointer to a DSP object -std::unique_ptr get_dsp(const nlohmann::json& config); +std::unique_ptr get_dsp(const nlohmann::json& config, std::optional expectedSampleRate = std::nullopt, + std::optional maxBufferSize = std::nullopt, + std::optional slimmableSize = std::nullopt); /// \brief Get sample rate from a .nam file /// \param j JSON object from the .nam file diff --git a/NAM/wavenet/slimmable.cpp b/NAM/wavenet/slimmable.cpp index 19b019e2..8a751882 100644 --- a/NAM/wavenet/slimmable.cpp +++ b/NAM/wavenet/slimmable.cpp @@ -495,6 +495,12 @@ void SlimmableWavenet::SetSlimmableSize(const double val) target[i] = ratio_to_channels(val, allowed); } + if (_current_buffer_size <= 0) + { + _rebuild_model(target); + return; + } + _stage_rebuild_model(target); } diff --git a/tools/run_tests.cpp b/tools/run_tests.cpp index 0f9d50a3..124d8f05 100644 --- a/tools/run_tests.cpp +++ b/tools/run_tests.cpp @@ -280,6 +280,9 @@ int main() test_get_dsp::test_gets_output_level(); test_get_dsp::test_null_input_level(); test_get_dsp::test_null_output_level(); + test_get_dsp::test_get_dsp_without_load_options_preserves_prewarm_only(); + test_get_dsp::test_get_dsp_applies_load_options(); + test_get_dsp::test_get_dsp_applies_slimmable_option_before_reset_with_defaults(); test_get_dsp::test_version_patch_one_beyond_supported(); test_get_dsp::test_version_minor_one_beyond_supported(); test_get_dsp::test_version_too_early(); diff --git a/tools/test/test_get_dsp.cpp b/tools/test/test_get_dsp.cpp index 150f50fe..30de68c7 100644 --- a/tools/test/test_get_dsp.cpp +++ b/tools/test/test_get_dsp.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include #include @@ -10,6 +11,8 @@ #include "json.hpp" #include "NAM/get_dsp.h" +#include "NAM/registry.h" +#include "NAM/slimmable.h" namespace test_get_dsp { @@ -46,6 +49,65 @@ nam::dspData _GetConfig(const std::string& configStr = basicConfigStr) return returnedConfig; } +class LoadOptionsDSP : public nam::DSP, public nam::SlimmableModel +{ +public: + explicit LoadOptionsDSP(const double expected_sample_rate) + : nam::DSP(1, 1, expected_sample_rate) + { + } + + void Reset(const double sampleRate, const int maxBufferSize) override + { + reset_count++; + reset_sample_rate = sampleRate; + reset_buffer_size = maxBufferSize; + nam::DSP::Reset(sampleRate, maxBufferSize); + } + + void prewarm() override { prewarm_count++; } + + void SetSlimmableSize(const double val) override + { + slim_set_count++; + slim_set_size = val; + slim_set_before_reset = reset_count == 0; + } + + int reset_count = 0; + int prewarm_count = 0; + int slim_set_count = 0; + double reset_sample_rate = 0.0; + int reset_buffer_size = 0; + double slim_set_size = -1.0; + bool slim_set_before_reset = false; +}; + +std::unique_ptr LoadOptionsFactory(const nlohmann::json& config, std::vector& weights, + const double expectedSampleRate) +{ + (void)config; + (void)weights; + return std::make_unique(expectedSampleRate); +} + +namespace +{ +static nam::factory::Helper _register_LoadOptionsArchitecture("LoadOptionsArchitecture", LoadOptionsFactory); +} + +nlohmann::json load_options_config() +{ + return nlohmann::json::parse(R"({ + "version": "0.7.0", + "metadata": {}, + "architecture": "LoadOptionsArchitecture", + "config": {}, + "weights": [], + "sample_rate": 48000 + })"); +} + void test_gets_input_level() { nam::dspData config = _GetConfig(); @@ -84,6 +146,45 @@ void test_null_output_level() assert(!dsp->HasOutputLevel()); } +void test_get_dsp_without_load_options_preserves_prewarm_only() +{ + auto dsp = nam::get_dsp(load_options_config()); + auto* loaded = dynamic_cast(dsp.get()); + assert(loaded != nullptr); + assert(loaded->reset_count == 0); + assert(loaded->prewarm_count == 1); + assert(std::abs(loaded->GetExpectedSampleRate() - 48000.0) < 1e-9); +} + +void test_get_dsp_applies_load_options() +{ + nam::dspData returnedConfig; + auto dsp = + nam::get_dsp(load_options_config(), returnedConfig, std::optional{44100.0}, std::optional{128}); + auto* loaded = dynamic_cast(dsp.get()); + assert(loaded != nullptr); + assert(loaded->reset_count == 1); + assert(loaded->prewarm_count == 1); + assert(loaded->reset_sample_rate == 44100.0); + assert(loaded->reset_buffer_size == 128); + assert(std::abs(loaded->GetExpectedSampleRate() - 44100.0) < 1e-9); + assert(returnedConfig.expected_sample_rate == 44100.0); +} + +void test_get_dsp_applies_slimmable_option_before_reset_with_defaults() +{ + auto dsp = nam::get_dsp(load_options_config(), std::nullopt, std::optional{256}, std::optional{0.25}); + auto* loaded = dynamic_cast(dsp.get()); + assert(loaded != nullptr); + assert(loaded->slim_set_count == 1); + assert(loaded->slim_set_size == 0.25); + assert(loaded->slim_set_before_reset); + assert(loaded->reset_count == 1); + assert(loaded->prewarm_count == 1); + assert(loaded->reset_sample_rate == 48000.0); + assert(loaded->reset_buffer_size == 256); +} + // Helper function to process buffers through a DSP model void process_buffers(nam::DSP* dsp, int num_buffers, int buffer_size) { @@ -273,4 +374,4 @@ void test_register_custom_version_support_checker() assert(nam::is_version_supported("DEMO::1.0.3") == nam::Supported::PARTIAL); assert(nam::is_version_supported("DEMO::2.0.0") == nam::Supported::NO); } -}; // namespace test_get_dsp \ No newline at end of file +}; // namespace test_get_dsp From 2386951308d64d6ab218b703277e19292c4471a1 Mon Sep 17 00:00:00 2001 From: Steven Atkinson Date: Tue, 9 Jun 2026 18:35:32 -0700 Subject: [PATCH 2/2] Update Reset prewarm documentation --- NAM/dsp.h | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/NAM/dsp.h b/NAM/dsp.h index 01408d0a..cb5f9ea6 100644 --- a/NAM/dsp.h +++ b/NAM/dsp.h @@ -133,8 +133,7 @@ class DSP /// \brief General function for resetting the DSP unit /// - /// This doesn't call prewarm(). If you want to do that, then you might want to use ResetAndPrewarm(). - /// See https://github.com/sdatkinson/NeuralAmpModelerCore/issues/96 for the reasoning. + /// This calls prewarm() after applying the sample rate and max buffer size. /// \param sampleRate Current sample rate /// \param maxBufferSize Maximum buffer size to process virtual void Reset(const double sampleRate, const int maxBufferSize);