From 41ddc5367ea3d1cb476e3ab990eb78a99c1164d0 Mon Sep 17 00:00:00 2001 From: Eman Cickusic Date: Thu, 18 Jun 2026 00:22:04 +0200 Subject: [PATCH] Route food-label analysis through Cloud Run proxy Analysis now works without a user-supplied LLM key: when no key is saved, requests go to the keyless Cloud Run /analyze proxy (Gemini via Vertex/ADC). A saved key still routes to the existing BYOK providers. - ProxyFoodLabelLlmWorkflow: one /analyze POST per scan, sliced across the 3 pipeline stages via a de-duplicated in-flight Deferred so the pipeline's per-stage timeout+retry joins the same call instead of issuing a second one (one scan = one call, correct token accounting). - SelectingFoodLabelLlmWorkflow: proxy when no key, BYOK when present. - PROXY_BASE_URL BuildConfig field (ZEST_PROXY_BASE_URL override). - Drop the "missing key" scanner banner/badge; scanning is keyless. - Result chat unchanged (still uses a personal key). Tests: parse, single-call/dedup, error mapping, and key-based routing. Co-Authored-By: Claude Opus 4.8 (1M context) --- app/build.gradle.kts | 10 + .../ui/AppChromeFunctionalTest.kt | 1 - .../analysis/FoodAnalysisPipeline.kt | 8 +- .../network/llm/ProxyFoodLabelLlmWorkflow.kt | 307 ++++++++++++++++++ .../llm/SelectingFoodLabelLlmWorkflow.kt | 41 +++ .../com/b2/ultraprocessed/ui/ScannerScreen.kt | 37 --- .../b2/ultraprocessed/ui/UltraProcessedApp.kt | 1 - .../llm/ProxyFoodLabelLlmWorkflowTest.kt | 144 ++++++++ .../llm/SelectingFoodLabelLlmWorkflowTest.kt | 125 +++++++ 9 files changed, 634 insertions(+), 40 deletions(-) create mode 100644 app/src/main/java/com/b2/ultraprocessed/network/llm/ProxyFoodLabelLlmWorkflow.kt create mode 100644 app/src/main/java/com/b2/ultraprocessed/network/llm/SelectingFoodLabelLlmWorkflow.kt create mode 100644 app/src/test/java/com/b2/ultraprocessed/network/llm/ProxyFoodLabelLlmWorkflowTest.kt create mode 100644 app/src/test/java/com/b2/ultraprocessed/network/llm/SelectingFoodLabelLlmWorkflowTest.kt diff --git a/app/build.gradle.kts b/app/build.gradle.kts index f4e300e..f3e75a7 100644 --- a/app/build.gradle.kts +++ b/app/build.gradle.kts @@ -26,6 +26,11 @@ val usdaBootstrapApiKeyB64 = localProperties .getProperty("ZEST_USDA_BOOTSTRAP_API_KEY_B64") .orEmpty() .trim() +val proxyBaseUrl = localProperties + .getProperty("ZEST_PROXY_BASE_URL") + .orEmpty() + .trim() + .ifBlank { "https://ultraprocessed-ai-proxy-894254677159.us-east1.run.app" } val releaseStoreFile = providers.environmentVariable("ZEST_RELEASE_STORE_FILE").orNull val releaseStorePassword = providers.environmentVariable("ZEST_RELEASE_STORE_PASSWORD").orNull val releaseKeyAlias = providers.environmentVariable("ZEST_RELEASE_KEY_ALIAS").orNull @@ -52,6 +57,11 @@ android { "USDA_BOOTSTRAP_API_KEY_B64", usdaBootstrapApiKeyB64.asBuildConfigStringLiteral(), ) + buildConfigField( + "String", + "PROXY_BASE_URL", + proxyBaseUrl.asBuildConfigStringLiteral(), + ) testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner" vectorDrawables { diff --git a/app/src/androidTest/java/com/b2/ultraprocessed/ui/AppChromeFunctionalTest.kt b/app/src/androidTest/java/com/b2/ultraprocessed/ui/AppChromeFunctionalTest.kt index 83cf277..d92bd20 100644 --- a/app/src/androidTest/java/com/b2/ultraprocessed/ui/AppChromeFunctionalTest.kt +++ b/app/src/androidTest/java/com/b2/ultraprocessed/ui/AppChromeFunctionalTest.kt @@ -31,7 +31,6 @@ class AppChromeFunctionalTest { composeRule.setContent { UltraProcessedTheme { ScannerScreen( - hasApiKey = false, hasUsdaApiKey = false, enableLiveCamera = false, onScan = {}, diff --git a/app/src/main/java/com/b2/ultraprocessed/analysis/FoodAnalysisPipeline.kt b/app/src/main/java/com/b2/ultraprocessed/analysis/FoodAnalysisPipeline.kt index 2456a80..8601823 100644 --- a/app/src/main/java/com/b2/ultraprocessed/analysis/FoodAnalysisPipeline.kt +++ b/app/src/main/java/com/b2/ultraprocessed/analysis/FoodAnalysisPipeline.kt @@ -17,7 +17,9 @@ import com.b2.ultraprocessed.network.llm.LlmUsage import com.b2.ultraprocessed.network.llm.MultiProviderFoodLabelLlmWorkflow import com.b2.ultraprocessed.network.llm.NovaClassification import com.b2.ultraprocessed.network.llm.OpenAiCompatibleFoodLabelLlmWorkflow +import com.b2.ultraprocessed.network.llm.ProxyFoodLabelLlmWorkflow import com.b2.ultraprocessed.network.llm.SecretLlmApiKeyProvider +import com.b2.ultraprocessed.network.llm.SelectingFoodLabelLlmWorkflow import com.b2.ultraprocessed.network.usda.SecretUsdaApiKeyProvider import com.b2.ultraprocessed.network.usda.UsdaHttpClientFactory import com.b2.ultraprocessed.network.usda.UsdaApiService @@ -327,7 +329,9 @@ class FoodAnalysisPipeline( client = UsdaHttpClientFactory.create(), ), ), - llmWorkflow = MultiProviderFoodLabelLlmWorkflow( + llmWorkflow = SelectingFoodLabelLlmWorkflow( + proxyWorkflow = ProxyFoodLabelLlmWorkflow(), + byokWorkflow = MultiProviderFoodLabelLlmWorkflow( geminiWorkflow = GeminiFoodLabelLlmWorkflow( context = appContext, apiKeyProvider = SecretLlmApiKeyProvider( @@ -358,6 +362,8 @@ class FoodAnalysisPipeline( baseUrl = "https://api.groq.com/openai/v1", providerTag = "groq", ), + ), + apiKeyProvider = SecretLlmApiKeyProvider(SecretKeyManager(appContext)), ), ) } diff --git a/app/src/main/java/com/b2/ultraprocessed/network/llm/ProxyFoodLabelLlmWorkflow.kt b/app/src/main/java/com/b2/ultraprocessed/network/llm/ProxyFoodLabelLlmWorkflow.kt new file mode 100644 index 0000000..43f3663 --- /dev/null +++ b/app/src/main/java/com/b2/ultraprocessed/network/llm/ProxyFoodLabelLlmWorkflow.kt @@ -0,0 +1,307 @@ +package com.b2.ultraprocessed.network.llm + +import com.b2.ultraprocessed.BuildConfig +import com.b2.ultraprocessed.analysis.AnalysisDebugLogger +import com.b2.ultraprocessed.analysis.AnalysisTelemetry +import java.io.IOException +import java.util.concurrent.TimeUnit +import kotlinx.coroutines.CancellationException +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Deferred +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.SupervisorJob +import kotlinx.coroutines.async +import kotlinx.coroutines.suspendCancellableCoroutine +import kotlinx.coroutines.sync.Mutex +import kotlinx.coroutines.sync.withLock +import kotlinx.coroutines.withContext +import okhttp3.Call +import okhttp3.Callback +import okhttp3.MediaType.Companion.toMediaType +import okhttp3.OkHttpClient +import okhttp3.Request +import okhttp3.RequestBody.Companion.toRequestBody +import okhttp3.Response +import org.json.JSONObject +import kotlin.coroutines.resume +import kotlin.coroutines.resumeWithException + +/** + * Routes food-label analysis through the Cloud Run proxy (`POST /analyze`), which runs Gemini via + * the runtime service account so no user API key is required. + * + * The proxy returns NOVA classification, ingredient analysis, allergens, and token usage in a + * single response, but the [FoodLabelLlmWorkflow] interface (and the pipeline that drives it) + * expects three sequential calls per scan. To guarantee **one scan = one network call** the + * `/analyze` request is run once as a de-duplicated in-flight [Deferred] keyed by the scan's + * ingredient text; all three interface methods join that same call. + * + * The de-dup also survives the pipeline's per-stage `withTimeout` + retry: the request runs in + * this workflow's own [scope], so a caller timeout cancels only the *await* (not the underlying + * call), and the retry joins the still-running request instead of issuing a second one. Token + * usage is therefore reported exactly once (on [classifyNova]). + */ +class ProxyFoodLabelLlmWorkflow( + private val baseUrl: String = BuildConfig.PROXY_BASE_URL, + private val client: OkHttpClient = ProxyHttpClientFactory.create(), + private val scope: CoroutineScope = CoroutineScope(SupervisorJob() + Dispatchers.IO), +) : FoodLabelLlmWorkflow { + + private val mutex = Mutex() + private val inFlightByExtraction = LinkedHashMap>() + private val byCorrectedIngredients = LinkedHashMap() + + override suspend fun classifyNova( + extraction: IngredientExtraction, + modelId: String, + onStatus: (String) -> Unit, + ): Result> = runStage { + val analysis = obtain(extraction) + // Usage is reported here only, so aggregation across the three stages is not tripled. + LlmStageResult(analysis.nova, analysis.usage) + } + + override suspend fun analyzeIngredientList( + extraction: IngredientExtraction, + modelId: String, + onStatus: (String) -> Unit, + ): Result> = runStage { + val analysis = obtain(extraction) + LlmStageResult(analysis.ingredients, usage = null) + } + + override suspend fun detectAllergens( + correctedIngredientNames: List, + modelId: String, + onStatus: (String) -> Unit, + ): Result> = runStage { + val cached = mutex.withLock { byCorrectedIngredients[correctedIngredientNames.cacheKey()] } + ?: run { + // classifyNova/analyzeIngredientList run first in the same scan and populate this + // cache, so a miss means the corrected list desynced. Fail loudly rather than + // silently report "no allergens" — a false negative is unsafe for an allergen check. + AnalysisDebugLogger.log("proxy_allergen_cache_miss", "names=$correctedIngredientNames") + throw IOException("Allergen analysis was lost for this scan. Please scan again.") + } + LlmStageResult(cached.allergens, usage = null) + } + + /** + * Returns the single `/analyze` result for [extraction], starting the request at most once and + * joining any in-flight or already-completed request for the same scan. + */ + private suspend fun obtain(extraction: IngredientExtraction): CachedAnalysis { + val key = extraction.cacheKey() + val deferred = mutex.withLock { + inFlightByExtraction[key]?.takeUnless { it.isCancelled } + ?: scope.async { executeAnalyze(extraction) } + .also { inFlightByExtraction.putBounded(key, it) } + } + val analysis = try { + deferred.await() + } catch (t: Throwable) { + // A non-cancellation throwable means the shared request itself failed (the workflow + // scope never cancels), so evict it to let a later attempt re-fetch. A + // CancellationException means only this awaiter was cancelled (e.g. the pipeline's + // stage timeout) — leave the request running so the retry joins the same call. + if (t !is CancellationException) { + mutex.withLock { if (inFlightByExtraction[key] === deferred) inFlightByExtraction.remove(key) } + } + throw t + } + mutex.withLock { + byCorrectedIngredients.putBounded( + analysis.ingredients.correctedIngredients.cacheKey(), + analysis, + ) + } + return analysis + } + + private suspend fun executeAnalyze(extraction: IngredientExtraction): CachedAnalysis { + val payload = JSONObject().put("ingredient_text", extraction.rawText) + if (extraction.productName.isNotBlank()) { + payload.put("product_name", extraction.productName) + } + val url = "${baseUrl.trimEnd('/')}/analyze" + AnalysisTelemetry.event("proxy_request_start url=$url") + val request = Request.Builder() + .url(url) + .header("Content-Type", "application/json") + .post(payload.toString().toRequestBody(JSON_MEDIA_TYPE)) + .build() + + val body = suspendCancellableCoroutine { continuation -> + val call = client.newCall(request) + continuation.invokeOnCancellation { call.cancel() } + call.enqueue( + object : Callback { + override fun onFailure(call: Call, e: IOException) { + if (!continuation.isCancelled) { + continuation.resumeWithException(e) + } + } + + override fun onResponse(call: Call, response: Response) { + response.use { + runCatching { readResponseBody(it) } + .onSuccess { parsed -> + if (!continuation.isCancelled) continuation.resume(parsed) + } + .onFailure { error -> + if (!continuation.isCancelled) continuation.resumeWithException(error) + } + } + } + }, + ) + } + return parseAnalyzeResponse(body) + } + + private fun readResponseBody(response: Response): String { + val raw = response.body?.string().orEmpty() + AnalysisTelemetry.event("proxy_response http=${response.code}") + AnalysisDebugLogger.log("proxy_http_body", "http=${response.code} body=${raw.take(8_000)}") + if (!response.isSuccessful) { + throw IOException(proxyErrorMessage(response.code, raw)) + } + return raw + } + + private fun proxyErrorMessage(statusCode: Int, body: String): String { + val detailMessage = runCatching { + JSONObject(body).optJSONObject("detail")?.optString("message").orEmpty() + }.getOrDefault("") + return when { + statusCode == 429 -> + "The AI service is temporarily busy (rate limit). Please wait a moment and try again." + statusCode == 422 -> + "The analysis service could not read this label. Please try again." + statusCode in 500..599 -> + "The AI service is temporarily unavailable. Please try again." + + if (detailMessage.isNotBlank()) " ($detailMessage)" else "" + else -> "Analysis service request failed with HTTP $statusCode." + } + } + + private fun parseAnalyzeResponse(body: String): CachedAnalysis { + val root = try { + JSONObject(body) + } catch (e: Exception) { + throw IOException("Analysis service returned an unreadable response.", e) + } + + val novaObj = root.optJSONObject("nova") ?: JSONObject() + val containsFood = novaObj.optBoolean("containsConsumableFoodItem", true) + val nova = NovaClassification( + novaGroup = novaObj.optInt("novaGroup", 0).coerceIn(0, 4), + summary = novaObj.optString("summary").trim(), + confidence = novaObj.optConfidence("confidence"), + warnings = novaObj.optStringList("warnings"), + containsConsumableFoodItem = containsFood, + rejectionReason = novaObj.optString("rejectionReason").trim(), + ) + + val ingredientsObj = root.optJSONObject("ingredients") ?: JSONObject() + val ingredients = IngredientListAnalysis( + correctedIngredients = ingredientsObj.optStringList("correctedIngredients"), + ultraProcessedIngredients = ingredientsObj.optRiskMarkers("ultraProcessedIngredients"), + warnings = ingredientsObj.optStringList("warnings"), + confidence = ingredientsObj.optConfidence("confidence"), + ) + + val allergensObj = root.optJSONObject("allergens") ?: JSONObject() + val allergens = AllergenDetection( + allergens = allergensObj.optStringList("allergens"), + warnings = allergensObj.optStringList("warnings"), + confidence = allergensObj.optConfidence("confidence"), + ) + + return CachedAnalysis( + nova = nova, + ingredients = ingredients, + allergens = allergens, + usage = root.optJSONObject("usage")?.toLlmUsage(), + ) + } + + /** Runs a stage body, mapping failures to [Result.failure] while letting cancellation propagate. */ + private suspend fun runStage(block: suspend () -> T): Result = withContext(Dispatchers.IO) { + try { + Result.success(block()) + } catch (c: CancellationException) { + throw c + } catch (t: Throwable) { + Result.failure(t) + } + } + + private data class CachedAnalysis( + val nova: NovaClassification, + val ingredients: IngredientListAnalysis, + val allergens: AllergenDetection, + val usage: LlmUsage?, + ) + + private fun LinkedHashMap.putBounded(key: K, value: V) { + remove(key) + put(key, value) + while (size > MAX_CACHE_ENTRIES) { + remove(keys.iterator().next()) + } + } + + companion object { + private const val MAX_CACHE_ENTRIES = 8 + private val JSON_MEDIA_TYPE = "application/json; charset=utf-8".toMediaType() + } +} + +object ProxyHttpClientFactory { + fun create(): OkHttpClient = + OkHttpClient.Builder() + .connectTimeout(5, TimeUnit.SECONDS) + .readTimeout(45, TimeUnit.SECONDS) + .writeTimeout(20, TimeUnit.SECONDS) + .callTimeout(60, TimeUnit.SECONDS) + .retryOnConnectionFailure(true) + .build() +} + +private fun IngredientExtraction.cacheKey(): String = productName + "" + rawText + +private fun List.cacheKey(): String = joinToString("") { it.trim().lowercase() } + +private fun JSONObject.optConfidence(name: String): Float = + optDouble(name, 0.5).toFloat().coerceIn(0f, 1f) + +private fun JSONObject.optStringList(name: String): List { + val array = optJSONArray(name) ?: return emptyList() + return buildList { + for (i in 0 until array.length()) { + val value = array.optString(i).trim() + if (value.isNotEmpty()) add(value) + } + } +} + +private fun JSONObject.optRiskMarkers(name: String): List { + val array = optJSONArray(name) ?: return emptyList() + return buildList { + for (i in 0 until array.length()) { + val obj = array.optJSONObject(i) ?: continue + val markerName = obj.optString("name").trim() + if (markerName.isEmpty()) continue + add(IngredientRiskMarker(name = markerName, reason = obj.optString("reason").trim())) + } + } +} + +private fun JSONObject.toLlmUsage(): LlmUsage? { + val input = optInt("inputTokens", 0).coerceAtLeast(0) + val output = optInt("outputTokens", 0).coerceAtLeast(0) + val total = optInt("totalTokens", input + output).coerceAtLeast(input + output) + if (input == 0 && output == 0 && total == 0) return null + return LlmUsage(inputTokens = input, outputTokens = output, totalTokens = total) +} diff --git a/app/src/main/java/com/b2/ultraprocessed/network/llm/SelectingFoodLabelLlmWorkflow.kt b/app/src/main/java/com/b2/ultraprocessed/network/llm/SelectingFoodLabelLlmWorkflow.kt new file mode 100644 index 0000000..9bf4f18 --- /dev/null +++ b/app/src/main/java/com/b2/ultraprocessed/network/llm/SelectingFoodLabelLlmWorkflow.kt @@ -0,0 +1,41 @@ +package com.b2.ultraprocessed.network.llm + +/** + * Picks the analysis backend per call: if the user has saved their own LLM key, route to the + * bring-your-own-key [byokWorkflow]; otherwise use the keyless Cloud Run [proxyWorkflow]. + * + * The key is re-read on every call (not cached at construction) so saving or clearing a key in + * Settings takes effect on the next scan without rebuilding the pipeline. A scan is not split + * across backends in practice: the key can only change from the Settings screen, which the user + * navigates to between scans, never mid-scan. + */ +class SelectingFoodLabelLlmWorkflow( + private val proxyWorkflow: FoodLabelLlmWorkflow, + private val byokWorkflow: FoodLabelLlmWorkflow, + private val apiKeyProvider: LlmApiKeyProvider, +) : FoodLabelLlmWorkflow { + + override suspend fun classifyNova( + extraction: IngredientExtraction, + modelId: String, + onStatus: (String) -> Unit, + ): Result> = + active().classifyNova(extraction, modelId, onStatus) + + override suspend fun analyzeIngredientList( + extraction: IngredientExtraction, + modelId: String, + onStatus: (String) -> Unit, + ): Result> = + active().analyzeIngredientList(extraction, modelId, onStatus) + + override suspend fun detectAllergens( + correctedIngredientNames: List, + modelId: String, + onStatus: (String) -> Unit, + ): Result> = + active().detectAllergens(correctedIngredientNames, modelId, onStatus) + + private fun active(): FoodLabelLlmWorkflow = + if (apiKeyProvider.getApiKey().isNotBlank()) byokWorkflow else proxyWorkflow +} diff --git a/app/src/main/java/com/b2/ultraprocessed/ui/ScannerScreen.kt b/app/src/main/java/com/b2/ultraprocessed/ui/ScannerScreen.kt index 6ca0171..c716fa3 100644 --- a/app/src/main/java/com/b2/ultraprocessed/ui/ScannerScreen.kt +++ b/app/src/main/java/com/b2/ultraprocessed/ui/ScannerScreen.kt @@ -118,7 +118,6 @@ private object ScannerMetrics { @Composable fun ScannerScreen( - hasApiKey: Boolean, hasUsdaApiKey: Boolean, enableLiveCamera: Boolean = true, onScan: (String) -> Unit, @@ -247,44 +246,10 @@ fun ScannerScreen( .background(DarkBg), ) { ScannerHomeHeader( - hasApiKey = hasApiKey, onHistory = onHistory, onSettings = onSettings, ) - if (!hasApiKey) { - Surface( - onClick = onSettings, - color = Amber500.copy(alpha = 0.08f), - shape = RoundedCornerShape(14.dp), - border = androidx.compose.foundation.BorderStroke( - 1.dp, - Amber500.copy(alpha = 0.30f), - ), - modifier = Modifier - .fillMaxWidth() - .padding(horizontal = ScannerMetrics.Space3, vertical = ScannerMetrics.Grid), - ) { - Row( - modifier = Modifier.padding(horizontal = ScannerMetrics.Space2, vertical = ScannerMetrics.Grid), - verticalAlignment = Alignment.CenterVertically, - ) { - Box( - modifier = Modifier - .size(8.dp) - .background(Amber400, CircleShape), - ) - Spacer(modifier = Modifier.width(ScannerMetrics.Grid)) - Text( - text = stringResource(R.string.scanner_missing_key_banner), - color = Amber400, - fontSize = ScannerMetrics.SecondaryText, - fontWeight = FontWeight.Bold, - ) - } - } - } - Column( modifier = Modifier .weight(1f) @@ -660,7 +625,6 @@ fun ScannerScreen( @Composable private fun ScannerHomeHeader( - hasApiKey: Boolean, onHistory: () -> Unit, onSettings: () -> Unit, ) { @@ -717,7 +681,6 @@ private fun ScannerHomeHeader( icon = Icons.Default.Settings, contentDescription = "Settings", onClick = onSettings, - badgeVisible = !hasApiKey, testTag = AppTestTags.HEADER_ACTION_SETTINGS, ) } diff --git a/app/src/main/java/com/b2/ultraprocessed/ui/UltraProcessedApp.kt b/app/src/main/java/com/b2/ultraprocessed/ui/UltraProcessedApp.kt index 9c136fd..ca33f25 100644 --- a/app/src/main/java/com/b2/ultraprocessed/ui/UltraProcessedApp.kt +++ b/app/src/main/java/com/b2/ultraprocessed/ui/UltraProcessedApp.kt @@ -236,7 +236,6 @@ fun UltraProcessedApp( ) AppDestination.Scanner -> ScannerScreen( - hasApiKey = hasLlmApiKey, hasUsdaApiKey = hasUsdaApiKey, enableLiveCamera = enableLiveCamera, onScan = { path -> diff --git a/app/src/test/java/com/b2/ultraprocessed/network/llm/ProxyFoodLabelLlmWorkflowTest.kt b/app/src/test/java/com/b2/ultraprocessed/network/llm/ProxyFoodLabelLlmWorkflowTest.kt new file mode 100644 index 0000000..486abde --- /dev/null +++ b/app/src/test/java/com/b2/ultraprocessed/network/llm/ProxyFoodLabelLlmWorkflowTest.kt @@ -0,0 +1,144 @@ +package com.b2.ultraprocessed.network.llm + +import java.util.concurrent.atomic.AtomicInteger +import kotlinx.coroutines.async +import kotlinx.coroutines.test.runTest +import okhttp3.Interceptor +import okhttp3.MediaType.Companion.toMediaType +import okhttp3.OkHttpClient +import okhttp3.Protocol +import okhttp3.Response +import okhttp3.ResponseBody.Companion.toResponseBody +import org.junit.Assert.assertEquals +import org.junit.Assert.assertNull +import org.junit.Assert.assertTrue +import org.junit.Test + +class ProxyFoodLabelLlmWorkflowTest { + + private val extraction = IngredientExtraction( + code = 0, + productName = "Test cereal", + rawText = "whole grain oats, sugar, corn syrup, salt, soy lecithin", + ingredients = listOf("whole grain oats", "sugar", "corn syrup", "salt", "soy lecithin"), + confidence = 0.6f, + warnings = emptyList(), + ) + + @Test + fun parsesResponse_andMakesExactlyOneHttpCallAcrossThreeStages() = runTest { + val callCount = AtomicInteger(0) + val workflow: FoodLabelLlmWorkflow = ProxyFoodLabelLlmWorkflow( + baseUrl = "https://proxy.test", + client = stubClient(callCount, code = 200, body = SUCCESS_BODY), + ) + + val nova = workflow.classifyNova(extraction, MODEL).getOrThrow() + val ingredients = workflow.analyzeIngredientList(extraction, MODEL).getOrThrow() + val allergens = + workflow.detectAllergens(ingredients.value.correctedIngredients, MODEL).getOrThrow() + + // One scan = one network call; stages 2 and 3 are served from cache. + assertEquals(1, callCount.get()) + + assertEquals(4, nova.value.novaGroup) + assertTrue(nova.value.containsConsumableFoodItem) + assertEquals(0.85f, nova.value.confidence, 0.0001f) + assertEquals(LlmUsage(inputTokens = 12, outputTokens = 34, totalTokens = 46), nova.usage) + + assertEquals( + listOf("water", "sugar", "palm oil", "emulsifier"), + ingredients.value.correctedIngredients, + ) + assertEquals(listOf("emulsifier"), ingredients.value.ultraProcessedIngredients.map { it.name }) + // Usage is reported only once (on the nova stage) so aggregation is not tripled. + assertNull(ingredients.usage) + + assertEquals(listOf("milk"), allergens.value.allergens) + assertNull(allergens.usage) + } + + @Test + fun concurrentStagesForSameScan_shareOneNetworkCall() = runTest { + val callCount = AtomicInteger(0) + val workflow: FoodLabelLlmWorkflow = ProxyFoodLabelLlmWorkflow( + baseUrl = "https://proxy.test", + client = stubClient(callCount, code = 200, body = SUCCESS_BODY), + ) + + val nova = async { workflow.classifyNova(extraction, MODEL) } + val ingredients = async { workflow.analyzeIngredientList(extraction, MODEL) } + + nova.await().getOrThrow() + ingredients.await().getOrThrow() + + assertEquals(1, callCount.get()) + } + + @Test + fun upstreamFailure_mapsToFriendlyFailure() = runTest { + val callCount = AtomicInteger(0) + val workflow: FoodLabelLlmWorkflow = ProxyFoodLabelLlmWorkflow( + baseUrl = "https://proxy.test", + client = stubClient(callCount, code = 502, body = ERROR_BODY), + ) + + val result = workflow.classifyNova(extraction, MODEL) + + assertTrue(result.isFailure) + val message = result.exceptionOrNull()?.message.orEmpty() + assertTrue(message, message.contains("temporarily unavailable", ignoreCase = true)) + } + + private fun stubClient(callCount: AtomicInteger, code: Int, body: String): OkHttpClient = + OkHttpClient.Builder() + .addInterceptor( + Interceptor { chain -> + callCount.incrementAndGet() + Response.Builder() + .request(chain.request()) + .protocol(Protocol.HTTP_1_1) + .code(code) + .message(if (code in 200..299) "OK" else "ERR") + .body(body.toResponseBody("application/json".toMediaType())) + .build() + }, + ) + .build() + + companion object { + private const val MODEL = "gemini-2.5-flash" + + private val SUCCESS_BODY = """ + { + "nova": { + "containsConsumableFoodItem": true, + "novaGroup": 4, + "summary": "Ultra-processed snack with additives.", + "rejectionReason": "", + "confidence": 0.85, + "warnings": [] + }, + "ingredients": { + "correctedIngredients": ["water", "sugar", "palm oil", "emulsifier"], + "ultraProcessedIngredients": [ + {"name": "emulsifier", "reason": "Common ultra-processing marker."} + ], + "confidence": 0.8, + "warnings": [] + }, + "allergens": { + "allergens": ["milk"], + "confidence": 0.6, + "warnings": [] + }, + "model": "gemini-2.5-flash", + "usage": {"inputTokens": 12, "outputTokens": 34, "totalTokens": 46} + } + """.trimIndent() + + private val ERROR_BODY = """ + {"detail": {"error": "model_call_failed", "message": "vertex unavailable"}} + """.trimIndent() + } +} diff --git a/app/src/test/java/com/b2/ultraprocessed/network/llm/SelectingFoodLabelLlmWorkflowTest.kt b/app/src/test/java/com/b2/ultraprocessed/network/llm/SelectingFoodLabelLlmWorkflowTest.kt new file mode 100644 index 0000000..63c83a0 --- /dev/null +++ b/app/src/test/java/com/b2/ultraprocessed/network/llm/SelectingFoodLabelLlmWorkflowTest.kt @@ -0,0 +1,125 @@ +package com.b2.ultraprocessed.network.llm + +import kotlinx.coroutines.test.runTest +import org.junit.Assert.assertEquals +import org.junit.Test + +class SelectingFoodLabelLlmWorkflowTest { + + private val extraction = IngredientExtraction( + code = 0, + productName = "Test", + rawText = "sugar, salt", + ingredients = listOf("sugar", "salt"), + confidence = 0.6f, + warnings = emptyList(), + ) + + @Test + fun routesToProxy_whenNoKey() = runTest { + val proxy = RecordingWorkflow() + val byok = RecordingWorkflow() + val selecting = SelectingFoodLabelLlmWorkflow(proxy, byok, FakeKeyProvider("")) + + selecting.classifyNova(extraction, MODEL) + selecting.analyzeIngredientList(extraction, MODEL) + selecting.detectAllergens(listOf("sugar"), MODEL) + + assertEquals(3, proxy.totalCalls) + assertEquals(0, byok.totalCalls) + } + + @Test + fun routesToByok_whenKeyPresent() = runTest { + val proxy = RecordingWorkflow() + val byok = RecordingWorkflow() + val selecting = SelectingFoodLabelLlmWorkflow(proxy, byok, FakeKeyProvider("aiza-some-key")) + + selecting.classifyNova(extraction, MODEL) + selecting.detectAllergens(listOf("sugar"), MODEL) + + assertEquals(0, proxy.totalCalls) + assertEquals(2, byok.totalCalls) + } + + @Test + fun reReadsKeyPerCall() = runTest { + val proxy = RecordingWorkflow() + val byok = RecordingWorkflow() + val keyProvider = FakeKeyProvider("") + val selecting = SelectingFoodLabelLlmWorkflow(proxy, byok, keyProvider) + + selecting.classifyNova(extraction, MODEL) // no key -> proxy + keyProvider.key = "aiza-now-present" + selecting.classifyNova(extraction, MODEL) // key added -> byok + + assertEquals(1, proxy.totalCalls) + assertEquals(1, byok.totalCalls) + } + + private class FakeKeyProvider(var key: String) : LlmApiKeyProvider { + override fun getApiKey(): String = key + } + + private class RecordingWorkflow : FoodLabelLlmWorkflow { + var totalCalls = 0 + private set + + override suspend fun classifyNova( + extraction: IngredientExtraction, + modelId: String, + onStatus: (String) -> Unit, + ): Result> { + totalCalls++ + return Result.success( + LlmStageResult( + NovaClassification( + novaGroup = 1, + summary = "", + confidence = 0.5f, + warnings = emptyList(), + ), + ), + ) + } + + override suspend fun analyzeIngredientList( + extraction: IngredientExtraction, + modelId: String, + onStatus: (String) -> Unit, + ): Result> { + totalCalls++ + return Result.success( + LlmStageResult( + IngredientListAnalysis( + correctedIngredients = emptyList(), + ultraProcessedIngredients = emptyList(), + warnings = emptyList(), + confidence = 0.5f, + ), + ), + ) + } + + override suspend fun detectAllergens( + correctedIngredientNames: List, + modelId: String, + onStatus: (String) -> Unit, + ): Result> { + totalCalls++ + return Result.success( + LlmStageResult( + AllergenDetection( + allergens = emptyList(), + warnings = emptyList(), + confidence = 0.5f, + ), + ), + ) + } + } + + companion object { + private const val MODEL = "gemini-2.5-flash" + } +}