Python Enum and Pydantic: accept compositions of enum members

I have an enum:

from enum import Enum

class MyEnum(Enum):
    val1 = "val1"
    val2 = "val2"
    val3 = "val3"

I would like to validate a pydantic field based on that enum.

from pydantic import BaseModel

class MyModel(BaseModel):
    my_enum_field: MyEnum

However, I would like this validation to also accept strings composed of several enum members' values joined by underscores.

For example, "val1_val2_val3" and "val1_val3" are valid inputs.

I cannot declare this field as a plain string field with a validator, because I use testing libraries (hypothesis and pydantic-factories) that need the enum type in order to generate one of the enum's values (for mocking random inputs).

Something like this:

from pydantic import BaseModel, validator

class MyModel(BaseModel):
    my_enum_field: str

    @validator('my_enum_field', pre=True)
    def validate_my_enum_field(cls, value):
        split_val = str(value).split('_')
        if not all(v in MyEnum._value2member_map_ for v in split_val):
            raise ValueError()
        return value

could work, but it breaks my test suite because the field is no longer of an enum type.

How can I keep this field as an Enum type (so that my mock structures remain valid) and make pydantic accept composite values at the same time?

So far, I have tried to dynamically extend the enum, without success.



Solution 1:[1]

I looked at this a bit further, and I believe something like this could be helpful. You can create a new class that defines the field's type as a list of enum values.

This class can supply a custom validate method, plus a __modify_schema__ method so that the JSON schema still describes the field as a string. Note that __get_validators__ and __modify_schema__ are the custom-type hooks of pydantic v1; pydantic v2 uses a different mechanism.

We can define a base class for generic lists of concatenated enums like this:

from typing import Generic, TypeVar, Type
from enum import Enum

T = TypeVar("T", bound=Enum)


class ConcatenatedEnum(Generic[T], list[T]):
    enum_type: Type[T]

    @classmethod
    def __get_validators__(cls):
        yield cls.validate

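    @classmethod
    def validate(cls, value):
        # Reject non-string input with a TypeError so that pydantic reports
        # a validation error instead of an unhandled AttributeError.
        if not isinstance(value, str):
            raise TypeError("string required")
        # enum_type(part) raises ValueError for an unknown value, which
        # pydantic also turns into a validation error.
        return list(map(cls.enum_type, value.split("_")))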

    @classmethod
    def __modify_schema__(cls, field_schema: dict):
        all_values = ', '.join(f"'{ex.value}'" for ex in cls.enum_type)
        field_schema.update(
            title=f"Concatenation of {cls.enum_type.__name__} values",
            description=f"Underscore delimited list of values {all_values}",
            type="string",
        )
        if "items" in field_schema:
            del field_schema["items"]

In the __modify_schema__ method I also provide a way to generate a description of which values are valid.
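The core of the validate method above is a split-and-lookup step, which can be tried without pydantic at all. The following is a minimal sketch using only the standard library; the Color enum and the parse_concatenated helper are hypothetical names introduced for illustration and are not part of the answer's code.

```python
from enum import Enum
from typing import List, Type, TypeVar

E = TypeVar("E", bound=Enum)


class Color(Enum):  # hypothetical enum, used only for this illustration
    red = "red"
    blue = "blue"


def parse_concatenated(enum_type: Type[E], value: str, sep: str = "_") -> List[E]:
    """Hypothetical helper: split `value` on `sep` and look up each piece by enum value."""
    if not isinstance(value, str):
        raise TypeError(f"expected str, got {type(value).__name__}")
    # Calling an Enum class with a value returns the matching member,
    # or raises ValueError when no member has that value.
    return [enum_type(part) for part in value.split(sep)]


print(parse_concatenated(Color, "red_blue"))

try:
    parse_concatenated(Color, "red_green")
except ValueError as exc:
    print(f"rejected: {exc}")
```

One consequence of this approach is that it only works when no enum value itself contains the separator; a value such as "dark_red" would be split into two pieces and rejected.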

To use this in your application:

from pydantic import BaseModel


class MyEnum(Enum):
    val1 = "val1"
    val2 = "val2"
    val3 = "val3"


class MyEnumList(ConcatenatedEnum[MyEnum]):
    enum_type = MyEnum


class MyModel(BaseModel):
    my_enum_field: MyEnumList

Example models:

print(MyModel.parse_obj({"my_enum_field": "val1"}))
# my_enum_field=[<MyEnum.val1: 'val1'>]
print(MyModel.parse_obj({"my_enum_field": "val1_val2"}))
# my_enum_field=[<MyEnum.val1: 'val1'>, <MyEnum.val2: 'val2'>]
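Because the validated field holds a list of enum members rather than the original string, code that needs the composite form again (for example when building test inputs) has to join the values back together. A minimal sketch, assuming the same underscore separator; join_members is a hypothetical helper and not part of the answer's code:

```python
from enum import Enum
from typing import Iterable


class MyEnum(Enum):
    val1 = "val1"
    val2 = "val2"
    val3 = "val3"


def join_members(members: Iterable[Enum], sep: str = "_") -> str:
    """Hypothetical helper: rebuild the delimited string from enum members."""
    return sep.join(member.value for member in members)


# A list like the one the field holds after validation.
parsed = [MyEnum.val1, MyEnum.val3]
print(join_members(parsed))  # val1_val3
```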

Example schema:

import json

print(json.dumps(MyModel.schema(), indent=2))

Output:

{
  "title": "MyModel",
  "type": "object",
  "properties": {
    "my_enum_field": {
      "title": "Concatenation of MyEnum values",
      "description": "Underscore delimited list of values 'val1', 'val2', 'val3'",
      "type": "string"
    }
  },
  "required": [
    "my_enum_field"
  ]
}

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1