Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 66 additions & 7 deletions tools/bench_a2_fast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@
// Runs both over the same audio, reports median wall time + real-time factor.
//
// Usage:
// bench_a2_fast [--buffer N] [--seconds S] [--iters I] <model.nam> [<model.nam> ...]
// bench_a2_fast [--buffer N] [--seconds S] [--iters I] [--slim V] <model.nam> [<model.nam> ...]
//
// A model may be a plain WaveNet .nam or a SlimmableContainer .nam (e.g. an A2 file
// bundling A2-Full + A2-Lite submodels). For a container, --slim V (0.0-1.0) selects the
// submodel exactly as ContainerModel::SetSlimmableSize would (first submodel whose
// max_value exceeds V, else the last); the selected submodel is then benched fast vs generic.
//
// Only compiled when NAM_ENABLE_A2_FAST is defined.

Expand Down Expand Up @@ -44,9 +49,13 @@ struct Options
int buffer_size = 64;
double seconds = 2.0;
int iterations = 10;
double slim = 1.0; // SlimmableContainer selector; defaults to full size (last submodel)
bool has_slim = false;
std::vector<std::string> model_paths;
};

const char* kUsage = "Usage: bench_a2_fast [--buffer N] [--seconds S] [--iters I] [--slim V] <model.nam> ...\n";

Options parse_args(int argc, char** argv)
{
Options o;
Expand All @@ -59,9 +68,19 @@ Options parse_args(int argc, char** argv)
o.seconds = std::atof(argv[++i]);
else if (a == "--iters" && i + 1 < argc)
o.iterations = std::atoi(argv[++i]);
else if (a == "--slim" && i + 1 < argc)
{
o.slim = std::atof(argv[++i]);
o.has_slim = true;
if (o.slim < 0.0 || o.slim > 1.0)
{
std::cerr << "--slim value must be between 0.0 and 1.0\n";
std::exit(1);
}
}
else if (a == "-h" || a == "--help")
{
std::cerr << "Usage: bench_a2_fast [--buffer N] [--seconds S] [--iters I] <model.nam> ...\n";
std::cerr << kUsage;
std::exit(0);
}
else
Expand All @@ -78,7 +97,33 @@ struct LoadedModel
std::string path;
};

LoadedModel load_nam(const std::string& path)
// Select a SlimmableContainer submodel for the given slim value, mirroring
// ContainerModel::SetSlimmableSize: the first submodel whose max_value exceeds the value,
// else the last. Returns the chosen submodel's model spec and appends a label to `path`.
nlohmann::json select_container_submodel(const nlohmann::json& config, double slim, std::string& path)
{
const auto& submodels = config["submodels"];
if (!submodels.is_array() || submodels.empty())
throw std::runtime_error(path + ": SlimmableContainer 'submodels' must be a non-empty array");

size_t idx = submodels.size() - 1;
for (size_t i = 0; i < submodels.size(); i++)
{
if (slim < submodels[i].at("max_value").get<double>())
{
idx = i;
break;
}
}

std::ostringstream label;
label << path << " [slim=" << slim << " -> submodel " << idx << "/" << submodels.size()
<< ", max_value=" << submodels[idx].at("max_value").get<double>() << "]";
path = label.str();
return submodels[idx].at("model");
}

LoadedModel load_nam(const std::string& path, const Options& o)
{
LoadedModel m;
m.path = path;
Expand All @@ -87,8 +132,22 @@ LoadedModel load_nam(const std::string& path)
throw std::runtime_error("Could not open " + path);
nlohmann::json j;
is >> j;
if (j.value("architecture", std::string()) != "WaveNet")
throw std::runtime_error(path + ": not a WaveNet model");

// Resolve a SlimmableContainer to the submodel selected by --slim (each submodel is a
// full, standalone model spec). A plain WaveNet is used directly.
std::string arch = j.value("architecture", std::string());
if (arch == "SlimmableContainer")
{
j = select_container_submodel(j["config"], o.slim, m.path);
arch = j.value("architecture", std::string());
}
else if (o.has_slim)
{
std::cerr << "[note] " << path << ": --slim ignored (not a SlimmableContainer)\n";
}

if (arch != "WaveNet")
throw std::runtime_error(path + ": not a WaveNet model (architecture=" + arch + ")");
m.config = j["config"];
m.weights = j["weights"].get<std::vector<float>>();
if (j.contains("sample_rate") && !j["sample_rate"].is_null())
Expand Down Expand Up @@ -244,14 +303,14 @@ int main(int argc, char** argv)
Options o = parse_args(argc, argv);
if (o.model_paths.empty())
{
std::cerr << "Usage: bench_a2_fast [--buffer N] [--seconds S] [--iters I] <model.nam> ...\n";
std::cerr << kUsage;
return 1;
}
for (const auto& p : o.model_paths)
{
try
{
bench_model(load_nam(p), o);
bench_model(load_nam(p, o), o);
}
catch (const std::exception& e)
{
Expand Down
Loading