diff --git a/CHANGELOG.md b/CHANGELOG.md index 34b37ce..335ebbd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,14 +1,19 @@ # Change Log All notable changes to this project will be documented in this file. +## 2.1.2 - 2026-01 + ### Maker + - Feature type normalization to Double and Long + +## 2.1.1 - 2026-01 + - Replaces 2.1.0 release + ## 2.1.0 - 2025-10 ### General - Updated python version to 3.12 and pyspark to 4.0 - Migrated from poetry to UV ### Runner - Added merge_schema manual override option - ### Maker - - Added another feature decorator _@template_ to support feature to text conversion ## 2.0.11 - 2025-08-12 ### Loader diff --git a/rialto/common/utils.py b/rialto/common/utils.py index 296cba8..875ed7d 100644 --- a/rialto/common/utils.py +++ b/rialto/common/utils.py @@ -16,12 +16,20 @@ import inspect import os -from typing import Any, List +from typing import Any import pyspark.sql.functions as F import yaml from pyspark.sql import DataFrame -from pyspark.sql.types import FloatType +from pyspark.sql.types import ( + ByteType, + DecimalType, + DoubleType, + FloatType, + IntegerType, + LongType, + ShortType, +) from rialto.common.env_yaml import EnvLoader @@ -62,12 +70,32 @@ def get_caller_module() -> Any: 0th entry is this function 1st entry is the function which needs to know who called it 2nd entry is the calling function - Therefore, we'll return a module which contains the function at the 2nd place on the stack. :return: Python Module containing the calling function. """ - stack = inspect.stack() last_stack = stack[2] return inspect.getmodule(last_stack[0]) + + +def normalize_types(df: DataFrame) -> DataFrame: + """ + Normalize data types in the DataFrame + + Converts all float and decimal columns to DoubleType and + all byte, short and integer columns to LongType. 
+ """ + float_types = (FloatType, DecimalType) + int_types = (ByteType, ShortType, IntegerType) + + return df.select( + [ + F.col(f.name).cast(DoubleType()) + if isinstance(f.dataType, float_types) + else F.col(f.name).cast(LongType()) + if isinstance(f.dataType, int_types) + else F.col(f.name) + for f in df.schema.fields + ] + ) diff --git a/rialto/maker/feature_maker.py b/rialto/maker/feature_maker.py index 7aa6c85..5fad175 100644 --- a/rialto/maker/feature_maker.py +++ b/rialto/maker/feature_maker.py @@ -23,6 +23,7 @@ from loguru import logger from pyspark.sql import DataFrame +from rialto.common.utils import normalize_types from rialto.maker.containers import FeatureFunction, FeatureHolder @@ -140,7 +141,8 @@ def _make_sequential(self, keep_preexisting: bool) -> DataFrame: if not keep_preexisting: logger.info("Dropping non-selected columns") self.data_frame = self.data_frame.select(*self.key, *feature_names) - return self._filter_null_keys(self.data_frame) + df = self._filter_null_keys(self.data_frame) + return normalize_types(df) def _make_aggregated(self) -> DataFrame: """ @@ -154,7 +156,8 @@ def _make_aggregated(self) -> DataFrame: aggregates.append(feature_function.callable().alias(feature_function.get_feature_name())) self.data_frame = self.data_frame.groupBy(self.key).agg(*aggregates) - return self._filter_null_keys(self.data_frame) + df = self._filter_null_keys(self.data_frame) + return normalize_types(df) def make( self, @@ -237,7 +240,8 @@ def make_single_feature( self.make_date = make_date feature_functions = self._register_module(features_module) feature = self._find_feature(name, feature_functions) - return df.withColumn(feature.get_feature_name(), feature.callable()).select(feature.get_feature_name()) + df = df.withColumn(feature.get_feature_name(), feature.callable()).select(feature.get_feature_name()) + return normalize_types(df) def make_single_agg_feature( self, @@ -261,7 +265,8 @@ def make_single_agg_feature( self.make_date = make_date 
feature_functions = self._register_module(features_module) feature = self._find_feature(name, feature_functions) - return df.groupBy(key).agg(feature.callable().alias(feature.get_feature_name())) + df = df.groupBy(key).agg(feature.callable().alias(feature.get_feature_name())) + return normalize_types(df) FeatureMaker = _FeatureMaker() diff --git a/tests/common/test_utils.py b/tests/common/test_utils.py index 44aa0a3..f11344d 100644 --- a/tests/common/test_utils.py +++ b/tests/common/test_utils.py @@ -15,8 +15,9 @@ import pyspark.sql.functions as F import pytest from numpy import dtype +from pyspark.sql.types import DoubleType, LongType, StringType -from rialto.common.utils import cast_decimals_to_floats +from rialto.common.utils import cast_decimals_to_floats, normalize_types @pytest.fixture @@ -29,6 +30,16 @@ def sample_df(spark): return df.select("a", "b", "c", F.col("d").cast("decimal"), F.col("e").cast("decimal(18,5)")) +@pytest.fixture +def sample_df2(spark): + df = spark.createDataFrame( + [(1, 2.33, "str", 4.55, 5.66, 4), (1, 2.33, "str", 4.55, 5.66, 5), (1, 2.33, "str", 4.55, 5.66, 6)], + schema="a long, b float, c string, d float, e float, f int", + ) + + return df.select("a", "b", "c", F.col("d").cast("decimal"), F.col("e").cast("double"), "f") + + def test_cast_decimals_to_floats(sample_df): df_fixed = cast_decimals_to_floats(sample_df) @@ -42,3 +53,14 @@ def test_cast_decimals_to_floats_topandas_works(sample_df): assert df_pd.dtypes.iloc[3] == dtype("float32") assert df_pd.dtypes.iloc[4] == dtype("float32") + + +def test_normalize_types(sample_df2): + df_fixed = normalize_types(sample_df2) + + assert isinstance(df_fixed.schema["a"].dataType, LongType) + assert isinstance(df_fixed.schema["b"].dataType, DoubleType) + assert isinstance(df_fixed.schema["c"].dataType, StringType) + assert isinstance(df_fixed.schema["d"].dataType, DoubleType) + assert isinstance(df_fixed.schema["e"].dataType, DoubleType) + assert isinstance(df_fixed.schema["f"].dataType, 
LongType)