Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright (C) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE in project root for information.

import warnings

__all__ = ["OpenAICompletion"]

# Single source of truth for the removal notice: the same text is used for the
# FutureWarning emitted by warn_openai_completion_deprecated() and for the
# RuntimeError raised when OpenAICompletion is instantiated.
_OPENAI_COMPLETION_DEPRECATION_MESSAGE = (
"OpenAICompletion has been removed because the legacy OpenAI Completions API "
"is deprecated and retired. Use OpenAIResponses, OpenAIChatCompletion, or "
"OpenAIPrompt with setApiType('chat_completions') or setApiType('responses') instead."
)


def warn_openai_completion_deprecated(stacklevel=2):
    """Emit a ``FutureWarning`` stating that ``OpenAICompletion`` has been removed.

    Args:
        stacklevel: Which frame the warning is attributed to, counted from the
            code that calls this helper, with the same meaning as the
            ``stacklevel`` argument of ``warnings.warn``: ``1`` is the caller
            itself and ``2`` (the default) is the caller's caller.
    """
    warnings.warn(
        _OPENAI_COMPLETION_DEPRECATION_MESSAGE,
        FutureWarning,
        # warnings.warn counts frames starting at this helper, so add one to
        # skip the helper's own frame; otherwise stacklevel=2 would blame the
        # shim code that calls the helper instead of the user's code.
        stacklevel=stacklevel + 1,
    )


# Module-level call: the removal warning is emitted as soon as this module is
# executed (i.e. on import), before anyone tries to instantiate the class.
warn_openai_completion_deprecated(stacklevel=2)


class OpenAICompletion:
    """Placeholder kept under the old name; constructing it always fails.

    The constructor accepts any arguments, emits the removal warning and then
    raises ``RuntimeError`` carrying the same message.
    """

    def __init__(self, *args, **kwargs):
        # Warn first so the notice is recorded even if the error is caught.
        warn_openai_completion_deprecated(stacklevel=2)
        removal_notice = _OPENAI_COMPLETION_DEPRECATION_MESSAGE
        raise RuntimeError(removal_notice)
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,8 @@ class AIFoundryChatCompletion(override val uid: String) extends OpenAIChatComple
setUrl(s"https://$v.services.ai.azure.com/" + urlPath.stripPrefix("/"))
}

override protected def prepareUrlRoot: Row => String = { row =>
s"${getUrl}models/chat/completions"
override protected def prepareUrlRoot: Row => String = { _ =>
endpointUrl("models/chat/completions")
}

}

Original file line number Diff line number Diff line change
Expand Up @@ -14,33 +14,9 @@ import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import spray.json.DefaultJsonProtocol._

import java.util.Locale
import scala.language.existentials

/** Service params for supplying text to complete: a single `prompt` or a `batchPrompt` sequence.
  *
  * Each param has a value accessor pair (backed by `getScalarParam` / `setScalarParam`) and a
  * `...Col` accessor pair (backed by `getVectorParam` / `setVectorParam`).
  */
trait HasPromptInputs extends HasServiceParams {

  /** Optional param holding the text to complete. */
  val prompt: ServiceParam[String] =
    new ServiceParam[String](this, "prompt", "The text to complete", isRequired = false)

  def getPrompt: String =
    getScalarParam(prompt)

  def setPrompt(v: String): this.type =
    setScalarParam(prompt, v)

  def getPromptCol: String =
    getVectorParam(prompt)

  def setPromptCol(v: String): this.type =
    setVectorParam(prompt, v)

  /** Optional param holding a sequence of prompts to complete. */
  val batchPrompt: ServiceParam[Seq[String]] =
    new ServiceParam[Seq[String]](
      this,
      "batchPrompt",
      "Sequence of prompts to complete",
      isRequired = false)

  def getBatchPrompt: Seq[String] =
    getScalarParam(batchPrompt)

  def setBatchPrompt(v: Seq[String]): this.type =
    setScalarParam(batchPrompt, v)

  def getBatchPromptCol: String =
    getVectorParam(batchPrompt)

  def setBatchPromptCol(v: String): this.type =
    setVectorParam(batchPrompt, v)

}

trait HasMessagesInput extends Params {
val messagesCol: Param[String] = new Param[String](
this, "messagesCol", "The column messages to generate chat completions for," +
Expand All @@ -54,6 +30,29 @@ trait HasMessagesInput extends Params {
case object OpenAIDeploymentNameKey extends GlobalKey[Either[String, String]]
case object OpenAIEmbeddingDeploymentNameKey extends GlobalKey[Either[String, String]]

/** String helpers for building OpenAI endpoint URLs and recognising v1 base URLs.
  *
  * Both public helpers treat everything from the first `?` or `#` onwards as a query/fragment
  * suffix that is not part of the URL path.
  */
private[openai] object OpenAIEndpointUtils {
  private def stripTrailingSlashes(value: String): String = value.replaceAll("/+$", "")

  /** Index of the first `?` or `#` in `value`, or `value.length` when neither occurs. */
  private def queryOrFragmentStart(value: String): Int = {
    val index = value.indexWhere(c => c == '?' || c == '#')
    if (index >= 0) index else value.length
  }

  private def withoutQueryOrFragment(value: String): String =
    value.take(queryOrFragmentStart(value))

  /** Joins `path` onto `baseUrl` with exactly one `/` between them.
    *
    * A query string or fragment on `baseUrl` is kept after the appended path, so
    * `https://host/v1?x=1` + `responses` becomes `https://host/v1/responses?x=1` rather than
    * the malformed `https://host/v1?x=1/responses`. URLs without `?` or `#` are joined as-is.
    *
    * @param baseUrl base URL, with or without a trailing slash
    * @param path    relative path; a single leading `/` is dropped
    * @return the combined URL
    */
  def appendPath(baseUrl: String, path: String): String = {
    val (root, queryAndFragment) = baseUrl.splitAt(queryOrFragmentStart(baseUrl))
    val separator = if (root.endsWith("/")) "" else "/"
    root + separator + path.stripPrefix("/") + queryAndFragment
  }

  /** Whether the path of `baseUrl` ends in `/v1`.
    *
    * The check ignores case, any trailing slashes, and any query string or fragment.
    */
  def isV1BaseUrl(baseUrl: String): Boolean = {
    stripTrailingSlashes(withoutQueryOrFragment(baseUrl))
      .toLowerCase(Locale.ROOT)
      .endsWith("/v1")
  }
}

trait HasOpenAISharedParams extends HasServiceParams with HasAPIVersion {

val deploymentName = new ServiceParam[String](
Expand Down Expand Up @@ -137,7 +136,7 @@ trait HasOpenAITextParams extends HasOpenAISharedParams {
"The maximum number of completion tokens to generate. Has minimum of 0." +
" Works with both reasoning and non-reasoning models." +
" Sent as max_completion_tokens for chat completions," +
" max_output_tokens for responses API, and max_tokens for legacy completions.",
" and max_output_tokens for responses API.",
isRequired = false) {
override val payloadName: String = "max_completion_tokens"
}
Expand Down Expand Up @@ -456,6 +455,39 @@ abstract class OpenAIServicesBase(override val uid: String) extends CognitiveSer
with HasOpenAISharedParams with OpenAIFabricSetting {
setDefault(timeout -> 360.0)

/** Stores `value` in the `url` param without modifying it. */
override def setUrl(value: String): this.type = {
  set(url, value)
}

/** Whether the explicitly set `url`, or failing that its default, is a v1 base URL.
  *
  * Returns false when neither an explicit value nor a default exists.
  */
protected[openai] def isOpenAIV1BaseUrl: Boolean = {
  val configuredUrl = get(url).orElse(getDefault(url))
  configuredUrl.exists(OpenAIEndpointUtils.isV1BaseUrl)
}

/** Builds a full endpoint URL by appending `path` to the value of `getUrl`. */
protected[openai] def endpointUrl(path: String): String = {
  val baseUrl = getUrl
  OpenAIEndpointUtils.appendPath(baseUrl, path)
}

/** Adds the deployment name as the `"model"` entry when targeting a v1 base URL.
  *
  * `params` is returned unchanged when the URL is not a v1 base URL or when it already has a
  * `"model"` key.
  */
protected[openai] def withV1DeploymentModel(params: Map[String, Any], row: Row): Map[String, Any] = {
  val needsModel = isOpenAIV1BaseUrl && !params.contains("model")
  if (!needsModel) {
    params
  } else {
    params.updated("model", getValue(row, deploymentName))
  }
}

/** Logs a warning when the URL is a v1 base URL and `apiVersion` is set locally or globally. */
private def warnIfV1ApiVersionConfigured(): Unit = {
  // Kept as a def so it is only evaluated after the v1 check passes.
  def apiVersionConfigured: Boolean =
    get(apiVersion).nonEmpty || GlobalParams.getParam(apiVersion).nonEmpty

  if (isOpenAIV1BaseUrl && apiVersionConfigured) {
    logWarning(
      "apiVersion is ignored when the OpenAI URL is a v1 base URL. " +
        "Remove apiVersion or use a non-v1 endpoint.")
  }
}

/** URL params from the parent, minus `apiVersion` when the URL is a v1 base URL.
  *
  * In the v1 case a warning is also logged if `apiVersion` was configured.
  */
override protected def getUrlParams: Array[ServiceParam[_]] = {
  val inherited = super.getUrlParams
  if (!isOpenAIV1BaseUrl) {
    inherited
  } else {
    warnIfV1ApiVersionConfigured()
    inherited.filter(_.name != apiVersion.name)
  }
}

/** Whether `getUrl` equals `FabricClient.MLWorkloadEndpointML` followed by `/cognitive/openai/`. */
private def usingDefaultOpenAIEndpoint(): Boolean = {
  val currentUrl = getUrl
  currentUrl == FabricClient.MLWorkloadEndpointML + "/cognitive/openai/"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,11 @@ class OpenAIChatCompletion(override val uid: String) extends OpenAIServicesBase(
}

override protected def prepareUrlRoot: Row => String = { row =>
s"${getUrl}openai/deployments/${getValue(row, deploymentName)}/chat/completions"
if (isOpenAIV1BaseUrl) {
endpointUrl("chat/completions")
} else {
endpointUrl(s"openai/deployments/${getValue(row, deploymentName)}/chat/completions")
}
}

override private[ml] def getOptionalParams(r: Row): Map[String, Any] = {
Expand All @@ -125,7 +129,7 @@ class OpenAIChatCompletion(override val uid: String) extends OpenAIServicesBase(
r =>
lazy val optionalParams: Map[String, Any] = getOptionalParams(r)
val messages = r.getAs[Seq[Row]](getMessagesCol)
Some(getStringEntity(messages, optionalParams))
Some(getStringEntity(messages, withV1DeploymentModel(optionalParams, r)))
}

override val subscriptionKeyHeaderName: String = "api-key"
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,7 @@ object OpenAIDefaults {
}

def setURL(v: String): Unit = {
val url = if (v.endsWith("/")) v else v + "/"
GlobalParams.setGlobalParam(URLKey, url)
GlobalParams.setGlobalParam(URLKey, v)
}

def getURL: Option[String] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,19 @@ class OpenAIEmbedding (override val uid: String) extends OpenAIServicesBase(uid)
}

/** Builds the embeddings endpoint URL for a row.
  *
  * The deployment is resolved first in both cases; it only appears in the URL for
  * non-v1 base URLs.
  */
override protected def prepareUrlRoot: Row => String = { row =>
  val deployment = getEmbeddingDeployment(row)
  val path =
    if (isOpenAIV1BaseUrl) "embeddings"
    else s"openai/deployments/$deployment/embeddings"
  endpointUrl(path)
}

private[this] def getEmbeddingDeployment(row: Row): String = {
val globalEmbeddingDeployment =
GlobalParams.getGlobalParam(OpenAIEmbeddingDeploymentNameKey).flatMap(_.left.toOption)

val dep = globalEmbeddingDeployment.orElse {
globalEmbeddingDeployment.orElse {
// If embedding-specific deployment is not set, check instance param
if (isSet(deploymentName)) {
getValueOpt(row, deploymentName)
Expand All @@ -77,8 +86,6 @@ class OpenAIEmbedding (override val uid: String) extends OpenAIServicesBase(uid)
}.getOrElse(throw new IllegalArgumentException(
"No embedding deployment name provided. Set the 'deploymentName' param or call " +
"OpenAIDefaults.setEmbeddingDeploymentName('<your-embedding-deployment>') to set a global default."))

s"${getUrl}openai/deployments/$dep/embeddings"
}

private[this] def getStringEntity[A](text: A, optionalParams: Map[String, Any]): StringEntity = {
Expand All @@ -88,7 +95,10 @@ class OpenAIEmbedding (override val uid: String) extends OpenAIServicesBase(uid)

override protected def prepareEntity: Row => Option[AbstractHttpEntity] = {
r =>
lazy val optionalParams: Map[String, Any] = getOptionalParams(r)
lazy val optionalParams: Map[String, Any] = {
val params = getOptionalParams(r)
if (isOpenAIV1BaseUrl) params.updated("model", getEmbeddingDeployment(r)) else params
}
getValueOpt(r, text)
.map(text => getStringEntity(text, optionalParams))
.orElse(throw new IllegalArgumentException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,8 @@ class OpenAIPrompt(override val uid: String) extends Transformer
store -> Left(false)
)

/** Stores `value` in the `url` param without modifying it. */
override def setUrl(value: String): this.type = {
  set(url, value)
}

/** Sets the URL to `https://<v>.openai.azure.com/` followed by `urlPath` without its leading slash. */
override def setCustomServiceName(v: String): this.type = {
  val host = s"https://$v.openai.azure.com/"
  val relativePath = urlPath.stripPrefix("/")
  setUrl(host + relativePath)
}
Expand Down Expand Up @@ -284,8 +286,6 @@ class OpenAIPrompt(override val uid: String) extends Transformer
df: DataFrame,
messagesCol: Column
): (DataFrame, String, OpenAIServicesBase with HasTextOutput) = {
// All services are now HasMessagesInput (OpenAIChatCompletion, OpenAIResponses, AIFoundryChatCompletion)
// Legacy OpenAICompletion did not support MessagesInput which is no longer used in this class.
val messagesService = service.asInstanceOf[HasMessagesInput]

if (isSet(responseFormat)) {
Expand Down Expand Up @@ -639,7 +639,12 @@ class OpenAIPrompt(override val uid: String) extends Transformer
host.exists(_.toLowerCase.endsWith("services.ai.azure.com"))
}

private[openai] def hasAIFoundryModel: Boolean = this.isDefined(model) && isAIFoundryEndpoint
/** Whether the explicitly set `url`, or failing that its default, is a v1 base URL. */
private def isOpenAIV1Endpoint: Boolean = {
  val configuredUrl = get(url).orElse(getDefault(url))
  configuredUrl.exists(OpenAIEndpointUtils.isV1BaseUrl)
}

private[openai] def hasAIFoundryModel: Boolean =
this.isDefined(model) && isAIFoundryEndpoint && !isOpenAIV1Endpoint

//deployment name can be set by user, it doesn't have to match with model name
private def getOpenAIChatService: OpenAIServicesBase with HasTextOutput = {
Expand All @@ -658,11 +663,10 @@ class OpenAIPrompt(override val uid: String) extends Transformer
.filter(p => !localParamNames.contains(p.param.name) && completion.hasParam(p.param.name))
.foreach(p => completion.set(completion.getParam(p.param.name), p.value))

completion match {
case resp: OpenAIResponses
if this.isDefined(model) && get(deploymentName).orElse(getDefault(deploymentName)).isEmpty =>
resp.setDeploymentName(getModel)
case _ =>
if (this.isDefined(model) &&
get(deploymentName).orElse(getDefault(deploymentName)).isEmpty &&
(isOpenAIV1Endpoint || completion.isInstanceOf[OpenAIResponses])) {
completion.setDeploymentName(getModel)
}

completion
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,11 @@ class OpenAIResponses(override val uid: String) extends OpenAIServicesBase(uid)
}

override protected def prepareUrlRoot: Row => String = { row =>
s"${getUrl}openai/responses"
if (isOpenAIV1BaseUrl) {
endpointUrl("responses")
} else {
endpointUrl("openai/responses")
}
}

override protected[openai] def prepareEntity: Row => Option[AbstractHttpEntity] = {
Expand All @@ -164,6 +168,9 @@ class OpenAIResponses(override val uid: String) extends OpenAIServicesBase(uid)
/** Puts the row's deployment name into `params` under `"model"` when one is available.
  *
  * A null or empty deployment name counts as missing. When it is missing, `params` is
  * returned unchanged, unless the URL is a v1 base URL and `params` has no `"model"` key,
  * in which case an `IllegalArgumentException` is thrown.
  */
private def mergeModel(params: Map[String, Any], r: Row): Map[String, Any] = {
  val deployment = getValueOpt(r, deploymentName).filter(m => m != null && m.nonEmpty)
  deployment match {
    case Some(m) =>
      params.updated("model", m)
    case None if isOpenAIV1BaseUrl && !params.contains("model") =>
      throw new IllegalArgumentException(
        "No deployment/model name provided for OpenAI v1 endpoint. Set the 'deploymentName' param.")
    case None =>
      params
  }
}
Expand Down
Loading
Loading