diff --git a/src/admin/guardrails.rs b/src/admin/guardrails.rs new file mode 100644 index 0000000..2e78ffe --- /dev/null +++ b/src/admin/guardrails.rs @@ -0,0 +1,232 @@ +use axum::{ + extract::{Path, State}, + response::{IntoResponse, Response}, +}; +use bytes::Bytes; +use http::StatusCode; +use uuid::Uuid; + +use crate::{ + admin::{ + AppState, + types::{APIError, DeleteResponse, ItemResponse, ListResponse}, + }, + config::{ + PutEntry, + entities::{ + Guardrail, Policy, + guardrails::{SCHEMA_VALIDATOR, validate_guardrail_definition}, + }, + }, + utils::jsonschema::format_evaluation_error, +}; + +pub const OPENAPI_TAG: &str = "Guardrails"; + +#[utoipa::path( + get, + context_path = crate::admin::PATH_PREFIX, + path = "/guardrails", + tag = OPENAPI_TAG, + responses( + (status = StatusCode::OK, description = "Get guardrail list success", body = ListResponse>), + (status = StatusCode::INTERNAL_SERVER_ERROR, description = "Internal server error", body = APIError) + ) +)] +pub async fn list(State(state): State) -> Response { + let data = match state + .config_provider + .get_all::("/guardrails") + .await + { + Ok(data) => data, + Err(err) => return APIError::InternalError(err).into_response(), + }; + + ListResponse { + total: data.len(), + list: data + .into_iter() + .map(|item| ItemResponse { + key: item.key, + value: item.value, + created_index: Some(item.create_revision), + modified_index: Some(item.mod_revision), + }) + .collect(), + } + .into_response() +} + +#[utoipa::path( + get, + context_path = crate::admin::PATH_PREFIX, + path = "/guardrails/{id}", + tag = OPENAPI_TAG, + params( + ("id" = String, Path, description = "The ID of the guardrail"), + ), + responses( + (status = StatusCode::OK, description = "Get guardrail success", body = ItemResponse), + (status = StatusCode::NOT_FOUND, description = "Guardrail not found", body = APIError), + (status = StatusCode::INTERNAL_SERVER_ERROR, description = "Internal server error", body = APIError) + ) +)] +pub async fn get(State(state): State, Path(id): Path) -> Response { + let key = format!("/guardrails/{id}"); + let data = match state.config_provider.get::(&key).await { + Ok(Some(data)) => data, + Ok(None) => { + return APIError::NotFound(format!("Guardrail with ID {id} not found")).into_response(); + } + Err(err) => return APIError::InternalError(err).into_response(), + }; + + ItemResponse { + key, + value: data.value, + created_index: Some(data.create_revision), + modified_index: Some(data.mod_revision), + } + .into_response() +} + +#[utoipa::path( + post, + context_path = crate::admin::PATH_PREFIX, + path = "/guardrails", + tag = OPENAPI_TAG, + request_body(content_type = "application/json", content = Guardrail), + responses( + (status = StatusCode::CREATED, description = "Guardrail created successfully", body = ItemResponse), + (status = StatusCode::BAD_REQUEST, description = "Bad request", body = APIError), + (status = StatusCode::INTERNAL_SERVER_ERROR, description = "Internal server error", body = APIError) + ) +)] +pub async fn post(State(state): State, body: Bytes) -> Response { + update(state, &Uuid::new_v4().to_string(), body).await +} + +#[utoipa::path( + put, + context_path = crate::admin::PATH_PREFIX, + path = "/guardrails/{id}", + tag = OPENAPI_TAG, + params( + ("id" = String, Path, description = "The ID of the guardrail"), + ), + request_body(content_type = "application/json", content = Guardrail), + responses( + (status = StatusCode::OK, description = "Guardrail updated successfully", body = ItemResponse), + (status = StatusCode::CREATED, description = "Guardrail created successfully", body = ItemResponse), + (status = StatusCode::BAD_REQUEST, description = "Bad request", body = APIError), + (status = StatusCode::INTERNAL_SERVER_ERROR, description = "Internal server error", body = APIError) + ) +)] +pub async fn put(State(state): State, Path(id): Path, body: Bytes) -> Response { + update(state, &id, body).await +} + +#[utoipa::path( + delete, + context_path = crate::admin::PATH_PREFIX, + path = "/guardrails/{id}", + tag = OPENAPI_TAG, + params( + ("id" = String, Path, description = "The ID of the guardrail"), + ), + responses( + (status = StatusCode::BAD_REQUEST, description = "Guardrail is still referenced by policies", body = APIError), + (status = StatusCode::OK, description = "Guardrail deleted successfully", body = DeleteResponse), + (status = StatusCode::NOT_FOUND, description = "Guardrail not found", body = APIError), + (status = StatusCode::INTERNAL_SERVER_ERROR, description = "Internal server error", body = APIError) + ) +)] +pub async fn delete(State(state): State, Path(id): Path) -> Response { + let key = format!("/guardrails/{id}"); + + match state.config_provider.get::(&key).await { + Ok(Some(_)) => {} + Ok(None) => { + return APIError::NotFound(format!("Guardrail with ID {id} not found")).into_response(); + } + Err(err) => return APIError::InternalError(err).into_response(), + } + + match state.config_provider.get_all::("/policies").await { + Ok(policies) => { + if policies.iter().any(|item| { + item.value + .referenced_guardrail_ids() + .any(|guardrail_id| guardrail_id == id) + }) { + return APIError::BadRequest( + "guardrail is still referenced by policies".to_string(), + ) + .into_response(); + } + } + Err(err) => return APIError::InternalError(err).into_response(), + } + + match state.config_provider.delete(&key).await { + Ok(deleted) if deleted > 0 => DeleteResponse { deleted, key }.into_response(), + Ok(_) => APIError::NotFound(format!("Guardrail with ID {id} not found")).into_response(), + Err(err) => APIError::InternalError(err).into_response(), + } +} + +async fn update(state: AppState, id: &str, body: Bytes) -> Response { + let key = format!("/guardrails/{id}"); + + let guardrail = match serde_json::from_slice::(&body) { + Ok(value) => value, + Err(err) => return APIError::BadRequest(format!("Invalid JSON: {err}")).into_response(), + }; + + let evaluation = SCHEMA_VALIDATOR.evaluate(&guardrail); + if !evaluation.flag().valid { + return APIError::BadRequest(format!( + "JSON schema validation error: {}", + format_evaluation_error(&evaluation) + )) + .into_response(); + } + + let guardrail = match serde_json::from_value::(guardrail) { + Ok(value) => value, + Err(err) => { + return APIError::BadRequest(format!("Invalid guardrail data: {err}")).into_response(); + } + }; + + if let Err(err) = validate_guardrail_definition(id, &guardrail) { + return APIError::BadRequest(err).into_response(); + } + + match state.config_provider.put(&key, &guardrail).await { + Ok(res) => match res { + PutEntry::Created => ( + StatusCode::CREATED, + ItemResponse { + key: key.to_string(), + value: guardrail, + created_index: None, + modified_index: None, + }, + ) + .into_response(), + PutEntry::Updated(_prev) => ( + StatusCode::OK, + ItemResponse { + key: key.to_string(), + value: guardrail, + created_index: None, + modified_index: None, + }, + ) + .into_response(), + }, + Err(err) => APIError::InternalError(err).into_response(), + } +} diff --git a/src/admin/mod.rs b/src/admin/mod.rs index 858cbf7..e3bfda7 100644 --- a/src/admin/mod.rs +++ b/src/admin/mod.rs @@ -1,4 +1,5 @@ mod apikeys; +mod guardrails; mod models; mod playground; mod policies; @@ -38,6 +39,7 @@ pub const PATH_PREFIX: &str = "/aisix/admin"; tags( (name = models::OPENAPI_TAG, description = "Admin API for managing AI models"), (name = apikeys::OPENAPI_TAG, description = "Admin API for managing API keys"), + (name = guardrails::OPENAPI_TAG, description = "Admin API for managing guardrails"), (name = policies::OPENAPI_TAG, description = "Admin API for managing guardrail policies"), (name = providers::OPENAPI_TAG, description = "Admin API for managing AI providers") ), @@ -61,6 +63,11 @@ pub const PATH_PREFIX: &str = "/aisix/admin"; apikeys::post, apikeys::put, apikeys::delete, + guardrails::list, + guardrails::get, + guardrails::post, + guardrails::put, + guardrails::delete, policies::list, policies::get, policies::post, @@ -146,6 +153,16 @@ pub fn create_router(state: AppState) -> Result { get(apikeys::get).put(apikeys::put).delete(apikeys::delete), ), ) + .merge( + Router::new() + .route("/guardrails", get(guardrails::list).post(guardrails::post)) + .route( + "/guardrails/{id}", + get(guardrails::get) + .put(guardrails::put) + .delete(guardrails::delete), + ), + ) .merge( Router::new() .route("/policies", get(policies::list).post(policies::post)) diff --git a/src/config/entities/guardrails.rs b/src/config/entities/guardrails.rs index 2b06929..759f3d4 100644 --- a/src/config/entities/guardrails.rs +++ b/src/config/entities/guardrails.rs @@ -57,7 +57,7 @@ impl Guardrail { } } -fn validate(key: &str, value: &Guardrail) -> Result<(), String> { +pub(crate) fn validate_guardrail_definition(key: &str, value: &Guardrail) -> Result<(), String> { let evaluation = SCHEMA_VALIDATOR.evaluate( &serde_json::to_value(value) .map_err(|error| format!("Failed to serialize guardrail for validation: {}", error))?, @@ -89,8 +89,14 @@ pub struct GuardrailsStore { impl GuardrailsStore { pub async fn new(provider: Arc) -> Self { Self { - store: EntityStore::new(provider, "/guardrails/", "guardrails", Some(validate), &[]) - .await, + store: EntityStore::new( + provider, + "/guardrails/", + "guardrails", + Some(validate_guardrail_definition), + &[], + ) + .await, } } diff --git a/tests/admin/guardrails.test.ts b/tests/admin/guardrails.test.ts new file mode 100644 index 0000000..61e2fdf --- /dev/null +++ b/tests/admin/guardrails.test.ts @@ -0,0 +1,174 @@ +import { randomUUID } from 'node:crypto'; + +import { + GUARDRAILS_URL, + POLICIES_URL, + adminDelete, + adminGet, + adminPost, + adminPut, + bearerAuthHeader, + extractIdFromStorageKey, + startIsolatedAdminApp, +} from '../utils/admin.js'; +import { App } from '../utils/setup.js'; + +const ADMIN_KEY = 'test_admin_key'; + +const buildRegexGuardrailBody = ( + name: string, + pattern = 'blocked phrase', + blockReason = 'blocked by guardrail admin test', +) => ({ + name, + type: 'regex', + config: { + pattern, + block_reason: blockReason, + }, +}); + +const buildPolicyBody = (name: string, guardrailId: string) => ({ + name, + when: 'true', + actions: [ + { + type: 'guardrail', + config: { + guardrail_ids: [guardrailId], + }, + }, + ], +}); + +describe('admin guardrails', () => { + let server: App | undefined; + let etcdPrefix = ''; + + beforeEach(async () => { + etcdPrefix = `/ai-admin-${randomUUID()}`; + server = await startIsolatedAdminApp(ADMIN_KEY, etcdPrefix); + }); + + afterEach(async () => await server?.exit()); + + test('test_crud', async () => { + const auth = bearerAuthHeader(ADMIN_KEY); + + const listBefore = await adminGet(GUARDRAILS_URL, auth); + expect(listBefore.status).toBe(200); + expect(listBefore.data.total).toBe(0); + + const createResp = await adminPost( + GUARDRAILS_URL, + buildRegexGuardrailBody('test_guardrail'), + auth, + ); + expect(createResp.status).toBe(201); + expect(createResp.data.value.name).toBe('test_guardrail'); + expect(createResp.data.value.type).toBe('regex'); + expect(createResp.data.value.config.pattern).toBe('blocked phrase'); + + const id = extractIdFromStorageKey(createResp.data.key); + + const listAfterCreate = await adminGet(GUARDRAILS_URL, auth); + expect(listAfterCreate.status).toBe(200); + expect(listAfterCreate.data.total).toBe(1); + + const updateResp = await adminPut( + `${GUARDRAILS_URL}/${id}`, + buildRegexGuardrailBody( + 'updated_guardrail', + 'updated phrase', + 'updated block reason', + ), + auth, + ); + expect(updateResp.status).toBe(200); + expect(updateResp.data.value.name).toBe('updated_guardrail'); + expect(updateResp.data.value.config.pattern).toBe('updated phrase'); + expect(updateResp.data.value.config.block_reason).toBe( + 'updated block reason', + ); + + const getResp = await adminGet(`${GUARDRAILS_URL}/${id}`, auth); + expect(getResp.status).toBe(200); + expect(getResp.data.value.name).toBe('updated_guardrail'); + + const deleteResp = await adminDelete(`${GUARDRAILS_URL}/${id}`, auth); + expect(deleteResp.status).toBe(200); + expect(deleteResp.data.deleted).toBe(1); + + const listAfterDelete = await adminGet(GUARDRAILS_URL, auth); + expect(listAfterDelete.status).toBe(200); + expect(listAfterDelete.data.total).toBe(0); + }); + + test('test_put_status_codes', async () => { + const auth = bearerAuthHeader(ADMIN_KEY); + const body = buildRegexGuardrailBody('put_guardrail'); + + const firstPut = await adminPut( + `${GUARDRAILS_URL}/put-guardrail-fixed-id`, + body, + auth, + ); + expect(firstPut.status).toBe(201); + + const secondPut = await adminPut( + `${GUARDRAILS_URL}/put-guardrail-fixed-id`, + body, + auth, + ); + expect(secondPut.status).toBe(200); + }); + + test('test_invalid_schema_rejected', async () => { + const auth = bearerAuthHeader(ADMIN_KEY); + + const createResp = await adminPost( + GUARDRAILS_URL, + { + name: 'invalid_guardrail', + type: 'regex', + config: {}, + }, + auth, + ); + + expect(createResp.status).toBe(400); + expect(createResp.data.error_msg).toContain('JSON schema validation error'); + }); + + test('test_delete_referenced_guardrail_rejected', async () => { + const auth = bearerAuthHeader(ADMIN_KEY); + + const guardrailResp = await adminPost( + GUARDRAILS_URL, + buildRegexGuardrailBody('referenced_guardrail'), + auth, + ); + expect(guardrailResp.status).toBe(201); + + const guardrailId = extractIdFromStorageKey(guardrailResp.data.key); + + const policyResp = await adminPost( + POLICIES_URL, + buildPolicyBody('guardrail_ref_policy', guardrailId), + auth, + ); + expect(policyResp.status).toBe(201); + + const deleteResp = await adminDelete( + `${GUARDRAILS_URL}/${guardrailId}`, + auth, + ); + expect(deleteResp.status).toBe(400); + expect(deleteResp.data.error_msg).toBe( + 'guardrail is still referenced by policies', + ); + + const getResp = await adminGet(`${GUARDRAILS_URL}/${guardrailId}`, auth); + expect(getResp.status).toBe(200); + }); +}); diff --git a/tests/utils/admin.ts b/tests/utils/admin.ts index 55b7d0a..28ff78a 100644 --- a/tests/utils/admin.ts +++ b/tests/utils/admin.ts @@ -5,6 +5,7 @@ import { App, defaultConfig } from './setup.js'; export const ADMIN_BASE_URL = 'http://127.0.0.1:3001'; export const ADMIN_PREFIX = '/aisix/admin'; +export const GUARDRAILS_URL = '/guardrails'; export const MODELS_URL = '/models'; export const POLICIES_URL = '/policies'; export const PROVIDERS_URL = '/providers';