Conversation
…construction, and diagnostic rules
…and data integrity checks
There was a problem hiding this comment.
Pull request overview
Adds structured correctness validation and structured plan diagnostics to OptiSpark’s interactive optimization loop, with accompanying unit tests, to improve reliability of AI-generated DataFrame rewrites.
Changes:
- Introduces a new correctness validator (
validate_optimization) and failure-prompt builder for self-healing retries. - Adds a structured Spark execution plan parser + rule-based analyzer to replace brittle string heuristics.
- Integrates plan diagnostics + correctness validation into the agent sandbox execution path and enhances safety detection for
explode/salt_array.
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 5 comments.
Show a summary per file
| File | Description |
|---|---|
| tests/test_validator.py | New unit/integration-style tests for validator checks and failure prompt building. |
| tests/test_plan_analyzer.py | New tests for plan parsing, rule detection, and debug rendering. |
| src/optispark/validator.py | New correctness validator module (schema/row-count/sample diff/aggregate parity) + prompt builder. |
| src/optispark/safety.py | Replaces brittle string detection with AST-based detection for high-memory ops (explode/salt). |
| src/optispark/plan_analyzer.py | New plan parser + diagnostics rules + debug helpers. |
| src/optispark/agent.py | Wires plan diagnostics into context gathering and adds correctness validation/self-correction in sandbox execution. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| orig_count = original_df.limit(sample_size).count() | ||
| opt_count = optimized_df.limit(sample_size).count() | ||
|
|
||
| if orig_count == opt_count: | ||
| return { |
| # Rows in original not in optimized | ||
| only_in_orig = orig_sample.subtract(opt_sample) | ||
| orig_diff_count = only_in_orig.count() | ||
|
|
||
| # Rows in optimized not in original |
| # Handles lines like "BroadcastHashJoin [user_id#10], ..." or "*(2) HashAggregate(...)" | ||
| _NODE_TYPE_PATTERN = re.compile( | ||
| r'(?:\*\(\d+\)\s*)?' # optional WholeStageCodegen marker like *(2) | ||
| r'([A-Z][A-Za-z]*(?:\s[A-Za-z]+)?)' # PascalCase node type (1-2 words) | ||
| ) |
| Returns: | ||
| A list of PlanDiagnostic instances, deduplicated by rule_id. | ||
| """ |
| for n, t in fields: | ||
| if isinstance(t, (T.IntegerType, T.LongType, T.FloatType, T.DoubleType)): | ||
| total = sum(r.get(n, 0) for r in rows if r.get(n) is not None) | ||
| agg_row.__getitem__ = lambda self, key, _n=n, _t=total: _t if key == _n else 0 | ||
| agg_row.__getitem__ = MagicMock(side_effect=lambda key: sum(r.get(key, 0) for r in rows if r.get(key) is not None)) |
There was a problem hiding this comment.
Pull request overview
Adds a structured plan analyzer and a correctness validator to OptiSpark, and integrates them into the agent flow so generated optimizations are both safety-checked and semantically validated before being accepted.
Changes:
- Introduces
optispark.plan_analyzerto parse Spark physical plans into a node tree and emit structured diagnostics. - Introduces
optispark.validatorto compare original vs optimized DataFrames (schema, sample row count, sample data integrity, aggregate parity) and generate an LLM failure prompt. - Integrates plan diagnostics + correctness validation into
OptiSpark’s DataFrame introspection and sandbox execution; adds unit tests for analyzer/validator.
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 5 comments.
Show a summary per file
| File | Description |
|---|---|
src/optispark/plan_analyzer.py |
New plan parsing + rule-based diagnostics and debug printers. |
src/optispark/validator.py |
New correctness validation pipeline and failure-prompt builder. |
src/optispark/agent.py |
Uses plan diagnostics in context and validates optimized output before accepting; adds validation UI. |
src/optispark/safety.py |
Replaces brittle string checks with AST-based detection for explode/salt patterns. |
tests/test_plan_analyzer.py |
Adds unit tests covering parsing, rule detection, and debug output. |
tests/test_validator.py |
Adds unit tests for validator checks, confidence scoring, and failure prompt building. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| # Rows in original not in optimized | ||
| only_in_orig = orig_sample.subtract(opt_sample) | ||
| orig_diff_count = only_in_orig.count() | ||
|
|
||
| # Rows in optimized not in original | ||
| only_in_opt = opt_sample.subtract(orig_sample) | ||
| opt_diff_count = only_in_opt.count() |
There was a problem hiding this comment.
Resolved in bdf76eb. _check_data_integrity now uses exceptAll() (multiset/EXCEPT ALL semantics) instead of subtract() (EXCEPT DISTINCT), so duplicate-row multiplicity mismatches are correctly detected. Test mocks updated to match.
| # Relative difference check | ||
| orig_f = float(orig_val) | ||
| opt_f = float(opt_val) |
| def analyze_plan(roots: list, broadcast_threshold_bytes: int = 10 * 1024 * 1024) -> list: | ||
| """Run all detection rules over a parsed plan tree. | ||
|
|
||
| Args: | ||
| roots: List of root PlanNode instances from parse_plan(). | ||
| broadcast_threshold_bytes: The autoBroadcastJoinThreshold from the SparkSession. | ||
| Defaults to Spark's default of 10 MB. | ||
|
|
||
| Returns: | ||
| A list of PlanDiagnostic instances, deduplicated by rule_id. | ||
| """ | ||
| # Flatten the forest into a single list for rule functions. | ||
| all_nodes = [] | ||
| for root in roots: | ||
| all_nodes.extend(root.walk()) | ||
|
|
||
| diagnostics = [] | ||
| diagnostics.extend(_rule_missing_broadcast(all_nodes, broadcast_threshold_bytes)) | ||
| diagnostics.extend(_rule_excessive_shuffle(all_nodes)) | ||
| diagnostics.extend(_rule_redundant_repartition(all_nodes)) | ||
| # Only add informational SMJ rule if MISSING_BROADCAST didn't already flag it | ||
| broadcast_flagged = {d.node for d in diagnostics if d.rule_id == "MISSING_BROADCAST"} | ||
| for d in _rule_sort_merge_join_info(all_nodes): | ||
| if d.node not in broadcast_flagged: | ||
| diagnostics.append(d) | ||
|
|
||
| return diagnostics |
| return local_vars["optimized_df"] | ||
| else: | ||
| if "optimized_df" not in local_vars: | ||
| print(f" {C.RED}✖ Error: The code executed but did not assign 'optimized_df'.{C.RESET}") |
|
@copilot fix all issues |
…nt clarity, _NODE_TYPE_PATTERN, docstrings, agent retry Agent-Logs-Url: https://github.com/Radom12/OptiSpark/sessions/5809aa55-d44b-4d24-8e6f-591290199a24 Co-authored-by: Radom12 <69920588+Radom12@users.noreply.github.com>
All issues addressed in bdf76eb:
|
|
@copilot tests/test_agent.py::test_optimize_success PASSED [ 0%] =================================== FAILURES =================================== self = <test_validator.TestAggregateParity object at 0x7f6a52a43590>
E assert False is True tests/test_validator.py:269: AssertionError self = <test_validator.TestAggregateParity object at 0x7f6a52a439d0>
E AssertionError: assert 'amount' in 'Aggregate parity check failed with error: ' tests/test_validator.py:286: AssertionError self = <test_validator.TestAggregateParity object at 0x7f6a52a49b50>
E assert False is True tests/test_validator.py:312: AssertionError self = <test_validator.TestAggregateParity object at 0x7f6a52a4a890>
E AssertionError: assert 'price' in 'Aggregate parity check failed with error: ' tests/test_validator.py:331: AssertionError self = <test_validator.TestValidateOptimization object at 0x7f6a52a4af90>
E AssertionError: assert False is True tests/test_validator.py:349: AssertionError self = <test_validator.TestValidateOptimization object at 0x7f6a52a48390>
E AssertionError: assert False is True tests/test_validator.py:386: AssertionError -- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html Name Stmts Miss Cover Missingserver/init.py 0 0 100%
|
… in aggregate parity check Agent-Logs-Url: https://github.com/Radom12/OptiSpark/sessions/d114b664-f815-4def-a65b-826592e9d68d Co-authored-by: Radom12 <69920588+Radom12@users.noreply.github.com>
Fixed in d9a011a. The root cause was that PySpark 4.x requires an active The fix replaces the Column-expression approach with a plain Python dict passed to |
No description provided.