From 9778bbce702c4a6c4750cf1d6cbee0e58ca10752 Mon Sep 17 00:00:00 2001 From: bzp2010 Date: Sat, 16 May 2026 09:24:46 +0800 Subject: [PATCH 1/2] feat(policy): add definition and admin api --- Cargo.lock | 199 +++++++++++++++- Cargo.toml | 2 + src/admin/mod.rs | 17 ++ src/admin/policies.rs | 244 ++++++++++++++++++++ src/config/entities/mod.rs | 5 + src/config/entities/policies-schema.json | 68 ++++++ src/config/entities/policies.rs | 282 +++++++++++++++++++++++ 7 files changed, 809 insertions(+), 8 deletions(-) create mode 100644 src/admin/policies.rs create mode 100644 src/config/entities/policies-schema.json create mode 100644 src/config/entities/policies.rs diff --git a/Cargo.lock b/Cargo.lock index e17edef..0f4cbf5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -45,6 +45,7 @@ dependencies = [ "axum-server", "backon", "bytes", + "cel", "clap", "config", "dashmap", @@ -80,7 +81,7 @@ dependencies = [ "serde_json", "skp-ratelimit", "tempfile", - "thiserror", + "thiserror 2.0.18", "tokio", "tokio-openssl", "tokio-test", @@ -120,7 +121,7 @@ dependencies = [ "reqwest", "serde", "serde_json", - "thiserror", + "thiserror 2.0.18", "tokio", "utoipa", ] @@ -147,7 +148,7 @@ dependencies = [ "reqwest", "serde", "serde_json", - "thiserror", + "thiserror 2.0.18", "tokio", "utoipa", "uuid", @@ -159,6 +160,15 @@ version = "0.2.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" +[[package]] +name = "android_system_properties" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] + [[package]] name = "anstream" version = "1.0.0" @@ -209,6 +219,23 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "antlr4rust" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "093d520274bfff7278d776f7ea12981a0a0a6f96db90964658e0f38fc6e9a6a6" +dependencies = [ + "better_any", + "bit-set", + "byteorder", + "lazy_static", + "murmur3", + "once_cell", + "parking_lot", + "typed-arena", + "uuid", +] + [[package]] name = "anyhow" version = "1.0.102" @@ -500,6 +527,12 @@ dependencies = [ "vsimd", ] +[[package]] +name = "better_any" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4372b9543397a4b86050cc5e7ee36953edf4bac9518e8a774c2da694977fb6e4" + [[package]] name = "bit-set" version = "0.8.0" @@ -560,6 +593,12 @@ version = "0.6.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "175812e0be2bccb6abe50bb8d566126198344f707e304f45c648fd8f2cc0365e" +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + [[package]] name = "bytes" version = "1.11.1" @@ -588,6 +627,22 @@ dependencies = [ "shlex", ] +[[package]] +name = "cel" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47a40f338a8c3505921000b609279775792c07cc21f97a3011578c0c5e1738ae" +dependencies = [ + "antlr4rust", + "chrono", + "lazy_static", + "nom", + "pastey", + "regex", + "serde", + "thiserror 1.0.69", +] + [[package]] name = "cfg-if" version = "1.0.4" @@ -605,6 +660,18 @@ dependencies = [ "rand_core 0.10.1", ] +[[package]] +name = "chrono" +version = "0.4.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c673075a2e0e5f4a1dde27ce9dee1ea4558c7ffe648f576438a20ca1d2acc4b0" +dependencies = [ + "iana-time-zone", + "num-traits", + "serde", + "windows-link", +] + [[package]] name = "clap" version = "4.6.1" @@ -1648,6 +1715,30 @@ dependencies = [ "windows-registry", ] +[[package]] +name = "iana-time-zone" +version = "0.1.65" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e31bc9ad994ba00e440a8aa5c9ef0ec67d5cb5e5cb0cc7f8b744a35b389cc470" +dependencies = [ + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "log", + "wasm-bindgen", + "windows-core", +] + +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" +dependencies = [ + "cc", +] + [[package]] name = "icu_collections" version = "2.1.1" @@ -2153,6 +2244,12 @@ dependencies = [ "unicase", ] +[[package]] +name = "minimal-lexical" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" + [[package]] name = "mio" version = "1.2.0" @@ -2170,6 +2267,15 @@ version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1d87ecb2933e8aeadb3e3a02b828fed80a7528047e68b4f424523a0981a3a084" +[[package]] +name = "murmur3" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a198f9589efc03f544388dfc4a19fe8af4323662b62f598b8dcfdac62c14771c" +dependencies = [ + "byteorder", +] + [[package]] name = "native-tls" version = "0.2.18" @@ -2187,6 +2293,16 @@ dependencies = [ "tempfile", ] +[[package]] +name = "nom" +version = "7.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" +dependencies = [ + "memchr", + "minimal-lexical", +] + [[package]] name = "num" version = "0.4.3" @@ -2346,7 +2462,7 @@ dependencies = [ "futures-sink", "js-sys", "pin-project-lite", - "thiserror", + "thiserror 2.0.18", "tracing", ] @@ -2374,7 +2490,7 @@ dependencies = [ "opentelemetry_sdk", "prost", "reqwest", - "thiserror", + "thiserror 2.0.18", "tokio", "tonic", "tonic-types", @@ -2416,7 +2532,7 @@ dependencies = [ "percent-encoding", "portable-atomic", "rand 0.9.4", - "thiserror", + "thiserror 2.0.18", "tokio", "tokio-stream", ] @@ -2460,6 +2576,12 @@ dependencies = [ "windows-link", ] +[[package]] +name = "pastey" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c5a797f0e07bdf071d15742978fc3128ec6c22891c31a3a931513263904c982a" + [[package]] name = "pathdiff" version = "0.2.3" @@ -3316,7 +3438,7 @@ dependencies = [ "parking_lot", "serde", "serde_json", - "thiserror", + "thiserror 2.0.18", "tokio", "tracing", ] @@ -3510,13 +3632,33 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "thiserror" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" +dependencies = [ + "thiserror-impl 1.0.69", +] + [[package]] name = "thiserror" version = "2.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4288b5bcbc7920c07a1149a35cf9590a2aa808e0bc1eafaade0b80947865fbc4" dependencies = [ - "thiserror-impl", + "thiserror-impl 2.0.18", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" +dependencies = [ + "proc-macro2", + "quote", + "syn", ] [[package]] @@ -3896,6 +4038,12 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" +[[package]] +name = "typed-arena" +version = "2.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6af6ae20167a9ece4bcb41af5b80f8a1f1df981f6391189ce00fd257af04126a" + [[package]] name = "typeid" version = "1.0.3" @@ -4327,6 +4475,41 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "windows-core" +version = "0.62.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8e83a14d34d0623b51dce9581199302a221863196a1dde71a7663a4c2be9deb" +dependencies = [ + "windows-implement", + "windows-interface", + "windows-link", + "windows-result", + "windows-strings", +] + +[[package]] +name = "windows-implement" +version = "0.60.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "053e2e040ab57b9dc951b72c264860db7eb3b0200ba345b4e4c3b14f67855ddf" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "windows-interface" +version = "0.59.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f316c4a2570ba26bbec722032c4099d8c8bc095efccdc15688708623367e358" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "windows-link" version = "0.2.1" diff --git a/Cargo.toml b/Cargo.toml index 48ac06f..cb1c53d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,6 +25,7 @@ axum = { version = "0.8.9", features = [ "macros", ], default-features = false } bytes = "1.0" +cel = "0.13.0" fastrace = { version = "0.7.16", features = ["enable"] } futures = "0.3" http = "1.4.0" @@ -67,6 +68,7 @@ axum.workspace = true regex.workspace = true serde.workspace = true serde_json.workspace = true +cel.workspace = true async-trait.workspace = true aws-credential-types.workspace = true aws-sigv4.workspace = true diff --git a/src/admin/mod.rs b/src/admin/mod.rs index 436677e..858cbf7 100644 --- a/src/admin/mod.rs +++ b/src/admin/mod.rs @@ -1,6 +1,7 @@ mod apikeys; mod models; mod playground; +mod policies; mod providers; mod types; @@ -37,6 +38,7 @@ pub const PATH_PREFIX: &str = "/aisix/admin"; tags( (name = models::OPENAPI_TAG, description = "Admin API for managing AI models"), (name = apikeys::OPENAPI_TAG, description = "Admin API for managing API keys"), + (name = policies::OPENAPI_TAG, description = "Admin API for managing guardrail policies"), (name = providers::OPENAPI_TAG, description = "Admin API for managing AI providers") ), security( @@ -59,6 +61,11 @@ pub const PATH_PREFIX: &str = "/aisix/admin"; apikeys::post, apikeys::put, apikeys::delete, + policies::list, + policies::get, + policies::post, + policies::put, + policies::delete, ) )] struct ApiDoc; @@ -139,6 +146,16 @@ pub fn create_router(state: AppState) -> Result { get(apikeys::get).put(apikeys::put).delete(apikeys::delete), ), ) + .merge( + Router::new() + .route("/policies", get(policies::list).post(policies::post)) + .route( + "/policies/{id}", + get(policies::get) + .put(policies::put) + .delete(policies::delete), + ), + ) .layer(axum::middleware::from_fn_with_state(state.clone(), auth)), ) // These routes use API key authentication instead of Admin key authentication. diff --git a/src/admin/policies.rs b/src/admin/policies.rs new file mode 100644 index 0000000..bd73b57 --- /dev/null +++ b/src/admin/policies.rs @@ -0,0 +1,244 @@ +use std::collections::HashSet; + +use axum::{ + extract::{Path, State}, + response::{IntoResponse, Response}, +}; +use bytes::Bytes; +use http::StatusCode; +use uuid::Uuid; + +use crate::{ + admin::{ + AppState, + types::{APIError, DeleteResponse, ItemResponse, ListResponse}, + }, + config::{ + PutEntry, + entities::{ + Guardrail, Policy, + policies::{SCHEMA_VALIDATOR, validate_policy_definition}, + }, + }, + utils::jsonschema::format_evaluation_error, +}; + +pub const OPENAPI_TAG: &str = "Policies"; + +#[utoipa::path( + get, + context_path = crate::admin::PATH_PREFIX, + path = "/policies", + tag = OPENAPI_TAG, + responses( + (status = StatusCode::OK, description = "Get policy list success", body = ListResponse>), + (status = StatusCode::INTERNAL_SERVER_ERROR, description = "Internal server error", body = APIError) + ) +)] +pub async fn list(State(state): State) -> Response { + let data = match state + .config_provider + .get_all::("/policies") + .await + { + Ok(data) => data, + Err(err) => return APIError::InternalError(err).into_response(), + }; + + ListResponse { + total: data.len(), + list: data + .into_iter() + .map(|item| ItemResponse { + key: item.key, + value: item.value, + created_index: Some(item.create_revision), + modified_index: Some(item.mod_revision), + }) + .collect(), + } + .into_response() +} + +#[utoipa::path( + get, + context_path = crate::admin::PATH_PREFIX, + path = "/policies/{id}", + tag = OPENAPI_TAG, + params( + ("id" = String, Path, description = "The ID of the policy"), + ), + responses( + (status = StatusCode::OK, description = "Get policy success", body = ItemResponse), + (status = StatusCode::NOT_FOUND, description = "Policy not found", body = APIError), + (status = StatusCode::INTERNAL_SERVER_ERROR, description = "Internal server error", body = APIError) + ) +)] +pub async fn get(State(state): State, Path(id): Path) -> Response { + let key = format!("/policies/{id}"); + let data = match state.config_provider.get::(&key).await { + Ok(Some(data)) => data, + Ok(None) => { + return APIError::NotFound(format!("Policy with ID {id} not found")).into_response(); + } + Err(err) => return APIError::InternalError(err).into_response(), + }; + + ItemResponse { + key, + value: data.value, + created_index: Some(data.create_revision), + modified_index: Some(data.mod_revision), + } + .into_response() +} + +#[utoipa::path( + post, + context_path = crate::admin::PATH_PREFIX, + path = "/policies", + tag = OPENAPI_TAG, + request_body(content_type = "application/json", content = Policy), + responses( + (status = StatusCode::CREATED, description = "Policy created successfully", body = ItemResponse), + (status = StatusCode::BAD_REQUEST, description = "Bad request", body = APIError), + (status = StatusCode::INTERNAL_SERVER_ERROR, description = "Internal server error", body = APIError) + ) +)] +pub async fn post(State(state): State, body: Bytes) -> Response { + update(state, &Uuid::new_v4().to_string(), body).await +} + +#[utoipa::path( + put, + context_path = crate::admin::PATH_PREFIX, + path = "/policies/{id}", + tag = OPENAPI_TAG, + params( + ("id" = String, Path, description = "The ID of the policy"), + ), + request_body(content_type = "application/json", content = Policy), + responses( + (status = StatusCode::OK, description = "Policy updated successfully", body = ItemResponse), + (status = StatusCode::CREATED, description = "Policy created successfully", body = ItemResponse), + (status = StatusCode::BAD_REQUEST, description = "Bad request", body = APIError), + (status = StatusCode::INTERNAL_SERVER_ERROR, description = "Internal server error", body = APIError) + ) +)] +pub async fn put(State(state): State, Path(id): Path, body: Bytes) -> Response { + update(state, &id, body).await +} + +#[utoipa::path( + delete, + context_path = crate::admin::PATH_PREFIX, + path = "/policies/{id}", + tag = OPENAPI_TAG, + params( + ("id" = String, Path, description = "The ID of the policy"), + ), + responses( + (status = StatusCode::OK, description = "Policy deleted successfully", body = DeleteResponse), + (status = StatusCode::NOT_FOUND, description = "Policy not found", body = APIError), + (status = StatusCode::INTERNAL_SERVER_ERROR, description = "Internal server error", body = APIError) + ) +)] +pub async fn delete(State(state): State, Path(id): Path) -> Response { + let key = format!("/policies/{id}"); + match state.config_provider.delete(&key).await { + Ok(deleted) if deleted > 0 => DeleteResponse { deleted, key }.into_response(), + Ok(_) => APIError::NotFound(format!("Policy with ID {id} not found")).into_response(), + Err(err) => APIError::InternalError(err).into_response(), + } +} + +async fn update(state: AppState, id: &str, body: Bytes) -> Response { + let key = format!("/policies/{id}"); + + let policy = match serde_json::from_slice::(&body) { + Ok(value) => value, + Err(err) => return APIError::BadRequest(format!("Invalid JSON: {err}")).into_response(), + }; + + let evaluation = SCHEMA_VALIDATOR.evaluate(&policy); + if !evaluation.flag().valid { + return APIError::BadRequest(format!( + "JSON schema validation error: {}", + format_evaluation_error(&evaluation) + )) + .into_response(); + } + + let policy = match serde_json::from_value::(policy) { + Ok(value) => value, + Err(err) => { + return APIError::BadRequest(format!("Invalid policy data: {err}")).into_response(); + } + }; + + if let Err(err) = validate_policy_definition(id, &policy) { + return APIError::BadRequest(err).into_response(); + } + + let mut seen_guardrails = HashSet::new(); + for guardrail_id in policy.referenced_guardrail_ids() { + if !seen_guardrails.insert(guardrail_id.to_string()) { + continue; + } + + let guardrail_key = format!("/guardrails/{guardrail_id}"); + match state.config_provider.get::(&guardrail_key).await { + Ok(Some(_)) => {} + Ok(None) => { + return APIError::BadRequest(format!("Guardrail with ID {guardrail_id} not found")) + .into_response(); + } + Err(err) => return APIError::InternalError(err).into_response(), + } + } + + if let Some(found) = state.resources.policies.get_by_name(&policy.name) + && found.id != id + { + return APIError::BadRequest("Policy name already exists".to_string()).into_response(); + } + + match state.config_provider.get_all::("/policies").await { + Ok(data) => { + if data + .iter() + .any(|item| item.value.name == policy.name && item.key != key) + { + return APIError::BadRequest("Policy name already exists".to_string()) + .into_response(); + } + } + Err(err) => return APIError::InternalError(err).into_response(), + } + + match state.config_provider.put(&key, &policy).await { + Ok(res) => match res { + PutEntry::Created => ( + StatusCode::CREATED, + ItemResponse { + key: key.to_string(), + value: policy, + created_index: None, + modified_index: None, + }, + ) + .into_response(), + PutEntry::Updated(_prev) => ( + StatusCode::OK, + ItemResponse { + key: key.to_string(), + value: policy, + created_index: None, + modified_index: None, + }, + ) + .into_response(), + }, + Err(err) => APIError::InternalError(err).into_response(), + } +} diff --git a/src/config/entities/mod.rs b/src/config/entities/mod.rs index d264863..bb2d23e 100644 --- a/src/config/entities/mod.rs +++ b/src/config/entities/mod.rs @@ -1,6 +1,7 @@ pub mod apikeys; pub mod guardrails; pub mod models; +pub mod policies; pub mod providers; pub mod types; @@ -11,6 +12,7 @@ use arc_swap::ArcSwap; pub use guardrails::Guardrail; use log::{info, warn}; pub use models::Model; +pub use policies::Policy; pub use providers::Provider; use serde::de::DeserializeOwned; use tokio::sync::mpsc::Receiver; @@ -22,6 +24,7 @@ pub struct ResourceRegistry { pub models: models::ModelsStore, pub apikeys: apikeys::ApiKeysStore, pub guardrails: guardrails::GuardrailsStore, + pub policies: policies::PoliciesStore, pub providers: providers::ProvidersStore, } @@ -29,6 +32,7 @@ impl ResourceRegistry { pub async fn new(provider: Arc) -> Self { let providers = providers::ProvidersStore::new(provider.clone()).await; let guardrails = guardrails::GuardrailsStore::new(provider.clone()).await; + let policies = policies::PoliciesStore::new(provider.clone()).await; let models = models::ModelsStore::new(provider.clone()).await; let apikeys = apikeys::ApiKeysStore::new(provider).await; @@ -36,6 +40,7 @@ impl ResourceRegistry { models, apikeys, guardrails, + policies, providers, } } diff --git a/src/config/entities/policies-schema.json b/src/config/entities/policies-schema.json new file mode 100644 index 0000000..9eff884 --- /dev/null +++ b/src/config/entities/policies-schema.json @@ -0,0 +1,68 @@ +{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "type": "object", + "additionalProperties": false, + "required": ["name", "when", "actions"], + "properties": { + "name": { + "type": "string", + "minLength": 1 + }, + "enabled": { + "type": "boolean", + "default": true + }, + "priority": { + "type": "integer", + "default": 0 + }, + "when": { + "type": "string", + "minLength": 1 + }, + "actions": { + "type": "array", + "minItems": 1, + "items": { + "oneOf": [ + { + "type": "object", + "additionalProperties": false, + "required": ["type", "config"], + "properties": { + "type": { + "const": "guardrail" + }, + "config": { + "type": "object", + "additionalProperties": false, + "required": ["guardrail_ids"], + "properties": { + "stages": { + "type": "array", + "default": ["input", "output"], + "minItems": 1, + "uniqueItems": true, + "items": { + "type": "string", + "enum": ["input", "output"] + } + }, + "guardrail_ids": { + "type": "array", + "minItems": 1, + "uniqueItems": true, + "items": { + "type": "string", + "minLength": 1 + } + } + } + } + } + } + ] + } + } + } +} diff --git a/src/config/entities/policies.rs b/src/config/entities/policies.rs new file mode 100644 index 0000000..4d248c7 --- /dev/null +++ b/src/config/entities/policies.rs @@ -0,0 +1,282 @@ +use std::{ + collections::HashMap, + sync::{Arc, LazyLock}, +}; + +use cel::Program; +use serde::{Deserialize, Serialize}; +use utoipa::ToSchema; + +use super::{ConfigProvider, EntityStore, IndexFn, ResourceEntry}; +use crate::utils::jsonschema::format_evaluation_error; + +static SCHEMA: LazyLock = LazyLock::new(|| { + serde_json::from_str(include_str!("policies-schema.json")) + .expect("Invalid JSON document for Policy schema") +}); +pub static SCHEMA_VALIDATOR: LazyLock = + LazyLock::new(|| jsonschema::validator_for(&SCHEMA).expect("Invalid JSON schema for Policy")); + +fn default_enabled() -> bool { + true +} + +fn default_policy_stages() -> Vec { + vec![PolicyStage::Input, PolicyStage::Output] +} + +#[derive(Debug, Clone, Serialize, Deserialize, ToSchema, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum PolicyStage { + Input, + Output, +} + +#[derive(Debug, Clone, Serialize, Deserialize, ToSchema, PartialEq, Eq)] +pub struct GuardrailPolicyAction { + #[serde(default = "default_policy_stages")] + pub stages: Vec, + pub guardrail_ids: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize, ToSchema, PartialEq, Eq)] +#[serde(tag = "type", content = "config")] +pub enum PolicyAction { + #[serde(rename = "guardrail")] + Guardrail(GuardrailPolicyAction), +} + +impl PolicyAction { + fn guardrail_ids(&self) -> &[String] { + match self { + Self::Guardrail(config) => &config.guardrail_ids, + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)] +pub struct Policy { + pub name: String, + + #[serde(default = "default_enabled")] + pub enabled: bool, + + #[serde(default)] + pub priority: i32, + + pub when: String, + + pub actions: Vec, +} + +impl Policy { + pub fn referenced_guardrail_ids(&self) -> impl Iterator { + self.actions + .iter() + .flat_map(|action| action.guardrail_ids().iter().map(String::as_str)) + } +} + +pub(crate) fn validate_policy_definition(key: &str, value: &Policy) -> Result<(), String> { + let evaluation = SCHEMA_VALIDATOR.evaluate( + &serde_json::to_value(value) + .map_err(|error| format!("Failed to serialize policy for validation: {error}"))?, + ); + if !evaluation.flag().valid { + return Err(format!( + r#"JSON schema validation error on policy "{key}": {}"#, + format_evaluation_error(&evaluation) + )); + } + + Program::compile(&value.when) + .map_err(|error| format!(r#"CEL validation error on policy "{key}": {error}"#))?; + + Ok(()) +} + +#[derive(Clone)] +pub struct PoliciesStore { + store: EntityStore, +} + +static INDEX_FNS: &[IndexFn] = &[("by_name", |policy: &Policy| Some(policy.name.clone()))]; + +impl PoliciesStore { + pub async fn new(provider: Arc) -> Self { + Self { + store: EntityStore::new( + provider, + "/policies/", + "policies", + Some(validate_policy_definition), + INDEX_FNS, + ) + .await, + } + } + + pub fn list(&self) -> Arc>> { + self.store.list() + } + + pub fn get_by_id(&self, id: &str) -> Option> { + self.store.get(id) + } + + pub fn get_by_name(&self, name: &str) -> Option> { + self.store.get_by_secondary("by_name", name) + } +} + +#[cfg(test)] +mod tests { + use pretty_assertions::assert_eq; + use serde_json::json; + + use super::{ + Policy, SCHEMA, SCHEMA_VALIDATOR, format_evaluation_error, validate_policy_definition, + }; + + #[test] + fn test_valid_jsonschema() { + assert!(jsonschema::meta::is_valid(&SCHEMA)); + } + + #[rstest::rstest] + #[case::ok_minimal(json!({ + "name": "tenant-a-bedrock-default", + "when": "auth.api_key.id == 'tenant-a' && provider.type == 'bedrock'", + "actions": [{ + "type": "guardrail", + "config": { + "guardrail_ids": ["gr-bedrock-default"] + } + }] + }), true, None)] + #[case::ok_with_explicit_defaults(json!({ + "name": "responses-session-review", + "enabled": true, + "priority": 80, + "when": "route.format == 'responses' && input.messages.size() > 20", + "actions": [{ + "type": "guardrail", + "config": { + "stages": ["input"], + "guardrail_ids": ["gr-session-review"] + } + }] + }), true, None)] + #[case::missing_name(json!({ + "when": "true", + "actions": [{ + "type": "guardrail", + "config": { + "stages": ["input"], + "guardrail_ids": ["gr-input"] + } + }] + }), false, Some(r#"property "/" validation failed: "name" is a required property"#.to_string()))] + #[case::invalid_stage(json!({ + "name": "invalid-stage", + "when": "true", + "actions": [{ + "type": "guardrail", + "config": { + "stages": ["tool_call"], + "guardrail_ids": ["gr-input"] + } + }] + }), false, Some(r#"property "/actions/0/config/stages/0" validation failed: "tool_call" is not one of "input" or "output""#.to_string()))] + #[case::duplicate_guardrail_ids(json!({ + "name": "duplicate-guardrails", + "when": "true", + "actions": [{ + "type": "guardrail", + "config": { + "stages": ["input"], + "guardrail_ids": ["gr-input", "gr-input"] + } + }] + }), false, Some(r#"property "/actions/0/config/guardrail_ids" validation failed: ["gr-input","gr-input"] has non-unique elements"#.to_string()))] + #[case::invalid_root_additional_property(json!({ + "name": "extra-field", + "when": "true", + "actions": [{ + "type": "guardrail", + "config": { + "stages": ["input"], + "guardrail_ids": ["gr-input"] + } + }], + "extra": true + }), false, Some(r#"property "/" validation failed: Additional properties are not allowed ('extra' was unexpected)"#.to_string()))] + fn schemas( + #[case] input: serde_json::Value, + #[case] ok: bool, + #[case] expected_error: Option, + ) { + let evaluation = SCHEMA_VALIDATOR.evaluate(&input); + + assert_eq!(evaluation.flag().valid, ok, "unexpected evaluation result"); + if !ok { + assert_eq!( + format_evaluation_error(&evaluation), + expected_error.unwrap(), + "unexpected error message" + ); + } + } + + #[test] + fn validate_policy_definition_rejects_invalid_cel() { + let policy: Policy = serde_json::from_value(json!({ + "name": "broken-cel", + "when": "route.format ==", + "actions": [{ + "type": "guardrail", + "config": { + "stages": ["input"], + "guardrail_ids": ["gr-input"] + } + }] + })) + .unwrap(); + + let error = validate_policy_definition("broken-cel", &policy).unwrap_err(); + + assert!(error.contains("CEL validation error on policy \"broken-cel\"")); + } + + #[test] + fn deserialize_policy_defaults_enabled_and_priority() { + let policy: Policy = serde_json::from_value(json!({ + "name": "defaults", + "when": "true", + "actions": [{ + "type": "guardrail", + "config": { + "guardrail_ids": ["gr-input"] + } + }] + })) + .unwrap(); + + assert_eq!(policy.name, "defaults"); + assert!(policy.enabled); + assert_eq!(policy.priority, 0); + assert_eq!( + policy.actions, + vec![super::PolicyAction::Guardrail( + super::GuardrailPolicyAction { + stages: vec![super::PolicyStage::Input, super::PolicyStage::Output], + guardrail_ids: vec!["gr-input".to_string()], + } + )] + ); + assert_eq!( + policy.referenced_guardrail_ids().collect::>(), + vec!["gr-input"] + ); + } +} From fce0cc126070a19139eae66fc82c830aebffcc14 Mon Sep 17 00:00:00 2001 From: bzp2010 Date: Sat, 16 May 2026 09:42:36 +0800 Subject: [PATCH 2/2] fix --- src/config/entities/policies.rs | 46 ++++++- tests/admin/policies.test.ts | 204 ++++++++++++++++++++++++++++++++ tests/utils/admin.ts | 1 + 3 files changed, 249 insertions(+), 2 deletions(-) create mode 100644 tests/admin/policies.test.ts diff --git a/src/config/entities/policies.rs b/src/config/entities/policies.rs index 4d248c7..ecdfe8f 100644 --- a/src/config/entities/policies.rs +++ b/src/config/entities/policies.rs @@ -1,6 +1,6 @@ use std::{ collections::HashMap, - sync::{Arc, LazyLock}, + sync::{Arc, LazyLock, OnceLock}, }; use cel::Program; @@ -25,6 +25,10 @@ fn default_policy_stages() -> Vec { vec![PolicyStage::Input, PolicyStage::Output] } +fn default_compiled_when() -> Arc>> { + Arc::new(OnceLock::new()) +} + #[derive(Debug, Clone, Serialize, Deserialize, ToSchema, PartialEq, Eq)] #[serde(rename_all = "snake_case")] pub enum PolicyStage { @@ -67,6 +71,10 @@ pub struct Policy { pub when: String, pub actions: Vec, + + #[serde(skip, default = "default_compiled_when")] + #[schema(ignore)] + compiled_when: Arc>>, } impl Policy { @@ -75,6 +83,17 @@ impl Policy { .iter() .flat_map(|action| action.guardrail_ids().iter().map(String::as_str)) } + + pub fn compiled_when(&self) -> Result, cel::ParseErrors> { + if let Some(program) = self.compiled_when.get() { + return Ok(Arc::clone(program)); + } + + let compiled = Arc::new(Program::compile(&self.when)?); + let _ = self.compiled_when.set(Arc::clone(&compiled)); + + Ok(self.compiled_when.get().map(Arc::clone).unwrap_or(compiled)) + } } pub(crate) fn validate_policy_definition(key: &str, value: &Policy) -> Result<(), String> { @@ -89,7 +108,8 @@ pub(crate) fn validate_policy_definition(key: &str, value: &Policy) -> Result<() )); } - Program::compile(&value.when) + value + .compiled_when() .map_err(|error| format!(r#"CEL validation error on policy "{key}": {error}"#))?; Ok(()) @@ -131,6 +151,8 @@ impl PoliciesStore { #[cfg(test)] mod tests { + use std::sync::Arc; + use pretty_assertions::assert_eq; use serde_json::json; @@ -279,4 +301,24 @@ mod tests { vec!["gr-input"] ); } + + #[test] + fn compiled_when_reuses_cached_program() { + let policy: Policy = serde_json::from_value(json!({ + "name": "cached-program", + "when": "route.format == 'chat_completions'", + "actions": [{ + "type": "guardrail", + "config": { + "guardrail_ids": ["gr-input"] + } + }] + })) + .unwrap(); + + let first = policy.compiled_when().unwrap(); + let second = policy.compiled_when().unwrap(); + + assert!(Arc::ptr_eq(&first, &second)); + } } diff --git a/tests/admin/policies.test.ts b/tests/admin/policies.test.ts new file mode 100644 index 0000000..98a6e4a --- /dev/null +++ b/tests/admin/policies.test.ts @@ -0,0 +1,204 @@ +import { randomUUID } from 'node:crypto'; + +import { + POLICIES_URL, + adminDelete, + adminGet, + adminPost, + adminPut, + bearerAuthHeader, + extractIdFromStorageKey, + startIsolatedAdminApp, +} from '../utils/admin.js'; +import { etcdPutJson } from '../utils/etcd.js'; +import { App } from '../utils/setup.js'; + +const ADMIN_KEY = 'test_admin_key'; + +const seedRegexGuardrail = async (etcdPrefix: string, guardrailId: string) => { + await etcdPutJson(etcdPrefix, `/guardrails/${guardrailId}`, { + name: `${guardrailId}-name`, + type: 'regex', + config: { + pattern: 'blocked phrase', + block_reason: 'blocked by policy admin test guardrail', + }, + }); +}; + +const buildPolicyBody = (name: string, guardrailId: string, when = 'true') => ({ + name, + when, + actions: [ + { + type: 'guardrail', + config: { + guardrail_ids: [guardrailId], + }, + }, + ], +}); + +describe('admin policies', () => { + let server: App | undefined; + let etcdPrefix = ''; + + beforeEach(async () => { + etcdPrefix = `/ai-admin-${randomUUID()}`; + server = await startIsolatedAdminApp(ADMIN_KEY, etcdPrefix); + }); + + afterEach(async () => await server?.exit()); + + test('test_crud', async () => { + const auth = bearerAuthHeader(ADMIN_KEY); + const guardrailId = `policy-crud-guardrail-${randomUUID()}`; + + await seedRegexGuardrail(etcdPrefix, guardrailId); + + const listBefore = await adminGet(POLICIES_URL, auth); + expect(listBefore.status).toBe(200); + expect(listBefore.data.total).toBe(0); + + const createResp = await adminPost( + POLICIES_URL, + buildPolicyBody('test_policy', guardrailId), + auth, + ); + expect(createResp.status).toBe(201); + expect(createResp.data.value.enabled).toBe(true); + expect(createResp.data.value.priority).toBe(0); + expect(createResp.data.value.actions[0].config.stages).toStrictEqual([ + 'input', + 'output', + ]); + const id = extractIdFromStorageKey(createResp.data.key); + + const listAfterCreate = await adminGet(POLICIES_URL, auth); + expect(listAfterCreate.status).toBe(200); + expect(listAfterCreate.data.total).toBe(1); + + const updateResp = await adminPut( + `${POLICIES_URL}/${id}`, + { + name: 'updated_policy', + enabled: false, + priority: 42, + when: "route.format == 'chat_completions'", + actions: [ + { + type: 'guardrail', + config: { + stages: ['input'], + guardrail_ids: [guardrailId], + }, + }, + ], + }, + auth, + ); + expect(updateResp.status).toBe(200); + expect(updateResp.data.value.name).toBe('updated_policy'); + expect(updateResp.data.value.enabled).toBe(false); + expect(updateResp.data.value.priority).toBe(42); + expect(updateResp.data.value.actions[0].config.stages).toStrictEqual([ + 'input', + ]); + + const getResp = await adminGet(`${POLICIES_URL}/${id}`, auth); + expect(getResp.status).toBe(200); + expect(getResp.data.value.name).toBe('updated_policy'); + + const deleteResp = await adminDelete(`${POLICIES_URL}/${id}`, auth); + expect(deleteResp.status).toBe(200); + expect(deleteResp.data.deleted).toBe(1); + + const listAfterDelete = await adminGet(POLICIES_URL, auth); + expect(listAfterDelete.status).toBe(200); + expect(listAfterDelete.data.total).toBe(0); + }); + + test('test_put_status_codes', async () => { + const auth = bearerAuthHeader(ADMIN_KEY); + const guardrailId = `policy-put-guardrail-${randomUUID()}`; + + await seedRegexGuardrail(etcdPrefix, guardrailId); + + const body = buildPolicyBody('put_policy', guardrailId); + + const firstPut = await adminPut( + `${POLICIES_URL}/put-policy-fixed-id`, + body, + auth, + ); + expect(firstPut.status).toBe(201); + + const secondPut = await adminPut( + `${POLICIES_URL}/put-policy-fixed-id`, + body, + auth, + ); + expect(secondPut.status).toBe(200); + }); + + test('test_put_duplicate_name_rejected', async () => { + const auth = bearerAuthHeader(ADMIN_KEY); + const guardrailId = `policy-dup-guardrail-${randomUUID()}`; + + await seedRegexGuardrail(etcdPrefix, guardrailId); + + const putA = await adminPut( + `${POLICIES_URL}/put-dup-policy-a`, + buildPolicyBody('put-dup-name-a', guardrailId), + auth, + ); + expect(putA.status).toBe(201); + + const putB = await adminPut( + `${POLICIES_URL}/put-dup-policy-b`, + buildPolicyBody('put-dup-name-b', guardrailId), + auth, + ); + expect(putB.status).toBe(201); + + const putDup = await adminPut( + `${POLICIES_URL}/put-dup-policy-b`, + buildPolicyBody('put-dup-name-a', guardrailId), + auth, + ); + expect(putDup.status).toBe(400); + expect(putDup.data.error_msg).toBe('Policy name already exists'); + }); + + test('test_missing_guardrail_rejected', async () => { + const auth = bearerAuthHeader(ADMIN_KEY); + const missingGuardrailId = `missing-guardrail-${randomUUID()}`; + + const createResp = await adminPost( + POLICIES_URL, + buildPolicyBody('missing_guardrail_policy', missingGuardrailId), + auth, + ); + expect(createResp.status).toBe(400); + expect(createResp.data.error_msg).toBe( + `Guardrail with ID ${missingGuardrailId} not found`, + ); + }); + + test('test_invalid_cel_rejected', async () => { + const auth = bearerAuthHeader(ADMIN_KEY); + const guardrailId = `policy-invalid-cel-guardrail-${randomUUID()}`; + + await seedRegexGuardrail(etcdPrefix, guardrailId); + + const putResp = await adminPut( + `${POLICIES_URL}/invalid-cel-policy`, + buildPolicyBody('invalid_cel_policy', guardrailId, 'route.format =='), + auth, + ); + expect(putResp.status).toBe(400); + expect(putResp.data.error_msg).toContain( + 'CEL validation error on policy "invalid-cel-policy"', + ); + }); +}); diff --git a/tests/utils/admin.ts b/tests/utils/admin.ts index 9509e63..55b7d0a 100644 --- a/tests/utils/admin.ts +++ b/tests/utils/admin.ts @@ -6,6 +6,7 @@ import { App, defaultConfig } from './setup.js'; export const ADMIN_BASE_URL = 'http://127.0.0.1:3001'; export const ADMIN_PREFIX = '/aisix/admin'; export const MODELS_URL = '/models'; +export const POLICIES_URL = '/policies'; export const PROVIDERS_URL = '/providers'; export const adminUrl = (path: string) =>