Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
232 changes: 232 additions & 0 deletions src/admin/guardrails.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,232 @@
use axum::{
extract::{Path, State},
response::{IntoResponse, Response},
};
use bytes::Bytes;
use http::StatusCode;
use uuid::Uuid;

use crate::{
admin::{
AppState,
types::{APIError, DeleteResponse, ItemResponse, ListResponse},
},
config::{
PutEntry,
entities::{
Guardrail, Policy,
guardrails::{SCHEMA_VALIDATOR, validate_guardrail_definition},
},
},
utils::jsonschema::format_evaluation_error,
};

pub const OPENAPI_TAG: &str = "Guardrails";

#[utoipa::path(
get,
context_path = crate::admin::PATH_PREFIX,
path = "/guardrails",
tag = OPENAPI_TAG,
responses(
(status = StatusCode::OK, description = "Get guardrail list success", body = ListResponse<ItemResponse<Guardrail>>),
(status = StatusCode::INTERNAL_SERVER_ERROR, description = "Internal server error", body = APIError)
)
)]
pub async fn list(State(state): State<AppState>) -> Response {
let data = match state
Comment thread
bzp2010 marked this conversation as resolved.
.config_provider
.get_all::<serde_json::Value>("/guardrails")
.await
{
Ok(data) => data,
Err(err) => return APIError::InternalError(err).into_response(),
};

ListResponse {
total: data.len(),
list: data
.into_iter()
.map(|item| ItemResponse {
key: item.key,
value: item.value,
created_index: Some(item.create_revision),
modified_index: Some(item.mod_revision),
})
.collect(),
}
.into_response()
}

#[utoipa::path(
get,
context_path = crate::admin::PATH_PREFIX,
path = "/guardrails/{id}",
tag = OPENAPI_TAG,
params(
("id" = String, Path, description = "The ID of the guardrail"),
),
responses(
(status = StatusCode::OK, description = "Get guardrail success", body = ItemResponse<Guardrail>),
(status = StatusCode::NOT_FOUND, description = "Guardrail not found", body = APIError),
(status = StatusCode::INTERNAL_SERVER_ERROR, description = "Internal server error", body = APIError)
)
)]
pub async fn get(State(state): State<AppState>, Path(id): Path<String>) -> Response {
let key = format!("/guardrails/{id}");
let data = match state.config_provider.get::<serde_json::Value>(&key).await {
Ok(Some(data)) => data,
Ok(None) => {
return APIError::NotFound(format!("Guardrail with ID {id} not found")).into_response();
}
Err(err) => return APIError::InternalError(err).into_response(),
};

ItemResponse {
key,
value: data.value,
created_index: Some(data.create_revision),
modified_index: Some(data.mod_revision),
}
.into_response()
}

#[utoipa::path(
post,
context_path = crate::admin::PATH_PREFIX,
path = "/guardrails",
tag = OPENAPI_TAG,
request_body(content_type = "application/json", content = Guardrail),
responses(
(status = StatusCode::CREATED, description = "Guardrail created successfully", body = ItemResponse<Guardrail>),
(status = StatusCode::BAD_REQUEST, description = "Bad request", body = APIError),
(status = StatusCode::INTERNAL_SERVER_ERROR, description = "Internal server error", body = APIError)
)
)]
pub async fn post(State(state): State<AppState>, body: Bytes) -> Response {
update(state, &Uuid::new_v4().to_string(), body).await
}

#[utoipa::path(
put,
context_path = crate::admin::PATH_PREFIX,
path = "/guardrails/{id}",
tag = OPENAPI_TAG,
params(
("id" = String, Path, description = "The ID of the guardrail"),
),
request_body(content_type = "application/json", content = Guardrail),
responses(
(status = StatusCode::OK, description = "Guardrail updated successfully", body = ItemResponse<Guardrail>),
(status = StatusCode::CREATED, description = "Guardrail created successfully", body = ItemResponse<Guardrail>),
(status = StatusCode::BAD_REQUEST, description = "Bad request", body = APIError),
(status = StatusCode::INTERNAL_SERVER_ERROR, description = "Internal server error", body = APIError)
)
)]
pub async fn put(State(state): State<AppState>, Path(id): Path<String>, body: Bytes) -> Response {
update(state, &id, body).await
}

#[utoipa::path(
delete,
context_path = crate::admin::PATH_PREFIX,
path = "/guardrails/{id}",
tag = OPENAPI_TAG,
params(
("id" = String, Path, description = "The ID of the guardrail"),
),
responses(
(status = StatusCode::BAD_REQUEST, description = "Guardrail is still referenced by policies", body = APIError),
(status = StatusCode::OK, description = "Guardrail deleted successfully", body = DeleteResponse),
(status = StatusCode::NOT_FOUND, description = "Guardrail not found", body = APIError),
(status = StatusCode::INTERNAL_SERVER_ERROR, description = "Internal server error", body = APIError)
)
)]
pub async fn delete(State(state): State<AppState>, Path(id): Path<String>) -> Response {
let key = format!("/guardrails/{id}");

match state.config_provider.get::<serde_json::Value>(&key).await {
Ok(Some(_)) => {}
Ok(None) => {
return APIError::NotFound(format!("Guardrail with ID {id} not found")).into_response();
}
Err(err) => return APIError::InternalError(err).into_response(),
}

match state.config_provider.get_all::<Policy>("/policies").await {
Ok(policies) => {
if policies.iter().any(|item| {
item.value
.referenced_guardrail_ids()
.any(|guardrail_id| guardrail_id == id)
}) {
return APIError::BadRequest(
"guardrail is still referenced by policies".to_string(),
)
.into_response();
}
}
Err(err) => return APIError::InternalError(err).into_response(),
}

match state.config_provider.delete(&key).await {
Ok(deleted) if deleted > 0 => DeleteResponse { deleted, key }.into_response(),
Ok(_) => APIError::NotFound(format!("Guardrail with ID {id} not found")).into_response(),
Err(err) => APIError::InternalError(err).into_response(),
}
}

async fn update(state: AppState, id: &str, body: Bytes) -> Response {
let key = format!("/guardrails/{id}");

let guardrail = match serde_json::from_slice::<serde_json::Value>(&body) {
Ok(value) => value,
Err(err) => return APIError::BadRequest(format!("Invalid JSON: {err}")).into_response(),
};

let evaluation = SCHEMA_VALIDATOR.evaluate(&guardrail);
if !evaluation.flag().valid {
return APIError::BadRequest(format!(
"JSON schema validation error: {}",
format_evaluation_error(&evaluation)
))
.into_response();
}

let guardrail = match serde_json::from_value::<Guardrail>(guardrail) {
Ok(value) => value,
Err(err) => {
return APIError::BadRequest(format!("Invalid guardrail data: {err}")).into_response();
}
};

if let Err(err) = validate_guardrail_definition(id, &guardrail) {
return APIError::BadRequest(err).into_response();
}

match state.config_provider.put(&key, &guardrail).await {
Ok(res) => match res {
PutEntry::Created => (
StatusCode::CREATED,
ItemResponse {
key: key.to_string(),
value: guardrail,
created_index: None,
modified_index: None,
},
)
.into_response(),
PutEntry::Updated(_prev) => (
StatusCode::OK,
ItemResponse {
key: key.to_string(),
value: guardrail,
created_index: None,
modified_index: None,
},
)
.into_response(),
},
Err(err) => APIError::InternalError(err).into_response(),
}
}
17 changes: 17 additions & 0 deletions src/admin/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
mod apikeys;
mod guardrails;
mod models;
mod playground;
mod policies;
Expand Down Expand Up @@ -38,6 +39,7 @@ pub const PATH_PREFIX: &str = "/aisix/admin";
tags(
(name = models::OPENAPI_TAG, description = "Admin API for managing AI models"),
(name = apikeys::OPENAPI_TAG, description = "Admin API for managing API keys"),
(name = guardrails::OPENAPI_TAG, description = "Admin API for managing guardrails"),
(name = policies::OPENAPI_TAG, description = "Admin API for managing guardrail policies"),
(name = providers::OPENAPI_TAG, description = "Admin API for managing AI providers")
),
Expand All @@ -61,6 +63,11 @@ pub const PATH_PREFIX: &str = "/aisix/admin";
apikeys::post,
apikeys::put,
apikeys::delete,
guardrails::list,
guardrails::get,
guardrails::post,
guardrails::put,
guardrails::delete,
policies::list,
policies::get,
policies::post,
Expand Down Expand Up @@ -146,6 +153,16 @@ pub fn create_router(state: AppState) -> Result<Router> {
get(apikeys::get).put(apikeys::put).delete(apikeys::delete),
),
)
.merge(
Router::new()
.route("/guardrails", get(guardrails::list).post(guardrails::post))
.route(
"/guardrails/{id}",
get(guardrails::get)
.put(guardrails::put)
.delete(guardrails::delete),
),
)
.merge(
Router::new()
.route("/policies", get(policies::list).post(policies::post))
Expand Down
12 changes: 9 additions & 3 deletions src/config/entities/guardrails.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ impl Guardrail {
}
}

fn validate(key: &str, value: &Guardrail) -> Result<(), String> {
pub(crate) fn validate_guardrail_definition(key: &str, value: &Guardrail) -> Result<(), String> {
let evaluation = SCHEMA_VALIDATOR.evaluate(
&serde_json::to_value(value)
.map_err(|error| format!("Failed to serialize guardrail for validation: {}", error))?,
Expand Down Expand Up @@ -89,8 +89,14 @@ pub struct GuardrailsStore {
impl GuardrailsStore {
pub async fn new(provider: Arc<dyn ConfigProvider + Send + Sync>) -> Self {
Self {
store: EntityStore::new(provider, "/guardrails/", "guardrails", Some(validate), &[])
.await,
store: EntityStore::new(
provider,
"/guardrails/",
"guardrails",
Some(validate_guardrail_definition),
&[],
)
.await,
}
}

Expand Down
Loading