Pyspark DecisionTreeRegressionModel bundle does not include all attributes #871

@anigmo97

Issue Description

A PySpark DecisionTreeRegressionModel loses node attribute values (prediction, impurity, impurityStats) after it is serialized to an MLeap bundle and loaded back.

Minimal Reproducible Example

mleap version: 0.23.1
pyspark version: 3.3.0
Python version: 3.10.6

import os  # needed for os.getcwd() when building the bundle path below

import pyspark
import mleap
import mleap.pyspark
from mleap.pyspark.spark_support import SimpleSparkSerializer

from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.regression import DecisionTreeRegressor, DecisionTreeRegressionModel

# Step 1: Create a Spark session
spark = SparkSession.builder \
    .config('spark.jars.packages', 'ml.combust.mleap:mleap-spark_2.12:0.23.1') \
    .getOrCreate()

# Step 2: Prepare Data
data = [(1.0, 2.0, 3.0), (2.0, 3.0, 4.0), (3.0, 4.0, 5.0)]
columns = ["feature1", "feature2", "label"]
df = spark.createDataFrame(data, columns)

# Step 3: Feature Vector Assembly
assembler = VectorAssembler(inputCols=["feature1", "feature2"], outputCol="features")
df = assembler.transform(df)

# Step 4: Split Data
(trainingData, testData) = df.randomSplit([0.8, 0.2], seed=1234)

# Step 5: Create and Train Decision Tree Model
dt = DecisionTreeRegressor(featuresCol="features", labelCol="label")
model = dt.fit(trainingData)

# Step 6: Make Predictions
predictions = model.transform(testData)

If we take a look at the trained model, we can see that the root node's attributes are populated.

print(model._to_java().rootNode().toString())
print(model._java_obj.rootNode().toString())

InternalNode(prediction = 4.0, impurity = 0.6666666666666666, split = org.apache.spark.ml.tree.ContinuousSplit@3ff80000)
InternalNode(prediction = 4.0, impurity = 0.6666666666666666, split = org.apache.spark.ml.tree.ContinuousSplit@3ff80000)

If I save the model as an MLeap bundle and load it back, the same attributes come back as zero or None:

model_path = f"{os.getcwd()}/tree_regressor.zip"
model.serializeToBundle(f"jar:file:{model_path}", predictions)
print(f"Model Saved as MLeap bundle at: {model_path}")

loaded_model = DecisionTreeRegressionModel.deserializeFromBundle(f"jar:file:{model_path}")

print(loaded_model._to_java().rootNode().toString())
print(loaded_model._java_obj.rootNode().toString())
print(loaded_model._to_java().rootNode().impurityStats())

InternalNode(prediction = 0.0, impurity = 0.0, split = org.apache.spark.ml.tree.ContinuousSplit@3ff80000)
InternalNode(prediction = 0.0, impurity = 0.0, split = org.apache.spark.ml.tree.ContinuousSplit@3ff80000)
None
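
To make the difference explicit, here is a minimal sketch that diffs the two printed node strings without needing a Spark session. It assumes the node string keeps the `name = value` layout shown in the output above; the helper names `parse_node_fields` and `diff_node_fields` are hypothetical and not part of MLeap or PySpark.

```python
import re

# Matches the numeric fields in the node string, e.g. "prediction = 4.0".
# Assumption: the layout is the one printed in this issue.
NODE_FIELD = re.compile(r"(prediction|impurity) = ([-+0-9.eE]+)")


def parse_node_fields(node_repr):
    """Extract prediction and impurity from a node's toString() output."""
    return {name: float(value) for name, value in NODE_FIELD.findall(node_repr)}


def diff_node_fields(before, after, tol=1e-9):
    """Return {field: (before_value, after_value)} for fields that changed."""
    b = parse_node_fields(before)
    a = parse_node_fields(after)
    return {
        name: (b[name], a.get(name))
        for name in b
        if a.get(name) is None or abs(b[name] - a[name]) > tol
    }


# Strings copied from the output reported in this issue.
original = (
    "InternalNode(prediction = 4.0, impurity = 0.6666666666666666, "
    "split = org.apache.spark.ml.tree.ContinuousSplit@3ff80000)"
)
reloaded = (
    "InternalNode(prediction = 0.0, impurity = 0.0, "
    "split = org.apache.spark.ml.tree.ContinuousSplit@3ff80000)"
)

print(diff_node_fields(original, reloaded))
# {'prediction': (4.0, 0.0), 'impurity': (0.6666666666666666, 0.0)}
```

With a live session, the same check could be fed from `model._java_obj.rootNode().toString()` and `loaded_model._java_obj.rootNode().toString()`, the calls already used above.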
