From 2998fbcfb6d2638a6820acfcf3498268d7fa0d72 Mon Sep 17 00:00:00 2001 From: Ranadeep Singh Date: Fri, 8 May 2026 12:45:38 -0700 Subject: [PATCH 01/13] Add v1 OpenAI Endpoint support and remove legacy completions API --- .../synapse/ml/services/openai/OpenAI.scala | 76 ++++-- .../openai/OpenAIChatCompletion.scala | 8 +- .../ml/services/openai/OpenAICompletion.scala | 75 ------ .../ml/services/openai/OpenAIEmbedding.scala | 18 +- .../ml/services/openai/OpenAIPrompt.scala | 18 +- .../ml/services/openai/OpenAIResponses.scala | 9 +- .../ml/services/openai/OpenAISchemas.scala | 19 -- .../openai/OpenAICompletionSuite.scala | 88 ------- .../openai/OpenAIV1EndpointSuite.scala | 155 ++++++++++++ docs/Explore Algorithms/OpenAI/OpenAI.ipynb | 226 +----------------- ...- OpenAI Embedding and GPU based KNN.ipynb | 2 +- .../Quickstart - OpenAI Embedding.ipynb | 2 +- .../Set up Cognitive Services.ipynb | 2 +- tools/docgen/docgen/manifest.yaml | 4 +- 14 files changed, 251 insertions(+), 451 deletions(-) delete mode 100644 cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAICompletion.scala delete mode 100644 cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAICompletionSuite.scala create mode 100644 cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIV1EndpointSuite.scala diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala index 2af6b308ce0..31af1532333 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala @@ -14,32 +14,10 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.types._ import spray.json.DefaultJsonProtocol._ +import java.net.URI +import java.util.Locale import scala.language.existentials - -trait HasPromptInputs extends HasServiceParams { - val prompt: ServiceParam[String] = new ServiceParam[String]( - this, "prompt", "The text to complete", isRequired = false) - - def getPrompt: String = getScalarParam(prompt) - - def setPrompt(v: String): this.type = setScalarParam(prompt, v) - - def getPromptCol: String = getVectorParam(prompt) - - def setPromptCol(v: String): this.type = setVectorParam(prompt, v) - - val batchPrompt: ServiceParam[Seq[String]] = new ServiceParam[Seq[String]]( - this, "batchPrompt", "Sequence of prompts to complete", isRequired = false) - - def getBatchPrompt: Seq[String] = getScalarParam(batchPrompt) - - def setBatchPrompt(v: Seq[String]): this.type = setScalarParam(batchPrompt, v) - - def getBatchPromptCol: String = getVectorParam(batchPrompt) - - def setBatchPromptCol(v: String): this.type = setVectorParam(batchPrompt, v) - -} +import scala.util.Try trait HasMessagesInput extends Params { val messagesCol: Param[String] = new Param[String]( @@ -54,6 +32,21 @@ trait HasMessagesInput extends Params { case object OpenAIDeploymentNameKey extends GlobalKey[Either[String, String]] case object OpenAIEmbeddingDeploymentNameKey extends GlobalKey[Either[String, String]] +private[openai] object OpenAIEndpointUtils { + private def stripTrailingSlashes(value: String): String = value.replaceAll("/+$", "") + + def appendPath(baseUrl: String, path: String): String = { + val normalizedBase = if (baseUrl.endsWith("/")) baseUrl else baseUrl + "/" + normalizedBase + path.stripPrefix("/") + } + + def isV1BaseUrl(baseUrl: String): Boolean = { + val path = Try(new URI(baseUrl.trim).getPath).toOption.getOrElse("") + val normalizedPath = stripTrailingSlashes(path).toLowerCase(Locale.ROOT) + normalizedPath == "/v1" || normalizedPath.endsWith("/openai/v1") + } +} + trait HasOpenAISharedParams extends HasServiceParams with HasAPIVersion { val deploymentName = new ServiceParam[String]( @@ -137,7 +130,7 @@ trait HasOpenAITextParams extends HasOpenAISharedParams { "The maximum number of completion tokens to generate. Has minimum of 0." + " Works with both reasoning and non-reasoning models." + " Sent as max_completion_tokens for chat completions," + - " max_output_tokens for responses API, and max_tokens for legacy completions.", + " and max_output_tokens for responses API.", isRequired = false) { override val payloadName: String = "max_completion_tokens" } @@ -456,6 +449,37 @@ abstract class OpenAIServicesBase(override val uid: String) extends CognitiveSer with HasOpenAISharedParams with OpenAIFabricSetting { setDefault(timeout -> 360.0) + protected[openai] def isOpenAIV1BaseUrl: Boolean = + get(url).orElse(getDefault(url)).exists(OpenAIEndpointUtils.isV1BaseUrl) + + protected[openai] def endpointUrl(path: String): String = OpenAIEndpointUtils.appendPath(getUrl, path) + + protected[openai] def withV1DeploymentModel(params: Map[String, Any], row: Row): Map[String, Any] = { + if (isOpenAIV1BaseUrl && !params.contains("model")) { + params.updated("model", getValue(row, deploymentName)) + } else { + params + } + } + + private def warnIfV1ApiVersionConfigured(): Unit = { + if (isOpenAIV1BaseUrl && get(apiVersion).nonEmpty) { + logWarning( + "apiVersion is ignored when the OpenAI URL is a v1 base URL. " + + "Remove apiVersion or use an Azure OpenAI endpoint without /openai/v1.") + } + } + + override protected def getUrlParams: Array[ServiceParam[_]] = { + val params = super.getUrlParams + if (isOpenAIV1BaseUrl) { + warnIfV1ApiVersionConfigured() + params.filterNot(_.name == apiVersion.name) + } else { + params + } + } + private def usingDefaultOpenAIEndpoint(): Boolean = { getUrl == FabricClient.MLWorkloadEndpointML + "/cognitive/openai/" } diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIChatCompletion.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIChatCompletion.scala index eee8b0b8c5a..e915c4d4677 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIChatCompletion.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIChatCompletion.scala @@ -113,7 +113,11 @@ class OpenAIChatCompletion(override val uid: String) extends OpenAIServicesBase( } override protected def prepareUrlRoot: Row => String = { row => - s"${getUrl}openai/deployments/${getValue(row, deploymentName)}/chat/completions" + if (isOpenAIV1BaseUrl) { + endpointUrl("chat/completions") + } else { + endpointUrl(s"openai/deployments/${getValue(row, deploymentName)}/chat/completions") + } } override private[ml] def getOptionalParams(r: Row): Map[String, Any] = { @@ -125,7 +129,7 @@ class OpenAIChatCompletion(override val uid: String) extends OpenAIServicesBase( r => lazy val optionalParams: Map[String, Any] = getOptionalParams(r) val messages = r.getAs[Seq[Row]](getMessagesCol) - Some(getStringEntity(messages, optionalParams)) + Some(getStringEntity(messages, withV1DeploymentModel(optionalParams, r))) } override val subscriptionKeyHeaderName: String = "api-key" diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAICompletion.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAICompletion.scala deleted file mode 100644 index 4b5b26a84b5..00000000000 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAICompletion.scala +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright (C) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. See LICENSE in project root for information. - -package com.microsoft.azure.synapse.ml.services.openai - -import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging} -import com.microsoft.azure.synapse.ml.param.AnyJsonFormat.anyFormat -import com.microsoft.azure.synapse.ml.services.{HasCognitiveServiceInput, HasInternalJsonOutputParser} -import org.apache.http.entity.{AbstractHttpEntity, ContentType, StringEntity} -import org.apache.spark.ml.ComplexParamsReadable -import org.apache.spark.ml.util._ -import org.apache.spark.sql.{functions => F, Row} -import org.apache.spark.sql.types._ -import spray.json.DefaultJsonProtocol._ -import spray.json._ - -import scala.language.existentials - -object OpenAICompletion extends ComplexParamsReadable[OpenAICompletion] - -class OpenAICompletion(override val uid: String) extends OpenAIServicesBase(uid) - with HasOpenAITextParams with HasPromptInputs with HasCognitiveServiceInput - with HasInternalJsonOutputParser with SynapseMLLogging with HasTextOutput { - logClass(FeatureNames.AiServices.OpenAI) - - def this() = this(Identifiable.randomUID("OpenAICompletion")) - - def urlPath: String = "" - - override private[ml] def internalServiceType: String = "openai" - - setDefault(apiVersion -> Left("2024-02-01")) - - override def setCustomServiceName(v: String): this.type = { - setUrl(s"https://$v.openai.azure.com/" + urlPath.stripPrefix("/")) - } - - override protected def prepareUrlRoot: Row => String = { row => - s"${getUrl}openai/deployments/${getValue(row, deploymentName)}/completions" - } - - override private[ml] def getOptionalParams(r: Row): Map[String, Any] = { - val base = super.getOptionalParams(r) - resolveMaxTokens(base, "max_tokens") - } - - override protected[openai] def prepareEntity: Row => Option[AbstractHttpEntity] = { - r => - lazy val optionalParams: Map[String, Any] = getOptionalParams(r) - getValueOpt(r, prompt) - .map(prompt => getStringEntity(prompt, optionalParams)) - .orElse(getValueOpt(r, batchPrompt) - .map(batchPrompt => getStringEntity(batchPrompt, optionalParams))) - .orElse(throw new IllegalArgumentException( - "Please set one of prompt, batchPrompt, indexPrompt or batchIndexPrompt.")) - } - - override val subscriptionKeyHeaderName: String = "api-key" - - override def shouldSkip(row: Row): Boolean = - super.shouldSkip(row) || - (emptyParamData(row, prompt) && emptyParamData(row, batchPrompt)) - - override def responseDataType: DataType = CompletionResponse.schema - - private[this] def getStringEntity[A](prompt: A, optionalParams: Map[String, Any]): StringEntity = { - val fullPayload = optionalParams.updated("prompt", prompt) - new StringEntity(fullPayload.toJson.compactPrint, ContentType.APPLICATION_JSON) - } - - override private[openai] def getOutputMessageText(outputColName: String): org.apache.spark.sql.Column = { - F.element_at(F.col(outputColName).getField("choices"), 1).getField("text") - } - -} diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIEmbedding.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIEmbedding.scala index 16821707020..6a67dfb92be 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIEmbedding.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIEmbedding.scala @@ -64,10 +64,19 @@ class OpenAIEmbedding (override val uid: String) extends OpenAIServicesBase(uid) } override protected def prepareUrlRoot: Row => String = { row => + val dep = getEmbeddingDeployment(row) + if (isOpenAIV1BaseUrl) { + endpointUrl("embeddings") + } else { + endpointUrl(s"openai/deployments/$dep/embeddings") + } + } + + private[this] def getEmbeddingDeployment(row: Row): String = { val globalEmbeddingDeployment = GlobalParams.getGlobalParam(OpenAIEmbeddingDeploymentNameKey).flatMap(_.left.toOption) - val dep = globalEmbeddingDeployment.orElse { + globalEmbeddingDeployment.orElse { // If embedding-specific deployment is not set, check instance param if (isSet(deploymentName)) { getValueOpt(row, deploymentName) @@ -77,8 +86,6 @@ class OpenAIEmbedding (override val uid: String) extends OpenAIServicesBase(uid) }.getOrElse(throw new IllegalArgumentException( "No embedding deployment name provided. Set the 'deploymentName' param or call " + "OpenAIDefaults.setEmbeddingDeploymentName('') to set a global default.")) - - s"${getUrl}openai/deployments/$dep/embeddings" } private[this] def getStringEntity[A](text: A, optionalParams: Map[String, Any]): StringEntity = { @@ -88,7 +95,10 @@ class OpenAIEmbedding (override val uid: String) extends OpenAIServicesBase(uid) override protected def prepareEntity: Row => Option[AbstractHttpEntity] = { r => - lazy val optionalParams: Map[String, Any] = getOptionalParams(r) + lazy val optionalParams: Map[String, Any] = { + val params = getOptionalParams(r) + if (isOpenAIV1BaseUrl) params.updated("model", getEmbeddingDeployment(r)) else params + } getValueOpt(r, text) .map(text => getStringEntity(text, optionalParams)) .orElse(throw new IllegalArgumentException( diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala index fbf1d584285..37787aa861c 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala @@ -284,8 +284,6 @@ class OpenAIPrompt(override val uid: String) extends Transformer df: DataFrame, messagesCol: Column ): (DataFrame, String, OpenAIServicesBase with HasTextOutput) = { - // All services are now HasMessagesInput (OpenAIChatCompletion, OpenAIResponses, AIFoundryChatCompletion) - // Legacy OpenAICompletion did not support MessagesInput which is no longer used in this class. val messagesService = service.asInstanceOf[HasMessagesInput] if (isSet(responseFormat)) { @@ -639,7 +637,12 @@ class OpenAIPrompt(override val uid: String) extends Transformer host.exists(_.toLowerCase.endsWith("services.ai.azure.com")) } - private[openai] def hasAIFoundryModel: Boolean = this.isDefined(model) && isAIFoundryEndpoint + private def isOpenAIV1Endpoint: Boolean = { + get(url).orElse(getDefault(url)).exists(OpenAIEndpointUtils.isV1BaseUrl) + } + + private[openai] def hasAIFoundryModel: Boolean = + this.isDefined(model) && isAIFoundryEndpoint && !isOpenAIV1Endpoint //deployment name can be set by user, it doesn't have to match with model name private def getOpenAIChatService: OpenAIServicesBase with HasTextOutput = { @@ -658,11 +661,10 @@ class OpenAIPrompt(override val uid: String) extends Transformer .filter(p => !localParamNames.contains(p.param.name) && completion.hasParam(p.param.name)) .foreach(p => completion.set(completion.getParam(p.param.name), p.value)) - completion match { - case resp: OpenAIResponses - if this.isDefined(model) && get(deploymentName).orElse(getDefault(deploymentName)).isEmpty => - resp.setDeploymentName(getModel) - case _ => + if (this.isDefined(model) && + get(deploymentName).orElse(getDefault(deploymentName)).isEmpty && + (isOpenAIV1Endpoint || completion.isInstanceOf[OpenAIResponses])) { + completion.setDeploymentName(getModel) } completion diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIResponses.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIResponses.scala index 53fa65ba6db..21ead3b99b3 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIResponses.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIResponses.scala @@ -142,7 +142,11 @@ class OpenAIResponses(override val uid: String) extends OpenAIServicesBase(uid) } override protected def prepareUrlRoot: Row => String = { row => - s"${getUrl}openai/responses" + if (isOpenAIV1BaseUrl) { + endpointUrl("responses") + } else { + endpointUrl("openai/responses") + } } override protected[openai] def prepareEntity: Row => Option[AbstractHttpEntity] = { @@ -164,6 +168,9 @@ class OpenAIResponses(override val uid: String) extends OpenAIServicesBase(uid) private def mergeModel(params: Map[String, Any], r: Row): Map[String, Any] = { getValueOpt(r, deploymentName) match { case Some(m) if m != null && m.nonEmpty => params.updated("model", m) + case _ if isOpenAIV1BaseUrl && !params.contains("model") => + throw new IllegalArgumentException( + "No deployment/model name provided for OpenAI v1 endpoint. Set the 'deploymentName' param.") case _ => params } } diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAISchemas.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAISchemas.scala index 5f4c34ef61e..6c0f78b218a 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAISchemas.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAISchemas.scala @@ -4,27 +4,8 @@ package com.microsoft.azure.synapse.ml.services.openai import com.microsoft.azure.synapse.ml.core.schema.SparkBindings -import org.apache.spark.sql.Row import spray.json.{DefaultJsonProtocol, RootJsonFormat} -object CompletionResponse extends SparkBindings[CompletionResponse] - -case class CompletionResponse(id: String, - `object`: String, - created: String, - model: String, - choices: Seq[OpenAIChoice]) - -case class OpenAIChoice(text: String, - index: Long, - logprobs: Option[OpenAILogProbs], - finish_reason: String) - -case class OpenAILogProbs(tokens: Seq[String], - token_logprobs: Seq[Double], - top_logprobs: Seq[Map[String, Double]], - text_offset: Seq[Long]) - object EmbeddingUsage extends SparkBindings[EmbeddingUsage] case class EmbeddingUsage(prompt_tokens: Long, diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAICompletionSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAICompletionSuite.scala deleted file mode 100644 index 997838b2841..00000000000 --- a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAICompletionSuite.scala +++ /dev/null @@ -1,88 +0,0 @@ -// Copyright (C) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. See LICENSE in project root for information. - -package com.microsoft.azure.synapse.ml.services.openai - -import com.microsoft.azure.synapse.ml.Secrets -import com.microsoft.azure.synapse.ml.Secrets.getAccessToken -import com.microsoft.azure.synapse.ml.core.test.base.Flaky -import com.microsoft.azure.synapse.ml.core.test.fuzzing.{TestObject, TransformerFuzzing} -import org.apache.spark.ml.util.MLReadable -import org.apache.spark.sql.{DataFrame, Row} - -class OpenAICompletionSuite extends TransformerFuzzing[OpenAICompletion] with OpenAIAPIKey with Flaky { - override val compareDataInSerializationTest: Boolean = false - - - import spark.implicits._ - - override def beforeAll(): Unit = { - val aadToken = getAccessToken("https://cognitiveservices.azure.com/") - println(s"Triggering token creation early ${aadToken.length}") - super.beforeAll() - } - - def newCompletion: OpenAICompletion = new OpenAICompletion() - .setDeploymentName(deploymentName) - .setCustomServiceName(openAIServiceName) - .setMaxTokens(200) - .setOutputCol("out") - .setSubscriptionKey(openAIAPIKey) - - lazy val promptCompletion: OpenAICompletion = newCompletion.setPromptCol("prompt") - lazy val batchPromptCompletion: OpenAICompletion = newCompletion.setBatchPromptCol("batchPrompt") - - lazy val df: DataFrame = Seq( - "Once upon a time", - "Best programming language award goes to", - "SynapseML is " - ).toDF("prompt") - - lazy val promptDF: DataFrame = Seq( - "Once upon a time", - "Best programming language award goes to", - "SynapseML is " - ).toDF("prompt") - - lazy val batchPromptDF: DataFrame = Seq( - Seq( - "This is a test", - "Now is the time", - "Knock, knock") - ).toDF("batchPrompt") - - ignore("Basic Usage") { - testCompletion(promptCompletion, promptDF) - } - - ignore("Basic usage with AAD auth") { - val aadToken = getAccessToken("https://cognitiveservices.azure.com/") - - val completion = new OpenAICompletion() - .setAADToken(aadToken) - .setDeploymentName(deploymentName) - .setCustomServiceName(openAIServiceName) - .setPromptCol("prompt") - .setOutputCol("out") - - testCompletion(completion, promptDF) - } - - ignore("Batch Prompt") { - testCompletion(batchPromptCompletion, batchPromptDF) - } - - def testCompletion(completion: OpenAICompletion, df: DataFrame, requiredLength: Int = 10): Unit = { - val fromRow = CompletionResponse.makeFromRowConverter - completion.transform(df).collect().foreach(r => - fromRow(r.getAs[Row]("out")).choices.foreach(c => - assert(c.text.length > requiredLength))) - } - - - override def testObjects(): Seq[TestObject[OpenAICompletion]] = - Seq(new TestObject(newCompletion, df)) - - override def reader: MLReadable[_] = OpenAICompletion - -} diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIV1EndpointSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIV1EndpointSuite.scala new file mode 100644 index 00000000000..b2bf023d0be --- /dev/null +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIV1EndpointSuite.scala @@ -0,0 +1,155 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. + +package com.microsoft.azure.synapse.ml.services.openai + +import com.microsoft.azure.synapse.ml.core.test.base.TestBase +import org.apache.http.entity.AbstractHttpEntity +import org.apache.http.util.EntityUtils +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema +import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructType} +import spray.json._ + +class OpenAIV1EndpointSuite extends TestBase { + + import spark.implicits._ + + private class InspectableChatCompletion extends OpenAIChatCompletion { + def requestUrl(row: Row): String = prepareUrl.apply(row) + def requestPayload(row: Row): JsObject = + EntityUtils.toString(prepareEntity.apply(row).get).parseJson.asJsObject + } + + private class InspectableEmbedding extends OpenAIEmbedding { + def requestUrl(row: Row): String = prepareUrl.apply(row) + def requestPayload(row: Row): JsObject = + EntityUtils.toString(prepareEntity.apply(row).get).parseJson.asJsObject + } + + private class InspectableResponses extends OpenAIResponses { + def requestUrl(row: Row): String = prepareUrl.apply(row) + def requestPayload(row: Row): JsObject = + EntityUtils.toString(prepareEntity.apply(row).get).parseJson.asJsObject + } + + private val messageSchema = StructType(Seq( + StructField("role", StringType, nullable = false), + StructField("content", StringType, nullable = true), + StructField("name", StringType, nullable = true) + )) + + private val messagesRequestSchema = StructType(Seq( + StructField("messages", ArrayType(messageSchema, containsNull = false), nullable = true) + )) + + private def messagesRow: Row = { + val message = new GenericRowWithSchema( + Array[Any]("user", "hello", null), // scalastyle:ignore null + messageSchema + ) + new GenericRowWithSchema(Array[Any](Seq(message)), messagesRequestSchema) + } + + test("chat completions uses OpenAI v1 base URL without api-version and sends model") { + val transformer = new InspectableChatCompletion() + .setUrl("https://example.services.ai.azure.com/openai/v1") + .setDeploymentName("gpt-4o") + .setMessagesCol("messages") + .setApiVersion("2025-04-01-preview") + + val row = messagesRow + assert(transformer.requestUrl(row) == "https://example.services.ai.azure.com/openai/v1/chat/completions") + + val payload = transformer.requestPayload(row) + assert(payload.fields.get("model").contains(JsString("gpt-4o"))) + assert(payload.fields.contains("messages")) + } + + test("chat completions keeps legacy Azure deployment URL and api-version") { + val transformer = new InspectableChatCompletion() + .setUrl("https://example.openai.azure.com/") + .setDeploymentName("gpt-4o") + .setMessagesCol("messages") + .setApiVersion("2025-04-01-preview") + + val row = messagesRow + assert(transformer.requestUrl(row) == + "https://example.openai.azure.com/openai/deployments/gpt-4o/chat/completions" + + "?api-version=2025-04-01-preview") + assert(!transformer.requestPayload(row).fields.contains("model")) + } + + test("embeddings uses OpenAI v1 base URL and sends deployment as model") { + val transformer = new InspectableEmbedding() + .setUrl("https://example.services.ai.azure.com/openai/v1") + .setDeploymentName("text-embedding-3-large") + .setTextCol("text") + .setApiVersion("2025-04-01-preview") + + val row = Seq("hello").toDF("text").collect().head + assert(transformer.requestUrl(row) == "https://example.services.ai.azure.com/openai/v1/embeddings") + + val payload = transformer.requestPayload(row) + assert(payload.fields.get("model").contains(JsString("text-embedding-3-large"))) + assert(payload.fields.get("input").contains(JsString("hello"))) + } + + test("embeddings keeps legacy Azure deployment URL and api-version") { + val transformer = new InspectableEmbedding() + .setUrl("https://example.openai.azure.com/") + .setDeploymentName("text-embedding-3-large") + .setTextCol("text") + .setApiVersion("2025-04-01-preview") + + val row = Seq("hello").toDF("text").collect().head + assert(transformer.requestUrl(row) == + "https://example.openai.azure.com/openai/deployments/text-embedding-3-large/embeddings" + + "?api-version=2025-04-01-preview") + + val payload = transformer.requestPayload(row) + assert(!payload.fields.contains("model")) + assert(payload.fields.get("input").contains(JsString("hello"))) + } + + test("responses uses OpenAI v1 base URL without api-version") { + val transformer = new InspectableResponses() + .setUrl("https://example.services.ai.azure.com/openai/v1") + .setDeploymentName("gpt-5-mini") + .setMessagesCol("messages") + .setApiVersion("2025-04-01-preview") + + val row = messagesRow + assert(transformer.requestUrl(row) == "https://example.services.ai.azure.com/openai/v1/responses") + + val payload = transformer.requestPayload(row) + assert(payload.fields.get("model").contains(JsString("gpt-5-mini"))) + assert(payload.fields.contains("input")) + } + + test("responses keeps legacy Azure URL shape when URL is not an OpenAI v1 base") { + val transformer = new InspectableResponses() + .setUrl("https://example.openai.azure.com/") + .setDeploymentName("gpt-5-mini") + .setMessagesCol("messages") + .setApiVersion("2025-04-01-preview") + + assert(transformer.requestUrl(messagesRow) == + "https://example.openai.azure.com/openai/responses?api-version=2025-04-01-preview") + } + + test("OpenAIPrompt treats services.ai.azure.com/openai/v1 as OpenAI v1, not models chat endpoint") { + val prompt = new OpenAIPrompt() + .setUrl("https://example.services.ai.azure.com/openai/v1") + .setModel("gpt-4o") + .setMessagesCol("messages") + + val prepareEntity = classOf[OpenAIPrompt].getDeclaredMethod("prepareEntity") + prepareEntity.setAccessible(true) + val buildEntity = prepareEntity.invoke(prompt).asInstanceOf[Row => Option[AbstractHttpEntity]] + + val payload = EntityUtils.toString(buildEntity(messagesRow).get).parseJson.asJsObject + assert(payload.fields.get("model").contains(JsString("gpt-4o"))) + assert(payload.fields.contains("messages")) + } +} diff --git a/docs/Explore Algorithms/OpenAI/OpenAI.ipynb b/docs/Explore Algorithms/OpenAI/OpenAI.ipynb index 39d125cd7bb..614ec3d9e4d 100644 --- a/docs/Explore Algorithms/OpenAI/OpenAI.ipynb +++ b/docs/Explore Algorithms/OpenAI/OpenAI.ipynb @@ -7,7 +7,7 @@ "source": [ "# Azure OpenAI for big data\n", "\n", - "The Azure OpenAI service can be used to solve a large number of natural language tasks through prompting the completion API. To make it easier to scale your prompting workflows from a few examples to large datasets of examples, we have integrated the Azure OpenAI service with the distributed machine learning library [SynapseML](https://www.microsoft.com/en-us/research/blog/synapseml-a-simple-multilingual-and-massively-parallel-machine-learning-library/). This integration makes it easy to use the [Apache Spark](https://spark.apache.org/) distributed computing framework to process millions of prompts with the OpenAI service. This tutorial shows how to apply large language models at a distributed scale using Azure OpenAI. " + "The Azure OpenAI service can be used to solve a large number of natural language tasks through chat, responses, and embedding APIs. To make it easier to scale your prompting workflows from a few examples to large datasets of examples, we have integrated the Azure OpenAI service with the distributed machine learning library [SynapseML](https://www.microsoft.com/en-us/research/blog/synapseml-a-simple-multilingual-and-massively-parallel-machine-learning-library/). This integration makes it easy to use the [Apache Spark](https://spark.apache.org/) distributed computing framework to process millions of prompts with the OpenAI service. This tutorial shows how to apply large language models at a distributed scale using Azure OpenAI.\n" ] }, { @@ -262,229 +262,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## (Legacy) Create the OpenAICompletion Apache Spark Client\n", + "## Retired Completions API\n", "\n", - "To apply the OpenAI Completion service to your dataframe you created, create an OpenAICompletion object, which serves as a distributed client. Parameters of the service can be set either with a single value, or by a column of the dataframe with the appropriate setters on the `OpenAICompletion` object. Here we're setting `maxTokens` to 200. A token is around four characters, and this limit applies to the sum of the prompt and the result. We're also setting the `promptCol` parameter with the name of the prompt column in the dataframe." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from synapse.ml.services.openai import OpenAICompletion\n", - "\n", - "completion = (\n", - " OpenAICompletion()\n", - " .setSubscriptionKey(key)\n", - " .setDeploymentName(deployment_name)\n", - " .setCustomServiceName(service_name)\n", - " .setMaxTokens(200)\n", - " .setPromptCol(\"prompt\")\n", - " .setErrorCol(\"error\")\n", - " .setOutputCol(\"completions\")\n", - ")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## (Legacy) Transform the dataframe with the OpenAICompletion Client\n", - "\n", - "After creating the dataframe and the completion client, you can transform your input dataset and add a column called `completions` with all of the information the service adds. Select just the text for simplicity." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from pyspark.sql.functions import col\n", - "\n", - "completed_df = completion.transform(df).cache()\n", - "display(\n", - " completed_df.select(\n", - " col(\"prompt\"),\n", - " col(\"error\"),\n", - " col(\"completions.choices.text\").getItem(0).alias(\"text\"),\n", - " ).show(truncate=False)\n", - ")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Your output should look something like this. The completion text will be different from the sample.\n", - "\n", - "| **prompt** \t| **error** \t| **text** \t|\n", - "|:----------------------------:\t|:----------:\t|:-------------------------------------------------------------------------------------------------------------------------------------:\t|\n", - "| Hello my name is \t| null \t| Makaveli I'm eighteen years old and I want to be a rapper when I grow up I love writing and making music I'm from Los Angeles, CA \t|\n", - "| The best code is code thats \t| null \t| understandable This is a subjective statement, and there is no definitive answer. \t|\n", - "| SynapseML is \t| null \t| A machine learning algorithm that is able to learn how to predict the future outcome of events. \t|" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Improve throughput with request batching for OpenAICompletion\n", - "\n", - "The example makes several requests to the service, one for each prompt. To complete multiple prompts in a single request, use batch mode. First, in the OpenAICompletion object, instead of setting the Prompt column to \"Prompt\", specify \"batchPrompt\" for the BatchPrompt column.\n", - "To do so, create a dataframe with a list of prompts per row.\n", - "\n", - "As of this writing there's currently a limit of 20 prompts in a single request, and a hard limit of 2048 \"tokens\", or approximately 1500 words." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "batch_df = spark.createDataFrame(\n", - " [\n", - " ([\"The time has come\", \"Pleased to\", \"Today stocks\", \"Here's to\"],),\n", - " ([\"The only thing\", \"Ask not what\", \"Every litter\", \"I am\"],),\n", - " ]\n", - ").toDF(\"batchPrompt\")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Next we create the OpenAICompletion object. Rather than setting the prompt column, set the batchPrompt column if your column is of type `Array[String]`." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "batch_completion = (\n", - " OpenAICompletion()\n", - " .setSubscriptionKey(key)\n", - " .setDeploymentName(deployment_name)\n", - " .setCustomServiceName(service_name)\n", - " .setMaxTokens(200)\n", - " .setBatchPromptCol(\"batchPrompt\")\n", - " .setErrorCol(\"error\")\n", - " .setOutputCol(\"completions\")\n", - ")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "In the call to transform, a request will be made per row. Since there are multiple prompts in a single row, each request is sent with all prompts in that row. The results contain a row for each row in the request." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "completed_batch_df = batch_completion.transform(batch_df).cache()\n", - "display(completed_batch_df.show(truncate=False))" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Using an automatic minibatcher\n", - "\n", - "If your data is in column format, you can transpose it to row format using SynapseML's `FixedMiniBatcherTransformer`." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from pyspark.sql.types import StringType\n", - "from synapse.ml.stages import FixedMiniBatchTransformer\n", - "from synapse.ml.core.spark import FluentAPI\n", - "\n", - "completed_autobatch_df = (\n", - " df.coalesce(\n", - " 1\n", - " ) # Force a single partition so that our little 4-row dataframe makes a batch of size 4, you can remove this step for large datasets\n", - " .mlTransform(FixedMiniBatchTransformer(batchSize=4))\n", - " .withColumnRenamed(\"prompt\", \"batchPrompt\")\n", - " .mlTransform(batch_completion)\n", - ")\n", - "\n", - "display(completed_autobatch_df.show(truncate=False))" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Prompt engineering for translation\n", - "\n", - "The Azure OpenAI service can solve many different natural language tasks through [prompt engineering](https://docs.microsoft.com/en-us/azure/cognitive-services/openai/how-to/completions). Here, we show an example of prompting for language translation:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "translate_df = spark.createDataFrame(\n", - " [\n", - " (\"Japanese: Ookina hako English: Big box Japanese: Midori takoEnglish:\",),\n", - " (\n", - " \"French: Quel heure et il au Montreal? English: What time is it in Montreal? French: Ou est le poulet? English:\",\n", - " ),\n", - " ]\n", - ").toDF(\"prompt\")\n", - "\n", - "display(completion.transform(translate_df).show(truncate=False))" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Prompt for question answering\n", - "\n", - "Here, we prompt GPT-3 for general-knowledge question answering:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "qa_df = spark.createDataFrame(\n", - " [\n", - " (\n", - " \"Q: Where is the Grand Canyon?A: The Grand Canyon is in Arizona.Q: What is the weight of the Burj Khalifa in kilograms?A:\",\n", - " )\n", - " ]\n", - ").toDF(\"prompt\")\n", - "\n", - "display(completion.transform(qa_df).show(truncate=False))" + "The `OpenAICompletion` transformer has been removed because the legacy Completions API is deprecated and retired. Use `OpenAIChatCompletion`, `OpenAIPrompt` with `chat_completions` or `responses`, or `OpenAIResponses` for text generation workloads.\n" ] }, { diff --git a/docs/Explore Algorithms/OpenAI/Quickstart - OpenAI Embedding and GPU based KNN.ipynb b/docs/Explore Algorithms/OpenAI/Quickstart - OpenAI Embedding and GPU based KNN.ipynb index 6e90974a480..82ae3f185cc 100644 --- a/docs/Explore Algorithms/OpenAI/Quickstart - OpenAI Embedding and GPU based KNN.ipynb +++ b/docs/Explore Algorithms/OpenAI/Quickstart - OpenAI Embedding and GPU based KNN.ipynb @@ -17,7 +17,7 @@ "source": [ "# Embedding Text with Azure OpenAI and GPU based KNN\n", "\n", - "The Azure OpenAI service can be used to solve a large number of natural language tasks through prompting the completion API. To make it easier to scale your prompting workflows from a few examples to large datasets of examples we have integrated the Azure OpenAI service with the distributed machine learning library [Spark Rapids ML](https://github.com/NVIDIA/spark-rapids-ml/). This integration makes it easy to use the [Apache Spark](https://spark.apache.org/) distributed computing framework to process millions of prompts with the OpenAI service. This tutorial shows how to apply large language models to generate embeddings for large datasets of text. This demo is based on \"Quickstart - OpenAI Embedding\" notebook with NVIDIA GPU accelerated KNN.\n", + "The Azure OpenAI service can be used to generate embeddings for large datasets of text. To make it easier to scale your embedding workflows from a few examples to large datasets of examples we have integrated the Azure OpenAI service with the distributed machine learning library [Spark Rapids ML](https://github.com/NVIDIA/spark-rapids-ml/). This integration makes it easy to use the [Apache Spark](https://spark.apache.org/) distributed computing framework to process millions of inputs with the OpenAI service. This tutorial shows how to apply large language models to generate embeddings for large datasets of text. This demo is based on \"Quickstart - OpenAI Embedding\" notebook with NVIDIA GPU accelerated KNN.\n", "\n", "**Note**: Running the notebook with the demo dataset (Step 4) will generate the same results as CPU based “Quickstart - OpenAI Embedding” notebook. To see GPU acceleration you need to run query against bigger embeddings. \n", "For example, running 100K rows dataset will give 6x acceleration and consume less than 10x memory on 2 nodes NVIDIA T4 cluster compare to AMD Epic (Rome) 2 nodes CPU cluster.\n", diff --git a/docs/Explore Algorithms/OpenAI/Quickstart - OpenAI Embedding.ipynb b/docs/Explore Algorithms/OpenAI/Quickstart - OpenAI Embedding.ipynb index 6b973bab22b..78995acdea7 100644 --- a/docs/Explore Algorithms/OpenAI/Quickstart - OpenAI Embedding.ipynb +++ b/docs/Explore Algorithms/OpenAI/Quickstart - OpenAI Embedding.ipynb @@ -17,7 +17,7 @@ "source": [ "# Embedding Text with Azure OpenAI\n", "\n", - "The Azure OpenAI service can be used to solve a large number of natural language tasks through prompting the completion API. To make it easier to scale your prompting workflows from a few examples to large datasets of examples we have integrated the Azure OpenAI service with the distributed machine learning library [SynapseML](https://www.microsoft.com/en-us/research/blog/synapseml-a-simple-multilingual-and-massively-parallel-machine-learning-library/). This integration makes it easy to use the [Apache Spark](https://spark.apache.org/) distributed computing framework to process millions of prompts with the OpenAI service. This tutorial shows how to apply large language models to generate embeddings for large datasets of text. \n", + "The Azure OpenAI service can be used to generate embeddings for large datasets of text. To make it easier to scale your embedding workflows from a few examples to large datasets of examples we have integrated the Azure OpenAI service with the distributed machine learning library [SynapseML](https://www.microsoft.com/en-us/research/blog/synapseml-a-simple-multilingual-and-massively-parallel-machine-learning-library/). This integration makes it easy to use the [Apache Spark](https://spark.apache.org/) distributed computing framework to process millions of inputs with the OpenAI service. This tutorial shows how to apply large language models to generate embeddings for large datasets of text.\n", "\n", "## Step 1: Prerequisites\n", "\n", diff --git a/docs/Get Started/Set up Cognitive Services.ipynb b/docs/Get Started/Set up Cognitive Services.ipynb index 7bf4333434b..73fdd8af729 100644 --- a/docs/Get Started/Set up Cognitive Services.ipynb +++ b/docs/Get Started/Set up Cognitive Services.ipynb @@ -27,7 +27,7 @@ "source": [ "## Azure OpenAI\n", "\n", - "The [Azure OpenAI service](https://azure.microsoft.com/products/cognitive-services/openai-service/) can be used to solve a large number of natural language tasks through prompting the completion API. To make it easier to scale your prompting workflows from a few examples to large datasets of examples, we have integrated the Azure OpenAI service with the distributed machine learning library SynapseML. This integration makes it easy to use the Apache Spark distributed computing framework to process millions of prompts with the OpenAI service." + "The [Azure OpenAI service](https://azure.microsoft.com/products/cognitive-services/openai-service/) can be used to solve a large number of natural language tasks through chat, responses, and embedding APIs. To make it easier to scale your prompting workflows from a few examples to large datasets of examples, we have integrated the Azure OpenAI service with the distributed machine learning library SynapseML. This integration makes it easy to use the Apache Spark distributed computing framework to process millions of prompts with the OpenAI service. The legacy Completions API and SynapseML `OpenAICompletion` transformer are deprecated and retired; use chat completions or responses APIs for text generation." ] }, { diff --git a/tools/docgen/docgen/manifest.yaml b/tools/docgen/docgen/manifest.yaml index 77302a46d8e..d141445fc30 100644 --- a/tools/docgen/docgen/manifest.yaml +++ b/tools/docgen/docgen/manifest.yaml @@ -99,7 +99,7 @@ channels: filename: open-ai metadata: title: Azure OpenAI for big data - description: Use Azure OpenAI service to solve a large number of natural language tasks through prompting the completion API. + description: Use Azure OpenAI service to solve a large number of natural language tasks through chat, responses, and embedding APIs. ms.topic: how-to ms.custom: build-2023 ms.reviewer: jessiwang @@ -161,4 +161,4 @@ channels: ms.topic: overview ms.reviewer: sngun, garye, negust, ruxu, jessiwang author: WilliamDAssafMSFT - ms.author: wiassaf \ No newline at end of file + ms.author: wiassaf From aea26ca9dede45805f903ce6b8e4723eae861b0f Mon Sep 17 00:00:00 2001 From: Ranadeep Singh Date: Fri, 8 May 2026 18:06:38 -0700 Subject: [PATCH 02/13] Fix FuzzingUnitTest --- .../openai/OpenAIV1EndpointSuite.scala | 59 +++++++++---------- 1 file changed, 28 insertions(+), 31 deletions(-) diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIV1EndpointSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIV1EndpointSuite.scala index b2bf023d0be..2b4b783c1a7 100644 --- a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIV1EndpointSuite.scala +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIV1EndpointSuite.scala @@ -4,6 +4,7 @@ package com.microsoft.azure.synapse.ml.services.openai import com.microsoft.azure.synapse.ml.core.test.base.TestBase +import com.microsoft.azure.synapse.ml.services.HasCognitiveServiceInput import org.apache.http.entity.AbstractHttpEntity import org.apache.http.util.EntityUtils import org.apache.spark.sql.Row @@ -15,22 +16,18 @@ class OpenAIV1EndpointSuite extends TestBase { import spark.implicits._ - private class InspectableChatCompletion extends OpenAIChatCompletion { - def requestUrl(row: Row): String = prepareUrl.apply(row) - def requestPayload(row: Row): JsObject = - EntityUtils.toString(prepareEntity.apply(row).get).parseJson.asJsObject - } + private val prepareUrl = classOf[HasCognitiveServiceInput].getDeclaredMethod("prepareUrl") + prepareUrl.setAccessible(true) - private class InspectableEmbedding extends OpenAIEmbedding { - def requestUrl(row: Row): String = prepareUrl.apply(row) - def requestPayload(row: Row): JsObject = - EntityUtils.toString(prepareEntity.apply(row).get).parseJson.asJsObject - } + private val prepareEntity = classOf[HasCognitiveServiceInput].getDeclaredMethod("prepareEntity") + prepareEntity.setAccessible(true) + + private def requestUrl(transformer: HasCognitiveServiceInput, row: Row): String = + prepareUrl.invoke(transformer).asInstanceOf[Row => String].apply(row) - private class InspectableResponses extends OpenAIResponses { - def requestUrl(row: Row): String = prepareUrl.apply(row) - def requestPayload(row: Row): JsObject = - EntityUtils.toString(prepareEntity.apply(row).get).parseJson.asJsObject + private def requestPayload(transformer: HasCognitiveServiceInput, row: Row): JsObject = { + val entityBuilder = prepareEntity.invoke(transformer).asInstanceOf[Row => Option[AbstractHttpEntity]] + EntityUtils.toString(entityBuilder.apply(row).get).parseJson.asJsObject } private val messageSchema = StructType(Seq( @@ -52,89 +49,89 @@ class OpenAIV1EndpointSuite extends TestBase { } test("chat completions uses OpenAI v1 base URL without api-version and sends model") { - val transformer = new InspectableChatCompletion() + val transformer = new OpenAIChatCompletion() .setUrl("https://example.services.ai.azure.com/openai/v1") .setDeploymentName("gpt-4o") .setMessagesCol("messages") .setApiVersion("2025-04-01-preview") val row = messagesRow - assert(transformer.requestUrl(row) == "https://example.services.ai.azure.com/openai/v1/chat/completions") + assert(requestUrl(transformer, row) == "https://example.services.ai.azure.com/openai/v1/chat/completions") - val payload = transformer.requestPayload(row) + val payload = requestPayload(transformer, row) assert(payload.fields.get("model").contains(JsString("gpt-4o"))) assert(payload.fields.contains("messages")) } test("chat completions keeps legacy Azure deployment URL and api-version") { - val transformer = new InspectableChatCompletion() + val transformer = new OpenAIChatCompletion() .setUrl("https://example.openai.azure.com/") .setDeploymentName("gpt-4o") .setMessagesCol("messages") .setApiVersion("2025-04-01-preview") val row = messagesRow - assert(transformer.requestUrl(row) == + assert(requestUrl(transformer, row) == "https://example.openai.azure.com/openai/deployments/gpt-4o/chat/completions" + "?api-version=2025-04-01-preview") - assert(!transformer.requestPayload(row).fields.contains("model")) + assert(!requestPayload(transformer, row).fields.contains("model")) } test("embeddings uses OpenAI v1 base URL and sends deployment as model") { - val transformer = new InspectableEmbedding() + val transformer = new OpenAIEmbedding() .setUrl("https://example.services.ai.azure.com/openai/v1") .setDeploymentName("text-embedding-3-large") .setTextCol("text") .setApiVersion("2025-04-01-preview") val row = Seq("hello").toDF("text").collect().head - assert(transformer.requestUrl(row) == "https://example.services.ai.azure.com/openai/v1/embeddings") + assert(requestUrl(transformer, row) == "https://example.services.ai.azure.com/openai/v1/embeddings") - val payload = transformer.requestPayload(row) + val payload = requestPayload(transformer, row) assert(payload.fields.get("model").contains(JsString("text-embedding-3-large"))) assert(payload.fields.get("input").contains(JsString("hello"))) } test("embeddings keeps legacy Azure deployment URL and api-version") { - val transformer = new InspectableEmbedding() + val transformer = new OpenAIEmbedding() .setUrl("https://example.openai.azure.com/") .setDeploymentName("text-embedding-3-large") .setTextCol("text") .setApiVersion("2025-04-01-preview") val row = Seq("hello").toDF("text").collect().head - assert(transformer.requestUrl(row) == + assert(requestUrl(transformer, row) == "https://example.openai.azure.com/openai/deployments/text-embedding-3-large/embeddings" + "?api-version=2025-04-01-preview") - val payload = transformer.requestPayload(row) + val payload = requestPayload(transformer, row) assert(!payload.fields.contains("model")) assert(payload.fields.get("input").contains(JsString("hello"))) } test("responses uses OpenAI v1 base URL without api-version") { - val transformer = new InspectableResponses() + val transformer = new OpenAIResponses() .setUrl("https://example.services.ai.azure.com/openai/v1") .setDeploymentName("gpt-5-mini") .setMessagesCol("messages") .setApiVersion("2025-04-01-preview") val row = messagesRow - assert(transformer.requestUrl(row) == "https://example.services.ai.azure.com/openai/v1/responses") + assert(requestUrl(transformer, row) == "https://example.services.ai.azure.com/openai/v1/responses") - val payload = transformer.requestPayload(row) + val payload = requestPayload(transformer, row) assert(payload.fields.get("model").contains(JsString("gpt-5-mini"))) assert(payload.fields.contains("input")) } test("responses keeps legacy Azure URL shape when URL is not an OpenAI v1 base") { - val transformer = new InspectableResponses() + val transformer = new OpenAIResponses() .setUrl("https://example.openai.azure.com/") .setDeploymentName("gpt-5-mini") .setMessagesCol("messages") .setApiVersion("2025-04-01-preview") - assert(transformer.requestUrl(messagesRow) == + assert(requestUrl(transformer, messagesRow) == "https://example.openai.azure.com/openai/responses?api-version=2025-04-01-preview") } From 60c6f7a9bebac7b7be12b05f2952b7d108fce142 Mon Sep 17 00:00:00 2001 From: Ranadeep Singh Date: Fri, 8 May 2026 18:30:29 -0700 Subject: [PATCH 03/13] Add test to increase code coverage --- .../ml/services/openai/OpenAIV1EndpointSuite.scala | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIV1EndpointSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIV1EndpointSuite.scala index 2b4b783c1a7..b5a3fdda707 100644 --- a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIV1EndpointSuite.scala +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIV1EndpointSuite.scala @@ -124,6 +124,17 @@ class OpenAIV1EndpointSuite extends TestBase { assert(payload.fields.contains("input")) } + test("responses v1 endpoint requires deployment name as model") { + val transformer = new OpenAIResponses() + .setUrl("https://example.services.ai.azure.com/openai/v1") + .setMessagesCol("messages") + + val err = intercept[IllegalArgumentException] { + requestPayload(transformer, messagesRow) + } + assert(err.getMessage.contains("No deployment/model name provided for OpenAI v1 endpoint")) + } + test("responses keeps legacy Azure URL shape when URL is not an OpenAI v1 base") { val transformer = new OpenAIResponses() .setUrl("https://example.openai.azure.com/") From ce41a3350d9ed7d91b838c661dededaf72caf5c0 Mon Sep 17 00:00:00 2001 From: Ranadeep Singh Date: Sat, 9 May 2026 01:45:23 -0700 Subject: [PATCH 04/13] Make v1 api assumption cleaner --- .../aifoundry/AIFoundryChatCompletion.scala | 5 +- .../synapse/ml/services/openai/OpenAI.scala | 26 +- .../ml/services/openai/OpenAIDefaults.scala | 3 +- .../ml/services/openai/OpenAIPrompt.scala | 2 + .../openai/OpenAIV1EndpointSuite.scala | 238 +++++++++++++++--- 5 files changed, 232 insertions(+), 42 deletions(-) diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/aifoundry/AIFoundryChatCompletion.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/aifoundry/AIFoundryChatCompletion.scala index 198a8a07c14..306b9fcf787 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/aifoundry/AIFoundryChatCompletion.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/aifoundry/AIFoundryChatCompletion.scala @@ -58,9 +58,8 @@ class AIFoundryChatCompletion(override val uid: String) extends OpenAIChatComple setUrl(s"https://$v.services.ai.azure.com/" + urlPath.stripPrefix("/")) } - override protected def prepareUrlRoot: Row => String = { row => - s"${getUrl}models/chat/completions" + override protected def prepareUrlRoot: Row => String = { _ => + endpointUrl("models/chat/completions") } } - diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala index 31af1532333..0f453dbc438 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala @@ -14,10 +14,8 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.types._ import spray.json.DefaultJsonProtocol._ -import java.net.URI import java.util.Locale import scala.language.existentials -import scala.util.Try trait HasMessagesInput extends Params { val messagesCol: Param[String] = new Param[String]( @@ -35,15 +33,23 @@ case object OpenAIEmbeddingDeploymentNameKey extends GlobalKey[Either[String, St private[openai] object OpenAIEndpointUtils { private def stripTrailingSlashes(value: String): String = value.replaceAll("/+$", "") + private def withoutQueryOrFragment(value: String): String = { + val stopAt = Seq(value.indexOf("?"), value.indexOf("#")).filter(_ >= 0) match { + case Seq() => value.length + case indexes => indexes.min + } + value.take(stopAt) + } + def appendPath(baseUrl: String, path: String): String = { - val normalizedBase = if (baseUrl.endsWith("/")) baseUrl else baseUrl + "/" - normalizedBase + path.stripPrefix("/") + val separator = if (baseUrl.endsWith("/")) "" else "/" + baseUrl + separator + path.stripPrefix("/") } def isV1BaseUrl(baseUrl: String): Boolean = { - val path = Try(new URI(baseUrl.trim).getPath).toOption.getOrElse("") - val normalizedPath = stripTrailingSlashes(path).toLowerCase(Locale.ROOT) - normalizedPath == "/v1" || normalizedPath.endsWith("/openai/v1") + stripTrailingSlashes(withoutQueryOrFragment(baseUrl)) + .toLowerCase(Locale.ROOT) + .endsWith("/v1") } } @@ -449,6 +455,8 @@ abstract class OpenAIServicesBase(override val uid: String) extends CognitiveSer with HasOpenAISharedParams with OpenAIFabricSetting { setDefault(timeout -> 360.0) + override def setUrl(value: String): this.type = set(url, value) + protected[openai] def isOpenAIV1BaseUrl: Boolean = get(url).orElse(getDefault(url)).exists(OpenAIEndpointUtils.isV1BaseUrl) @@ -463,10 +471,10 @@ abstract class OpenAIServicesBase(override val uid: String) extends CognitiveSer } private def warnIfV1ApiVersionConfigured(): Unit = { - if (isOpenAIV1BaseUrl && get(apiVersion).nonEmpty) { + if (isOpenAIV1BaseUrl && (get(apiVersion).nonEmpty || GlobalParams.getParam(apiVersion).nonEmpty)) { logWarning( "apiVersion is ignored when the OpenAI URL is a v1 base URL. " + - "Remove apiVersion or use an Azure OpenAI endpoint without /openai/v1.") + "Remove apiVersion or use a non-v1 endpoint.") } } diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaults.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaults.scala index 8d63032898a..cc86df4478c 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaults.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaults.scala @@ -47,8 +47,7 @@ object OpenAIDefaults { } def setURL(v: String): Unit = { - val url = if (v.endsWith("/")) v else v + "/" - GlobalParams.setGlobalParam(URLKey, url) + GlobalParams.setGlobalParam(URLKey, v) } def getURL: Option[String] = { diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala index 37787aa861c..9e56e997748 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala @@ -241,6 +241,8 @@ class OpenAIPrompt(override val uid: String) extends Transformer store -> Left(false) ) + override def setUrl(value: String): this.type = set(url, value) + override def setCustomServiceName(v: String): this.type = { setUrl(s"https://$v.openai.azure.com/" + urlPath.stripPrefix("/")) } diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIV1EndpointSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIV1EndpointSuite.scala index b5a3fdda707..e2d765eb22d 100644 --- a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIV1EndpointSuite.scala +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIV1EndpointSuite.scala @@ -5,6 +5,7 @@ package com.microsoft.azure.synapse.ml.services.openai import com.microsoft.azure.synapse.ml.core.test.base.TestBase import com.microsoft.azure.synapse.ml.services.HasCognitiveServiceInput +import com.microsoft.azure.synapse.ml.services.aifoundry.AIFoundryChatCompletion import org.apache.http.entity.AbstractHttpEntity import org.apache.http.util.EntityUtils import org.apache.spark.sql.Row @@ -48,6 +49,52 @@ class OpenAIV1EndpointSuite extends TestBase { new GenericRowWithSchema(Array[Any](Seq(message)), messagesRequestSchema) } + test("OpenAI URLs preserve configured base URL strings") { + val root = new OpenAIChatCompletion().setUrl("https://example.openai.azure.com") + assert(root.getUrl == "https://example.openai.azure.com") + + val v1 = new OpenAIChatCompletion().setUrl("https://example.openai.azure.com/openai/v1") + assert(v1.getUrl == "https://example.openai.azure.com/openai/v1") + + val prompt = new OpenAIPrompt().setUrl("https://example.services.ai.azure.com") + assert(prompt.getUrl == "https://example.services.ai.azure.com") + + val versionedPath = "https://synapseml-openai-3.openai.azure.com/openai/v2" + OpenAIDefaults.setURL(versionedPath) + try { + assert(OpenAIDefaults.getURL.contains(versionedPath)) + } finally { + OpenAIDefaults.resetURL() + } + + OpenAIDefaults.setURL("https://example.services.ai.azure.com/openai/v1") + try { + val transformer = new OpenAIChatCompletion() + transformer.transferGlobalParamsToParamMap() + assert(transformer.getUrl == "https://example.services.ai.azure.com/openai/v1") + } finally { + OpenAIDefaults.resetURL() + } + } + + test("non-v1 versioned paths remain literal non-v1 base URLs") { + val versionedPath = "https://synapseml-openai-3.openai.azure.com/openai/v2" + OpenAIDefaults.setURL(versionedPath) + try { + val transformer = new OpenAIChatCompletion() + .setDeploymentName("gpt-4o") + .setMessagesCol("messages") + transformer.transferGlobalParamsToParamMap() + + assert(OpenAIDefaults.getURL.contains(versionedPath)) + assert(transformer.getUrl == versionedPath) + assert(requestUrl(transformer, messagesRow) == + versionedPath + "/openai/deployments/gpt-4o/chat/completions?api-version=2025-04-01-preview") + } finally { + OpenAIDefaults.resetURL() + } + } + test("chat completions uses OpenAI v1 base URL without api-version and sends model") { val transformer = new OpenAIChatCompletion() .setUrl("https://example.services.ai.azure.com/openai/v1") @@ -63,33 +110,159 @@ class OpenAIV1EndpointSuite extends TestBase { assert(payload.fields.contains("messages")) } - test("chat completions keeps legacy Azure deployment URL and api-version") { + test("chat completions accepts OpenAI-compatible v1 base URLs with and without trailing slash") { + Seq( + "https://example.openai.azure.com/openai/v1" -> + "https://example.openai.azure.com/openai/v1/chat/completions", + "https://example.openai.azure.com/openai/v1/" -> + "https://example.openai.azure.com/openai/v1/chat/completions", + "https://api.openai.com/v1" -> + "https://api.openai.com/v1/chat/completions", + "http://localhost:8000/v1/" -> + "http://localhost:8000/v1/chat/completions" + ).foreach { case (baseUrl, expectedUrl) => + val transformer = new OpenAIChatCompletion() + .setUrl(baseUrl) + .setDeploymentName("gpt-4o") + .setMessagesCol("messages") + .setApiVersion("2025-04-01-preview") + + assert(requestUrl(transformer, messagesRow) == expectedUrl) + } + } + + test("chat completions keeps legacy Azure deployment URL and api-version with and without trailing slash") { + Seq("https://example.openai.azure.com", "https://example.openai.azure.com/").foreach { baseUrl => + val transformer = new OpenAIChatCompletion() + .setUrl(baseUrl) + .setDeploymentName("gpt-4o") + .setMessagesCol("messages") + .setApiVersion("2025-04-01-preview") + + val row = messagesRow + assert(requestUrl(transformer, row) == + "https://example.openai.azure.com/openai/deployments/gpt-4o/chat/completions" + + "?api-version=2025-04-01-preview") + assert(!requestPayload(transformer, row).fields.contains("model")) + } + } + + test("chat completions accepts services.ai.azure.com resource root with and without trailing slash") { + Seq("https://example.services.ai.azure.com", "https://example.services.ai.azure.com/").foreach { baseUrl => + val transformer = new OpenAIChatCompletion() + .setUrl(baseUrl) + .setDeploymentName("gpt-4o") + .setMessagesCol("messages") + .setApiVersion("2025-04-01-preview") + + assert(requestUrl(transformer, messagesRow) == + "https://example.services.ai.azure.com/openai/deployments/gpt-4o/chat/completions" + + "?api-version=2025-04-01-preview") + } + } + + test("AI Foundry chat accepts services.ai.azure.com resource root with and without trailing slash") { + Seq("https://example.services.ai.azure.com", "https://example.services.ai.azure.com/").foreach { baseUrl => + val transformer = new AIFoundryChatCompletion() + .setUrl(baseUrl) + .setModel("gpt-4o") + .setMessagesCol("messages") + .setApiVersion("2025-04-01-preview") + + assert(requestUrl(transformer, messagesRow) == + "https://example.services.ai.azure.com/models/chat/completions?api-version=2025-04-01-preview") + } + } + + test("non-v1 URL paths remain permissive and use legacy request construction") { val transformer = new OpenAIChatCompletion() - .setUrl("https://example.openai.azure.com/") + .setUrl("https://example.openai.azure.com/openai") .setDeploymentName("gpt-4o") .setMessagesCol("messages") .setApiVersion("2025-04-01-preview") - val row = messagesRow - assert(requestUrl(transformer, row) == - "https://example.openai.azure.com/openai/deployments/gpt-4o/chat/completions" + + assert(requestUrl(transformer, messagesRow) == + "https://example.openai.azure.com/openai/openai/deployments/gpt-4o/chat/completions" + "?api-version=2025-04-01-preview") - assert(!requestPayload(transformer, row).fields.contains("model")) } - test("embeddings uses OpenAI v1 base URL and sends deployment as model") { - val transformer = new OpenAIEmbedding() - .setUrl("https://example.services.ai.azure.com/openai/v1") - .setDeploymentName("text-embedding-3-large") - .setTextCol("text") + test("custom non-Azure URL strings remain permissive") { + val transformer = new OpenAIChatCompletion() + .setUrl("https://proxy.contoso.com/openai") + .setDeploymentName("gpt-4o") + .setMessagesCol("messages") .setApiVersion("2025-04-01-preview") - val row = Seq("hello").toDF("text").collect().head - assert(requestUrl(transformer, row) == "https://example.services.ai.azure.com/openai/v1/embeddings") + assert(requestUrl(transformer, messagesRow) == + "https://proxy.contoso.com/openai/openai/deployments/gpt-4o/chat/completions" + + "?api-version=2025-04-01-preview") + } - val payload = requestPayload(transformer, row) - assert(payload.fields.get("model").contains(JsString("text-embedding-3-large"))) - assert(payload.fields.get("input").contains(JsString("hello"))) + test("OpenAI defaults allow non-v1 URL paths") { + OpenAIDefaults.setURL("https://example.openai.azure.com/openai") + try { + val transformer = new OpenAIChatCompletion() + .setDeploymentName("gpt-4o") + .setMessagesCol("messages") + transformer.transferGlobalParamsToParamMap() + + assert(requestUrl(transformer, messagesRow) == + "https://example.openai.azure.com/openai/openai/deployments/gpt-4o/chat/completions" + + "?api-version=2025-04-01-preview") + } finally { + OpenAIDefaults.resetURL() + } + } + + test("OpenAI defaults allow arbitrary URL strings") { + OpenAIDefaults.setURL("not-a-url") + try { + val transformer = new OpenAIChatCompletion() + transformer.transferGlobalParamsToParamMap() + assert(transformer.getUrl == "not-a-url") + } finally { + OpenAIDefaults.resetURL() + } + } + + test("OpenAI defaults accept v1 URL and omit global api-version") { + OpenAIDefaults.setURL("https://example.openai.azure.com/openai/v1") + OpenAIDefaults.setApiVersion("2025-04-01-preview") + try { + val transformer = new OpenAIChatCompletion() + .setDeploymentName("gpt-4o") + .setMessagesCol("messages") + transformer.transferGlobalParamsToParamMap() + + assert(requestUrl(transformer, messagesRow) == "https://example.openai.azure.com/openai/v1/chat/completions") + } finally { + OpenAIDefaults.resetURL() + OpenAIDefaults.resetApiVersion() + } + } + + test("embeddings uses OpenAI v1 base URL and sends deployment as model") { + Seq( + "https://example.services.ai.azure.com/openai/v1" -> + "https://example.services.ai.azure.com/openai/v1/embeddings", + "https://example.services.ai.azure.com/openai/v1/" -> + "https://example.services.ai.azure.com/openai/v1/embeddings", + "https://api.openai.com/v1" -> + "https://api.openai.com/v1/embeddings" + ).foreach { case (baseUrl, expectedUrl) => + val transformer = new OpenAIEmbedding() + .setUrl(baseUrl) + .setDeploymentName("text-embedding-3-large") + .setTextCol("text") + .setApiVersion("2025-04-01-preview") + + val row = Seq("hello").toDF("text").collect().head + assert(requestUrl(transformer, row) == expectedUrl) + + val payload = requestPayload(transformer, row) + assert(payload.fields.get("model").contains(JsString("text-embedding-3-large"))) + assert(payload.fields.get("input").contains(JsString("hello"))) + } } test("embeddings keeps legacy Azure deployment URL and api-version") { @@ -110,18 +283,27 @@ class OpenAIV1EndpointSuite extends TestBase { } test("responses uses OpenAI v1 base URL without api-version") { - val transformer = new OpenAIResponses() - .setUrl("https://example.services.ai.azure.com/openai/v1") - .setDeploymentName("gpt-5-mini") - .setMessagesCol("messages") - .setApiVersion("2025-04-01-preview") - - val row = messagesRow - assert(requestUrl(transformer, row) == "https://example.services.ai.azure.com/openai/v1/responses") - - val payload = requestPayload(transformer, row) - assert(payload.fields.get("model").contains(JsString("gpt-5-mini"))) - assert(payload.fields.contains("input")) + Seq( + "https://example.services.ai.azure.com/openai/v1" -> + "https://example.services.ai.azure.com/openai/v1/responses", + "https://example.services.ai.azure.com/openai/v1/" -> + "https://example.services.ai.azure.com/openai/v1/responses", + "https://api.openai.com/v1" -> + "https://api.openai.com/v1/responses" + ).foreach { case (baseUrl, expectedUrl) => + val transformer = new OpenAIResponses() + .setUrl(baseUrl) + .setDeploymentName("gpt-5-mini") + .setMessagesCol("messages") + .setApiVersion("2025-04-01-preview") + + val row = messagesRow + assert(requestUrl(transformer, row) == expectedUrl) + + val payload = requestPayload(transformer, row) + assert(payload.fields.get("model").contains(JsString("gpt-5-mini"))) + assert(payload.fields.contains("input")) + } } test("responses v1 endpoint requires deployment name as model") { From fa708e25cbacebd959bd2ca3ef9573a9d0bb8670 Mon Sep 17 00:00:00 2001 From: Ranadeep Singh Date: Mon, 18 May 2026 23:47:34 -0700 Subject: [PATCH 05/13] Add OpenAICompletion deprecation --- .../ml/services/openai/OpenAICompletion.py | 29 ++++++++ .../openai/test_OpenAICompletionDeprecated.py | 68 +++++++++++++++++++ .../azure/synapse/ml/codegen/PyCodegen.scala | 29 +++++++- 3 files changed, 125 insertions(+), 1 deletion(-) create mode 100644 cognitive/src/main/python/synapse/ml/services/openai/OpenAICompletion.py create mode 100644 cognitive/src/test/python/synapsemltest/services/openai/test_OpenAICompletionDeprecated.py diff --git a/cognitive/src/main/python/synapse/ml/services/openai/OpenAICompletion.py b/cognitive/src/main/python/synapse/ml/services/openai/OpenAICompletion.py new file mode 100644 index 00000000000..6b3df7ec9c2 --- /dev/null +++ b/cognitive/src/main/python/synapse/ml/services/openai/OpenAICompletion.py @@ -0,0 +1,29 @@ +# Copyright (C) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See LICENSE in project root for information. + +import warnings + +__all__ = ["OpenAICompletion"] + +_OPENAI_COMPLETION_DEPRECATION_MESSAGE = ( + "OpenAICompletion has been removed because the legacy OpenAI Completions API " + "is deprecated and retired. Use OpenAIResponses, OpenAIChatCompletion, or " + "OpenAIPrompt with setApiType('chat_completions') or setApiType('responses') instead." +) + + +def warn_openai_completion_deprecated(stacklevel=2): + warnings.warn( + _OPENAI_COMPLETION_DEPRECATION_MESSAGE, + FutureWarning, + stacklevel=stacklevel, + ) + + +warn_openai_completion_deprecated(stacklevel=2) + + +class OpenAICompletion: + def __init__(self, *args, **kwargs): + warn_openai_completion_deprecated(stacklevel=2) + raise RuntimeError(_OPENAI_COMPLETION_DEPRECATION_MESSAGE) diff --git a/cognitive/src/test/python/synapsemltest/services/openai/test_OpenAICompletionDeprecated.py b/cognitive/src/test/python/synapsemltest/services/openai/test_OpenAICompletionDeprecated.py new file mode 100644 index 00000000000..075359b1d79 --- /dev/null +++ b/cognitive/src/test/python/synapsemltest/services/openai/test_OpenAICompletionDeprecated.py @@ -0,0 +1,68 @@ +# Copyright (C) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See LICENSE in project root for information. + +import sys +import unittest +import warnings + +_MODULE_NAME = "synapse.ml.services.openai.OpenAICompletion" +_PACKAGE_NAME = "synapse.ml.services.openai" +_WARNING_TEXT = "OpenAICompletion has been removed" + + +def _clear_openai_completion_imports(): + sys.modules.pop(_MODULE_NAME, None) + package = sys.modules.get(_PACKAGE_NAME) + if package is not None: + package.__dict__.pop("OpenAICompletion", None) + + +def _has_openai_completion_warning(caught): + return any( + issubclass(warning.category, FutureWarning) + and _WARNING_TEXT in str(warning.message) + for warning in caught + ) + + +class TestOpenAICompletionDeprecated(unittest.TestCase): + def test_package_import_warns(self): + _clear_openai_completion_imports() + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + from synapse.ml.services.openai import OpenAICompletion + + package = sys.modules[_PACKAGE_NAME] + if hasattr(package, "__getattr__"): + self.assertIsInstance(OpenAICompletion, type) + else: + self.assertIsInstance(OpenAICompletion.OpenAICompletion, type) + self.assertTrue(_has_openai_completion_warning(caught)) + + def test_submodule_import_warns(self): + _clear_openai_completion_imports() + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + from synapse.ml.services.openai.OpenAICompletion import OpenAICompletion + + self.assertIsInstance(OpenAICompletion, type) + self.assertTrue(_has_openai_completion_warning(caught)) + + def test_instantiation_warns_and_raises(self): + _clear_openai_completion_imports() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + from synapse.ml.services.openai.OpenAICompletion import OpenAICompletion + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + with self.assertRaisesRegex(RuntimeError, _WARNING_TEXT): + OpenAICompletion() + + self.assertTrue(_has_openai_completion_warning(caught)) + + +if __name__ == "__main__": + result = unittest.main() diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/codegen/PyCodegen.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/codegen/PyCodegen.scala index 425d7314f6f..e316202f80b 100644 --- a/core/src/main/scala/com/microsoft/azure/synapse/ml/codegen/PyCodegen.scala +++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/codegen/PyCodegen.scala @@ -18,6 +18,32 @@ object PyCodegen { import CodeGenUtils._ + private val DeprecatedOpenAICompletionFile = "OpenAICompletion.py" + + private val OpenAICompletionImportHook: String = + """ + |def __getattr__(name): + | if name == "OpenAICompletion": + | import warnings + | + | with warnings.catch_warnings(): + | warnings.simplefilter("ignore", FutureWarning) + | from synapse.ml.services.openai.OpenAICompletion import ( + | OpenAICompletion, + | warn_openai_completion_deprecated, + | ) + | warn_openai_completion_deprecated(stacklevel=2) + | globals()["OpenAICompletion"] = OpenAICompletion + | return OpenAICompletion + | raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + |""".stripMargin + + private def isOpenAICompletionStub(packageFolder: String, fileName: String): Boolean = + packageFolder == "/services/openai" && fileName == DeprecatedOpenAICompletionFile + + private def initFileExtra(packageFolder: String): String = + if (packageFolder == "/services/openai") OpenAICompletionImportHook else "" + def generatePythonClasses(conf: CodegenConfig): Unit = { val instantiatedClasses = instantiateServices[PythonWrappable](conf.jarName) instantiatedClasses.foreach { w => @@ -37,12 +63,13 @@ object PyCodegen { dir.listFiles.filter(_.isFile).sorted .map(_.getName) .filter(name => name.endsWith(".py") && !name.startsWith("_") && !name.startsWith("test")) + .filterNot(name => isOpenAICompletionStub(packageFolder, name)) .map(name => s"from synapse.ml$packageString.${getBaseName(name)} import *\n").mkString("") } val initFile = new File(dir, "__init__.py") if (packageFolder != "/cognitive"){ if (packageFolder != "") { - writeFile(initFile, conf.packageHelp(importStrings)) + writeFile(initFile, conf.packageHelp(importStrings) + initFileExtra(packageFolder)) } else if (initFile.exists()) { initFile.delete() } From 9a40c5c21ff9ffbdb0ed04cd3eb10703a5569165 Mon Sep 17 00:00:00 2001 From: Ranadeep Singh Date: Mon, 18 May 2026 23:50:22 -0700 Subject: [PATCH 06/13] Remove deprecation warnings --- .../synapse/ml/causal/DoubleMLEstimator.scala | 2 +- .../ml/causal/OrthoForestDMLEstimator.scala | 2 +- .../ml/causal/ResidualTransformer.scala | 4 ++-- .../ml/core/utils/CloseableIterator.scala | 3 --- .../azure/synapse/ml/param/GlobalParams.scala | 23 ++++++++----------- 5 files changed, 14 insertions(+), 20 deletions(-) diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/DoubleMLEstimator.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/DoubleMLEstimator.scala index 738b7ffeed1..dffe6246823 100644 --- a/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/DoubleMLEstimator.scala +++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/DoubleMLEstimator.scala @@ -246,7 +246,7 @@ class DoubleMLEstimator(override val uid: String) 4. Cross-fit treatment and outcome models with the second split, residual model with the first split. 5. Average slopes from the two residual models. */ - val splits = dataset.randomSplit(getSampleSplitRatio) + val splits = dataset.toDF().randomSplit(getSampleSplitRatio) val (train, test) = (splits(0).cache, splits(1).cache) val residualsDF1 = calculateResiduals(train, test).select(outcomeResidualCol, treatmentResidualVecCol) val residualsDF2 = calculateResiduals(test, train).select(outcomeResidualCol, treatmentResidualVecCol) diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/OrthoForestDMLEstimator.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/OrthoForestDMLEstimator.scala index 46c7e4a9593..cbb82267888 100644 --- a/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/OrthoForestDMLEstimator.scala +++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/OrthoForestDMLEstimator.scala @@ -104,7 +104,7 @@ class OrthoForestDMLEstimator(override val uid: String) 4. Cross-fit treatment and outcome models with the second split, residual model with the first split. 5. Average slopes from the two residual models is eqiuivalent to fitting one tree */ - val splits = dataset.randomSplit(getSampleSplitRatio) + val splits = dataset.toDF().randomSplit(getSampleSplitRatio) val (train, test) = (splits(0).cache, splits(1).cache) val residualsDF1 = calculateResiduals(train, test) val residualsDF2 = calculateResiduals(test, train) diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/ResidualTransformer.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/ResidualTransformer.scala index de6dfe9d3f9..5248b944554 100644 --- a/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/ResidualTransformer.scala +++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/ResidualTransformer.scala @@ -72,9 +72,9 @@ class ResidualTransformer(override val uid: String) extends Transformer s"${this.getClass.getSimpleName}: " + s"observedCol must be of type DoubleType, LongType, IntegerType or BooleanType but got $observedColType") - val convertedDataset = if (observedColType == BooleanType) { + val convertedDataset: DataFrame = if (observedColType == BooleanType) { dataset.withColumn(getObservedCol, col(getObservedCol).cast(IntegerType)) - } else dataset + } else dataset.toDF() val predictedColDataType = convertedDataset.schema(getPredictedCol).dataType diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/core/utils/CloseableIterator.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/core/utils/CloseableIterator.scala index 68656f6ff0c..e541079eee9 100644 --- a/core/src/main/scala/com/microsoft/azure/synapse/ml/core/utils/CloseableIterator.scala +++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/core/utils/CloseableIterator.scala @@ -26,9 +26,6 @@ class CloseableIterator[+T](delegate: Iterator[T], cleanup: => Unit) extends Ite catch { case _: Throwable => } - - super.finalize() } } //scalastyle:on no.finalize - diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/param/GlobalParams.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/param/GlobalParams.scala index 9ff42cd8b28..8d888df85d9 100644 --- a/core/src/main/scala/com/microsoft/azure/synapse/ml/param/GlobalParams.scala +++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/param/GlobalParams.scala @@ -10,34 +10,31 @@ import scala.collection.mutable trait GlobalKey[T] object GlobalParams { - private val ParamToKeyMap: mutable.Map[Any, GlobalKey[_]] = mutable.Map.empty - private val GlobalParams: mutable.Map[GlobalKey[_], Any] = mutable.Map.empty + private val ParamToKeyMap: mutable.Map[Any, GlobalKey[Any]] = mutable.Map.empty + private val GlobalParams: mutable.Map[GlobalKey[Any], Any] = mutable.Map.empty + private def untypedKey[T](key: GlobalKey[T]): GlobalKey[Any] = { + key.asInstanceOf[GlobalKey[Any]] + } def setGlobalParam[T](key: GlobalKey[T], value: T): Unit = { - GlobalParams(key) = value + GlobalParams(untypedKey(key)) = value } def getGlobalParam[T](key: GlobalKey[T]): Option[T] = { - GlobalParams.get(key.asInstanceOf[GlobalKey[Any]]).map(_.asInstanceOf[T]) + GlobalParams.get(untypedKey(key)).map(_.asInstanceOf[T]) } def resetGlobalParam[T](key: GlobalKey[T]): Unit = { - GlobalParams -= key + GlobalParams -= untypedKey(key) } def getParam[T](p: Param[T]): Option[T] = { - ParamToKeyMap.get(p).flatMap { key => - key match { - case k: GlobalKey[T] => - getGlobalParam(k) - case _ => None - } - } + ParamToKeyMap.get(p).flatMap(GlobalParams.get).map(_.asInstanceOf[T]) } def registerParam[T](p: Param[T], key: GlobalKey[T]): Unit = { - ParamToKeyMap(p) = key + ParamToKeyMap(p) = untypedKey(key) } } From 3ed60449fe0ef594e742b24c701040e2b3624e8d Mon Sep 17 00:00:00 2001 From: Ranadeep Singh Date: Tue, 19 May 2026 02:45:30 -0700 Subject: [PATCH 07/13] Fix RAI test for OpenAIPrompt --- .../openai/OpenAIChatCompletionSuite.scala | 28 +++++++++++++++++++ .../services/openai/OpenAIPromptSuite.scala | 20 +++++++++---- 2 files changed, 43 insertions(+), 5 deletions(-) diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIChatCompletionSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIChatCompletionSuite.scala index 60b1bc54018..d540cd71a55 100644 --- a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIChatCompletionSuite.scala +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIChatCompletionSuite.scala @@ -539,6 +539,34 @@ class OpenAIChatCompletionSuite extends TransformerFuzzing[OpenAIChatCompletion] testCompletion(completion, goodDf) } + test("content filtering identifies empty assistant content") { + val responseJson = + """{ + | "id":"chatcmpl_test", + | "object":"chat.completion", + | "created":"1", + | "model":"gpt-4.1", + | "choices":[ + | { + | "message":{"role":"assistant","content":null,"name":null}, + | "index":0, + | "finish_reason":"content_filter" + | } + | ], + | "system_fingerprint":null, + | "usage":null + |}""".stripMargin + + val outputRow = spark.read + .schema(ChatModelResponse.schema) + .json(Seq(responseJson).toDS) + .collect() + .head + val completion = new OpenAIChatCompletion() + assert(completion.isContentFiltered(outputRow)) + assert(completion.getFilterReason(outputRow) == "content_filter") + } + ignore("Custom EndPoint") { lazy val accessToken: String = sys.env.getOrElse("CUSTOM_ACCESS_TOKEN", "") lazy val customRootUrlValue: String = sys.env.getOrElse("CUSTOM_ROOT_URL", "") diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala index 35b425c973b..1942bbdb71c 100644 --- a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala @@ -112,17 +112,27 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK } // scalastyle:on null + private def hasNonEmptyError(row: Row, errorColName: String): Boolean = { + Option(row.getAs[Row](errorColName)) + .flatMap(error => Option(error.getAs[String]("response"))) + .exists(_.nonEmpty) + } + test("RAI Usage") { - val result = prompt + val rows = prompt .setDeploymentName(deploymentName) .setPromptTemplate("Tell me about a graphically disgusting " + "and violent movie in detail, " + "be very gory and NSFW in your description.") .transform(df) - .where(col(prompt.getErrorCol).isNotNull) - .select(prompt.getErrorCol) - .collect().head.getAs[Row](0) - assert(Option(result).nonEmpty) + .select(col(prompt.getOutputCol), col(prompt.getErrorCol)) + .collect() + + assert(rows.length == 3) + rows.foreach { row => + val hasOutput = Option(row.getAs[String](prompt.getOutputCol)).exists(_.nonEmpty) + assert(hasOutput || hasNonEmptyError(row, prompt.getErrorCol)) + } } test("Basic Usage") { From 10715cd4c20ef80b0b318e27710a3027f5bd5236 Mon Sep 17 00:00:00 2001 From: Ranadeep Singh Date: Tue, 19 May 2026 03:06:54 -0700 Subject: [PATCH 08/13] Revert "Add OpenAICompletion deprecation" This reverts commit fa708e25cbacebd959bd2ca3ef9573a9d0bb8670. --- .../ml/services/openai/OpenAICompletion.py | 29 -------- .../openai/test_OpenAICompletionDeprecated.py | 68 ------------------- .../azure/synapse/ml/codegen/PyCodegen.scala | 29 +------- 3 files changed, 1 insertion(+), 125 deletions(-) delete mode 100644 cognitive/src/main/python/synapse/ml/services/openai/OpenAICompletion.py delete mode 100644 cognitive/src/test/python/synapsemltest/services/openai/test_OpenAICompletionDeprecated.py diff --git a/cognitive/src/main/python/synapse/ml/services/openai/OpenAICompletion.py b/cognitive/src/main/python/synapse/ml/services/openai/OpenAICompletion.py deleted file mode 100644 index 6b3df7ec9c2..00000000000 --- a/cognitive/src/main/python/synapse/ml/services/openai/OpenAICompletion.py +++ /dev/null @@ -1,29 +0,0 @@ -# Copyright (C) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See LICENSE in project root for information. - -import warnings - -__all__ = ["OpenAICompletion"] - -_OPENAI_COMPLETION_DEPRECATION_MESSAGE = ( - "OpenAICompletion has been removed because the legacy OpenAI Completions API " - "is deprecated and retired. Use OpenAIResponses, OpenAIChatCompletion, or " - "OpenAIPrompt with setApiType('chat_completions') or setApiType('responses') instead." -) - - -def warn_openai_completion_deprecated(stacklevel=2): - warnings.warn( - _OPENAI_COMPLETION_DEPRECATION_MESSAGE, - FutureWarning, - stacklevel=stacklevel, - ) - - -warn_openai_completion_deprecated(stacklevel=2) - - -class OpenAICompletion: - def __init__(self, *args, **kwargs): - warn_openai_completion_deprecated(stacklevel=2) - raise RuntimeError(_OPENAI_COMPLETION_DEPRECATION_MESSAGE) diff --git a/cognitive/src/test/python/synapsemltest/services/openai/test_OpenAICompletionDeprecated.py b/cognitive/src/test/python/synapsemltest/services/openai/test_OpenAICompletionDeprecated.py deleted file mode 100644 index 075359b1d79..00000000000 --- a/cognitive/src/test/python/synapsemltest/services/openai/test_OpenAICompletionDeprecated.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright (C) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See LICENSE in project root for information. - -import sys -import unittest -import warnings - -_MODULE_NAME = "synapse.ml.services.openai.OpenAICompletion" -_PACKAGE_NAME = "synapse.ml.services.openai" -_WARNING_TEXT = "OpenAICompletion has been removed" - - -def _clear_openai_completion_imports(): - sys.modules.pop(_MODULE_NAME, None) - package = sys.modules.get(_PACKAGE_NAME) - if package is not None: - package.__dict__.pop("OpenAICompletion", None) - - -def _has_openai_completion_warning(caught): - return any( - issubclass(warning.category, FutureWarning) - and _WARNING_TEXT in str(warning.message) - for warning in caught - ) - - -class TestOpenAICompletionDeprecated(unittest.TestCase): - def test_package_import_warns(self): - _clear_openai_completion_imports() - - with warnings.catch_warnings(record=True) as caught: - warnings.simplefilter("always") - from synapse.ml.services.openai import OpenAICompletion - - package = sys.modules[_PACKAGE_NAME] - if hasattr(package, "__getattr__"): - self.assertIsInstance(OpenAICompletion, type) - else: - self.assertIsInstance(OpenAICompletion.OpenAICompletion, type) - self.assertTrue(_has_openai_completion_warning(caught)) - - def test_submodule_import_warns(self): - _clear_openai_completion_imports() - - with warnings.catch_warnings(record=True) as caught: - warnings.simplefilter("always") - from synapse.ml.services.openai.OpenAICompletion import OpenAICompletion - - self.assertIsInstance(OpenAICompletion, type) - self.assertTrue(_has_openai_completion_warning(caught)) - - def test_instantiation_warns_and_raises(self): - _clear_openai_completion_imports() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - from synapse.ml.services.openai.OpenAICompletion import OpenAICompletion - - with warnings.catch_warnings(record=True) as caught: - warnings.simplefilter("always") - with self.assertRaisesRegex(RuntimeError, _WARNING_TEXT): - OpenAICompletion() - - self.assertTrue(_has_openai_completion_warning(caught)) - - -if __name__ == "__main__": - result = unittest.main() diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/codegen/PyCodegen.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/codegen/PyCodegen.scala index e316202f80b..425d7314f6f 100644 --- a/core/src/main/scala/com/microsoft/azure/synapse/ml/codegen/PyCodegen.scala +++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/codegen/PyCodegen.scala @@ -18,32 +18,6 @@ object PyCodegen { import CodeGenUtils._ - private val DeprecatedOpenAICompletionFile = "OpenAICompletion.py" - - private val OpenAICompletionImportHook: String = - """ - |def __getattr__(name): - | if name == "OpenAICompletion": - | import warnings - | - | with warnings.catch_warnings(): - | warnings.simplefilter("ignore", FutureWarning) - | from synapse.ml.services.openai.OpenAICompletion import ( - | OpenAICompletion, - | warn_openai_completion_deprecated, - | ) - | warn_openai_completion_deprecated(stacklevel=2) - | globals()["OpenAICompletion"] = OpenAICompletion - | return OpenAICompletion - | raise AttributeError(f"module {__name__!r} has no attribute {name!r}") - |""".stripMargin - - private def isOpenAICompletionStub(packageFolder: String, fileName: String): Boolean = - packageFolder == "/services/openai" && fileName == DeprecatedOpenAICompletionFile - - private def initFileExtra(packageFolder: String): String = - if (packageFolder == "/services/openai") OpenAICompletionImportHook else "" - def generatePythonClasses(conf: CodegenConfig): Unit = { val instantiatedClasses = instantiateServices[PythonWrappable](conf.jarName) instantiatedClasses.foreach { w => @@ -63,13 +37,12 @@ object PyCodegen { dir.listFiles.filter(_.isFile).sorted .map(_.getName) .filter(name => name.endsWith(".py") && !name.startsWith("_") && !name.startsWith("test")) - .filterNot(name => isOpenAICompletionStub(packageFolder, name)) .map(name => s"from synapse.ml$packageString.${getBaseName(name)} import *\n").mkString("") } val initFile = new File(dir, "__init__.py") if (packageFolder != "/cognitive"){ if (packageFolder != "") { - writeFile(initFile, conf.packageHelp(importStrings) + initFileExtra(packageFolder)) + writeFile(initFile, conf.packageHelp(importStrings)) } else if (initFile.exists()) { initFile.delete() } From f06f1ade547f4dbfec418159af5795aa626d57ab Mon Sep 17 00:00:00 2001 From: Ranadeep Singh Date: Tue, 19 May 2026 03:06:59 -0700 Subject: [PATCH 09/13] Revert "Fix RAI test for OpenAIPrompt" This reverts commit 3ed60449fe0ef594e742b24c701040e2b3624e8d. --- .../openai/OpenAIChatCompletionSuite.scala | 28 ------------------- .../services/openai/OpenAIPromptSuite.scala | 20 ++++--------- 2 files changed, 5 insertions(+), 43 deletions(-) diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIChatCompletionSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIChatCompletionSuite.scala index d540cd71a55..60b1bc54018 100644 --- a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIChatCompletionSuite.scala +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIChatCompletionSuite.scala @@ -539,34 +539,6 @@ class OpenAIChatCompletionSuite extends TransformerFuzzing[OpenAIChatCompletion] testCompletion(completion, goodDf) } - test("content filtering identifies empty assistant content") { - val responseJson = - """{ - | "id":"chatcmpl_test", - | "object":"chat.completion", - | "created":"1", - | "model":"gpt-4.1", - | "choices":[ - | { - | "message":{"role":"assistant","content":null,"name":null}, - | "index":0, - | "finish_reason":"content_filter" - | } - | ], - | "system_fingerprint":null, - | "usage":null - |}""".stripMargin - - val outputRow = spark.read - .schema(ChatModelResponse.schema) - .json(Seq(responseJson).toDS) - .collect() - .head - val completion = new OpenAIChatCompletion() - assert(completion.isContentFiltered(outputRow)) - assert(completion.getFilterReason(outputRow) == "content_filter") - } - ignore("Custom EndPoint") { lazy val accessToken: String = sys.env.getOrElse("CUSTOM_ACCESS_TOKEN", "") lazy val customRootUrlValue: String = sys.env.getOrElse("CUSTOM_ROOT_URL", "") diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala index 1942bbdb71c..35b425c973b 100644 --- a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala @@ -112,27 +112,17 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK } // scalastyle:on null - private def hasNonEmptyError(row: Row, errorColName: String): Boolean = { - Option(row.getAs[Row](errorColName)) - .flatMap(error => Option(error.getAs[String]("response"))) - .exists(_.nonEmpty) - } - test("RAI Usage") { - val rows = prompt + val result = prompt .setDeploymentName(deploymentName) .setPromptTemplate("Tell me about a graphically disgusting " + "and violent movie in detail, " + "be very gory and NSFW in your description.") .transform(df) - .select(col(prompt.getOutputCol), col(prompt.getErrorCol)) - .collect() - - assert(rows.length == 3) - rows.foreach { row => - val hasOutput = Option(row.getAs[String](prompt.getOutputCol)).exists(_.nonEmpty) - assert(hasOutput || hasNonEmptyError(row, prompt.getErrorCol)) - } + .where(col(prompt.getErrorCol).isNotNull) + .select(prompt.getErrorCol) + .collect().head.getAs[Row](0) + assert(Option(result).nonEmpty) } test("Basic Usage") { From 987484ceee2c2dd89b1b04ef104b0c716695ef6a Mon Sep 17 00:00:00 2001 From: Ranadeep Singh Date: Tue, 19 May 2026 03:07:12 -0700 Subject: [PATCH 10/13] Revert "Remove deprecation warnings" This reverts commit 9a40c5c21ff9ffbdb0ed04cd3eb10703a5569165. --- .../synapse/ml/causal/DoubleMLEstimator.scala | 2 +- .../ml/causal/OrthoForestDMLEstimator.scala | 2 +- .../ml/causal/ResidualTransformer.scala | 4 ++-- .../ml/core/utils/CloseableIterator.scala | 3 +++ .../azure/synapse/ml/param/GlobalParams.scala | 23 +++++++++++-------- 5 files changed, 20 insertions(+), 14 deletions(-) diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/DoubleMLEstimator.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/DoubleMLEstimator.scala index dffe6246823..738b7ffeed1 100644 --- a/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/DoubleMLEstimator.scala +++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/DoubleMLEstimator.scala @@ -246,7 +246,7 @@ class DoubleMLEstimator(override val uid: String) 4. Cross-fit treatment and outcome models with the second split, residual model with the first split. 5. Average slopes from the two residual models. */ - val splits = dataset.toDF().randomSplit(getSampleSplitRatio) + val splits = dataset.randomSplit(getSampleSplitRatio) val (train, test) = (splits(0).cache, splits(1).cache) val residualsDF1 = calculateResiduals(train, test).select(outcomeResidualCol, treatmentResidualVecCol) val residualsDF2 = calculateResiduals(test, train).select(outcomeResidualCol, treatmentResidualVecCol) diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/OrthoForestDMLEstimator.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/OrthoForestDMLEstimator.scala index cbb82267888..46c7e4a9593 100644 --- a/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/OrthoForestDMLEstimator.scala +++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/OrthoForestDMLEstimator.scala @@ -104,7 +104,7 @@ class OrthoForestDMLEstimator(override val uid: String) 4. Cross-fit treatment and outcome models with the second split, residual model with the first split. 5. Average slopes from the two residual models is eqiuivalent to fitting one tree */ - val splits = dataset.toDF().randomSplit(getSampleSplitRatio) + val splits = dataset.randomSplit(getSampleSplitRatio) val (train, test) = (splits(0).cache, splits(1).cache) val residualsDF1 = calculateResiduals(train, test) val residualsDF2 = calculateResiduals(test, train) diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/ResidualTransformer.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/ResidualTransformer.scala index 5248b944554..de6dfe9d3f9 100644 --- a/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/ResidualTransformer.scala +++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/ResidualTransformer.scala @@ -72,9 +72,9 @@ class ResidualTransformer(override val uid: String) extends Transformer s"${this.getClass.getSimpleName}: " + s"observedCol must be of type DoubleType, LongType, IntegerType or BooleanType but got $observedColType") - val convertedDataset: DataFrame = if (observedColType == BooleanType) { + val convertedDataset = if (observedColType == BooleanType) { dataset.withColumn(getObservedCol, col(getObservedCol).cast(IntegerType)) - } else dataset.toDF() + } else dataset val predictedColDataType = convertedDataset.schema(getPredictedCol).dataType diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/core/utils/CloseableIterator.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/core/utils/CloseableIterator.scala index e541079eee9..68656f6ff0c 100644 --- a/core/src/main/scala/com/microsoft/azure/synapse/ml/core/utils/CloseableIterator.scala +++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/core/utils/CloseableIterator.scala @@ -26,6 +26,9 @@ class CloseableIterator[+T](delegate: Iterator[T], cleanup: => Unit) extends Ite catch { case _: Throwable => } + + super.finalize() } } //scalastyle:on no.finalize + diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/param/GlobalParams.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/param/GlobalParams.scala index 8d888df85d9..9ff42cd8b28 100644 --- a/core/src/main/scala/com/microsoft/azure/synapse/ml/param/GlobalParams.scala +++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/param/GlobalParams.scala @@ -10,31 +10,34 @@ import scala.collection.mutable trait GlobalKey[T] object GlobalParams { - private val ParamToKeyMap: mutable.Map[Any, GlobalKey[Any]] = mutable.Map.empty - private val GlobalParams: mutable.Map[GlobalKey[Any], Any] = mutable.Map.empty + private val ParamToKeyMap: mutable.Map[Any, GlobalKey[_]] = mutable.Map.empty + private val GlobalParams: mutable.Map[GlobalKey[_], Any] = mutable.Map.empty - private def untypedKey[T](key: GlobalKey[T]): GlobalKey[Any] = { - key.asInstanceOf[GlobalKey[Any]] - } def setGlobalParam[T](key: GlobalKey[T], value: T): Unit = { - GlobalParams(untypedKey(key)) = value + GlobalParams(key) = value } def getGlobalParam[T](key: GlobalKey[T]): Option[T] = { - GlobalParams.get(untypedKey(key)).map(_.asInstanceOf[T]) + GlobalParams.get(key.asInstanceOf[GlobalKey[Any]]).map(_.asInstanceOf[T]) } def resetGlobalParam[T](key: GlobalKey[T]): Unit = { - GlobalParams -= untypedKey(key) + GlobalParams -= key } def getParam[T](p: Param[T]): Option[T] = { - ParamToKeyMap.get(p).flatMap(GlobalParams.get).map(_.asInstanceOf[T]) + ParamToKeyMap.get(p).flatMap { key => + key match { + case k: GlobalKey[T] => + getGlobalParam(k) + case _ => None + } + } } def registerParam[T](p: Param[T], key: GlobalKey[T]): Unit = { - ParamToKeyMap(p) = untypedKey(key) + ParamToKeyMap(p) = key } } From 59226acd367f5227ce0fb36115f0f8a86b5b7e03 Mon Sep 17 00:00:00 2001 From: Ranadeep Singh Date: Tue, 19 May 2026 03:23:37 -0700 Subject: [PATCH 11/13] Reapply "Remove deprecation warnings" This reverts commit 987484ceee2c2dd89b1b04ef104b0c716695ef6a. --- .../synapse/ml/causal/DoubleMLEstimator.scala | 2 +- .../ml/causal/OrthoForestDMLEstimator.scala | 2 +- .../ml/causal/ResidualTransformer.scala | 4 ++-- .../ml/core/utils/CloseableIterator.scala | 3 --- .../azure/synapse/ml/param/GlobalParams.scala | 23 ++++++++----------- 5 files changed, 14 insertions(+), 20 deletions(-) diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/DoubleMLEstimator.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/DoubleMLEstimator.scala index 738b7ffeed1..dffe6246823 100644 --- a/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/DoubleMLEstimator.scala +++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/DoubleMLEstimator.scala @@ -246,7 +246,7 @@ class DoubleMLEstimator(override val uid: String) 4. Cross-fit treatment and outcome models with the second split, residual model with the first split. 5. Average slopes from the two residual models. */ - val splits = dataset.randomSplit(getSampleSplitRatio) + val splits = dataset.toDF().randomSplit(getSampleSplitRatio) val (train, test) = (splits(0).cache, splits(1).cache) val residualsDF1 = calculateResiduals(train, test).select(outcomeResidualCol, treatmentResidualVecCol) val residualsDF2 = calculateResiduals(test, train).select(outcomeResidualCol, treatmentResidualVecCol) diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/OrthoForestDMLEstimator.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/OrthoForestDMLEstimator.scala index 46c7e4a9593..cbb82267888 100644 --- a/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/OrthoForestDMLEstimator.scala +++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/OrthoForestDMLEstimator.scala @@ -104,7 +104,7 @@ class OrthoForestDMLEstimator(override val uid: String) 4. Cross-fit treatment and outcome models with the second split, residual model with the first split. 5. Average slopes from the two residual models is eqiuivalent to fitting one tree */ - val splits = dataset.randomSplit(getSampleSplitRatio) + val splits = dataset.toDF().randomSplit(getSampleSplitRatio) val (train, test) = (splits(0).cache, splits(1).cache) val residualsDF1 = calculateResiduals(train, test) val residualsDF2 = calculateResiduals(test, train) diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/ResidualTransformer.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/ResidualTransformer.scala index de6dfe9d3f9..5248b944554 100644 --- a/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/ResidualTransformer.scala +++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/ResidualTransformer.scala @@ -72,9 +72,9 @@ class ResidualTransformer(override val uid: String) extends Transformer s"${this.getClass.getSimpleName}: " + s"observedCol must be of type DoubleType, LongType, IntegerType or BooleanType but got $observedColType") - val convertedDataset = if (observedColType == BooleanType) { + val convertedDataset: DataFrame = if (observedColType == BooleanType) { dataset.withColumn(getObservedCol, col(getObservedCol).cast(IntegerType)) - } else dataset + } else dataset.toDF() val predictedColDataType = convertedDataset.schema(getPredictedCol).dataType diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/core/utils/CloseableIterator.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/core/utils/CloseableIterator.scala index 68656f6ff0c..e541079eee9 100644 --- a/core/src/main/scala/com/microsoft/azure/synapse/ml/core/utils/CloseableIterator.scala +++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/core/utils/CloseableIterator.scala @@ -26,9 +26,6 @@ class CloseableIterator[+T](delegate: Iterator[T], cleanup: => Unit) extends Ite catch { case _: Throwable => } - - super.finalize() } } //scalastyle:on no.finalize - diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/param/GlobalParams.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/param/GlobalParams.scala index 9ff42cd8b28..8d888df85d9 100644 --- a/core/src/main/scala/com/microsoft/azure/synapse/ml/param/GlobalParams.scala +++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/param/GlobalParams.scala @@ -10,34 +10,31 @@ import scala.collection.mutable trait GlobalKey[T] object GlobalParams { - private val ParamToKeyMap: mutable.Map[Any, GlobalKey[_]] = mutable.Map.empty - private val GlobalParams: mutable.Map[GlobalKey[_], Any] = mutable.Map.empty + private val ParamToKeyMap: mutable.Map[Any, GlobalKey[Any]] = mutable.Map.empty + private val GlobalParams: mutable.Map[GlobalKey[Any], Any] = mutable.Map.empty + private def untypedKey[T](key: GlobalKey[T]): GlobalKey[Any] = { + key.asInstanceOf[GlobalKey[Any]] + } def setGlobalParam[T](key: GlobalKey[T], value: T): Unit = { - GlobalParams(key) = value + GlobalParams(untypedKey(key)) = value } def getGlobalParam[T](key: GlobalKey[T]): Option[T] = { - GlobalParams.get(key.asInstanceOf[GlobalKey[Any]]).map(_.asInstanceOf[T]) + GlobalParams.get(untypedKey(key)).map(_.asInstanceOf[T]) } def resetGlobalParam[T](key: GlobalKey[T]): Unit = { - GlobalParams -= key + GlobalParams -= untypedKey(key) } def getParam[T](p: Param[T]): Option[T] = { - ParamToKeyMap.get(p).flatMap { key => - key match { - case k: GlobalKey[T] => - getGlobalParam(k) - case _ => None - } - } + ParamToKeyMap.get(p).flatMap(GlobalParams.get).map(_.asInstanceOf[T]) } def registerParam[T](p: Param[T], key: GlobalKey[T]): Unit = { - ParamToKeyMap(p) = key + ParamToKeyMap(p) = untypedKey(key) } } From 867c9800ede7b50157947abd1cb2e90a702f3750 Mon Sep 17 00:00:00 2001 From: Ranadeep Singh Date: Tue, 19 May 2026 03:23:42 -0700 Subject: [PATCH 12/13] Reapply "Fix RAI test for OpenAIPrompt" This reverts commit f06f1ade547f4dbfec418159af5795aa626d57ab. --- .../openai/OpenAIChatCompletionSuite.scala | 28 +++++++++++++++++++ .../services/openai/OpenAIPromptSuite.scala | 20 +++++++++---- 2 files changed, 43 insertions(+), 5 deletions(-) diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIChatCompletionSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIChatCompletionSuite.scala index 60b1bc54018..d540cd71a55 100644 --- a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIChatCompletionSuite.scala +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIChatCompletionSuite.scala @@ -539,6 +539,34 @@ class OpenAIChatCompletionSuite extends TransformerFuzzing[OpenAIChatCompletion] testCompletion(completion, goodDf) } + test("content filtering identifies empty assistant content") { + val responseJson = + """{ + | "id":"chatcmpl_test", + | "object":"chat.completion", + | "created":"1", + | "model":"gpt-4.1", + | "choices":[ + | { + | "message":{"role":"assistant","content":null,"name":null}, + | "index":0, + | "finish_reason":"content_filter" + | } + | ], + | "system_fingerprint":null, + | "usage":null + |}""".stripMargin + + val outputRow = spark.read + .schema(ChatModelResponse.schema) + .json(Seq(responseJson).toDS) + .collect() + .head + val completion = new OpenAIChatCompletion() + assert(completion.isContentFiltered(outputRow)) + assert(completion.getFilterReason(outputRow) == "content_filter") + } + ignore("Custom EndPoint") { lazy val accessToken: String = sys.env.getOrElse("CUSTOM_ACCESS_TOKEN", "") lazy val customRootUrlValue: String = sys.env.getOrElse("CUSTOM_ROOT_URL", "") diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala index 35b425c973b..1942bbdb71c 100644 --- a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala @@ -112,17 +112,27 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK } // scalastyle:on null + private def hasNonEmptyError(row: Row, errorColName: String): Boolean = { + Option(row.getAs[Row](errorColName)) + .flatMap(error => Option(error.getAs[String]("response"))) + .exists(_.nonEmpty) + } + test("RAI Usage") { - val result = prompt + val rows = prompt .setDeploymentName(deploymentName) .setPromptTemplate("Tell me about a graphically disgusting " + "and violent movie in detail, " + "be very gory and NSFW in your description.") .transform(df) - .where(col(prompt.getErrorCol).isNotNull) - .select(prompt.getErrorCol) - .collect().head.getAs[Row](0) - assert(Option(result).nonEmpty) + .select(col(prompt.getOutputCol), col(prompt.getErrorCol)) + .collect() + + assert(rows.length == 3) + rows.foreach { row => + val hasOutput = Option(row.getAs[String](prompt.getOutputCol)).exists(_.nonEmpty) + assert(hasOutput || hasNonEmptyError(row, prompt.getErrorCol)) + } } test("Basic Usage") { From 97ac819113e94d5782b03f4a1a8071f28c015ae5 Mon Sep 17 00:00:00 2001 From: Ranadeep Singh Date: Tue, 19 May 2026 03:23:45 -0700 Subject: [PATCH 13/13] Reapply "Add OpenAICompletion deprecation" This reverts commit 10715cd4c20ef80b0b318e27710a3027f5bd5236. --- .../ml/services/openai/OpenAICompletion.py | 29 ++++++++ .../openai/test_OpenAICompletionDeprecated.py | 68 +++++++++++++++++++ .../azure/synapse/ml/codegen/PyCodegen.scala | 29 +++++++- 3 files changed, 125 insertions(+), 1 deletion(-) create mode 100644 cognitive/src/main/python/synapse/ml/services/openai/OpenAICompletion.py create mode 100644 cognitive/src/test/python/synapsemltest/services/openai/test_OpenAICompletionDeprecated.py diff --git a/cognitive/src/main/python/synapse/ml/services/openai/OpenAICompletion.py b/cognitive/src/main/python/synapse/ml/services/openai/OpenAICompletion.py new file mode 100644 index 00000000000..6b3df7ec9c2 --- /dev/null +++ b/cognitive/src/main/python/synapse/ml/services/openai/OpenAICompletion.py @@ -0,0 +1,29 @@ +# Copyright (C) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See LICENSE in project root for information. + +import warnings + +__all__ = ["OpenAICompletion"] + +_OPENAI_COMPLETION_DEPRECATION_MESSAGE = ( + "OpenAICompletion has been removed because the legacy OpenAI Completions API " + "is deprecated and retired. Use OpenAIResponses, OpenAIChatCompletion, or " + "OpenAIPrompt with setApiType('chat_completions') or setApiType('responses') instead." +) + + +def warn_openai_completion_deprecated(stacklevel=2): + warnings.warn( + _OPENAI_COMPLETION_DEPRECATION_MESSAGE, + FutureWarning, + stacklevel=stacklevel, + ) + + +warn_openai_completion_deprecated(stacklevel=2) + + +class OpenAICompletion: + def __init__(self, *args, **kwargs): + warn_openai_completion_deprecated(stacklevel=2) + raise RuntimeError(_OPENAI_COMPLETION_DEPRECATION_MESSAGE) diff --git a/cognitive/src/test/python/synapsemltest/services/openai/test_OpenAICompletionDeprecated.py b/cognitive/src/test/python/synapsemltest/services/openai/test_OpenAICompletionDeprecated.py new file mode 100644 index 00000000000..075359b1d79 --- /dev/null +++ b/cognitive/src/test/python/synapsemltest/services/openai/test_OpenAICompletionDeprecated.py @@ -0,0 +1,68 @@ +# Copyright (C) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See LICENSE in project root for information. + +import sys +import unittest +import warnings + +_MODULE_NAME = "synapse.ml.services.openai.OpenAICompletion" +_PACKAGE_NAME = "synapse.ml.services.openai" +_WARNING_TEXT = "OpenAICompletion has been removed" + + +def _clear_openai_completion_imports(): + sys.modules.pop(_MODULE_NAME, None) + package = sys.modules.get(_PACKAGE_NAME) + if package is not None: + package.__dict__.pop("OpenAICompletion", None) + + +def _has_openai_completion_warning(caught): + return any( + issubclass(warning.category, FutureWarning) + and _WARNING_TEXT in str(warning.message) + for warning in caught + ) + + +class TestOpenAICompletionDeprecated(unittest.TestCase): + def test_package_import_warns(self): + _clear_openai_completion_imports() + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + from synapse.ml.services.openai import OpenAICompletion + + package = sys.modules[_PACKAGE_NAME] + if hasattr(package, "__getattr__"): + self.assertIsInstance(OpenAICompletion, type) + else: + self.assertIsInstance(OpenAICompletion.OpenAICompletion, type) + self.assertTrue(_has_openai_completion_warning(caught)) + + def test_submodule_import_warns(self): + _clear_openai_completion_imports() + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + from synapse.ml.services.openai.OpenAICompletion import OpenAICompletion + + self.assertIsInstance(OpenAICompletion, type) + self.assertTrue(_has_openai_completion_warning(caught)) + + def test_instantiation_warns_and_raises(self): + _clear_openai_completion_imports() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + from synapse.ml.services.openai.OpenAICompletion import OpenAICompletion + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + with self.assertRaisesRegex(RuntimeError, _WARNING_TEXT): + OpenAICompletion() + + self.assertTrue(_has_openai_completion_warning(caught)) + + +if __name__ == "__main__": + result = unittest.main() diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/codegen/PyCodegen.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/codegen/PyCodegen.scala index 425d7314f6f..e316202f80b 100644 --- a/core/src/main/scala/com/microsoft/azure/synapse/ml/codegen/PyCodegen.scala +++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/codegen/PyCodegen.scala @@ -18,6 +18,32 @@ object PyCodegen { import CodeGenUtils._ + private val DeprecatedOpenAICompletionFile = "OpenAICompletion.py" + + private val OpenAICompletionImportHook: String = + """ + |def __getattr__(name): + | if name == "OpenAICompletion": + | import warnings + | + | with warnings.catch_warnings(): + | warnings.simplefilter("ignore", FutureWarning) + | from synapse.ml.services.openai.OpenAICompletion import ( + | OpenAICompletion, + | warn_openai_completion_deprecated, + | ) + | warn_openai_completion_deprecated(stacklevel=2) + | globals()["OpenAICompletion"] = OpenAICompletion + | return OpenAICompletion + | raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + |""".stripMargin + + private def isOpenAICompletionStub(packageFolder: String, fileName: String): Boolean = + packageFolder == "/services/openai" && fileName == DeprecatedOpenAICompletionFile + + private def initFileExtra(packageFolder: String): String = + if (packageFolder == "/services/openai") OpenAICompletionImportHook else "" + def generatePythonClasses(conf: CodegenConfig): Unit = { val instantiatedClasses = instantiateServices[PythonWrappable](conf.jarName) instantiatedClasses.foreach { w => @@ -37,12 +63,13 @@ object PyCodegen { dir.listFiles.filter(_.isFile).sorted .map(_.getName) .filter(name => name.endsWith(".py") && !name.startsWith("_") && !name.startsWith("test")) + .filterNot(name => isOpenAICompletionStub(packageFolder, name)) .map(name => s"from synapse.ml$packageString.${getBaseName(name)} import *\n").mkString("") } val initFile = new File(dir, "__init__.py") if (packageFolder != "/cognitive"){ if (packageFolder != "") { - writeFile(initFile, conf.packageHelp(importStrings)) + writeFile(initFile, conf.packageHelp(importStrings) + initFileExtra(packageFolder)) } else if (initFile.exists()) { initFile.delete() }