diff --git a/keepercommander/commands/pam_import/README.md b/keepercommander/commands/pam_import/README.md index e5baae283..233f8f37a 100644 --- a/keepercommander/commands/pam_import/README.md +++ b/keepercommander/commands/pam_import/README.md @@ -303,7 +303,7 @@ Each Machine (pamMachine, pamDatabase, pamDirectory) can specify **Administrativ > **Note 3:** Post rotation scripts (a.k.a. `scripts`) are executed in following order: `pamUser` scripts after any **successful** rotation for that user, `pamMachine` scripts after any **successful** rotation on the machine and `pamConfiguration` scripts after any rotation using that configuration. > **Note 4:** When `allow_supply_user` is false and JIT ephemeral is not used, vault may require a launch credential; import can provide it via `launch_credentials` in the resource's `connection` block. -JIT and KeeperAI settings below are shared across all resource types (pamMachine, pamDatabase, pamDirectory) except User and RBI (pamRemoteBrowser) records. +JIT and KeeperAI settings below are shared across all resource types (pamMachine, pamDatabase, pamDirectory) except User and RBI (pamRemoteBrowser) records. **Workflow** (approvals / checkout / temporal restrictions) is supported on all four resource types: pamMachine, pamDatabase, pamDirectory, **and** pamRemoteBrowser.
Just-In-Time Access (JIT) @@ -406,6 +406,79 @@ JIT and KeeperAI settings below are shared across all resource types (pamMachine ```
+Workflow (Approvals, Checkout, Temporal Access) + +Workflow controls how privileged access to a resource is gated: how many approvals are needed, whether sessions require check-out, MFA, reason/ticket, what time windows access is allowed in, and who can approve (with optional escalation). Workflow is applied via the Keeper Router **after** the resource record and DAG/JIT/AI steps are complete and is not stored on the record itself. + +**How to Configure:** Add `pam_settings.options.workflow` to any pamMachine, pamDatabase, pamDirectory, or pamRemoteBrowser. The workflow object maps directly to the Web Vault's "Workflow" tab on a resource record. + +```json +{ + "pam_settings": { + "options": { + "workflow": { + "approvals_needed": 2, + "checkout_needed": true, + "start_access_on_approval": false, + "require_reason": true, + "require_ticket": false, + "require_mfa": true, + "access_duration": "8h", + "allowed_times": { + "allowed_days": ["mon", "tue", "wed", "thu", "fri"], + "time_ranges": [ + { "start": "09:00", "end": "17:30" } + ], + "timezone": "America/New_York" + }, + "approvers": [ + { + "principal": { "type": "user", "email": "primary.approver@example.com" }, + "escalation": false + }, + { + "principal": { "type": "user", "email": "second.approver@example.com" }, + "escalation": false + }, + { + "principal": { + "type": "team", + "team_uid_base64url": "REPLACE_TEAM_UID_BASE64URL" + }, + "escalation": true, + "escalation_after": "45m" + } + ] + } + } + } +} +``` + +**Field reference:** +- `approvals_needed` *(int, default `0`)* — number of approvals required to grant access. +- `checkout_needed` *(bool, default `false`)* — require explicit check-out before launching a session. +- `start_access_on_approval` *(bool, default `false`)* — start the access window the moment approval is granted (rather than at session launch). +- `require_reason` / `require_ticket` *(bool, default `false`)* — prompt the user for a reason / ticket reference at request time. +- `require_mfa` *(bool, default `false`)* — require MFA at session launch. +- `access_duration` *(string, default `"1d"`)* — how long approved access remains valid. Accepts `Xm` / `Xh` / `Xd` (e.g. `"30m"`, `"8h"`, `"2d"`); a bare integer is interpreted as minutes. Must be positive. +- `allowed_times.allowed_days` *(list of strings)* — restrict access to these weekdays. Accepts 3-letter (`mon`..`sun`) or full names (`monday`..`sunday`), case-insensitive. +- `allowed_times.time_ranges` *(list of `{start, end}` objects)* — one or more allowed daily time windows in `HH:MM` (24-hour) format. **Multiple ranges per day are supported.** A single range whose `end` is earlier than its `start` (e.g. an overnight `22:00–06:00`) **should be split into two ranges** that both fall inside one day (e.g. `22:00–23:59` and `00:00–06:00`) +- `allowed_times.timezone` *(string)* — IANA timezone name (e.g. `"UTC"`, `"America/New_York"`). **Required when `time_ranges` is non-empty.** +- `approvers[]` — list of approver entries. + - `principal.type` — `"user"` or `"team"`. + - For users: `principal.email` (must exist in the enterprise). + - For teams: `principal.team_uid_base64url` (the team's vault UID, base64url-encoded; validated against the local team cache during import — unknown UIDs fail in dry-run). + - `escalation` *(bool)* — whether this approver is in the escalation chain. + - `escalation_after` *(duration string, optional)* — wait this long before escalating to this approver. **Requires `escalation: true`.** + +**Behavior notes:** +- **Trivial workflow is a no-op.** If none of `approvals_needed > 0`, `checkout_needed`, `require_mfa`, `start_access_on_approval`, `allowed_times.allowed_days`, or `allowed_times.time_ranges` is set, the workflow block is treated as absent and no Router call is made. +- **Pre-flight validation runs in `--dry-run`.** Bad durations, malformed `HH:MM`, missing timezone, escalation rule violations, and unknown team UIDs are reported during dry-run before any vault writes. +- **Dry-run skips the Router calls.** Workflow is applied (Router create/update + approver reconcile) only on a real run. +- **`extend` only applies workflow to newly created resources** (existing resources are not touched). +
+
pam_data.resources.pamMachine (RDP) ```json @@ -435,7 +508,8 @@ JIT and KeeperAI settings below are shared across all resource types (pamMachine "ai_threat_detection": "off", "ai_terminate_session_on_detection": "off", "jit_settings": {}, - "ai_settings": {} + "ai_settings": {}, + "workflow": {} }, "allow_supply_host": false, "port_forward": { diff --git a/keepercommander/commands/pam_import/base.py b/keepercommander/commands/pam_import/base.py index 22137b8cf..e5cf37835 100644 --- a/keepercommander/commands/pam_import/base.py +++ b/keepercommander/commands/pam_import/base.py @@ -22,9 +22,11 @@ from typing import Any, Dict, Optional, List, Union from ..record_edit import RecordAddCommand as RecordEditAddCommand +from ..workflow.helpers import RecordResolver, WorkflowFormatter from ... import api, attachment, utils, vault, vault_extensions, \ record_facades, record_management from ...display import bcolors +from ...error import CommandError from ...recordv3 import RecordV3 @@ -69,7 +71,8 @@ "pam_settings": { "options" : { "jit_settings": {}, - "ai_settings": {} + "ai_settings": {}, + "workflow": {} }, "connection" : {} }, @@ -611,6 +614,144 @@ def load(cls, data: Union[str, dict]): return obj +class PamWorkflowOptions: + """Parsed workflow settings from pam_settings.options.workflow. + Not stored on record fields nor in DAG; applied via Krouter after record/DAG creation. + """ + + _DEFAULT_DURATION_MS = 86_400_000 # "1d" + + def __init__(self): + self.approvals_needed: int = 0 + self.checkout_needed: bool = False + self.start_access_on_approval: bool = False + self.require_reason: bool = False + self.require_ticket: bool = False + self.require_mfa: bool = False + self.access_duration_ms: int = self._DEFAULT_DURATION_MS + self.allowed_days: List[str] = [] # canonical 3-letter tokens: "mon".."sun" + self.time_ranges: List[dict] = [] # each: {"start": "HH:MM", "end": "HH:MM"} + self.timezone: str = "" + self.approvers: List[dict] = [] # each: {principal_type, email, team_uid_b64, escalation, escalation_after_ms} + + @staticmethod + def _parse_duration(value) -> int: + """Return milliseconds. Raises CommandError on invalid/non-positive value. + Delegates to WorkflowFormatter.parse_duration; adds a None -> default-1d shim + (the CLI command always supplies a string, but the JSON import may omit the key). + """ + if value is None: + return PamWorkflowOptions._DEFAULT_DURATION_MS + return WorkflowFormatter.parse_duration(str(value)) + + @classmethod + def load(cls, data) -> Optional['PamWorkflowOptions']: + """Parse workflow JSON dict. Returns None when absent / null / trivial (V2 guard).""" + if not data or not isinstance(data, dict): + return None + + obj = cls() + obj.approvals_needed = max(0, int(data.get('approvals_needed', 0) or 0)) + obj.checkout_needed = bool(data.get('checkout_needed', False)) + obj.start_access_on_approval = bool(data.get('start_access_on_approval', False)) + obj.require_reason = bool(data.get('require_reason', False)) + obj.require_ticket = bool(data.get('require_ticket', False)) + obj.require_mfa = bool(data.get('require_mfa', False)) + + # V9: access_duration — default "1d" + obj.access_duration_ms = cls._parse_duration(data.get('access_duration')) + + # allowed_times + at = data.get('allowed_times') or {} + if isinstance(at, dict): + days_raw = at.get('allowed_days') or [] + if isinstance(days_raw, list): + for day in days_raw: + d = str(day).lower().strip() + if d not in WorkflowFormatter.DAY_PARSE_MAP: + raise CommandError('', f'workflow: invalid allowed_times.allowed_days token "{day}"') + obj.allowed_days.append(d[:3]) # store as "mon".."sun" + + ranges_raw = at.get('time_ranges') or [] + if isinstance(ranges_raw, list): + for r in ranges_raw: + if isinstance(r, dict): + start = str(r.get('start', '') or '').strip() + end = str(r.get('end', '') or '').strip() + if start and end: + obj.time_ranges.append({'start': start, 'end': end}) + + obj.timezone = str(at.get('timezone', '') or '').strip() + + # V8: time_ranges non-empty => timezone required + if obj.time_ranges and not obj.timezone: + raise CommandError('', 'workflow: allowed_times.time_ranges requires timezone') + + # approvers + for idx, a in enumerate(data.get('approvers') or []): + if not isinstance(a, dict): + continue + principal = a.get('principal') or {} + if not isinstance(principal, dict): + continue + ptype = str(principal.get('type', '') or '').lower() + escalation = bool(a.get('escalation', False)) + esc_after_raw = a.get('escalation_after') + esc_after_ms = cls._parse_duration(esc_after_raw) if esc_after_raw else 0 + # V7: escalation_after requires escalation: true + if esc_after_ms and not escalation: + raise CommandError('', f'workflow: approvers[{idx}] escalation_after requires escalation: true') + if ptype == 'user': + email = str(principal.get('email', '') or '').strip() + if not email: + raise CommandError('', f'workflow: approvers[{idx}] user principal requires non-empty email') + obj.approvers.append({ + 'principal_type': 'user', 'email': email, 'team_uid_b64': None, + 'escalation': escalation, 'escalation_after_ms': esc_after_ms, + }) + elif ptype == 'team': + uid_b64 = str(principal.get('team_uid_base64url', '') or '').strip() + if not uid_b64: + raise CommandError('', f'workflow: approvers[{idx}] team principal requires non-empty team_uid_base64url') + obj.approvers.append({ + 'principal_type': 'team', 'email': None, 'team_uid_b64': uid_b64, + 'escalation': escalation, 'escalation_after_ms': esc_after_ms, + }) + else: + raise CommandError('', f'workflow: approvers[{idx}] principal.type must be "user" or "team", got "{ptype}"') + + # V2: non-trivial guard — at least one meaningful flag must be set + is_trivial = ( + obj.approvals_needed == 0 + and not obj.start_access_on_approval + and not obj.checkout_needed + and not obj.require_mfa + and not obj.allowed_days + and not obj.time_ranges + ) + if is_trivial: + return None # nothing to persist; caller treats as delete/no-op + + # V4 warning: approvals_needed > 0 with no approvers + if obj.approvals_needed > 0 and not obj.approvers: + logging.warning('workflow: approvals_needed > 0 but no approvers specified') + + return obj + + def validate_principals(self, params, resource_title: str = '') -> None: + """Validate team UIDs via RecordResolver.validate_team (which checks both + team_cache and enterprise.teams). Raises CommandError on first unknown UID. + """ + for idx, a in enumerate(self.approvers): + if a['principal_type'] != 'team': + continue + try: + RecordResolver.validate_team(params, a['team_uid_b64']) + except CommandError as e: + prefix = f'Resource "{resource_title}": ' if resource_title else '' + raise CommandError('', f'{prefix}workflow approvers[{idx}]: {e.message or str(e)}') + + class DagJitSettingsObject(): def __init__(self): self.create_ephemeral: bool = False @@ -2900,10 +3041,12 @@ class PamRemoteBrowserSettings: def __init__( self, options: Optional[DagSettingsObject] = None, - connection: Optional[ConnectionSettingsHTTP] = None + connection: Optional[ConnectionSettingsHTTP] = None, + workflow: Optional[PamWorkflowOptions] = None, ): self.options = options self.connection = connection + self.workflow = workflow # not on record nor in DAG; applied via Krouter @classmethod def load(cls, data: Optional[Union[str, dict]]): @@ -2912,9 +3055,14 @@ def load(cls, data: Optional[Union[str, dict]]): except: logging.error(f"PAM RBI Settings field failed to load from: {str(data)[:80]}...") if not isinstance(data, dict): return obj - options = DagSettingsObject.load(data.get("options", {})) + options_dict = data.get("options", {}) or {} + options = DagSettingsObject.load(options_dict) if not is_empty_instance(options): obj.options = options + if isinstance(options_dict, dict): + workflow_value = options_dict.get("workflow") + if workflow_value is not None: + obj.workflow = PamWorkflowOptions.load(workflow_value) cdata = data.get("connection", {}) # TO DO: if isinstance(cdata, str): lookup_by_name(pam_data.connections) @@ -2944,6 +3092,7 @@ def __init__( options: Optional[DagSettingsObject] = None, jit_settings: Optional[DagJitSettingsObject] = None, ai_settings: Optional[DagAiSettingsObject] = None, + workflow: Optional[PamWorkflowOptions] = None, ): self.allowSupplyHost = allowSupplyHost self.connection = connection @@ -2951,6 +3100,7 @@ def __init__( self.options = options self.jit_settings = jit_settings self.ai_settings = ai_settings + self.workflow = workflow # not on record nor in DAG; applied via Krouter # PamConnectionSettings excludes ConnectionSettingsHTTP pam_connection_classes = [ @@ -2981,8 +3131,8 @@ def is_empty(self): empty = is_empty_instance(self.options) empty = empty and is_empty_instance(self.portForward) empty = empty and is_empty_instance(self.connection, ["protocol"]) - # NB! JIT and AI settings are in import json but not in record json (just DAG json) - empty = empty and self.jit_settings is None and self.ai_settings is None + # NB! JIT, AI, workflow are in import json but not in record json (not DAG either for workflow) + empty = empty and self.jit_settings is None and self.ai_settings is None and self.workflow is None return empty @classmethod @@ -3008,6 +3158,9 @@ def load(cls, data: Union[str, dict]): ai_settings = DagAiSettingsObject.load(ai_value) if ai_settings: obj.ai_settings = ai_settings + workflow_value = options_dict.get("workflow") + if workflow_value is not None: + obj.workflow = PamWorkflowOptions.load(workflow_value) portForward = PamPortForwardSettings.load(data.get("port_forward", {})) if not is_empty_instance(portForward): diff --git a/keepercommander/commands/pam_import/edit.py b/keepercommander/commands/pam_import/edit.py index 0b5d35686..d80fd8354 100644 --- a/keepercommander/commands/pam_import/edit.py +++ b/keepercommander/commands/pam_import/edit.py @@ -19,6 +19,7 @@ from typing import Any, Dict, Optional, List, Union from .keeper_ai_settings import set_resource_jit_settings, set_resource_keeper_ai_settings, refresh_meta_to_latest, refresh_link_to_config_to_latest +from .workflow_apply import apply_workflow, validate_workflow_principals from .base import ( PAM_RESOURCES_RECORD_TYPES, PROJECT_IMPORT_JSON_TEMPLATE, @@ -1642,6 +1643,9 @@ def process_data(self, params, project): resolve_domain_admin(pce, users) # only resolve here - create after machine and user creation + # pre-flight: validate workflow team UIDs before any vault writes (runs in dry-run too) + validate_workflow_principals(params, resources) + # dry run if project["options"].get("dry_run", False) is True: print("Will import file data here...") @@ -1696,6 +1700,9 @@ def process_data(self, params, project): args["connections"] = True args["v_type"] = RefType.PAM_BROWSER tdag.set_resource_allowed(**args) + rbi_wf = getattr(getattr(mach, 'rbi_settings', None), 'workflow', None) + if rbi_wf: + apply_workflow(params, mach.uid, mach.title or '', rbi_wf) else: # machine/db/directory args = parse_command_options(mach, True) if admin_uid: args["admin"] = admin_uid @@ -1739,6 +1746,10 @@ def process_data(self, params, project): if ai: refresh_link_to_config_to_latest(params, mach.uid, pam_cfg_uid) + ps_wf = getattr(getattr(mach, 'pam_settings', None), 'workflow', None) + if ps_wf: + apply_workflow(params, mach.uid, mach.title or '', ps_wf) + # Machine - create its users (if any) users = getattr(mach, "users", []) users = users if isinstance(users, list) else [] diff --git a/keepercommander/commands/pam_import/extend.py b/keepercommander/commands/pam_import/extend.py index 82fb5522b..c21a2a6f2 100644 --- a/keepercommander/commands/pam_import/extend.py +++ b/keepercommander/commands/pam_import/extend.py @@ -53,6 +53,7 @@ refresh_meta_to_latest, refresh_link_to_config_to_latest, ) +from .workflow_apply import apply_workflow, validate_workflow_principals from ...keeper_dag import EdgeType from ...keeper_dag.types import RefType from ..base import Command @@ -549,6 +550,10 @@ def execute(self, params, **kwargs): fp = (getattr(u, "folder_path", None) or "").strip() u.resolved_folder_uid = path_to_folder_uid.get(fp) or usr_folder_uid + # pre-flight: validate workflow team UIDs for new resources (runs in dry-run too) + new_rscs = [r for r in project.get('mapped_resources', []) if getattr(r, '_extend_tag', None) == 'new'] + validate_workflow_principals(params, new_rscs) + if dry_run: print("[DRY RUN COMPLETE] No changes were made. All actions were validated but not executed.") return @@ -1402,6 +1407,9 @@ def process_data(self, params, project): args["connections"] = True args["v_type"] = RefType.PAM_BROWSER tdag.set_resource_allowed(**args) + rbi_wf = getattr(getattr(mach, 'rbi_settings', None), 'workflow', None) + if rbi_wf: + apply_workflow(params, mach.uid, mach.title or '', rbi_wf) else: args = parse_command_options(mach, True) if admin_uid: @@ -1444,6 +1452,10 @@ def process_data(self, params, project): if ai: refresh_link_to_config_to_latest(params, mach.uid, pam_cfg_uid) + ps_wf = getattr(getattr(mach, 'pam_settings', None), 'workflow', None) + if ps_wf: + apply_workflow(params, mach.uid, mach.title or '', ps_wf) + mach_users = getattr(mach, "users", []) or [] for user in mach_users: if getattr(user, "_extend_tag", None) != "new": diff --git a/keepercommander/commands/pam_import/workflow_apply.py b/keepercommander/commands/pam_import/workflow_apply.py new file mode 100644 index 000000000..65ffc8c5f --- /dev/null +++ b/keepercommander/commands/pam_import/workflow_apply.py @@ -0,0 +1,262 @@ +# _ __ +# | |/ /___ ___ _ __ ___ _ _ ® +# | ' bool: + if isinstance(e, KeeperApiError) and e.result_code == 429: + return True + msg = str(getattr(e, 'message', None) or e).lower() + return 'throttle' in msg or 'too many' in msg + + +def _post_with_throttle_retry(params, path: str, **kwargs): + """Wrap _post_request_to_router with progressive backoff on 429 / throttle errors. + Non-throttle errors propagate immediately. Final retry's exception is re-raised. + """ + wait = _THROTTLE_BASE_WAIT + for attempt in range(1, _THROTTLE_MAX_RETRIES + 1): + try: + return _post_request_to_router(params, path, **kwargs) + except Exception as e: + if not _is_throttle_error(e) or attempt >= _THROTTLE_MAX_RETRIES: + raise + logging.warning( + 'Krouter rate-limited on %s (attempt %d/%d); waiting %.1fs', + path, attempt, _THROTTLE_MAX_RETRIES, wait, + ) + time.sleep(wait) + wait *= _THROTTLE_MULTIPLIER + + +# Re-exported for tests and any downstream importers; the canonical map lives +# in WorkflowFormatter.DAY_PARSE_MAP and accepts both 3-letter and full names. +_DAY_PROTO_MAP = { + k: v for k, v in WorkflowFormatter.DAY_PARSE_MAP.items() if len(k) == 3 +} + + +def _build_temporal_filter(opts: PamWorkflowOptions): + """Build TemporalAccessFilter from opts. Returns None when no temporal slice is set. + + startTime / endTime on TimeOfDayRange are HHMM integers (hours*100 + minutes); + see WorkflowFormatter._parse_time_to_hhmm. Canonical sources: + - keeperapp-protobuf/workflow.proto:140 (`int32 startTime = 1; // HHMM format`) + - ka-libs/workflow/.../handlers/WfConfigCRUD.kt::validateHHMM (server validator) + """ + if not opts.allowed_days and not opts.time_ranges and not opts.timezone: + return None + temporal = workflow_pb2.TemporalAccessFilter() + for day_token in opts.allowed_days: + day_enum = WorkflowFormatter.DAY_PARSE_MAP.get(day_token) + if day_enum is not None: + temporal.allowedDays.append(day_enum) + for r in opts.time_ranges: + tr = workflow_pb2.TimeOfDayRange() + tr.startTime = WorkflowFormatter._parse_time_to_hhmm(r['start']) + tr.endTime = WorkflowFormatter._parse_time_to_hhmm(r['end']) + temporal.timeRanges.append(tr) + if opts.timezone: + temporal.timeZone = opts.timezone + return temporal + + +def _build_parameters( + record_uid_bytes: bytes, + record_title: str, + opts: PamWorkflowOptions, +) -> workflow_pb2.WorkflowParameters: + params_proto = workflow_pb2.WorkflowParameters() + params_proto.resource.CopyFrom(ProtobufRefBuilder.record_ref(record_uid_bytes, record_title)) + params_proto.approvalsNeeded = opts.approvals_needed + params_proto.checkoutNeeded = opts.checkout_needed + params_proto.startAccessOnApproval = opts.start_access_on_approval + params_proto.requireReason = opts.require_reason + params_proto.requireTicket = opts.require_ticket + params_proto.requireMFA = opts.require_mfa + params_proto.accessLength = opts.access_duration_ms + + temporal = _build_temporal_filter(opts) + if temporal: + params_proto.allowedTimes.CopyFrom(temporal) + + return params_proto + + +def _build_approver_proto(a: dict) -> workflow_pb2.WorkflowApprover: + approver = workflow_pb2.WorkflowApprover() + if a['principal_type'] == 'user': + approver.user = a['email'] + else: + approver.teamUid = utils.base64_url_decode(a['team_uid_b64']) + approver.escalation = a['escalation'] + if a['escalation_after_ms']: + approver.escalationAfterMs = a['escalation_after_ms'] + return approver + + +def _approver_key(params: KeeperParams, approver: workflow_pb2.WorkflowApprover) -> str: + """Return a stable identity key for an existing server approver (for reconcile diff). + Server may return either user (email) or userId (int). When userId is set, resolve + to email through the enterprise user list so it matches the import-side key. + """ + if approver.HasField('user'): + return f'user:{approver.user}' + if approver.HasField('userId'): + email = RecordResolver.resolve_user(params, approver.userId) + # resolve_user returns 'User ID ' when not found — fall back to userId so + # we don't accidentally key two different unknown users to the same string. + if email and not email.startswith('User ID '): + return f'user:{email}' + return f'userid:{approver.userId}' + if approver.HasField('teamUid'): + return f'team:{utils.base64_url_encode(approver.teamUid)}' + return '' + + +def _new_approver_key(a: dict) -> str: + if a['principal_type'] == 'user': + return f'user:{a["email"]}' + return f'team:{a["team_uid_b64"]}' + + +def _reconcile_approvers( + params: KeeperParams, + record_uid_bytes: bytes, + record_title: str, + existing: List[workflow_pb2.WorkflowApprover], + new_approvers: List[dict], +) -> None: + ref = ProtobufRefBuilder.record_ref(record_uid_bytes, record_title) + + existing_keys = {_approver_key(params, a): a for a in existing} + new_keys = {_new_approver_key(a): a for a in new_approvers} + + to_delete = [a for k, a in existing_keys.items() if k not in new_keys] + to_add = [a for k, a in new_keys.items() if k not in existing_keys] + + if to_delete: + config = workflow_pb2.WorkflowConfig() + config.parameters.resource.CopyFrom(ref) + for a in to_delete: + config.approvers.append(a) + _post_with_throttle_retry(params, 'delete_workflow_approvers', rq_proto=config) + + if to_add: + config = workflow_pb2.WorkflowConfig() + config.parameters.resource.CopyFrom(ref) + for a in to_add: + config.approvers.append(_build_approver_proto(a)) + _post_with_throttle_retry(params, 'add_workflow_approvers', rq_proto=config) + + +def apply_workflow( + params: KeeperParams, + record_uid: str, + record_title: str, + opts: PamWorkflowOptions, +) -> None: + """Create or update workflow config via Krouter. Raises CommandError on failure.""" + record_uid_bytes = utils.base64_url_decode(record_uid) + ref = ProtobufRefBuilder.record_ref(record_uid_bytes, record_title) + + try: + existing = _post_with_throttle_retry( + params, 'read_workflow_config', + rq_proto=ref, rs_type=workflow_pb2.WorkflowConfig, + ) + except Exception as e: + raise CommandError('', f'workflow read failed for "{record_title}": {sanitize_router_error(e)}') + + parameters = _build_parameters(record_uid_bytes, record_title, opts) + + try: + if existing: + _post_with_throttle_retry(params, 'update_workflow_config', rq_proto=parameters) + if opts.approvals_needed > 0: + _reconcile_approvers( + params, record_uid_bytes, record_title, + list(existing.approvers), opts.approvers, + ) + elif existing.approvers: + # approvals_needed dropped to 0: remove all existing approvers (V5) + config = workflow_pb2.WorkflowConfig() + config.parameters.resource.CopyFrom(ref) + for a in existing.approvers: + config.approvers.append(a) + _post_with_throttle_retry(params, 'delete_workflow_approvers', rq_proto=config) + else: + _post_with_throttle_retry(params, 'create_workflow_config', rq_proto=parameters) + if opts.approvals_needed > 0 and opts.approvers: + config = workflow_pb2.WorkflowConfig() + config.parameters.resource.CopyFrom(ref) + for a in opts.approvers: + config.approvers.append(_build_approver_proto(a)) + _post_with_throttle_retry(params, 'add_workflow_approvers', rq_proto=config) + except CommandError: + raise + except Exception as e: + raise CommandError('', f'workflow apply failed for "{record_title}": {sanitize_router_error(e)}') + + +def validate_workflow_principals(params: KeeperParams, resources) -> None: + """Pre-flight: validate team UIDs in workflow approvers for all resources. + Uses RecordResolver.validate_team which checks both team_cache and enterprise.teams, + matching the lookup path used by `pam workflow add-approver`. Raises CommandError + on the first unknown UID, with the resource title in the message for context. + """ + for mach in resources or []: + opts = None + ps = getattr(mach, 'pam_settings', None) + if ps: + opts = getattr(ps, 'workflow', None) + if opts is None: + rbi = getattr(mach, 'rbi_settings', None) + if rbi: + opts = getattr(rbi, 'workflow', None) + if opts is None: + continue + title = getattr(mach, 'title', '') or '' + for idx, a in enumerate(opts.approvers): + if a['principal_type'] != 'team': + continue + try: + RecordResolver.validate_team(params, a['team_uid_b64']) + except CommandError as e: + prefix = f'Resource "{title}": ' if title else '' + raise CommandError('', f'{prefix}workflow approvers[{idx}]: {e.message or str(e)}') diff --git a/keepercommander/commands/workflow/config_commands.py b/keepercommander/commands/workflow/config_commands.py index 30137bf9c..28ca9c71f 100644 --- a/keepercommander/commands/workflow/config_commands.py +++ b/keepercommander/commands/workflow/config_commands.py @@ -329,8 +329,11 @@ def _print_table(params, response, record_uid): print(f" Days: {', '.join(day_names)}") if at.timeRanges: for tr in at.timeRanges: - start_h, start_m = divmod(tr.startTime, 60) - end_h, end_m = divmod(tr.endTime, 60) + # startTime / endTime are HHMM (hours*100 + minutes); see + # WorkflowFormatter._parse_time_to_hhmm and the canonical + # ka-libs/workflow/.../WfConfigCRUD.kt::validateHHMM. + start_h, start_m = divmod(tr.startTime, 100) + end_h, end_m = divmod(tr.endTime, 100) print(f" Time: {start_h:02d}:{start_m:02d} - {end_h:02d}:{end_m:02d}") if at.timeZone: print(f" Timezone: {at.timeZone}") diff --git a/keepercommander/commands/workflow/helpers.py b/keepercommander/commands/workflow/helpers.py index e46eb263b..21a3ac4dd 100644 --- a/keepercommander/commands/workflow/helpers.py +++ b/keepercommander/commands/workflow/helpers.py @@ -523,9 +523,17 @@ def build_temporal_filter(allowed_days_str, time_range_str, timezone_str): @staticmethod def _parse_time_to_hhmm(time_str): - """Parse 'HH:MM' into the HHMM integer encoding the server expects on - TimeOfDayRange.startTime / .endTime — e.g. '03:00' -> 300, '17:30' -> 1730. - Server validates: HHMM integer with HH in 0-23 and MM in 0-59. + """Parse 'HH:MM' to the HHMM integer the server stores on + TimeOfDayRange.startTime / .endTime: hours*100 + minutes. + Examples: '00:00' -> 0, '03:00' -> 300, '09:00' -> 900, '17:30' -> 1730. + Valid range: 0..2359 with hours in 0-23 and minutes in 0-59. + + Canonical sources (all agree on HHMM): + - keeperapp-protobuf/workflow.proto:140 + `int32 startTime = 1; // HHMM format` + - ka-libs/workflow/src/main/kotlin/com/keepersecurity/workflow/handlers/WfConfigCRUD.kt::validateHHMM + `val hours = value / 100; val minutes = value % 100` + throws "Invalid : . Expected HHMM integer with HH in 0-23 and MM in 0-59" on bad input. """ try: parts = time_str.split(':') @@ -547,6 +555,7 @@ def format_temporal_filter(at): if at.timeRanges: ranges = [] for tr in at.timeRanges: + # startTime / endTime are HHMM integers (see _parse_time_to_hhmm). sh, sm = divmod(tr.startTime, 100) eh, em = divmod(tr.endTime, 100) ranges.append(f"{sh:02d}:{sm:02d}-{eh:02d}:{em:02d}") diff --git a/keepercommander/commands/workflow/registry.py b/keepercommander/commands/workflow/registry.py index ae87e7e8c..2ea6f31ed 100644 --- a/keepercommander/commands/workflow/registry.py +++ b/keepercommander/commands/workflow/registry.py @@ -9,9 +9,6 @@ # Contact: ops@keepersecurity.com # -import logging -from urllib.parse import urlparse - from ..base import GroupCommand, dump_report_data from ...display import bcolors from .helpers import _ENFORCEMENT_KEY @@ -42,15 +39,8 @@ class PAMWorkflowCommand(GroupCommand): - NOTICE_MSG = 'Notice: PAM Workflow commands are not in production yet. They will be available soon.' - _ALLOWED_PREFIXES = ('dev.', 'qa.') _ADMIN_VERBS = frozenset({'create', 'update', 'delete', 'add-approver', 'remove-approver'}) - @staticmethod - def _is_allowed_server(params): - hostname = urlparse(params.rest_context.server_base).hostname or '' - return any(hostname.startswith(p) for p in PAMWorkflowCommand._ALLOWED_PREFIXES) - @staticmethod def _can_manage_workflows(params): enforcements = getattr(params, 'enforcements', None) @@ -62,10 +52,6 @@ def _can_manage_workflows(params): ) def execute_args(self, params, args, **kwargs): - if not self._is_allowed_server(params): - logging.warning(f"{bcolors.WARNING}{self.NOTICE_MSG}{bcolors.ENDC}") - return - self._current_params = params pos = args.find(' ') if args else -1 diff --git a/keepercommander/constants.py b/keepercommander/constants.py index f60ae1c90..89ced2729 100644 --- a/keepercommander/constants.py +++ b/keepercommander/constants.py @@ -112,6 +112,7 @@ class PrivilegeScope(enum.IntEnum): ("MASTER_PASSWORD_MINIMUM_UPPER", 12, "LONG", "LOGIN_SETTINGS"), ("MASTER_PASSWORD_MINIMUM_LOWER", 13, "LONG", "LOGIN_SETTINGS"), ("MASTER_PASSWORD_MINIMUM_DIGITS", 14, "LONG", "LOGIN_SETTINGS"), + ("MASTER_PASSWORD_MINIMUM_LENGTH_NO_PROMPT", 15, "LONG", "LOGIN_SETTINGS"), ("MASTER_PASSWORD_RESTRICT_DAYS_BEFORE_REUSE", 16, "LONG", "LOGIN_SETTINGS"), ("REQUIRE_TWO_FACTOR", 20, "BOOLEAN", "TWO_FACTOR_AUTHENTICATION"), ("MASTER_PASSWORD_MAXIMUM_DAYS_BEFORE_CHANGE", 22, "LONG", "LOGIN_SETTINGS"), @@ -231,6 +232,7 @@ class PrivilegeScope(enum.IntEnum): ("ALLOW_VIEW_KCM_RECORDINGS", 234, "BOOLEAN", "ACCOUNT_ENFORCEMENTS"), ("RESTRICT_TOTP_FIELD", 235, "BOOLEAN", "ACCOUNT_ENFORCEMENTS"), ("ALLOW_VIEW_RBI_RECORDINGS", 236, "BOOLEAN", "ACCOUNT_ENFORCEMENTS"), + ("USE_DEFAULT_BROWSER_FOR_SSO", 237, "TERNARY_DEN", "ACCOUNT_ENFORCEMENTS"), ("RESTRICT_MANAGE_TLA", 238, "BOOLEAN", "ACCOUNT_ENFORCEMENTS"), ("RESTRICT_SELF_DESTRUCT_RECORDS", 239, "BOOLEAN", "ACCOUNT_ENFORCEMENTS"), ("RESTRICT_PERSONAL_USING_BUSINESS_DOMAINS", 240, "STRING", "ACCOUNT_ENFORCEMENTS"), @@ -240,6 +242,8 @@ class PrivilegeScope(enum.IntEnum): ("WARN_PERSONAL_USING_BUSINESS_SITES", 244, "STRING", "ACCOUNT_ENFORCEMENTS"), ("RESTRICT_ACCOUNT_SWITCHING", 245, "BOOLEAN", "AUTHENTICATION_ENFORCEMENTS"), ("RESTRICT_PASSKEY_LOGIN", 246, "BOOLEAN", "ACCOUNT_ENFORCEMENTS"), + # NOTE: 247 server name is ALLOW_CAN_EDIT_EXTERNAL_SHARES (positive). Commander's + # RESTRICT_ name is kept for backward compat but the polarity is inverted vs the server. ("RESTRICT_CAN_EDIT_EXTERNAL_SHARES", 247, "BOOLEAN", "ACCOUNT_ENFORCEMENTS"), ("RESTRICT_SNAPSHOT_TOOL", 248, "BOOLEAN", "ACCOUNT_ENFORCEMENTS"), ("RESTRICT_FORCEFIELD", 249, "BOOLEAN", "ACCOUNT_ENFORCEMENTS"), @@ -248,6 +252,16 @@ class PrivilegeScope(enum.IntEnum): ("RESTRICT_SF_FOLDER_DELETION", 253, "BOOLEAN", "SHARING_ENFORCEMENTS"), ("RESTRICT_PLATFORM_PASSKEY_LOGIN", 254, "BOOLEAN", "ACCOUNT_ENFORCEMENTS"), ("RESTRICT_CROSS_PLATFORM_PASSKEY_LOGIN", 255, "BOOLEAN", "ACCOUNT_ENFORCEMENTS"), + ("IP_MAX_DISTANCE_SESSION_WEB", 256, "LONG", "ACCOUNT_ENFORCEMENTS"), + ("IP_MAX_DISTANCE_SESSION_MOBILE", 257, "LONG", "ACCOUNT_ENFORCEMENTS"), + ("IP_MAX_DISTANCE_SESSION_DESKTOP", 258, "LONG", "ACCOUNT_ENFORCEMENTS"), + ("IP_MAX_DISTANCE_SESSION_CONSOLE", 259, "LONG", "ACCOUNT_ENFORCEMENTS"), + ("IP_MAX_DISTANCE_DEFAULT_WEB", 260, "LONG", "ACCOUNT_ENFORCEMENTS"), + ("IP_MAX_DISTANCE_DEFAULT_MOBILE", 261, "LONG", "ACCOUNT_ENFORCEMENTS"), + ("IP_MAX_DISTANCE_DEFAULT_DESKTOP", 262, "LONG", "ACCOUNT_ENFORCEMENTS"), + ("IP_MAX_DISTANCE_DEFAULT_CONSOLE", 263, "LONG", "ACCOUNT_ENFORCEMENTS"), + ("LOGOUT_TIMER_CONSOLE", 264, "LONG", "ACCOUNT_SETTINGS"), + ("ALLOW_CONFIGURE_WORKFLOW_SETTINGS", 267, "BOOLEAN", "ACCOUNT_ENFORCEMENTS"), ] _COMPOUND_ENFORCEMENTS = [ diff --git a/tests/test_pam_workflow.py b/tests/test_pam_workflow.py new file mode 100644 index 000000000..155025c55 --- /dev/null +++ b/tests/test_pam_workflow.py @@ -0,0 +1,362 @@ +"""Unit tests for PAM import workflow parsing, validation, and protobuf assembly.""" + +import unittest +from unittest.mock import MagicMock, patch + +from keepercommander.error import CommandError, KeeperApiError +from keepercommander.commands.pam_import.base import PamWorkflowOptions +from keepercommander.commands.pam_import import workflow_apply +from keepercommander.commands.pam_import.workflow_apply import ( + _build_temporal_filter, + _build_parameters, + _DAY_PROTO_MAP, + _is_throttle_error, + _post_with_throttle_retry, +) +from keepercommander.commands.workflow.helpers import WorkflowFormatter +from keepercommander.proto import workflow_pb2 + +# Server expects HHMM integer (workflow.proto:140 "HHMM format" + server validator). +_parse_time_to_hhmm = WorkflowFormatter._parse_time_to_hhmm + + +# --------------------------------------------------------------------------- +# Duration parsing +# --------------------------------------------------------------------------- + +class TestParseDuration(unittest.TestCase): + + def test_hours(self): + self.assertEqual(PamWorkflowOptions._parse_duration('8h'), 8 * 3_600_000) + + def test_minutes(self): + self.assertEqual(PamWorkflowOptions._parse_duration('30m'), 30 * 60_000) + + def test_days(self): + self.assertEqual(PamWorkflowOptions._parse_duration('1d'), 86_400_000) + + def test_bare_integer_treated_as_minutes(self): + self.assertEqual(PamWorkflowOptions._parse_duration('45'), 45 * 60_000) + + def test_none_returns_default(self): + self.assertEqual(PamWorkflowOptions._parse_duration(None), 86_400_000) + + def test_zero_raises(self): + with self.assertRaises(CommandError): + PamWorkflowOptions._parse_duration('0h') + + def test_negative_raises(self): + with self.assertRaises(CommandError): + PamWorkflowOptions._parse_duration('-1d') + + def test_invalid_string_raises(self): + with self.assertRaises(CommandError): + PamWorkflowOptions._parse_duration('invalid') + + def test_uppercase_suffix(self): + self.assertEqual(PamWorkflowOptions._parse_duration('2H'), 2 * 3_600_000) + + +# --------------------------------------------------------------------------- +# Day mapping +# --------------------------------------------------------------------------- + +class TestDayMapping(unittest.TestCase): + + def test_all_3letter_tokens_in_map(self): + expected = {'mon', 'tue', 'wed', 'thu', 'fri', 'sat', 'sun'} + self.assertEqual(set(_DAY_PROTO_MAP.keys()), expected) + + def test_monday_maps_to_proto(self): + self.assertEqual(_DAY_PROTO_MAP['mon'], workflow_pb2.MONDAY) + + def test_friday_maps_to_proto(self): + self.assertEqual(_DAY_PROTO_MAP['fri'], workflow_pb2.FRIDAY) + + +# --------------------------------------------------------------------------- +# Time-of-day parsing +# --------------------------------------------------------------------------- + +class TestParseTimeToHHMM(unittest.TestCase): + """Server expects HHMM integer encoding per workflow.proto and the server-side + validator (returns "Expected HHMM integer with HH in 0-23 and MM in 0-59").""" + + def test_midnight(self): + self.assertEqual(_parse_time_to_hhmm('00:00'), 0) + + def test_nine_am(self): + self.assertEqual(_parse_time_to_hhmm('09:00'), 900) + + def test_half_past_five_pm(self): + self.assertEqual(_parse_time_to_hhmm('17:30'), 1730) + + def test_invalid_format_raises(self): + with self.assertRaises(CommandError): + _parse_time_to_hhmm('25:00') + + def test_non_numeric_raises(self): + with self.assertRaises(CommandError): + _parse_time_to_hhmm('ab:cd') + + +# --------------------------------------------------------------------------- +# V2: trivial workflow detection +# --------------------------------------------------------------------------- + +class TestTrivialWorkflow(unittest.TestCase): + + def test_empty_dict_returns_none(self): + self.assertIsNone(PamWorkflowOptions.load({})) + + def test_none_returns_none(self): + self.assertIsNone(PamWorkflowOptions.load(None)) + + def test_all_flags_off_no_temporal_returns_none(self): + self.assertIsNone(PamWorkflowOptions.load({ + 'approvals_needed': 0, + 'checkout_needed': False, + 'require_mfa': False, + })) + + def test_checkout_needed_true_is_non_trivial(self): + opts = PamWorkflowOptions.load({'checkout_needed': True, 'access_duration': '2h'}) + self.assertIsNotNone(opts) + self.assertTrue(opts.checkout_needed) + + def test_require_mfa_true_is_non_trivial(self): + opts = PamWorkflowOptions.load({'require_mfa': True}) + self.assertIsNotNone(opts) + + def test_allowed_days_is_non_trivial(self): + opts = PamWorkflowOptions.load({'allowed_times': {'allowed_days': ['mon'], 'timezone': 'UTC'}}) + self.assertIsNotNone(opts) + + def test_approvals_needed_gt0_is_non_trivial(self): + opts = PamWorkflowOptions.load({'approvals_needed': 2}) + self.assertIsNotNone(opts) + + +# --------------------------------------------------------------------------- +# V7: escalation_after requires escalation: true +# --------------------------------------------------------------------------- + +class TestEscalationValidation(unittest.TestCase): + + def test_escalation_after_without_escalation_raises(self): + data = { + 'approvals_needed': 1, + 'approvers': [{ + 'principal': {'type': 'user', 'email': 'a@b.com'}, + 'escalation': False, + 'escalation_after': '30m', + }], + } + with self.assertRaises(CommandError): + PamWorkflowOptions.load(data) + + def test_escalation_after_with_escalation_true_ok(self): + data = { + 'approvals_needed': 1, + 'approvers': [{ + 'principal': {'type': 'user', 'email': 'a@b.com'}, + 'escalation': True, + 'escalation_after': '30m', + }], + } + opts = PamWorkflowOptions.load(data) + self.assertIsNotNone(opts) + self.assertEqual(opts.approvers[0]['escalation_after_ms'], 30 * 60_000) + + +# --------------------------------------------------------------------------- +# V8: time_ranges requires timezone +# --------------------------------------------------------------------------- + +class TestTimezoneRequirement(unittest.TestCase): + + def test_time_ranges_without_timezone_raises(self): + data = { + 'require_mfa': True, + 'allowed_times': { + 'time_ranges': [{'start': '09:00', 'end': '17:00'}], + }, + } + with self.assertRaises(CommandError): + PamWorkflowOptions.load(data) + + def test_time_ranges_with_timezone_ok(self): + data = { + 'require_mfa': True, + 'allowed_times': { + 'time_ranges': [{'start': '09:00', 'end': '17:00'}], + 'timezone': 'America/New_York', + }, + } + opts = PamWorkflowOptions.load(data) + self.assertIsNotNone(opts) + self.assertEqual(opts.timezone, 'America/New_York') + self.assertEqual(len(opts.time_ranges), 1) + + +# --------------------------------------------------------------------------- +# V9: access_duration default +# --------------------------------------------------------------------------- + +class TestAccessDurationDefault(unittest.TestCase): + + def test_missing_access_duration_defaults_to_1d(self): + opts = PamWorkflowOptions.load({'approvals_needed': 1}) + self.assertEqual(opts.access_duration_ms, 86_400_000) + + def test_explicit_duration_parsed(self): + opts = PamWorkflowOptions.load({'approvals_needed': 1, 'access_duration': '4h'}) + self.assertEqual(opts.access_duration_ms, 4 * 3_600_000) + + +# --------------------------------------------------------------------------- +# Protobuf assembly: _build_parameters +# --------------------------------------------------------------------------- + +class TestBuildParameters(unittest.TestCase): + + def _make_uid_bytes(self): + import base64 + return base64.urlsafe_b64decode('AAAAAAAAAAAAAAAAAAAAAA==') + + def test_basic_fields_populated(self): + opts = PamWorkflowOptions.load({ + 'approvals_needed': 2, + 'checkout_needed': True, + 'require_mfa': True, + 'access_duration': '8h', + }) + uid_bytes = self._make_uid_bytes() + params_proto = _build_parameters(uid_bytes, 'Test Machine', opts) + self.assertEqual(params_proto.approvalsNeeded, 2) + self.assertTrue(params_proto.checkoutNeeded) + self.assertTrue(params_proto.requireMFA) + self.assertEqual(params_proto.accessLength, 8 * 3_600_000) + self.assertEqual(params_proto.resource.value, uid_bytes) + self.assertEqual(params_proto.resource.name, 'Test Machine') + + def test_temporal_filter_attached(self): + opts = PamWorkflowOptions.load({ + 'require_mfa': True, + 'allowed_times': { + 'allowed_days': ['mon', 'fri'], + 'time_ranges': [{'start': '09:00', 'end': '17:00'}], + 'timezone': 'UTC', + }, + }) + uid_bytes = self._make_uid_bytes() + params_proto = _build_parameters(uid_bytes, 'Box', opts) + at = params_proto.allowedTimes + self.assertIn(workflow_pb2.MONDAY, at.allowedDays) + self.assertIn(workflow_pb2.FRIDAY, at.allowedDays) + self.assertEqual(len(at.timeRanges), 1) + # HHMM integer encoding: 09:00 -> 900, 17:00 -> 1700 + self.assertEqual(at.timeRanges[0].startTime, 900) + self.assertEqual(at.timeRanges[0].endTime, 1700) + self.assertEqual(at.timeZone, 'UTC') + + def test_no_allowed_times_no_temporal(self): + opts = PamWorkflowOptions.load({'approvals_needed': 1}) + uid_bytes = self._make_uid_bytes() + params_proto = _build_parameters(uid_bytes, 'Box', opts) + self.assertFalse(params_proto.HasField('allowedTimes')) + + +# --------------------------------------------------------------------------- +# validate_principals +# --------------------------------------------------------------------------- + +class TestValidatePrincipals(unittest.TestCase): + + def _make_params(self, team_uids): + p = MagicMock() + p.team_cache = {uid: {} for uid in team_uids} + return p + + def test_known_team_uid_passes(self): + opts = PamWorkflowOptions.load({ + 'approvals_needed': 1, + 'approvers': [{'principal': {'type': 'team', 'team_uid_base64url': 'validUID123'}}], + }) + params = self._make_params(['validUID123']) + opts.validate_principals(params, 'MyResource') + + def test_unknown_team_uid_raises(self): + opts = PamWorkflowOptions.load({ + 'approvals_needed': 1, + 'approvers': [{'principal': {'type': 'team', 'team_uid_base64url': 'unknownUID'}}], + }) + params = self._make_params(['otherUID']) + with self.assertRaises(CommandError): + opts.validate_principals(params, 'MyResource') + + def test_user_principal_not_checked_against_team_cache(self): + opts = PamWorkflowOptions.load({ + 'approvals_needed': 1, + 'approvers': [{'principal': {'type': 'user', 'email': 'user@example.com'}}], + }) + params = self._make_params([]) + opts.validate_principals(params) + + +# --------------------------------------------------------------------------- +# Throttle / 429 retry wrapper +# --------------------------------------------------------------------------- + +class TestThrottleErrorDetection(unittest.TestCase): + + def test_keeper_api_error_429_is_throttle(self): + self.assertTrue(_is_throttle_error(KeeperApiError(429, 'Too many requests'))) + + def test_keeper_api_error_500_is_not_throttle(self): + self.assertFalse(_is_throttle_error(KeeperApiError(500, 'Internal error'))) + + def test_string_throttle_in_msg_is_throttle(self): + self.assertTrue(_is_throttle_error(Exception('record was throttled'))) + + def test_too_many_in_msg_is_throttle(self): + self.assertTrue(_is_throttle_error(Exception('Too many requests'))) + + def test_unrelated_error_is_not_throttle(self): + self.assertFalse(_is_throttle_error(Exception('connection refused'))) + + +class TestThrottleRetry(unittest.TestCase): + + def test_no_retry_on_non_throttle(self): + with patch.object(workflow_apply, '_post_request_to_router', + side_effect=KeeperApiError(500, 'boom')) as mock_post: + with self.assertRaises(KeeperApiError): + _post_with_throttle_retry(MagicMock(), 'read_workflow_config') + self.assertEqual(mock_post.call_count, 1) + + def test_retries_then_succeeds(self): + # First two calls 429, third succeeds. Patch sleep to keep test fast. + side_effects = [KeeperApiError(429, 'Too many requests'), + KeeperApiError(429, 'Too many requests'), + 'OK'] + with patch.object(workflow_apply, '_post_request_to_router', + side_effect=side_effects) as mock_post, \ + patch.object(workflow_apply.time, 'sleep') as mock_sleep: + result = _post_with_throttle_retry(MagicMock(), 'read_workflow_config') + self.assertEqual(result, 'OK') + self.assertEqual(mock_post.call_count, 3) + # Two backoff sleeps: 10s, 15s (10 * 1.5) + self.assertEqual([round(c.args[0], 2) for c in mock_sleep.call_args_list], [10.0, 15.0]) + + def test_exhausts_retries_and_reraises(self): + with patch.object(workflow_apply, '_post_request_to_router', + side_effect=KeeperApiError(429, 'Too many requests')) as mock_post, \ + patch.object(workflow_apply.time, 'sleep'): + with self.assertRaises(KeeperApiError): + _post_with_throttle_retry(MagicMock(), 'read_workflow_config') + self.assertEqual(mock_post.call_count, workflow_apply._THROTTLE_MAX_RETRIES) + + +if __name__ == '__main__': + unittest.main()