Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
289 changes: 188 additions & 101 deletions materializationengine/blueprints/deltalake/api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import os

import cachetools.func
from flask import (
Blueprint,
current_app,
Expand Down Expand Up @@ -31,6 +32,25 @@ def _is_auth_disabled():
)


@cachetools.func.ttl_cache(maxsize=256, ttl=600)
def _dataset_for_datastack(datastack_name):
"""Resolve a datastack name to its auth *dataset* name via middle_auth.

``g.auth_user["datasets_admin"]`` is keyed by the auth *dataset* name, which
can differ from the *datastack* name. ``auth_requires_dataset_admin``
resolves datastack -> dataset (auth-service ``datastack`` namespace) before
checking; this mirrors that resolution so the wizard's datastack list agrees
with what the API endpoints actually allow. The mapping is
user-independent, so we resolve with the service ``AUTH_TOKEN`` and cache it.
"""
from middle_auth_client.decorators import dataset_from_table_id_from_request

return dataset_from_table_id_from_request(
datastack_name,
service_token=current_app.config.get("AUTH_TOKEN"),
)


# ---------------------------------------------------------------------------
# Wizard page routes
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -64,23 +84,37 @@ def wizard_step(step_number):
datastacks = []

if not _is_auth_disabled():
datasets_admin = g.get("auth_user", {}).get("datasets_admin", [])
datastacks = [ds for ds in datastacks if ds in datasets_admin]

if not datastacks and not _is_auth_disabled():
return render_template(
"deltalake_wizard.html",
current_step=step_number,
total_steps=total_steps,
step_template=None,
datastacks=[],
current_user=g.get("auth_user", {}),
access_denied=True,
access_denied_message="dataset_admin permission is required. You do not have dataset_admin access for any datastacks.",
target_partition_size_mb=get_config_param("DELTALAKE_TARGET_PARTITION_SIZE_MB", 256),
bloom_filter_fpp=get_config_param("DELTALAKE_BLOOM_FILTER_FPP", 0.001),
output_bucket=get_config_param("DELTALAKE_OUTPUT_BUCKET", ""),
), 403
datasets_admin = set(g.get("auth_user", {}).get("datasets_admin", []) or [])
# Resolve each datastack to its auth dataset name and keep only those the
# user is a dataset_admin of — matching what auth_requires_dataset_admin
# enforces on the API endpoints.
admin_datastacks = []
for ds in datastacks:
try:
dataset = _dataset_for_datastack(ds)
except Exception as e:
current_app.logger.warning(
"Could not resolve dataset for datastack %s: %s", ds, e
)
continue
if dataset in datasets_admin:
admin_datastacks.append(ds)
datastacks = admin_datastacks

if not datastacks:
return render_template(
"deltalake_wizard.html",
current_step=step_number,
total_steps=total_steps,
step_template=None,
datastacks=[],
current_user=g.get("auth_user", {}),
access_denied=True,
access_denied_message="dataset_admin permission is required. You do not have dataset_admin access for any datastacks.",
target_partition_size_mb=get_config_param("DELTALAKE_TARGET_PARTITION_SIZE_MB", 256),
bloom_filter_fpp=get_config_param("DELTALAKE_BLOOM_FILTER_FPP", 0.001),
output_bucket=get_config_param("DELTALAKE_OUTPUT_BUCKET", ""),
), 403

step_template_path = f"deltalake/step{step_number}.html"

Expand Down Expand Up @@ -150,11 +184,14 @@ def discover_specs(datastack_name):
_DEFAULT_DROP_COLUMNS,
TableSource,
_build_frozen_db_connection_string,
_classify_relation,
_get_redis_client,
_resolve_select_columns,
_validate_identifier,
discover_default_output_specs,
discover_view_output_specs,
estimate_bytes_per_row,
estimate_view_rows,
resolve_n_partitions,
)

Expand Down Expand Up @@ -217,100 +254,150 @@ def discover_specs(datastack_name):
{"error": f"Cannot connect to frozen DB for version {version}: {e}"}
), 404

# Look up row count.
with db_manager.session_scope(analysis_database) as session:
metadata_row = (
session.query(MaterializedMetadata)
.filter(MaterializedMetadata.table_name == table_name)
.first()
)
if metadata_row is None:
# Wrap discovery in a JSON error handler so backend failures (a missing
# frozen DB, an unreachable instance, a bug in the view path, etc.) surface
# as a readable JSON error instead of an HTML 500 page that the wizard
# reports as the opaque "Unexpected token '<'... is not valid JSON".
try:
# Classify as table vs view. Views are materialized views cloned into the
# frozen DB; they are not tracked in MaterializedMetadata and have no
# segmentation join, so they need view-specific row-count and spec
# discovery.
relation_kind = _classify_relation(connection_string, table_name)
if relation_kind is None:
return jsonify(
{"error": f"Table {table_name!r} not found in version {version}"}
), 404
row_count = metadata_row.row_count
is_view = relation_kind == "view"

if is_view:
row_count = estimate_view_rows(engine, table_name)
source = TableSource(annotation_table=table_name)
resolved_specs = discover_view_output_specs(source, connection_string)
else:
# Look up row count.
with db_manager.session_scope(analysis_database) as session:
metadata_row = (
session.query(MaterializedMetadata)
.filter(MaterializedMetadata.table_name == table_name)
.first()
)
if metadata_row is None:
return jsonify(
{
"error": f"Table {table_name!r} not found in version {version}"
}
), 404
row_count = metadata_row.row_count

# Detect segmentation table.
seg_table_name = build_segmentation_table_name(table_name, pcg_table_name)
has_seg_table = engine.dialect.has_table(engine, seg_table_name)
segmentation_table_name = seg_table_name if has_seg_table else None

source = TableSource(
annotation_table=table_name,
segmentation_table=segmentation_table_name,
)

# Detect segmentation table.
seg_table_name = build_segmentation_table_name(table_name, pcg_table_name)
has_seg_table = engine.dialect.has_table(engine, seg_table_name)
segmentation_table_name = seg_table_name if has_seg_table else None
# Discover specs.
resolved_specs = discover_default_output_specs(source, engine)

source = TableSource(
annotation_table=table_name,
segmentation_table=segmentation_table_name,
)
bytes_per_row = estimate_bytes_per_row(connection_string, source)

# Discover specs.
resolved_specs = discover_default_output_specs(source, engine)
bytes_per_row = estimate_bytes_per_row(connection_string, source)
# For a small table, collapse to a single output — partitioning a tiny
# table just produces many undersized files. (For views the first spec
# is the flat base, so this keeps the flat export.)
small_table_threshold_mb = int(
get_config_param("DELTALAKE_SMALL_TABLE_THRESHOLD_MB", 200)
)
estimated_total_mb = row_count * bytes_per_row / (1024 * 1024)
if estimated_total_mb < small_table_threshold_mb and len(resolved_specs) > 1:
resolved_specs = resolved_specs[:1]

small_table_threshold_mb = int(
get_config_param("DELTALAKE_SMALL_TABLE_THRESHOLD_MB", 200)
)
estimated_total_mb = row_count * bytes_per_row / (1024 * 1024)
if estimated_total_mb < small_table_threshold_mb and len(resolved_specs) > 1:
resolved_specs = resolved_specs[:1]

# Track which specs had "auto" before resolution (for caching).
was_auto = [spec.n_partitions == "auto" for spec in resolved_specs]

# Resolve partition counts.
for spec in resolved_specs:
if spec.n_partitions == "auto":
effective_target = spec.target_file_size_mb or target_partition_size_mb
spec.n_partitions = resolve_n_partitions(
"auto",
row_count,
target_file_size_mb=effective_target,
bytes_per_row=bytes_per_row,
)
# Track which specs had "auto" before resolution (for caching).
was_auto = [spec.n_partitions == "auto" for spec in resolved_specs]

from dataclasses import asdict
# Resolve partition counts.
for spec in resolved_specs:
if spec.n_partitions == "auto":
effective_target = (
spec.target_file_size_mb or target_partition_size_mb
)
spec.n_partitions = resolve_n_partitions(
"auto",
row_count,
target_file_size_mb=effective_target,
bytes_per_row=bytes_per_row,
)

# Build available columns list (base columns + computed columns from specs).
available_columns = _resolve_select_columns(
connection_string, source, _DEFAULT_DROP_COLUMNS
)
for spec in resolved_specs:
if spec.source_geometry_column:
col = spec.source_geometry_column
for suffix in ["_x", "_y", "_z", "_morton"]:
computed = f"{col}{suffix}"
if computed not in available_columns:
available_columns.append(computed)

# Collect geometry columns (position columns that get morton-encoded).
geometry_columns = sorted(
{s.source_geometry_column for s in resolved_specs if s.source_geometry_column}
)
from dataclasses import asdict

# Build available columns list (base columns + computed columns from specs).
available_columns = _resolve_select_columns(
connection_string, source, _DEFAULT_DROP_COLUMNS
)
for spec in resolved_specs:
if spec.source_geometry_column:
col = spec.source_geometry_column
for suffix in ["_x", "_y", "_z", "_morton"]:
computed = f"{col}{suffix}"
if computed not in available_columns:
available_columns.append(computed)

# Collect geometry columns (position columns that get morton-encoded).
geometry_columns = sorted(
{
s.source_geometry_column
for s in resolved_specs
if s.source_geometry_column
}
)

# Cache raw specs (before n_partitions resolution) so the cache stays
# valid regardless of the caller's target_partition_size_mb.
raw_specs = [asdict(s) for s in resolved_specs]
# Reset resolved n_partitions back to "auto" for specs that were auto.
for raw, auto in zip(raw_specs, was_auto):
if auto:
raw["n_partitions"] = "auto"

cache_result = {
"row_count": row_count,
"bytes_per_row": bytes_per_row,
"available_columns": available_columns,
"geometry_columns": geometry_columns,
"specs": raw_specs,
}
redis_client.set(cache_key, json.dumps(cache_result), ex=600)

# Return the result with resolved partition counts.
result = {
"row_count": row_count,
"bytes_per_row": bytes_per_row,
"available_columns": available_columns,
"geometry_columns": geometry_columns,
"specs": [asdict(s) for s in resolved_specs],
}

return jsonify(result)
# Cache raw specs (before n_partitions resolution) so the cache stays
# valid regardless of the caller's target_partition_size_mb.
raw_specs = [asdict(s) for s in resolved_specs]
# Reset resolved n_partitions back to "auto" for specs that were auto.
for raw, auto in zip(raw_specs, was_auto):
if auto:
raw["n_partitions"] = "auto"

# For views, row_count is a fast Postgres planner estimate (an exact
# count would execute the full view); it can be far off for aggregating
# views and is advisory only — the exact count is determined at export.
cache_result = {
"row_count": row_count,
"row_count_estimated": is_view,
"bytes_per_row": bytes_per_row,
"available_columns": available_columns,
"geometry_columns": geometry_columns,
"specs": raw_specs,
}
redis_client.set(cache_key, json.dumps(cache_result), ex=600)

# Return the result with resolved partition counts.
result = {
"row_count": row_count,
"row_count_estimated": is_view,
"bytes_per_row": bytes_per_row,
"available_columns": available_columns,
"geometry_columns": geometry_columns,
"specs": [asdict(s) for s in resolved_specs],
}

return jsonify(result)
except Exception as e:
current_app.logger.error(
"discover_specs failed for %s v%s table %r: %s",
datastack_name,
version,
table_name,
e,
exc_info=True,
)
return jsonify(
{"error": f"Spec discovery failed for {table_name!r}: {e}"}
), 500


@deltalake_bp.route("/api/<string:datastack_name>/check-exists", methods=["POST"])
Expand Down
Loading
Loading