diff --git a/.gitignore b/.gitignore index 8154e92..5cf1cf6 100644 --- a/.gitignore +++ b/.gitignore @@ -9,4 +9,8 @@ test.py venv/ .venv/ env/ -.env \ No newline at end of file +.env/ +venv*/ +.venv*/ +env*/ +.env*/ \ No newline at end of file diff --git a/flask_parameter_validation/parameter_types/parameter.py b/flask_parameter_validation/parameter_types/parameter.py index c540a08..187a659 100644 --- a/flask_parameter_validation/parameter_types/parameter.py +++ b/flask_parameter_validation/parameter_types/parameter.py @@ -196,8 +196,8 @@ def convert(self, value, allowed_types, current_error=None): error = ValueError(f"datetime format does not match: {self.datetime_format}") if blank_none and type(None) in allowed_types and str in allowed_types and type(value) is str and len(value) == 0: return None - if any(isclass(allowed_type) and (issubclass(allowed_type, str) or issubclass(allowed_type, int) and issubclass(allowed_type, Enum)) for allowed_type in allowed_types): - for allowed_type in allowed_types: + for allowed_type in allowed_types: + if isclass(allowed_type) and (issubclass(allowed_type, str) or issubclass(allowed_type, int) and issubclass(allowed_type, Enum)): if issubclass(allowed_type, Enum): try: if issubclass(allowed_type, int): diff --git a/flask_parameter_validation/test/test_form_params.py b/flask_parameter_validation/test/test_form_params.py index 28aa698..55d1709 100644 --- a/flask_parameter_validation/test/test_form_params.py +++ b/flask_parameter_validation/test/test_form_params.py @@ -1299,6 +1299,22 @@ def test_list_func(client): assert "error" in r.json +def test_list_in_union(client): + url = "/form/list/in_union" + # Test that input passing func yields input + v = ["hi", "ho"] + r = client.post(url, data={"v": v}) + assert "v" in r.json + assert type(r.json["v"]) is list + assert len(r.json["v"]) == 2 + list_assertion_helper(2, str, v, r.json["v"]) + # Test that input passing func yields input + v = "hi" + r = client.post(url, data={"v": v}) + assert "v" in r.json + assert type(r.json["v"]) is str + + def test_min_list_length(client): url = "/form/list/min_list_length" # Test that below length yields error diff --git a/flask_parameter_validation/test/test_json_params.py b/flask_parameter_validation/test/test_json_params.py index 8b56576..6c0e25c 100644 --- a/flask_parameter_validation/test/test_json_params.py +++ b/flask_parameter_validation/test/test_json_params.py @@ -1404,6 +1404,25 @@ def test_list_func(client): assert "error" in r.json +def test_list_in_union(client): + url = "/json/list/in_union" + # Test that input passing func yields input + v = ["hi", "ho"] + r = client.post(url, json={"v": v}) + assert "v" in r.json + assert type(r.json["v"]) is list + assert len(r.json["v"]) == 2 + list_assertion_helper(2, str, v, r.json["v"]) + # Test that input passing func yields input + v = "hi" + r = client.post(url, json={"v": v}) + assert "v" in r.json + assert type(r.json["v"]) is str + # Test that input failing func yields error + r = client.post(url, json={"v": [0.4, 0.5]}) + assert "error" in r.json + + def test_min_list_length(client): url = "/json/list/min_list_length" # Test that below length yields error diff --git a/flask_parameter_validation/test/test_multi_source_params.py b/flask_parameter_validation/test/test_multi_source_params.py index 99e516c..71d2e46 100644 --- a/flask_parameter_validation/test/test_multi_source_params.py +++ b/flask_parameter_validation/test/test_multi_source_params.py @@ -433,6 +433,29 @@ def test_multi_source_list_dict(client, source_a, source_b): r = client.get(url) assert "error" in r.json +@pytest.mark.parametrize(*common_parameters) +def test_multi_source_list_in_union(client, source_a, source_b): + if source_a == source_b or "route" in [source_a, source_b]: # Duplicate sources shouldn't be something someone does, so we won't test for it, Route does not support parameters of type 'List' + return + l = ["hi", "ho"] + url = f"/ms_{source_a}_{source_b}/list/in_union" + for source in [source_a, source_b]: + # Test that present input yields input value + r = None + if source == "query": + r = client.get(url, query_string={"v": l}) + elif source == "form": + r = client.get(url, data={"v": l}) + elif source == "json": + r = client.get(url, json={"v": l}) + assert r is not None + assert "v" in r.json + assert json.dumps(r.json["v"]) == json.dumps(l) + + # Test that missing input yields error + r = client.get(url) + assert "error" in r.json + @pytest.mark.parametrize(*common_parameters) def test_multi_source_str(client, source_a, source_b): if source_a == source_b: # This shouldn't be something someone does, so we won't test for it diff --git a/flask_parameter_validation/test/test_query_params.py b/flask_parameter_validation/test/test_query_params.py index 2feb83c..ff05927 100644 --- a/flask_parameter_validation/test/test_query_params.py +++ b/flask_parameter_validation/test/test_query_params.py @@ -2409,6 +2409,22 @@ def test_list_func(client): assert "error" in r.json +def test_list_in_union(client): + url = "/query/list/in_union" + # Test that input passing func yields input + v = ["hi", "ho"] + r = client.get(url, query_string={"v": v}) + assert "v" in r.json + assert type(r.json["v"]) is list + assert len(r.json["v"]) == 2 + list_assertion_helper(2, str, v, r.json["v"]) + # Test that input passing func yields input + v = "hi" + r = client.get(url, query_string={"v": v}) + assert "v" in r.json + assert type(r.json["v"]) is str + + def test_min_list_length(client): url = "/query/list/min_list_length" # Test that below length yields error diff --git a/flask_parameter_validation/test/testing_blueprints/list_blueprint.py b/flask_parameter_validation/test/testing_blueprints/list_blueprint.py index 4367068..fd7ce51 100644 --- a/flask_parameter_validation/test/testing_blueprints/list_blueprint.py +++ b/flask_parameter_validation/test/testing_blueprints/list_blueprint.py @@ -384,4 +384,15 @@ def dict_args_str_union(v: list[dict[str, Union[str, int]]] = ParamType(list_dis assert type(val) is str or type(val) is int return jsonify({"v": v}) + @decorator("/in_union") + @ValidateParameters() + def in_union(v: Union[str, list[str]] = ParamType(list_disable_query_csv=True)): + if type(v) is list: + for ele in v: + assert type(ele) is str + else: + assert type(v) is str + return jsonify({"v": v}) + return list_bp + diff --git a/flask_parameter_validation/test/testing_blueprints/multi_source_blueprint.py b/flask_parameter_validation/test/testing_blueprints/multi_source_blueprint.py index aa62ec2..2432d69 100644 --- a/flask_parameter_validation/test/testing_blueprints/multi_source_blueprint.py +++ b/flask_parameter_validation/test/testing_blueprints/multi_source_blueprint.py @@ -156,6 +156,17 @@ def multi_source_list_dict_str_union(v: list[dict[str, Union[str, int]]] = Multi def multi_source_optional_list(v: Optional[List[int]] = MultiSource(sources[0], sources[1])): return jsonify({"v": v}) + @param_bp.route("/list/in_union", methods=["GET", "POST"]) + # Route doesn't support List parameters + @ValidateParameters() + def in_union(v: Union[str, list[str]] = MultiSource(sources[0], sources[1])): + if type(v) is list: + for ele in v: + assert type(ele) is str + else: + assert type(v) is str + return jsonify({"v": v}) + @param_bp.route("/required_str", methods=["GET", "POST"]) @param_bp.route("/required_str/", methods=["GET", "POST"]) @ValidateParameters()