Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions crates/sprout-auth/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ pub struct AuthContext {
pub pubkey: nostr::PublicKey,
/// Permission scopes granted to this connection.
pub scopes: Vec<Scope>,
/// Token-level channel restriction, if authentication used a scoped API token.
///
/// `None` means unrestricted or not token-authenticated.
pub channel_ids: Option<Vec<uuid::Uuid>>,
/// How the connection was authenticated.
pub auth_method: AuthMethod,
}
Expand Down Expand Up @@ -187,6 +191,7 @@ impl AuthService {
Ok(AuthContext {
pubkey: verified_pubkey,
scopes,
channel_ids: None,
auth_method,
})
}
Expand Down Expand Up @@ -417,6 +422,7 @@ mod tests {
let ctx = AuthContext {
pubkey: keys.public_key(),
scopes: vec![Scope::MessagesRead, Scope::ChannelsRead],
channel_ids: None,
auth_method: AuthMethod::Nip42PubkeyOnly,
};
assert!(ctx.has_scope(&Scope::MessagesRead));
Expand Down
21 changes: 12 additions & 9 deletions crates/sprout-relay/src/api/agents.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use nostr::util::hex as nostr_hex;

use crate::state::AppState;

use super::{extract_auth_context, internal_error};
use super::{constrain_accessible_channels, extract_auth_context, internal_error};

/// Returns all bot/agent members visible to the authenticated user, with presence status.
///
Expand All @@ -28,14 +28,17 @@ pub async fn agents_handler(
let pubkey_bytes = ctx.pubkey_bytes.clone();

// Get requester's accessible channels to filter bot channel visibility.
let accessible_channels = state
.db
.get_accessible_channels(&pubkey_bytes, None, None)
.await
.map_err(|e| {
tracing::error!("agents: failed to load accessible channels: {e}");
internal_error("presence lookup failed")
})?;
let accessible_channels = constrain_accessible_channels(
state
.db
.get_accessible_channels(&pubkey_bytes, None, None)
.await
.map_err(|e| {
tracing::error!("agents: failed to load accessible channels: {e}");
internal_error("presence lookup failed")
})?,
ctx.channel_ids.as_deref(),
);
let accessible_ids: std::collections::HashSet<String> = accessible_channels
.iter()
.map(|ac| ac.channel.id.to_string())
Expand Down
15 changes: 9 additions & 6 deletions crates/sprout-relay/src/api/channels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use sprout_db::channel::ChannelRecord;

use crate::state::AppState;

use super::{extract_auth_context, internal_error};
use super::{constrain_accessible_channels, extract_auth_context, internal_error};

/// Query parameters for `GET /api/channels`.
#[derive(Debug, Deserialize)]
Expand All @@ -41,11 +41,14 @@ pub async fn channels_handler(
.map_err(super::scope_error)?;
let pubkey_bytes = ctx.pubkey_bytes.clone();

let channels = state
.db
.get_accessible_channels(&pubkey_bytes, params.visibility.as_deref(), params.member)
.await
.map_err(|e| internal_error(&format!("db error: {e}")))?;
let channels = constrain_accessible_channels(
state
.db
.get_accessible_channels(&pubkey_bytes, params.visibility.as_deref(), params.member)
.await
.map_err(|e| internal_error(&format!("db error: {e}")))?,
ctx.channel_ids.as_deref(),
);

// Bulk-fetch member counts and last-message timestamps in two queries
// instead of 2N queries (one per channel per metric).
Expand Down
15 changes: 9 additions & 6 deletions crates/sprout-relay/src/api/feed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use sprout_core::kind::{self, event_kind_u32};

use crate::state::AppState;

use super::{extract_auth_context, internal_error};
use super::{constrain_channel_ids, extract_auth_context, internal_error};

/// Agent activity kind set — used to partition activity into agent vs channel activity.
const AGENT_KINDS: &[u32] = &[
Expand Down Expand Up @@ -71,11 +71,14 @@ pub async fn feed_handler(
.map(|t| t.split(',').map(|s| s.trim()).collect());
let wants = |cat: &str| -> bool { type_filter.as_ref().is_none_or(|f| f.contains(cat)) };

let accessible_ids = state
.db
.get_accessible_channel_ids(&pubkey_bytes)
.await
.map_err(|e| internal_error(&format!("db error: {e}")))?;
let accessible_ids = constrain_channel_ids(
state
.db
.get_accessible_channel_ids(&pubkey_bytes)
.await
.map_err(|e| internal_error(&format!("db error: {e}")))?,
ctx.channel_ids.as_deref(),
);

if accessible_ids.is_empty() {
let generated_at = Utc::now().timestamp();
Expand Down
89 changes: 89 additions & 0 deletions crates/sprout-relay/src/api/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,28 @@ pub fn check_token_channel_access(
Ok(())
}

/// Intersect an owner-accessible channel list with a token's `channel_ids`, when present.
pub(crate) fn constrain_channel_ids(
mut channel_ids: Vec<Uuid>,
allowed: Option<&[Uuid]>,
) -> Vec<Uuid> {
if let Some(allowed) = allowed {
channel_ids.retain(|channel_id| allowed.contains(channel_id));
}
channel_ids
}

/// Filter accessible channel records against a token's `channel_ids`, when present.
pub(crate) fn constrain_accessible_channels(
mut channels: Vec<sprout_db::channel::AccessibleChannel>,
allowed: Option<&[Uuid]>,
) -> Vec<sprout_db::channel::AccessibleChannel> {
if let Some(allowed) = allowed {
channels.retain(|channel| allowed.contains(&channel.channel.id));
}
channels
}

/// Convert a scope-check failure into a 403 Forbidden response.
///
/// Used by handlers to propagate `require_scope` errors via `?`.
Expand Down Expand Up @@ -557,8 +579,41 @@ where

#[cfg(test)]
mod tests {
use chrono::Utc;

use super::*;

fn accessible_channel(id: Uuid) -> sprout_db::channel::AccessibleChannel {
let now = Utc::now();
sprout_db::channel::AccessibleChannel {
channel: sprout_db::channel::ChannelRecord {
id,
name: "restricted".to_string(),
channel_type: "stream".to_string(),
visibility: "private".to_string(),
description: None,
canvas: None,
created_by: vec![0; 32],
created_at: now,
updated_at: now,
archived_at: None,
deleted_at: None,
nip29_group_id: None,
topic_required: false,
max_members: None,
topic: None,
topic_set_by: None,
topic_set_at: None,
purpose: None,
purpose_set_by: None,
purpose_set_at: None,
ttl_seconds: None,
ttl_deadline: None,
},
is_member: true,
}
}

// ── decode_jwt_payload_unverified ─────────────────────────────────────────
//
// This private helper is the core of the dev-mode JWT path in
Expand Down Expand Up @@ -763,4 +818,38 @@ mod tests {
assert_eq!(status, StatusCode::NOT_FOUND);
assert_eq!(body.0["error"], "approval not found");
}

#[test]
fn constrain_channel_ids_intersects_with_token_allowlist() {
let allowed = uuid::Uuid::new_v4();
let denied = uuid::Uuid::new_v4();

let constrained = constrain_channel_ids(vec![allowed, denied], Some(&[allowed]));

assert_eq!(constrained, vec![allowed]);
}

#[test]
fn constrain_channel_ids_leaves_unrestricted_lists_unchanged() {
let a = uuid::Uuid::new_v4();
let b = uuid::Uuid::new_v4();

let constrained = constrain_channel_ids(vec![a, b], None);

assert_eq!(constrained, vec![a, b]);
}

#[test]
fn constrain_accessible_channels_respects_token_allowlist() {
let allowed = uuid::Uuid::new_v4();
let denied = uuid::Uuid::new_v4();

let constrained = constrain_accessible_channels(
vec![accessible_channel(allowed), accessible_channel(denied)],
Some(&[allowed]),
);

assert_eq!(constrained.len(), 1);
assert_eq!(constrained[0].channel.id, allowed);
}
}
27 changes: 19 additions & 8 deletions crates/sprout-relay/src/api/search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use sprout_search::SearchQuery;

use crate::state::AppState;

use super::extract_auth_context;
use super::{constrain_channel_ids, extract_auth_context};

/// Query parameters for the search endpoint.
#[derive(Debug, Deserialize)]
Expand Down Expand Up @@ -42,21 +42,32 @@ pub async fn search_handler(
let query_str = params.q.unwrap_or_default();
let per_page = params.limit.unwrap_or(20).min(100);

let channel_ids = state
.db
.get_accessible_channel_ids(&pubkey_bytes)
.await
.unwrap_or_default();
let channel_ids = constrain_channel_ids(
state
.db
.get_accessible_channel_ids(&pubkey_bytes)
.await
.unwrap_or_default(),
ctx.channel_ids.as_deref(),
);

// Build Typesense filter_by: channel_id:=[id1,id2,...] || global events
// Channel-restricted tokens must stay within their allowlist; unrestricted callers
// may also see global events.
let include_global = ctx.channel_ids.is_none();
if channel_ids.is_empty() && !include_global {
return Ok(Json(serde_json::json!({ "hits": [], "found": 0 })));
}
let filter_by = if channel_ids.is_empty() {
Some("channel_id:=__global__".to_string())
} else {
} else if include_global {
let ids: Vec<String> = channel_ids.iter().map(|id| id.to_string()).collect();
Some(format!(
"(channel_id:=[{}] || channel_id:=__global__)",
ids.join(",")
))
} else {
let ids: Vec<String> = channel_ids.iter().map(|id| id.to_string()).collect();
Some(format!("channel_id:=[{}]", ids.join(",")))
};

let search_query = SearchQuery {
Expand Down
85 changes: 67 additions & 18 deletions crates/sprout-relay/src/api/tokens.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,29 @@ pub struct RevokeAllResponse {
pub revoked_count: u64,
}

fn ensure_requested_scopes_within_caller(
ctx: &super::RestAuthContext,
requested_scopes: &[Scope],
) -> Result<(), (StatusCode, Json<serde_json::Value>)> {
if matches!(ctx.auth_method, RestAuthMethod::Nip98) {
return Ok(());
}

for scope in requested_scopes {
if !ctx.scopes.contains(scope) {
return Err((
StatusCode::FORBIDDEN,
Json(serde_json::json!({
"error": "scope_escalation",
"message": format!("Cannot mint scope '{}' — not in your token's scopes", scope)
})),
));
}
}

Ok(())
}

// ── Handlers ──────────────────────────────────────────────────────────────────

/// `POST /api/tokens` — mint a new API token.
Expand Down Expand Up @@ -343,24 +366,14 @@ pub async fn post_tokens(
let mut seen = std::collections::HashSet::new();
parsed_scopes.retain(|s| seen.insert(s.clone()));

// ── Scope escalation prevention (Bearer-authenticated callers) ───────────
// If the caller authenticated via an API token, the requested scopes must be
// a subset of the caller's own scopes, and channel_ids must be a subset too.
// NIP-98 and Okta JWT callers are unrestricted (they authenticate the identity
// directly, not via a scoped token).
if matches!(ctx.auth_method, RestAuthMethod::ApiToken) {
for scope in &parsed_scopes {
if !ctx.scopes.contains(scope) {
return Err((
StatusCode::FORBIDDEN,
Json(serde_json::json!({
"error": "scope_escalation",
"message": format!("Cannot mint scope '{}' — not in your token's scopes", scope)
})),
));
}
}
// ── Scope escalation prevention (all non-bootstrap callers) ──────────────
// Any caller that arrived with an already-authorized identity must stay within
// the scopes granted to that identity. NIP-98 bootstrap mints are the only
// exception because they deliberately authenticate ownership, not preexisting
// relay scopes.
ensure_requested_scopes_within_caller(&ctx, &parsed_scopes)?;

if !matches!(ctx.auth_method, RestAuthMethod::Nip98) {
// If caller has channel_ids restriction, child must also be restricted
// to a subset of those channels.
if let Some(ref caller_channels) = ctx.channel_ids {
Expand Down Expand Up @@ -423,7 +436,7 @@ pub async fn post_tokens(
// token is channel-restricted, this would be an escalation. The subset
// check above already rejects `None` for restricted callers, so we
// must also reject empty arrays here.
if matches!(ctx.auth_method, RestAuthMethod::ApiToken) && ctx.channel_ids.is_some() {
if ctx.channel_ids.is_some() {
return Err((
StatusCode::FORBIDDEN,
Json(serde_json::json!({
Expand Down Expand Up @@ -816,6 +829,24 @@ fn reconstruct_canonical_url_for_tokens(relay_url: &str) -> String {
#[cfg(test)]
mod tests {
use super::*;
use nostr::Keys;

fn auth_context(
auth_method: RestAuthMethod,
scopes: Vec<Scope>,
channel_ids: Option<Vec<Uuid>>,
) -> super::super::RestAuthContext {
let keys = Keys::generate();
let pubkey = keys.public_key();
super::super::RestAuthContext {
pubkey,
pubkey_bytes: pubkey.to_bytes().to_vec(),
scopes,
auth_method,
token_id: None,
channel_ids,
}
}

#[test]
fn rate_limiter_allows_up_to_limit() {
Expand Down Expand Up @@ -883,4 +914,22 @@ mod tests {
let url = reconstruct_canonical_url_for_tokens("wss://relay.example.test/");
assert_eq!(url, "https://relay.example.test/api/tokens");
}

#[test]
fn bearer_callers_cannot_self_mint_new_scopes() {
let ctx = auth_context(RestAuthMethod::OktaJwt, vec![Scope::MessagesRead], None);

let err = ensure_requested_scopes_within_caller(&ctx, &[Scope::MessagesWrite])
.expect_err("bearer-auth callers must stay within their granted scopes");

assert_eq!(err.0, StatusCode::FORBIDDEN);
assert_eq!(err.1 .0["error"].as_str(), Some("scope_escalation"));
}

#[test]
fn nip98_bootstrap_mints_are_not_limited_by_existing_scope_list() {
let ctx = auth_context(RestAuthMethod::Nip98, Vec::new(), None);

assert!(ensure_requested_scopes_within_caller(&ctx, &[Scope::MessagesWrite]).is_ok());
}
}
Loading