diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/search/AzureSearch.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/search/AzureSearch.scala index 5471329997..03cb9d0bdf 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/search/AzureSearch.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/search/AzureSearch.scala @@ -20,7 +20,8 @@ import org.apache.spark.ml.util._ import org.apache.spark.ml.{ComplexParamsReadable, NamespaceInjections, PipelineModel} import org.apache.spark.ml.linalg.SQLDataTypes.VectorType import org.apache.spark.ml.functions.vector_to_array -import org.apache.spark.sql.functions.{col, expr, struct, to_json, to_utc_timestamp, date_format, when} +import org.apache.spark.sql.functions.{col, expr, from_json, struct, to_json, to_utc_timestamp, + date_format, when} import org.apache.spark.sql.streaming.DataStreamWriter import org.apache.spark.sql.types._ import org.apache.spark.sql.{DataFrame, Dataset, Row} @@ -249,6 +250,57 @@ object AzureSearchWriter extends IndexParser with IndexJsonGetter with SLogging } } + /** + * Converts string columns containing GeoJSON to the proper struct shape required for + * Azure Search `Edm.GeographyPoint` fields. + * + * Azure AI Search expects spatial values to be sent as a GeoJSON object + * (e.g. `{"type":"Point","coordinates":[lon, lat]}`), not as a JSON-encoded string. + * Users frequently have their GeoJSON readily available as a string column, and + * passing it as a `StringType` previously caused a `400 Bad Request` + * (see [[https://github.com/microsoft/SynapseML/issues/2420]]) because the writer + * JSON-escaped the entire string. 
+ * + * For each '''top-level''' field declared as `Edm.GeographyPoint` in the index, if the + * corresponding DataFrame column is a `StringType`, parse it into the canonical + * `StructType(type: StringType, coordinates: ArrayType(DoubleType))` so that downstream + * `to_json` emits a proper GeoJSON object. Null values stay null, and columns that are + * not `StringType` are left as-is. GeographyPoint fields nested inside complex types are + * not auto-converted (mirrors the top-level-only handling in `convertDateTimeToISO8601`). + * + * Parsing uses Spark's `FAILFAST` mode, so a string that fails to parse raises when rows + * are materialized rather than becoming `null`; GeoJSON semantics are not validated here. + * + * @param df DataFrame with potential GeographyPoint columns + * @param indexJson JSON string containing the index schema + * @return DataFrame with string GeographyPoint columns converted to GeoJSON structs + */ + private[ml] def convertGeographyPointToStruct(df: DataFrame, indexJson: String): DataFrame = { + val geoStructType = StructType(Seq( + StructField("type", StringType), + StructField("coordinates", ArrayType(DoubleType)) + )) + val parseOptions = Map("mode" -> "FAILFAST") + val geoFields = parseIndexJson(indexJson).fields + .filter(_.`type` == "Edm.GeographyPoint") + .map(_.name) + geoFields.foldLeft(df) { (currentDF, fieldName) => + if (currentDF.columns.contains(fieldName)) { + currentDF.schema(fieldName).dataType match { + case StringType => + currentDF.withColumn(fieldName, + when(col(fieldName).isNotNull, from_json(col(fieldName), geoStructType, parseOptions)) + ) + case _ => + // Non-string column: returned unchanged; the writer runs checkSchemaParity on the result. 
+ currentDF + } + } else { + currentDF + } + } + } + private def dfToIndexJson(schema: StructType, indexName: String, keyCol: String, @@ -328,17 +380,18 @@ object AzureSearchWriter extends IndexParser with IndexJsonGetter with SLogging SearchIndex.createIfNoneExists(subscriptionKey, serviceName, indexJson, apiVersion) val dateConvertedDF = convertDateTimeToISO8601(preppedDF, indexJson) + val geoConvertedDF = convertGeographyPointToStruct(dateConvertedDF, indexJson) logInfo("checking schema parity") - checkSchemaParity(dateConvertedDF.schema, indexJson, actionCol) + checkSchemaParity(geoConvertedDF.schema, indexJson, actionCol) val df1 = if (filterNulls) { val collectionColumns = parseIndexJson(indexJson).fields .filter(_.`type`.startsWith("Collection")) .map(_.name) - collectionColumns.foldLeft(dateConvertedDF) { (ndf, c) => filterOutNulls(ndf, c) } + collectionColumns.foldLeft(geoConvertedDF) { (ndf, c) => filterOutNulls(ndf, c) } } else { - dateConvertedDF + geoConvertedDF } // Convert date/timestamp columns to ISO8601 strings for Azure Search diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/search/split2/SearchWriterSuitePart2.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/search/split2/SearchWriterSuitePart2.scala index 54ace956e1..55d161aefd 100644 --- a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/search/split2/SearchWriterSuitePart2.scala +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/search/split2/SearchWriterSuitePart2.scala @@ -171,4 +171,123 @@ class SearchWriterSuite extends SearchWriterSuiteUtilities { } + test("Handle GeoJSON GeographyPoint fields supplied as strings") { + + val in = generateIndexName() + val df = spark.createDataFrame(Seq( + ("upload", "0", """{"type":"Point","coordinates":[-122.3493, 47.6205]}"""), + ("upload", "1", """{"type":"Point","coordinates":[-122.3351, 47.6080]}""") + )).toDF("searchAction", "id", "location") + + val indexJson 
= + s""" + |{ + | "name": "$in", + | "fields": [ + | { "name": "id", "type": "Edm.String", "key": true, "searchable": true, "retrievable": true }, + | { "name": "location", "type": "Edm.GeographyPoint", "searchable": false, + | "filterable": true, "retrievable": true, "sortable": true } + | ] + |} + |""".stripMargin + + AzureSearchWriter.write(df, + Map( + "subscriptionKey" -> azureSearchKey, + "actionCol" -> "searchAction", + "serviceName" -> testServiceName, + "indexJson" -> indexJson + ) + ) + + // With fatalErrors=true (default) any 400 from Azure Search becomes a thrown + // RuntimeException, so reaching this `assertSize` proves the documents were + // accepted as valid spatial objects -- a count of 2 is only achievable if the + // GeoJSON strings were correctly parsed and serialized as GeoJSON objects. + retryWithBackoff(assertSize(in, 2)) + + } + + test("convertGeographyPointToStruct parses GeoJSON strings into structs") { + val df = spark.createDataFrame(Seq( + ("0", """{"type":"Point","coordinates":[-122.3493, 47.6205]}"""), + ("1", null) + )).toDF("id", "location") + + val indexJson = + """ + |{ + | "name": "unit-test-geo", + | "fields": [ + | { "name": "id", "type": "Edm.String", "key": true }, + | { "name": "location", "type": "Edm.GeographyPoint" } + | ] + |} + |""".stripMargin + + val converted = AzureSearchWriter.convertGeographyPointToStruct(df, indexJson) + val expected = StructType(Seq( + StructField("type", StringType), + StructField("coordinates", ArrayType(DoubleType)) + )) + assert(converted.schema("location").dataType == expected) + + val rows = converted.orderBy("id").collect() + val parsed = rows.head.getStruct(rows.head.fieldIndex("location")) + assert(parsed.getString(0) == "Point") + assert(parsed.getSeq[Double](1) == Seq(-122.3493, 47.6205)) + assert(rows(1).isNullAt(rows(1).fieldIndex("location"))) + } + + test("convertGeographyPointToStruct leaves struct columns untouched") { + val schema = StructType(Seq( + StructField("id", 
StringType), + StructField("location", StructType(Seq( + StructField("type", StringType, nullable = false), + StructField("coordinates", ArrayType(DoubleType, containsNull = false), nullable = false) + ))) + )) + val df = spark.createDataFrame( + spark.sparkContext.parallelize(Seq(Row("0", Row("Point", Seq(-122.3493, 47.6205))))), + schema + ) + + val indexJson = + """ + |{ + | "name": "unit-test-geo", + | "fields": [ + | { "name": "id", "type": "Edm.String", "key": true }, + | { "name": "location", "type": "Edm.GeographyPoint" } + | ] + |} + |""".stripMargin + + val converted = AzureSearchWriter.convertGeographyPointToStruct(df, indexJson) + assert(converted.schema("location").dataType == schema("location").dataType) + } + + test("convertGeographyPointToStruct fails fast on malformed GeoJSON instead of silently nulling") { + val df = spark.createDataFrame(Seq( + ("0", "{not valid json") + )).toDF("id", "location") + + val indexJson = + """ + |{ + | "name": "unit-test-geo", + | "fields": [ + | { "name": "id", "type": "Edm.String", "key": true }, + | { "name": "location", "type": "Edm.GeographyPoint" } + | ] + |} + |""".stripMargin + + val converted = AzureSearchWriter.convertGeographyPointToStruct(df, indexJson) + // FAILFAST surfaces parse errors when the row is materialized, not at plan time. + intercept[org.apache.spark.SparkException] { + converted.collect() + } + } + }