diff --git a/tools/bench_a2_fast.cpp b/tools/bench_a2_fast.cpp
index 2dcaff1..4057503 100644
--- a/tools/bench_a2_fast.cpp
+++ b/tools/bench_a2_fast.cpp
@@ -6,7 +6,12 @@
 // Runs both over the same audio, reports median wall time + real-time factor.
 //
 // Usage:
-//   bench_a2_fast [--buffer N] [--seconds S] [--iters I] [ ...]
+//   bench_a2_fast [--buffer N] [--seconds S] [--iters I] [--slim V] [ ...]
+//
+// A model may be a plain WaveNet .nam or a SlimmableContainer .nam (e.g. an A2 file
+// bundling A2-Full + A2-Lite submodels). For a container, --slim V (0.0-1.0) selects the
+// submodel exactly as ContainerModel::SetSlimmableSize would (first submodel whose
+// max_value exceeds V, else the last); the selected submodel is then benched fast vs generic.
 //
 // Only compiled when NAM_ENABLE_A2_FAST is defined.
 
@@ -44,9 +49,13 @@ struct Options
   int buffer_size = 64;
   double seconds = 2.0;
   int iterations = 10;
+  double slim = 1.0; // SlimmableContainer selector; defaults to full size (last submodel)
+  bool has_slim = false;
   std::vector<std::string> model_paths;
 };
 
+const char* const kUsage = "Usage: bench_a2_fast [--buffer N] [--seconds S] [--iters I] [--slim V] ...\n";
+
 Options parse_args(int argc, char** argv)
 {
   Options o;
@@ -59,9 +68,19 @@ Options parse_args(int argc, char** argv)
       o.seconds = std::atof(argv[++i]);
     else if (a == "--iters" && i + 1 < argc)
       o.iterations = std::atoi(argv[++i]);
+    else if (a == "--slim" && i + 1 < argc)
+    {
+      o.slim = std::atof(argv[++i]);
+      o.has_slim = true;
+      if (!(o.slim >= 0.0 && o.slim <= 1.0)) // negated form so a NaN value is rejected too
+      {
+        std::cerr << "--slim value must be between 0.0 and 1.0\n";
+        std::exit(1);
+      }
+    }
     else if (a == "-h" || a == "--help")
     {
-      std::cerr << "Usage: bench_a2_fast [--buffer N] [--seconds S] [--iters I] ...\n";
+      std::cerr << kUsage;
       std::exit(0);
     }
     else
@@ -78,7 +97,33 @@ struct LoadedModel
   std::string path;
 };
 
-LoadedModel load_nam(const std::string& path)
+// Select a SlimmableContainer submodel for the given slim value, mirroring
+// ContainerModel::SetSlimmableSize: the first submodel whose max_value exceeds the value,
+// else the last. Returns the chosen submodel's model spec and appends a label to `path`.
+nlohmann::json select_container_submodel(const nlohmann::json& config, double slim, std::string& path)
+{
+  const auto& submodels = config.at("submodels"); // at() throws if absent; const operator[] would be UB
+  if (!submodels.is_array() || submodels.empty())
+    throw std::runtime_error(path + ": SlimmableContainer 'submodels' must be a non-empty array");
+
+  size_t idx = submodels.size() - 1;
+  for (size_t i = 0; i < submodels.size(); i++)
+  {
+    if (slim < submodels[i].at("max_value").get<double>())
+    {
+      idx = i;
+      break;
+    }
+  }
+
+  std::ostringstream label;
+  label << path << " [slim=" << slim << " -> submodel " << idx << "/" << submodels.size()
+        << ", max_value=" << submodels[idx].at("max_value").get<double>() << "]";
+  path = label.str();
+  return submodels[idx].at("model");
+}
+
+LoadedModel load_nam(const std::string& path, const Options& o)
 {
   LoadedModel m;
   m.path = path;
@@ -87,8 +132,22 @@ LoadedModel load_nam(const std::string& path)
     throw std::runtime_error("Could not open " + path);
   nlohmann::json j;
   is >> j;
-  if (j.value("architecture", std::string()) != "WaveNet")
-    throw std::runtime_error(path + ": not a WaveNet model");
+
+  // Resolve a SlimmableContainer to the submodel selected by --slim (each submodel is a
+  // full, standalone model spec). A plain WaveNet is used directly.
+  std::string arch = j.value("architecture", std::string());
+  if (arch == "SlimmableContainer")
+  {
+    j = select_container_submodel(j["config"], o.slim, m.path);
+    arch = j.value("architecture", std::string());
+  }
+  else if (o.has_slim)
+  {
+    std::cerr << "[note] " << path << ": --slim ignored (not a SlimmableContainer)\n";
+  }
+
+  if (arch != "WaveNet")
+    throw std::runtime_error(path + ": not a WaveNet model (architecture=" + arch + ")");
   m.config = j["config"];
   m.weights = j["weights"].get<std::vector<float>>();
   if (j.contains("sample_rate") && !j["sample_rate"].is_null())
@@ -244,14 +303,14 @@ int main(int argc, char** argv)
   Options o = parse_args(argc, argv);
   if (o.model_paths.empty())
   {
-    std::cerr << "Usage: bench_a2_fast [--buffer N] [--seconds S] [--iters I] ...\n";
+    std::cerr << kUsage;
     return 1;
   }
   for (const auto& p : o.model_paths)
   {
     try
     {
-      bench_model(load_nam(p), o);
+      bench_model(load_nam(p, o), o);
     }
     catch (const std::exception& e)
     {