Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .staging/pre-1.0/infer/vllm/driver
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ def main():
t1 = fmwork.time_get()
print(); print('FMWORK DATASET', '%.6f' % (fmwork.time_diff(t1, t0)))

os.environ["VLLM_SPYRE_WARMUP_PROMPT_LENS"] = par.input_size
os.environ["VLLM_SPYRE_WARMUP_NEW_TOKENS"] = par.output_size
os.environ['VLLM_SPYRE_WARMUP_BATCH_SIZES'] = par.batch_size
os.environ["SENDNN_INFERENCE_WARMUP_PROMPT_LENS"] = par.input_size
os.environ["SENDNN_INFERENCE_WARMUP_NEW_TOKENS"] = par.output_size
os.environ['SENDNN_INFERENCE_WARMUP_BATCH_SIZES'] = par.batch_size

llm()
runs()
Expand Down
6 changes: 3 additions & 3 deletions .staging/spyre/infer/vllm/driver
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ def main():
params()

# spyre environment variables
os.environ["VLLM_SPYRE_WARMUP_PROMPT_LENS"] = par.input_size
os.environ["VLLM_SPYRE_WARMUP_NEW_TOKENS"] = str(max(var.output_sizes))
os.environ['VLLM_SPYRE_WARMUP_BATCH_SIZES'] = par.batch_size
os.environ["SENDNN_INFERENCE_WARMUP_PROMPT_LENS"] = par.input_size
os.environ["SENDNN_INFERENCE_WARMUP_NEW_TOKENS"] = str(max(var.output_sizes))
os.environ['SENDNN_INFERENCE_WARMUP_BATCH_SIZES'] = par.batch_size

if par.dataset_path:
fmwork.banner('DATASET')
Expand Down
16 changes: 8 additions & 8 deletions infer/vllm/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,9 @@ client \
--env FLEX_RDMA_MODE_FULL=FALSE \
--env FLEX_HDMA_MODE_FULL=1 \
--env OMP_NUM_THREADS=32 \
--env VLLM_SPYRE_WARMUP_PROMPT_LENS=1024 \
--env VLLM_SPYRE_WARMUP_NEW_TOKENS=128 \
--env VLLM_SPYRE_WARMUP_BATCH_SIZES=1 \
--env SENDNN_INFERENCE_WARMUP_PROMPT_LENS=1024 \
--env SENDNN_INFERENCE_WARMUP_NEW_TOKENS=128 \
--env SENDNN_INFERENCE_WARMUP_BATCH_SIZES=1 \
-- \
driver \
--platform spyre \
Expand Down Expand Up @@ -108,7 +108,7 @@ driver \
--env FLEX_RDMA_MODE_FULL=FALSE \
--env FLEX_HDMA_MODE_FULL=1 \
--env OMP_NUM_THREADS=32 \
--env VLLM_SPYRE_USE_CB=1 \
--env SENDNN_INFERENCE_USE_CB=1 \
-- \
driver \
--platform spyre \
Expand Down Expand Up @@ -146,9 +146,9 @@ server \
--env FLEX_RDMA_MODE_FULL=FALSE \
--env FLEX_HDMA_MODE_FULL=1 \
--env OMP_NUM_THREADS=32 \
--env VLLM_SPYRE_WARMUP_PROMPT_LENS=1024 \
--env VLLM_SPYRE_WARMUP_NEW_TOKENS=128 \
--env VLLM_SPYRE_WARMUP_BATCH_SIZES=1 \
--env SENDNN_INFERENCE_WARMUP_PROMPT_LENS=1024 \
--env SENDNN_INFERENCE_WARMUP_NEW_TOKENS=128 \
--env SENDNN_INFERENCE_WARMUP_BATCH_SIZES=1 \
--no-enable-prefix-caching \
--max-model-len 2048 \
--max-num-seqs 1 \
Expand Down Expand Up @@ -182,7 +182,7 @@ server \
--env FLEX_RDMA_MODE_FULL=FALSE \
--env FLEX_HDMA_MODE_FULL=1 \
--env OMP_NUM_THREADS=32 \
--env VLLM_SPYRE_USE_CB=1 \
--env SENDNN_INFERENCE_USE_CB=1 \
--no-enable-prefix-caching \
--max-model-len 2048 \
--max-num-seqs 1 \
Expand Down
6 changes: 3 additions & 3 deletions infer/vllm/driver
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,9 @@ def setup_runtime_spyre(args):

print('setup_runtime: spyre')

os.environ["VLLM_SPYRE_WARMUP_PROMPT_LENS"] = str(max(args.input_sizes))
os.environ["VLLM_SPYRE_WARMUP_NEW_TOKENS"] = str(max(args.output_sizes))
os.environ['VLLM_SPYRE_WARMUP_BATCH_SIZES'] = str(max(args.batch_sizes))
os.environ["SENDNN_INFERENCE_WARMUP_PROMPT_LENS"] = str(max(args.input_sizes))
os.environ["SENDNN_INFERENCE_WARMUP_NEW_TOKENS"] = str(max(args.output_sizes))
os.environ['SENDNN_INFERENCE_WARMUP_BATCH_SIZES'] = str(max(args.batch_sizes))

# ------------
# setup engine
Expand Down
8 changes: 4 additions & 4 deletions infer/vllm/process
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def process_direct(args):
split = line.strip().split(' ')
opt = '--env ' + ' '.join(split[2:])
opts.append(opt)
if 'VLLM_SPYRE_USE_CB' in line:
if 'SENDNN_INFERENCE_USE_CB' in line:
value = line.strip().split('=')[-1]
if value == '1':
batch_mode = 'continuous'
Expand Down Expand Up @@ -460,7 +460,7 @@ def process_server(args):
batch_mode = 'static'
for line in open(log_server):
if line.startswith('FMWORK EXP'):
if 'VLLM_SPYRE_USE_CB' in line:
if 'SENDNN_INFERENCE_USE_CB' in line:
value = line.strip().split('=')[-1]
if value == '1':
batch_mode = 'continuous'
Expand Down Expand Up @@ -586,8 +586,8 @@ def process_server(args):
num_prompts = None

# Extract warmup parameters from server.cmd
warmup_prompt_match = re.search(r'VLLM_SPYRE_WARMUP_PROMPT_LENS=(\d+)', server_cmd_content)
warmup_tokens_match = re.search(r'VLLM_SPYRE_WARMUP_NEW_TOKENS=(\d+)', server_cmd_content)
warmup_prompt_match = re.search(r'SENDNN_INFERENCE_WARMUP_PROMPT_LENS=(\d+)', server_cmd_content)
warmup_tokens_match = re.search(r'SENDNN_INFERENCE_WARMUP_NEW_TOKENS=(\d+)', server_cmd_content)
warmup_input = int(warmup_prompt_match.group(1)) if warmup_prompt_match else None
warmup_output = int(warmup_tokens_match.group(1)) if warmup_tokens_match else None

Expand Down