一个基于 Ray 的分布式推理框架,用于高效处理大规模音频数据。框架采用队列式流水线架构,支持多 GPU 并行推理、断点续跑和灵活的数据源接入。
- 🚀 分布式推理:基于 Ray 实现多 GPU 并行处理
- 📦 队列式流水线:解耦数据加载、模型推理和结果保存
- 🔄 断点续跑:基于 SQLite 的任务跟踪,支持中断后恢复
- 🔌 灵活扩展:易于添加新模型和数据加载器
- 📊 统一数据格式:使用
AudioInfo类统一数据流
- BEATs:音频分类模型
- Qwen3-Omni:多模态生成模型(基于 vLLM)
- AudioJSONLDataLoader:从本地 JSONL 文件加载音频路径
- OSSDataloader:从 OSS(对象存储)加载音频数据,支持预签名 URL 和 FFmpeg 解码
- JSONLDataLoader:通用 JSONL 数据加载器
框架采用队列式流水线架构,数据在各组件间通过 Ray Queue 传递:
┌─────────────────┐
│ DataLoaderWorker │ 顺序迭代数据,跳过已完成任务
└────────┬────────┘
│ AudioInfo[]
▼
┌─────────────────┐
│ input_queue │ 输入队列(Ray Queue)
└────────┬────────┘
│ AudioInfo[]
▼
┌─────────────────┐
│ ModelWorker(s) │ 并行处理,支持多 GPU
│ (可动态扩展) │
└────────┬────────┘
│ AudioInfo[] (含预测结果)
▼
┌─────────────────┐
│ result_queue │ 结果队列(Ray Queue)
└────────┬────────┘
│ AudioInfo[]
▼
┌─────────────────┐
│ SaveWorker │ 批量保存结果,更新任务状态
└────────┬────────┘
│
▼
┌─────────────────┐
│ TaskTracker │ SQLite 数据库(任务状态跟踪)
│ (SaveWorker │ 直接调用 mark_tasks_completed
│ 内部使用) │
└─────────────────┘
- DataLoaderWorker:顺序迭代数据加载器,跳过已完成的任务(通过
TaskTracker查询),将批次数据放入input_queue - ModelWorker(s):从
input_queue获取批次数据,调用模型进行推理,将结果放入result_queue - SaveWorker:从
result_queue获取结果,批量写入文件并更新数据库中的任务状态 - TaskTracker:使用 SQLite 跟踪任务状态(unallocated → completed),支持断点续跑
所有组件间传递的数据都使用 AudioInfo 类,这是一个统一的数据结构:
@dataclass
class AudioInfo:
# 音频路径/URL
audio_path: Optional[str] = None
url: Optional[str] = None
# OSS 相关参数(用于 OSSDataloader)
bucket: Optional[str] = None
object_key: Optional[str] = None
endpoint: Optional[str] = None
# ...
# 音频片段信息
start: Optional[float] = None
end: Optional[float] = None
# 模型输出
predictions: List[Dict[str, Any]] = field(default_factory=list)
error: Optional[str] = None
# 额外字段
_extra: Dict[str, Any] = field(default_factory=dict)# 基础依赖
pip install ray torch torchaudio
# OSSDataloader 依赖(可选)
pip install minio ffmpeg-python
# Qwen3-Omni 依赖(可选)
pip install vllmpython ray_inference.py --model /path/to/model --output /path/to/output所有配置都在 config.py 中的 Config 类中管理。修改 ray_inference.py 的 __main__ 部分:
# 数据路径
cfg.data_path = '/path/to/data.jsonl'
cfg.output_path = './outputs'
# 模型和数据加载器
cfg.model_type = 'beats' # 或 'qwen3omni'
cfg.dataloader_type = 'audio_jsonl' # 或 'oss', 'jsonl'
# 推理配置
cfg.batch_size = 512
# OSS 配置(如果使用 OSSDataloader)
cfg.meta_dir = '/path/to/metadata'
cfg.endpoint = 'oss.example.com:8009'
cfg.access_key = 'your_access_key'
cfg.secret_key = 'your_secret_key'在 models/ 目录下创建新模型文件,继承 BaseModel 类:
from models.base_model import BaseModel
from audio_info import AudioInfo
from typing import List
class MyModel(BaseModel):
def __init__(self, model_name: str, model_path: str = None, **kwargs):
super().__init__(model_name, model_path, **kwargs)
# 初始化你的模型
self._load_model()
def _load_model(self):
"""加载模型"""
# 实现模型加载逻辑
pass
def generate(self, inputs: List[AudioInfo], **kwargs) -> List[AudioInfo]:
"""
处理 AudioInfo 列表,返回处理后的 AudioInfo 列表
Args:
inputs: AudioInfo 对象列表
**kwargs: 其他参数
Returns:
处理后的 AudioInfo 对象列表(包含 predictions)
"""
results = []
for audio_info in inputs:
# 处理单个 AudioInfo
# 将结果写入 audio_info.predictions
audio_info.predictions = [...] # 你的预测结果
results.append(audio_info)
return results
def generate_batch(self, batch_data: List[AudioInfo], **kwargs) -> List[AudioInfo]:
"""批量处理(可选,默认调用 generate)"""
return self.generate(batch_data, **kwargs)然后在 workers/model_worker.py 中添加创建函数:
def create_my_model_worker(model_name: str, model_path: str = None, **kwargs):
from models.my_model import MyModel
return ModelWorker.remote(MyModel, model_name, model_path, **kwargs)在 dataloader.py 中创建新数据加载器,继承 BaseDataLoader 类:
from dataloader import BaseDataLoader
from audio_info import AudioInfo
from typing import Iterator, List
class MyDataLoader(BaseDataLoader):
def _load_data(self):
"""加载数据到 self.data"""
# 实现数据加载逻辑
self.data = [...] # 数据列表(字典格式)
def __iter__(self) -> Iterator[List[AudioInfo]]:
"""返回批次数据"""
for i in range(0, len(self.data), self.batch_size):
batch = self.data[i:i + self.batch_size]
# 转换为 AudioInfo 对象
yield [AudioInfo.from_dict(item) for item in batch]
def get_item(self, index: int) -> dict:
"""根据索引获取单个数据项"""
return self.data[index]然后在 dataloader.py 的 create_dataloader 函数中添加:
def create_dataloader(dataloader_type: str = 'oss', ...):
# ...
elif dataloader_type == 'my_dataloader':
return MyDataLoader(data_path=data_path, batch_size=batch_size, ...)- DataLoaderWorker:数据加载工作器,负责从数据源加载数据并放入队列
- ModelWorker:模型推理工作器,每个 GPU 一个实例,从队列获取数据并推理
- SaveWorker:结果保存工作器,批量保存结果并更新任务状态
- QueueMonitor:队列监控器,定期输出队列大小信息
- TaskTracker:任务跟踪器,使用 SQLite 管理任务状态
框架使用 SQLite 数据库跟踪任务状态:
- 启动时查询已完成的任务,跳过这些任务
- 处理完成后标记任务为
completed - 支持中断后恢复,自动跳过已完成的任务
- 批量处理:SaveWorker 使用缓冲区批量写入文件和数据库
- 异步流水线:数据加载、推理和保存并行执行
- 资源管理:每个 ModelWorker 使用
num_gpus=0.1,支持多 worker 共享 GPU
完整示例请参考 example_ray_run.py。
[根据项目实际情况填写]
欢迎提交 Issue 和 Pull Request!