Skip to content

VoiceAgentGroup/AudioCaptionPipeline

Repository files navigation

Ray Inference Framework

一个基于 Ray 的分布式推理框架,用于高效处理大规模音频数据。框架采用队列式流水线架构,支持多 GPU 并行推理、断点续跑和灵活的数据源接入。

特性

  • 🚀 分布式推理:基于 Ray 实现多 GPU 并行处理
  • 📦 队列式流水线:解耦数据加载、模型推理和结果保存
  • 🔄 断点续跑:基于 SQLite 的任务跟踪,支持中断后恢复
  • 🔌 灵活扩展:易于添加新模型和数据加载器
  • 📊 统一数据格式:使用 AudioInfo 类统一数据流

支持的模型

  • BEATs:音频分类模型
  • Qwen3-Omni:多模态生成模型(基于 vLLM)

支持的数据加载器

  • AudioJSONLDataLoader:从本地 JSONL 文件加载音频路径
  • OSSDataloader:从 OSS(对象存储)加载音频数据,支持预签名 URL 和 FFmpeg 解码
  • JSONLDataLoader:通用 JSONL 数据加载器

整体流程

框架采用队列式流水线架构,数据在各组件间通过 Ray Queue 传递:

┌─────────────────┐
│ DataLoaderWorker │  顺序迭代数据,跳过已完成任务
└────────┬────────┘
         │ AudioInfo[]
         ▼
┌─────────────────┐
│   input_queue    │  输入队列(Ray Queue)
└────────┬────────┘
         │ AudioInfo[]
         ▼
┌─────────────────┐
│  ModelWorker(s) │  并行处理,支持多 GPU
│  (可动态扩展)    │
└────────┬────────┘
         │ AudioInfo[] (含预测结果)
         ▼
┌─────────────────┐
│  result_queue   │  结果队列(Ray Queue)
└────────┬────────┘
         │ AudioInfo[]
         ▼
┌─────────────────┐
│   SaveWorker    │  批量保存结果,更新任务状态
└────────┬────────┘
         │
         ▼
┌─────────────────┐
│  TaskTracker    │  SQLite 数据库(任务状态跟踪)
│  (SaveWorker    │  直接调用 mark_tasks_completed
│   内部使用)      │
└─────────────────┘

工作流程说明

  1. DataLoaderWorker:顺序迭代数据加载器,跳过已完成的任务(通过 TaskTracker 查询),将批次数据放入 input_queue
  2. ModelWorker(s):从 input_queue 获取批次数据,调用模型进行推理,将结果放入 result_queue
  3. SaveWorker:从 result_queue 获取结果,批量写入文件并更新数据库中的任务状态
  4. TaskTracker:使用 SQLite 跟踪任务状态(unallocated → completed),支持断点续跑

数据格式

所有组件间传递的数据都使用 AudioInfo 类,这是一个统一的数据结构:

@dataclass
class AudioInfo:
    # 音频路径/URL
    audio_path: Optional[str] = None
    url: Optional[str] = None
  
    # OSS 相关参数(用于 OSSDataloader)
    bucket: Optional[str] = None
    object_key: Optional[str] = None
    endpoint: Optional[str] = None
    # ...
  
    # 音频片段信息
    start: Optional[float] = None
    end: Optional[float] = None
  
    # 模型输出
    predictions: List[Dict[str, Any]] = field(default_factory=list)
    error: Optional[str] = None
  
    # 额外字段
    _extra: Dict[str, Any] = field(default_factory=dict)

快速开始

安装依赖

# 基础依赖
pip install ray torch torchaudio

# OSSDataloader 依赖(可选)
pip install minio ffmpeg-python

# Qwen3-Omni 依赖(可选)
pip install vllm

基本使用

python ray_inference.py --model /path/to/model --output /path/to/output

配置修改

所有配置都在 config.py 中的 Config 类中管理。修改 ray_inference.py__main__ 部分:

# 数据路径
cfg.data_path = '/path/to/data.jsonl'
cfg.output_path = './outputs'

# 模型和数据加载器
cfg.model_type = 'beats'  # 或 'qwen3omni'
cfg.dataloader_type = 'audio_jsonl'  # 或 'oss', 'jsonl'

# 推理配置
cfg.batch_size = 512

# OSS 配置(如果使用 OSSDataloader)
cfg.meta_dir = '/path/to/metadata'
cfg.endpoint = 'oss.example.com:8009'
cfg.access_key = 'your_access_key'
cfg.secret_key = 'your_secret_key'

扩展指南

添加新模型

models/ 目录下创建新模型文件,继承 BaseModel 类:

from models.base_model import BaseModel
from audio_info import AudioInfo
from typing import List

class MyModel(BaseModel):
    def __init__(self, model_name: str, model_path: str = None, **kwargs):
        super().__init__(model_name, model_path, **kwargs)
        # 初始化你的模型
        self._load_model()
  
    def _load_model(self):
        """加载模型"""
        # 实现模型加载逻辑
        pass
  
    def generate(self, inputs: List[AudioInfo], **kwargs) -> List[AudioInfo]:
        """
        处理 AudioInfo 列表,返回处理后的 AudioInfo 列表
      
        Args:
            inputs: AudioInfo 对象列表
            **kwargs: 其他参数
          
        Returns:
            处理后的 AudioInfo 对象列表(包含 predictions)
        """
        results = []
        for audio_info in inputs:
            # 处理单个 AudioInfo
            # 将结果写入 audio_info.predictions
            audio_info.predictions = [...]  # 你的预测结果
            results.append(audio_info)
        return results
  
    def generate_batch(self, batch_data: List[AudioInfo], **kwargs) -> List[AudioInfo]:
        """批量处理(可选,默认调用 generate)"""
        return self.generate(batch_data, **kwargs)

然后在 workers/model_worker.py 中添加创建函数:

def create_my_model_worker(model_name: str, model_path: str = None, **kwargs):
    from models.my_model import MyModel
    return ModelWorker.remote(MyModel, model_name, model_path, **kwargs)

添加新数据加载器

dataloader.py 中创建新数据加载器,继承 BaseDataLoader 类:

from dataloader import BaseDataLoader
from audio_info import AudioInfo
from typing import Iterator, List

class MyDataLoader(BaseDataLoader):
    def _load_data(self):
        """加载数据到 self.data"""
        # 实现数据加载逻辑
        self.data = [...]  # 数据列表(字典格式)
  
    def __iter__(self) -> Iterator[List[AudioInfo]]:
        """返回批次数据"""
        for i in range(0, len(self.data), self.batch_size):
            batch = self.data[i:i + self.batch_size]
            # 转换为 AudioInfo 对象
            yield [AudioInfo.from_dict(item) for item in batch]
  
    def get_item(self, index: int) -> dict:
        """根据索引获取单个数据项"""
        return self.data[index]

然后在 dataloader.pycreate_dataloader 函数中添加:

def create_dataloader(dataloader_type: str = 'oss', ...):
    # ...
    elif dataloader_type == 'my_dataloader':
        return MyDataLoader(data_path=data_path, batch_size=batch_size, ...)

架构设计

组件说明

  • DataLoaderWorker:数据加载工作器,负责从数据源加载数据并放入队列
  • ModelWorker:模型推理工作器,每个 GPU 一个实例,从队列获取数据并推理
  • SaveWorker:结果保存工作器,批量保存结果并更新任务状态
  • QueueMonitor:队列监控器,定期输出队列大小信息
  • TaskTracker:任务跟踪器,使用 SQLite 管理任务状态

断点续跑机制

框架使用 SQLite 数据库跟踪任务状态:

  • 启动时查询已完成的任务,跳过这些任务
  • 处理完成后标记任务为 completed
  • 支持中断后恢复,自动跳过已完成的任务

性能优化

  • 批量处理:SaveWorker 使用缓冲区批量写入文件和数据库
  • 异步流水线:数据加载、推理和保存并行执行
  • 资源管理:每个 ModelWorker 使用 num_gpus=0.1,支持多 worker 共享 GPU

示例

完整示例请参考 example_ray_run.py

许可证

[根据项目实际情况填写]

贡献

欢迎提交 Issue 和 Pull Request!

About

A repository for using Ray to process data.

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors