From ad48f2f27fad47ef2f2f88d05655dd9e5a64ffbf Mon Sep 17 00:00:00 2001 From: bzp2010 Date: Tue, 19 May 2026 01:17:17 +0800 Subject: [PATCH 1/3] feat(guardrail): apply to request by policy --- src/config/entities/models-schema.json | 4 - src/config/entities/models.rs | 27 +-- src/proxy/guardrails.rs | 84 ++++----- src/proxy/handlers/format_handler.rs | 96 +++++++--- src/proxy/hooks/context.rs | 10 + src/proxy/hooks/mod.rs | 1 + src/proxy/mod.rs | 7 +- src/proxy/policies.rs | 245 +++++++++++++++++++++++++ tests/admin/models.test.ts | 27 +++ tests/proxy/guardrail/shared.ts | 47 ++++- 10 files changed, 439 insertions(+), 109 deletions(-) create mode 100644 src/proxy/policies.rs diff --git a/src/config/entities/models-schema.json b/src/config/entities/models-schema.json index 3b31dbe..af8c3de 100644 --- a/src/config/entities/models-schema.json +++ b/src/config/entities/models-schema.json @@ -5,10 +5,6 @@ "name": { "type": "string", "minLength": 1 }, "provider_id": { "type": "string", "minLength": 1 }, "model": { "type": "string", "minLength": 1 }, - "guardrail_ids": { - "type": "array", - "items": { "type": "string", "minLength": 1 } - }, "timeout": { "type": "integer", "minimum": 0 diff --git a/src/config/entities/models.rs b/src/config/entities/models.rs index 02d7d5a..5cc6bfe 100644 --- a/src/config/entities/models.rs +++ b/src/config/entities/models.rs @@ -28,12 +28,6 @@ pub struct Model { pub provider_id: String, pub model: String, - // Temporary binding surface for guardrail runtime wiring until policy evaluation attaches - // guardrails dynamically. Keeping this on Model lets tests and the current runtime path stay - // simple without committing to the long-term control-plane shape. - #[serde(default, skip_serializing_if = "Vec::is_empty")] - pub guardrail_ids: Vec, - #[serde(skip_serializing_if = "Option::is_none")] pub timeout: Option, @@ -115,12 +109,6 @@ mod tests { "provider_id": "openai-primary", "model": "gpt-5" }), true, None)] - #[case::ok_with_guardrails(json!({ - "name": "test", - "provider_id": "bedrock-primary", - "model": "anthropic.claude-3-5-sonnet-20240620-v1:0", - "guardrail_ids": ["gr-input", "gr-output"] - }), true, None)] #[case::ok_with_rate_limit(json!({ "name": "test", "provider_id": "bedrock-primary", @@ -154,18 +142,12 @@ mod tests { "provider_id": "openai-primary", "model": 123 }), false, Some(r#"property "/model" validation failed: 123 is not of type "string""#.to_string()))] - #[case::invalid_guardrail_ids_type(json!({ - "name": "test", - "provider_id": "openai-primary", - "model": "gpt-5", - "guardrail_ids": "gr-input" - }), false, Some(r#"property "/guardrail_ids" validation failed: "gr-input" is not of type "array""#.to_string()))] - #[case::invalid_guardrail_ids_element_type(json!({ + #[case::legacy_guardrail_ids_rejected_by_schema(json!({ "name": "test", "provider_id": "openai-primary", "model": "gpt-5", - "guardrail_ids": [1] - }), false, Some(r#"property "/guardrail_ids/0" validation failed: 1 is not of type "string""#.to_string()))] + "guardrail_ids": ["gr-input"] + }), false, Some(r#"property "/" validation failed: Additional properties are not allowed ('guardrail_ids' was unexpected)"#.to_string()))] #[case::invalid_root_additional_property(json!({ "name": "test", "provider_id": "openai-primary", @@ -190,7 +172,7 @@ mod tests { } #[test] - fn deserialize_model_preserves_provider_reference_and_model_name() { + fn deserialize_model_ignores_legacy_guardrail_ids() { let model: super::Model = serde_json::from_value(json!({ "name": "test", "provider_id": "bedrock-primary", @@ -206,7 +188,6 @@ mod tests { model.model, "arn:aws:bedrock:us-east-1:123456789012:inference-profile/us.anthropic.claude-3-5-sonnet-20240620-v1:0" ); - assert_eq!(model.guardrail_ids, vec!["gr-input"]); assert_eq!(model.timeout, Some(30000)); } } diff --git a/src/proxy/guardrails.rs b/src/proxy/guardrails.rs index 132cb07..923e65f 100644 --- a/src/proxy/guardrails.rs +++ b/src/proxy/guardrails.rs @@ -12,7 +12,7 @@ use thiserror::Error; pub(crate) mod streaming; use crate::{ - config::entities::{Model, ResourceEntry, ResourceRegistry, guardrails::GuardrailConfig}, + config::entities::guardrails::GuardrailConfig, gateway::{ error::GatewayError, types::openai::{ @@ -42,7 +42,7 @@ where } #[async_trait] -pub(crate) trait ConfiguredGuardrailRuntime: Send + Sync { +pub(crate) trait ResolvedGuardrail: Send + Sync { fn name(&self) -> &'static str; fn supports_stage(&self, stage: GuardrailStage) -> bool; @@ -53,19 +53,24 @@ pub(crate) trait ConfiguredGuardrailRuntime: Send + Sync { ) -> Result, GatewayError>; } -struct GuardrailRuntimeHandle { +struct RuntimeResolvedGuardrail { runtime: R, config: C, + stage: GuardrailStage, } -impl GuardrailRuntimeHandle { - fn new(runtime: R, config: C) -> Self { - Self { runtime, config } +impl RuntimeResolvedGuardrail { + fn new(runtime: R, config: C, stage: GuardrailStage) -> Self { + Self { + runtime, + config, + stage, + } } } #[async_trait] -impl ConfiguredGuardrailRuntime for GuardrailRuntimeHandle +impl ResolvedGuardrail for RuntimeResolvedGuardrail where R: GuardrailRuntime + Send + Sync, C: Send + Sync, @@ -76,7 +81,7 @@ where } fn supports_stage(&self, stage: GuardrailStage) -> bool { - self.runtime.supports_stage(stage) + self.stage == stage && self.runtime.supports_stage(stage) } async fn check( @@ -213,29 +218,6 @@ pub(crate) fn output_payload_from_check_payload( } } -pub(crate) fn resolve_model_guardrails( - model: &ResourceEntry, - resources: &ResourceRegistry, -) -> Result>, GatewayError> { - // This direct Model -> guardrail lookup is intentionally temporary. The long-term attachment - // point should come from policy evaluation so request-time guardrail selection is not encoded - // in the model resource itself. - model - .guardrail_ids - .iter() - .map(|guardrail_id| { - let guardrail = resources - .guardrails - .get_by_id(guardrail_id) - .ok_or_else(|| { - GatewayError::Internal(format!("guardrail {} not found", guardrail_id)) - })?; - - configured_guardrail_runtime_from_configs(&guardrail.guardrail) - }) - .collect() -} - #[cfg(test)] pub(crate) async fn run_guardrail_check( runtime: &R, @@ -272,17 +254,20 @@ where run_guardrail_check(runtime, config, payload).await } -fn configured_guardrail_runtime_from_configs( +pub(crate) fn build_resolved_guardrail_for_stage( guardrail: &GuardrailConfig, -) -> Result, GatewayError> { + stage: GuardrailStage, +) -> Result, GatewayError> { match guardrail { - GuardrailConfig::Bedrock(config) => Ok(Box::new(GuardrailRuntimeHandle::new( + GuardrailConfig::Bedrock(config) => Ok(Box::new(RuntimeResolvedGuardrail::new( BedrockGuardrailRuntime::new(), config.clone(), + stage, ))), - GuardrailConfig::Regex(config) => Ok(Box::new(GuardrailRuntimeHandle::new( + GuardrailConfig::Regex(config) => Ok(Box::new(RuntimeResolvedGuardrail::new( RegexGuardrailRuntime::new(), config.clone(), + stage, ))), } } @@ -370,8 +355,8 @@ mod tests { use thiserror::Error; use super::{ - GuardrailBridgeError, chat_message_to_guardrail_message, - configured_guardrail_runtime_from_configs, guardrail_message_to_chat_message, + GuardrailBridgeError, build_resolved_guardrail_for_stage, + chat_message_to_guardrail_message, guardrail_message_to_chat_message, input_guardrail_payload_from_chat_messages, input_payload_from_check_payload, input_payload_to_chat_messages, output_guardrail_payload_from_chat_messages, output_payload_from_check_payload, output_payload_to_chat_messages, @@ -631,9 +616,9 @@ mod tests { } #[test] - fn configured_guardrail_runtime_from_configs_builds_bedrock_runtime() { - let runtime = configured_guardrail_runtime_from_configs(&GuardrailConfig::Bedrock( - BedrockGuardrailConfig { + fn build_resolved_guardrail_for_stage_builds_bedrock_runtime() { + let runtime = build_resolved_guardrail_for_stage( + &GuardrailConfig::Bedrock(BedrockGuardrailConfig { identifier: "guardrail-123".into(), version: "1".into(), region: "us-east-1".into(), @@ -641,23 +626,30 @@ mod tests { secret_access_key: "secret".into(), session_token: None, endpoint: None, - }, - )) + }), + GuardrailStage::Input, + ) .unwrap(); assert_eq!(runtime.name(), "bedrock"); assert!(runtime.supports_stage(GuardrailStage::Input)); + assert!(!runtime.supports_stage(GuardrailStage::Output)); } #[test] - fn configured_guardrail_runtime_from_configs_builds_regex_runtime() { - let runtime = configured_guardrail_runtime_from_configs(&GuardrailConfig::Regex( - RegexGuardrailConfig::new("secret", Some("matched blocked content".into())).unwrap(), - )) + fn build_resolved_guardrail_for_stage_builds_regex_runtime() { + let runtime = build_resolved_guardrail_for_stage( + &GuardrailConfig::Regex( + RegexGuardrailConfig::new("secret", Some("matched blocked content".into())) + .unwrap(), + ), + GuardrailStage::Output, + ) .unwrap(); assert_eq!(runtime.name(), "regex"); assert!(runtime.supports_stage(GuardrailStage::Output)); + assert!(!runtime.supports_stage(GuardrailStage::Input)); } #[test] diff --git a/src/proxy/handlers/format_handler.rs b/src/proxy/handlers/format_handler.rs index 58441bf..a015487 100644 --- a/src/proxy/handlers/format_handler.rs +++ b/src/proxy/handlers/format_handler.rs @@ -25,6 +25,7 @@ use crate::{ traits::{ChatFormat, ProviderCapabilities}, types::{ common::Usage, + openai::ChatMessage, response::{ChatResponse, ChatResponseStream}, }, }, @@ -32,7 +33,7 @@ use crate::{ proxy::{ AppState, guardrails::{ - ConfiguredGuardrailRuntime, resolve_model_guardrails, + ResolvedGuardrail, input_payload_from_check_payload, input_payload_to_chat_messages, streaming::{ StreamGuardrailDecision, WholeResponseReplayAction, WholeResponseReplayDriver, WholeResponseReplayFinalize, @@ -41,6 +42,7 @@ use crate::{ hooks::{ self, RequestContext, authorization::AuthorizationError, rate_limit::RateLimitError, }, + policies::{resolve_request_guardrails, stable_route_format}, provider::create_provider_instance, utils::trace::span_attributes::{apply_span_properties, usage_span_properties}, }, @@ -59,7 +61,7 @@ pub(crate) trait FormatHandlerAdapter: Send + Sync + 'static { Response = Self::Response, StreamChunk = Self::StreamChunk, >; - type Request: Sync; + type Request: Sync + Serialize; type Response: Serialize; type StreamChunk: Clone + Serialize + Send + 'static; type Error: IntoResponse @@ -207,11 +209,9 @@ where A: FormatHandlerAdapter, { hooks::observability::record_start_time(&mut request_ctx).await; - hooks::authorization::check( - &mut request_ctx, - as ChatFormat>::extract_model(&request_data).to_owned(), - ) - .await?; + let requested_model_name = + as ChatFormat>::extract_model(&request_data).to_owned(); + hooks::authorization::check(&mut request_ctx, requested_model_name.clone()).await?; hooks::rate_limit::pre_check(&mut request_ctx).await?; let model = request_ctx @@ -221,7 +221,6 @@ where .cloned() .ok_or_else(A::missing_model_error)?; - A::set_model(&mut request_data, model.model.clone()); let timeout = model.timeout.map(Duration::from_millis); let gateway = state.gateway(); @@ -229,14 +228,32 @@ where let provider = model.provider(resources.as_ref()).ok_or_else(|| { GatewayError::Internal(format!("provider {} not found", model.provider_id)) })?; - let configured_guardrails = resolve_model_guardrails(&model, resources.as_ref())?; let provider_instance = create_provider_instance(gateway.as_ref(), &provider)?; let provider_base_url = provider_instance.effective_base_url().ok(); let mut lifecycle_state = A::prepare_lifecycle(&state, &mut request_ctx, &mut request_data).await?; + let policy_request_raw = serde_json::to_value(&request_data).map_err(|err| { + GatewayError::Internal(format!( + "failed to serialize request for policy evaluation: {err}" + )) + })?; + let policy_input_messages = policy_input_messages::(&lifecycle_state, &request_data)?; + let resolved_guardrails = resolve_request_guardrails( + &request_ctx, + &model, + &provider, + stable_route_format( as ChatFormat>::name()), + &requested_model_name, + as ChatFormat>::is_stream(&request_data), + &policy_request_raw, + &policy_input_messages, + resources.as_ref(), + ) + .await?; + A::set_model(&mut request_data, model.model.clone()); apply_input_guardrails::( - &configured_guardrails, + &resolved_guardrails, &mut lifecycle_state, &mut request_data, ) @@ -269,7 +286,7 @@ where usage, })) => { let output_guardrail_result = apply_output_guardrails::( - &configured_guardrails, + &resolved_guardrails, &mut lifecycle_state, &mut response, ) @@ -295,7 +312,7 @@ where Ok(Ok(ChatResponse::Stream { stream, usage_rx })) => { handle_stream_response::( state, - configured_guardrails, + resolved_guardrails, stream, usage_rx, &mut request_ctx, @@ -315,8 +332,30 @@ where } } +fn policy_input_messages( + lifecycle_state: &A::LifecycleState, + request: &AdapterRequest, +) -> Result, A::Error> +where + A: FormatHandlerAdapter, +{ + let Some(payload) = A::guardrail_input_payload(lifecycle_state, request)? else { + return Ok(Vec::new()); + }; + + let input = input_payload_from_check_payload(payload).map_err(guardrail_bridge_error)?; + Ok(input_payload_to_chat_messages(&input).map_err(guardrail_bridge_error)?) +} + +fn guardrail_bridge_error(error: E) -> GatewayError +where + E: std::fmt::Display, +{ + GatewayError::Bridge(error.to_string()) +} + async fn apply_input_guardrails( - guardrails: &[Box], + guardrails: &[Box], lifecycle_state: &mut A::LifecycleState, request: &mut AdapterRequest, ) -> Result<(), A::Error> @@ -351,7 +390,7 @@ where } async fn apply_output_guardrails( - guardrails: &[Box], + guardrails: &[Box], lifecycle_state: &mut A::LifecycleState, response: &mut AdapterResponse, ) -> Result<(), A::Error> @@ -386,7 +425,7 @@ where } async fn apply_stream_output_guardrails( - guardrails: &[Box], + guardrails: &[Box], payload: &GuardrailCheckPayload, ) -> Result<(), A::Error> where @@ -431,7 +470,7 @@ fn require_stream_output_guardrail_payload( }) } -fn has_output_guardrails(guardrails: &[Box]) -> bool { +fn has_output_guardrails(guardrails: &[Box]) -> bool { guardrails .iter() .any(|guardrail| guardrail.supports_stage(crate::guardrail::traits::GuardrailStage::Output)) @@ -595,7 +634,7 @@ async fn record_first_stream_output_emit( async fn handle_stream_response( state: AppState, - configured_guardrails: Vec>, + resolved_guardrails: Vec>, stream: ChatResponseStream>, usage_rx: oneshot::Receiver, request_ctx: &mut RequestContext, @@ -609,13 +648,12 @@ where let stream_request_ctx = request_ctx.clone(); let stream_state = state.clone(); - let replay_driver = - WholeResponseReplayDriver::new(has_output_guardrails(&configured_guardrails)); - let configured_guardrails = Arc::new(configured_guardrails); + let replay_driver = WholeResponseReplayDriver::new(has_output_guardrails(&resolved_guardrails)); + let resolved_guardrails = Arc::new(resolved_guardrails); let sse_stream = futures::stream::unfold( ( stream_state, - configured_guardrails, + resolved_guardrails, stream, span, stream_request_ctx, @@ -630,7 +668,7 @@ where ), |( state, - configured_guardrails, + resolved_guardrails, mut stream, span, mut request_ctx, @@ -674,7 +712,7 @@ where Ok::(A::serialize_stream_item(&chunk)), ( state, - configured_guardrails, + resolved_guardrails, stream, span, request_ctx, @@ -711,7 +749,7 @@ where Ok(event), ( state, - configured_guardrails, + resolved_guardrails, stream, span, request_ctx, @@ -737,7 +775,7 @@ where Ok(event), ( state, - configured_guardrails, + resolved_guardrails, stream, span, request_ctx, @@ -788,7 +826,7 @@ where Ok::(A::serialize_stream_item(&chunk)), ( state, - configured_guardrails, + resolved_guardrails, stream, span, request_ctx, @@ -834,7 +872,7 @@ where Ok(event), ( state, - configured_guardrails, + resolved_guardrails, stream, span, request_ctx, @@ -865,7 +903,7 @@ where match require_stream_output_guardrail_payload(payload) { Ok(payload) => { apply_stream_output_guardrails::( - configured_guardrails.as_ref(), + resolved_guardrails.as_ref(), &payload, ) .await @@ -926,7 +964,7 @@ where Ok(event), ( state, - configured_guardrails, + resolved_guardrails, stream, span, request_ctx, diff --git a/src/proxy/hooks/context.rs b/src/proxy/hooks/context.rs index 630caf9..215be65 100644 --- a/src/proxy/hooks/context.rs +++ b/src/proxy/hooks/context.rs @@ -14,6 +14,12 @@ struct RequestContextInner { extensions: RwLock, } +#[derive(Clone)] +pub(crate) struct RequestRouteInfo { + pub method: String, + pub path: String, +} + #[derive(Clone)] pub struct RequestContext { inner: Arc, @@ -27,6 +33,10 @@ impl FromRequestParts for RequestContext { state: &AppState, ) -> Result { let mut ctx = http::Extensions::new(); + ctx.insert(RequestRouteInfo { + method: parts.method.as_str().to_string(), + path: parts.uri.path().to_string(), + }); ctx.insert(parts.extensions.remove::>().expect( "Authentication middleware should have inserted ApiKey into request extensions", )); diff --git a/src/proxy/hooks/mod.rs b/src/proxy/hooks/mod.rs index 419de01..09b6a1a 100644 --- a/src/proxy/hooks/mod.rs +++ b/src/proxy/hooks/mod.rs @@ -4,3 +4,4 @@ pub mod observability; pub mod rate_limit; pub use context::RequestContext; +pub(crate) use context::RequestRouteInfo; diff --git a/src/proxy/mod.rs b/src/proxy/mod.rs index 6adbed0..d670eb9 100644 --- a/src/proxy/mod.rs +++ b/src/proxy/mod.rs @@ -3,6 +3,7 @@ mod handlers; mod hooks; pub(crate) mod message_history; mod middlewares; +mod policies; mod provider; mod utils; @@ -64,11 +65,7 @@ pub fn create_router(state: AppState) -> Result { .merge(Router::new().route("/v1/models", get(handlers::models::list_models))) .route( "/v1/chat/completions", - post( - handlers::format_handler::< - handlers::chat_completions::ChatCompletionsAdapter, - >, - ), + post(handlers::format_handler::), ) .route( "/v1/messages", diff --git a/src/proxy/policies.rs b/src/proxy/policies.rs new file mode 100644 index 0000000..754020e --- /dev/null +++ b/src/proxy/policies.rs @@ -0,0 +1,245 @@ +use std::collections::HashSet; + +use cel::{Context, Value as CelValue}; +use serde::Serialize; +use serde_json::Value as JsonValue; + +use crate::{ + config::entities::{ + ApiKey, Model, Policy, Provider, ResourceEntry, ResourceRegistry, + policies::{PolicyAction, PolicyStage}, + }, + gateway::{error::GatewayError, types::openai::ChatMessage}, + guardrail::traits::GuardrailStage, + proxy::{ + guardrails::{ResolvedGuardrail, build_resolved_guardrail_for_stage}, + hooks::{RequestContext, RequestRouteInfo}, + }, +}; + +pub(crate) fn stable_route_format(format_name: &'static str) -> &'static str { + match format_name { + "openai_chat" => "chat_completions", + "anthropic_messages" => "messages", + "openai_responses" => "responses", + other => other, + } +} + +pub(crate) async fn resolve_request_guardrails( + request_ctx: &RequestContext, + model: &ResourceEntry, + provider: &ResourceEntry, + route_format: &'static str, + request_model: &str, + request_stream: bool, + request_raw: &JsonValue, + input_messages: &[ChatMessage], + resources: &ResourceRegistry, +) -> Result>, GatewayError> { + let (api_key, route) = { + let extensions = request_ctx.extensions().await; + let api_key = extensions + .get::>() + .cloned() + .ok_or_else(|| GatewayError::Internal("policy context missing api key".into()))?; + let route = extensions + .get::() + .cloned() + .ok_or_else(|| GatewayError::Internal("policy context missing route info".into()))?; + (api_key, route) + }; + + let context = PolicyContext { + auth: PolicyAuthContext { + api_key: PolicyApiKeyContext { id: &api_key.id }, + }, + model: PolicyModelContext { + id: &model.id, + name: &model.name, + upstream: &model.model, + }, + provider: PolicyProviderContext { + id: &provider.id, + name: &provider.name, + provider_type: provider.provider_type(), + }, + route: PolicyRouteContext { + method: &route.method, + path: &route.path, + format: route_format, + }, + request: PolicyRequestContext { + model: request_model, + stream: request_stream, + raw: request_raw, + }, + input: PolicyInputContext { + messages: input_messages, + }, + }; + + let policies = resources.policies.list(); + let mut matched_policies = policies + .values() + .filter(|policy| policy.enabled) + .filter_map(|policy| match policy_matches(policy, &context) { + Ok(true) => Some(Ok(policy)), + Ok(false) => None, + Err(err) => Some(Err(err)), + }) + .collect::, _>>()?; + + matched_policies.sort_by(|left, right| { + right + .priority + .cmp(&left.priority) + .then_with(|| left.id.cmp(&right.id)) + }); + + let mut seen_bindings = HashSet::new(); + let mut resolved_guardrails = Vec::new(); + + for policy in matched_policies { + for action in &policy.actions { + match action { + PolicyAction::Guardrail(config) => { + for stage in &config.stages { + let resolved_stage = guardrail_stage_from_policy_stage(stage); + let stage_key = guardrail_stage_key(resolved_stage); + + for guardrail_id in &config.guardrail_ids { + let dedupe_key = format!("{guardrail_id}:{stage_key}"); + if !seen_bindings.insert(dedupe_key) { + continue; + } + + let guardrail = + resources.guardrails.get_by_id(guardrail_id).ok_or_else(|| { + GatewayError::Internal(format!( + "guardrail {} referenced by policy {} not found", + guardrail_id, policy.id + )) + })?; + + resolved_guardrails.push(build_resolved_guardrail_for_stage( + &guardrail.guardrail, + resolved_stage, + )?); + } + } + } + } + } + } + + Ok(resolved_guardrails) +} + +fn policy_matches( + policy: &ResourceEntry, + context: &PolicyContext<'_>, +) -> Result { + let mut cel_context = Context::default(); + cel_context.add_variable("auth", &context.auth).map_err(|err| { + GatewayError::Internal(format!("policy {} context error: {err}", policy.id)) + })?; + cel_context.add_variable("model", &context.model).map_err(|err| { + GatewayError::Internal(format!("policy {} context error: {err}", policy.id)) + })?; + cel_context.add_variable("provider", &context.provider).map_err(|err| { + GatewayError::Internal(format!("policy {} context error: {err}", policy.id)) + })?; + cel_context.add_variable("route", &context.route).map_err(|err| { + GatewayError::Internal(format!("policy {} context error: {err}", policy.id)) + })?; + cel_context.add_variable("request", &context.request).map_err(|err| { + GatewayError::Internal(format!("policy {} context error: {err}", policy.id)) + })?; + cel_context.add_variable("input", &context.input).map_err(|err| { + GatewayError::Internal(format!("policy {} context error: {err}", policy.id)) + })?; + + let program = policy.compiled_when().map_err(|err| { + GatewayError::Internal(format!("policy {} failed to compile cached CEL: {err}", policy.id)) + })?; + let result = program.execute(&cel_context).map_err(|err| { + GatewayError::Internal(format!("policy {} evaluation failed: {err}", policy.id)) + })?; + + match result { + CelValue::Bool(value) => Ok(value), + other => Err(GatewayError::Internal(format!( + "policy {} when must evaluate to bool, got {:?}", + policy.id, other + ))), + } +} + +fn guardrail_stage_from_policy_stage(stage: &PolicyStage) -> GuardrailStage { + match stage { + PolicyStage::Input => GuardrailStage::Input, + PolicyStage::Output => GuardrailStage::Output, + } +} + +fn guardrail_stage_key(stage: GuardrailStage) -> &'static str { + match stage { + GuardrailStage::Input => "input", + GuardrailStage::Output => "output", + } +} + +#[derive(Serialize)] +struct PolicyContext<'a> { + auth: PolicyAuthContext<'a>, + model: PolicyModelContext<'a>, + provider: PolicyProviderContext<'a>, + route: PolicyRouteContext<'a>, + request: PolicyRequestContext<'a>, + input: PolicyInputContext<'a>, +} + +#[derive(Serialize)] +struct PolicyAuthContext<'a> { + api_key: PolicyApiKeyContext<'a>, +} + +#[derive(Serialize)] +struct PolicyApiKeyContext<'a> { + id: &'a str, +} + +#[derive(Serialize)] +struct PolicyModelContext<'a> { + id: &'a str, + name: &'a str, + upstream: &'a str, +} + +#[derive(Serialize)] +struct PolicyProviderContext<'a> { + id: &'a str, + name: &'a str, + #[serde(rename = "type")] + provider_type: &'a str, +} + +#[derive(Serialize)] +struct PolicyRouteContext<'a> { + method: &'a str, + path: &'a str, + format: &'a str, +} + +#[derive(Serialize)] +struct PolicyRequestContext<'a> { + model: &'a str, + stream: bool, + raw: &'a JsonValue, +} + +#[derive(Serialize)] +struct PolicyInputContext<'a> { + messages: &'a [ChatMessage], +} diff --git a/tests/admin/models.test.ts b/tests/admin/models.test.ts index 4438b17..4b14d5f 100644 --- a/tests/admin/models.test.ts +++ b/tests/admin/models.test.ts @@ -176,4 +176,31 @@ describe('admin models', () => { expect(duplicateResp.status).toBe(400); expect(duplicateResp.data.error_msg).toBe('Model name already exists'); }); + + test('test_guardrail_ids_rejected', async () => { + const auth = bearerAuthHeader(ADMIN_KEY); + const providerId = 'legacy-guardrail-provider'; + const providerResp = await adminPut( + `${PROVIDERS_URL}/${providerId}`, + { + name: providerId, + type: 'openai', + config: TEST_PROVIDER_CONFIG, + }, + auth, + ); + expect(providerResp.status).toBe(201); + + const createResp = await adminPost( + MODELS_URL, + { + ...buildModelBody('legacy_guardrail_model', providerId), + guardrail_ids: ['gr-legacy'], + }, + auth, + ); + + expect(createResp.status).toBe(400); + expect(createResp.data.error_msg).toContain('guardrail_ids'); + }); }); diff --git a/tests/proxy/guardrail/shared.ts b/tests/proxy/guardrail/shared.ts index ba355d9..d0a0dad 100644 --- a/tests/proxy/guardrail/shared.ts +++ b/tests/proxy/guardrail/shared.ts @@ -2,6 +2,7 @@ import { randomUUID } from 'node:crypto'; import { MODELS_URL, + POLICIES_URL, PROVIDERS_URL, adminPost, adminPut, @@ -111,7 +112,6 @@ export const setupOpenAiRegexGuardrailFixture = async ({ name: inputGuardedModelName, model: buildModel(upstreamModel), provider_id: providerId, - guardrail_ids: [inputGuardrailId], }, auth, ), @@ -126,7 +126,6 @@ export const setupOpenAiRegexGuardrailFixture = async ({ name: outputGuardedModelName, model: buildModel(upstreamModel), provider_id: providerId, - guardrail_ids: [outputGuardrailId], }, auth, ), @@ -134,6 +133,50 @@ export const setupOpenAiRegexGuardrailFixture = async ({ 'create output-guarded model', ); + ensureStatus( + await adminPost( + POLICIES_URL, + { + name: `${modelPrefix}-input-policy-${randomUUID()}`, + when: `model.name == '${inputGuardedModelName}'`, + actions: [ + { + type: 'guardrail', + config: { + stages: ['input'], + guardrail_ids: [inputGuardrailId], + }, + }, + ], + }, + auth, + ), + 201, + 'create input guardrail policy', + ); + + ensureStatus( + await adminPost( + POLICIES_URL, + { + name: `${modelPrefix}-output-policy-${randomUUID()}`, + when: `model.name == '${outputGuardedModelName}'`, + actions: [ + { + type: 'guardrail', + config: { + stages: ['output'], + guardrail_ids: [outputGuardrailId], + }, + }, + ], + }, + auth, + ), + 201, + 'create output guardrail policy', + ); + ensureStatus( await adminPost( '/apikeys', From 9f0df8aa77167144e7047df879942a26b09ca2a9 Mon Sep 17 00:00:00 2001 From: bzp2010 Date: Tue, 19 May 2026 01:28:05 +0800 Subject: [PATCH 2/3] fix lint --- src/proxy/handlers/format_handler.rs | 24 +++--- src/proxy/policies.rs | 108 ++++++++++++++++----------- 2 files changed, 77 insertions(+), 55 deletions(-) diff --git a/src/proxy/handlers/format_handler.rs b/src/proxy/handlers/format_handler.rs index a015487..b1d9c04 100644 --- a/src/proxy/handlers/format_handler.rs +++ b/src/proxy/handlers/format_handler.rs @@ -42,7 +42,7 @@ use crate::{ hooks::{ self, RequestContext, authorization::AuthorizationError, rate_limit::RateLimitError, }, - policies::{resolve_request_guardrails, stable_route_format}, + policies::{RequestGuardrailResolution, resolve_request_guardrails, stable_route_format}, provider::create_provider_instance, utils::trace::span_attributes::{apply_span_properties, usage_span_properties}, }, @@ -238,17 +238,17 @@ where )) })?; let policy_input_messages = policy_input_messages::(&lifecycle_state, &request_data)?; - let resolved_guardrails = resolve_request_guardrails( - &request_ctx, - &model, - &provider, - stable_route_format( as ChatFormat>::name()), - &requested_model_name, - as ChatFormat>::is_stream(&request_data), - &policy_request_raw, - &policy_input_messages, - resources.as_ref(), - ) + let resolved_guardrails = resolve_request_guardrails(RequestGuardrailResolution { + request_ctx: &request_ctx, + model: &model, + provider: &provider, + route_format: stable_route_format( as ChatFormat>::name()), + request_model: &requested_model_name, + request_stream: as ChatFormat>::is_stream(&request_data), + request_raw: &policy_request_raw, + input_messages: &policy_input_messages, + resources: resources.as_ref(), + }) .await?; A::set_model(&mut request_data, model.model.clone()); diff --git a/src/proxy/policies.rs b/src/proxy/policies.rs index 754020e..41a167c 100644 --- a/src/proxy/policies.rs +++ b/src/proxy/policies.rs @@ -26,19 +26,23 @@ pub(crate) fn stable_route_format(format_name: &'static str) -> &'static str { } } +pub(crate) struct RequestGuardrailResolution<'a> { + pub request_ctx: &'a RequestContext, + pub model: &'a ResourceEntry, + pub provider: &'a ResourceEntry, + pub route_format: &'static str, + pub request_model: &'a str, + pub request_stream: bool, + pub request_raw: &'a JsonValue, + pub input_messages: &'a [ChatMessage], + pub resources: &'a ResourceRegistry, +} + pub(crate) async fn resolve_request_guardrails( - request_ctx: &RequestContext, - model: &ResourceEntry, - provider: &ResourceEntry, - route_format: &'static str, - request_model: &str, - request_stream: bool, - request_raw: &JsonValue, - input_messages: &[ChatMessage], - resources: &ResourceRegistry, + request: RequestGuardrailResolution<'_>, ) -> Result>, GatewayError> { let (api_key, route) = { - let extensions = request_ctx.extensions().await; + let extensions = request.request_ctx.extensions().await; let api_key = extensions .get::>() .cloned() @@ -55,31 +59,31 @@ pub(crate) async fn resolve_request_guardrails( api_key: PolicyApiKeyContext { id: &api_key.id }, }, model: PolicyModelContext { - id: &model.id, - name: &model.name, - upstream: &model.model, + id: &request.model.id, + name: &request.model.name, + upstream: &request.model.model, }, provider: PolicyProviderContext { - id: &provider.id, - name: &provider.name, - provider_type: provider.provider_type(), + id: &request.provider.id, + name: &request.provider.name, + provider_type: request.provider.provider_type(), }, route: PolicyRouteContext { method: &route.method, path: &route.path, - format: route_format, + format: request.route_format, }, request: PolicyRequestContext { - model: request_model, - stream: request_stream, - raw: request_raw, + model: request.request_model, + stream: request.request_stream, + raw: request.request_raw, }, input: PolicyInputContext { - messages: input_messages, + messages: request.input_messages, }, }; - let policies = resources.policies.list(); + let policies = request.resources.policies.list(); let mut matched_policies = policies .values() .filter(|policy| policy.enabled) @@ -114,8 +118,11 @@ pub(crate) async fn resolve_request_guardrails( continue; } - let guardrail = - resources.guardrails.get_by_id(guardrail_id).ok_or_else(|| { + let guardrail = request + .resources + .guardrails + .get_by_id(guardrail_id) + .ok_or_else(|| { GatewayError::Internal(format!( "guardrail {} referenced by policy {} not found", guardrail_id, policy.id @@ -141,27 +148,42 @@ fn policy_matches( context: &PolicyContext<'_>, ) -> Result { let mut cel_context = Context::default(); - cel_context.add_variable("auth", &context.auth).map_err(|err| { - GatewayError::Internal(format!("policy {} context error: {err}", policy.id)) - })?; - cel_context.add_variable("model", &context.model).map_err(|err| { - GatewayError::Internal(format!("policy {} context error: {err}", policy.id)) - })?; - cel_context.add_variable("provider", &context.provider).map_err(|err| { - GatewayError::Internal(format!("policy {} context error: {err}", policy.id)) - })?; - cel_context.add_variable("route", &context.route).map_err(|err| { - GatewayError::Internal(format!("policy {} context error: {err}", policy.id)) - })?; - cel_context.add_variable("request", &context.request).map_err(|err| { - GatewayError::Internal(format!("policy {} context error: {err}", policy.id)) - })?; - cel_context.add_variable("input", &context.input).map_err(|err| { - GatewayError::Internal(format!("policy {} context error: {err}", policy.id)) - })?; + cel_context + .add_variable("auth", &context.auth) + .map_err(|err| { + GatewayError::Internal(format!("policy {} context error: {err}", policy.id)) + })?; + cel_context + .add_variable("model", &context.model) + .map_err(|err| { + GatewayError::Internal(format!("policy {} context error: {err}", policy.id)) + })?; + cel_context + .add_variable("provider", &context.provider) + .map_err(|err| { + GatewayError::Internal(format!("policy {} context error: {err}", policy.id)) + })?; + cel_context + .add_variable("route", &context.route) + .map_err(|err| { + GatewayError::Internal(format!("policy {} context error: {err}", policy.id)) + })?; + cel_context + .add_variable("request", &context.request) + .map_err(|err| { + GatewayError::Internal(format!("policy {} context error: {err}", policy.id)) + })?; + cel_context + .add_variable("input", &context.input) + .map_err(|err| { + GatewayError::Internal(format!("policy {} context error: {err}", policy.id)) + })?; let program = policy.compiled_when().map_err(|err| { - GatewayError::Internal(format!("policy {} failed to compile cached CEL: {err}", policy.id)) + GatewayError::Internal(format!( + "policy {} failed to compile cached CEL: {err}", + policy.id + )) })?; let result = program.execute(&cel_context).map_err(|err| { GatewayError::Internal(format!("policy {} evaluation failed: {err}", policy.id)) From e18522ec726b1ea073fc62d1792b715706d3331f Mon Sep 17 00:00:00 2001 From: bzp2010 Date: Tue, 19 May 2026 09:56:13 +0800 Subject: [PATCH 3/3] fix tests --- tests/package.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/package.json b/tests/package.json index b400d53..45e2db7 100644 --- a/tests/package.json +++ b/tests/package.json @@ -8,7 +8,7 @@ "test": "vitest run", "test:dev": "vitest" }, - "packageManager": "pnpm@11.1.2+sha512.415a1cc25974731e75455c1468371be74c5aa5fb7621b50d4056d222451609f11412f23fd602e6169f1e060466641f798597e1be961a10688836a67b16569499", + "packageManager": "pnpm@11.1.3+sha512.c85357fe17ca12dd23dd7071822666dfd7e3cb76fe214e3370b5ea2fb34f2a231185509b63e717f3cd0acb38dd3f8d82bcd5e8172400ae678b70ea4fbed0896d", "devDependencies": { "@anthropic-ai/sdk": "^0.88.0", "@eslint/js": "^10.0.1",