Combining a descriptor class with dataclass and field

I am using a dataclass and field to pass in a default value. When an argument is provided I want to validate it using a descriptor class.

Is there any way to utilize the benefits of field (repr, default, init, etc) while getting the validator benefits of a descriptor class?

from dataclasses import dataclass, field


class Descriptor:
    """Data descriptor from the question; keeps its value in the instance ``__dict__``.

    NOTE(review): this is the asker's original attempt, kept verbatim. The
    inline notes mark the spots the answers below take issue with.
    """

    def __init__(self, default):
        # Fallback that __set__ substitutes whenever the assigned value is falsy.
        self.default = default

    def __set_name__(self, owner, name):
        # Remember the attribute name this descriptor was bound to.
        self.name = name

    def __get__(self, obj, objtype=None):
        # NOTE(review): this is a truthiness test rather than `obj is not None`,
        # so an instance that evaluates as false takes the `else` branch too.
        if obj:
            # `.get` yields None when nothing has been stored under this name.
            return vars(obj).get(self.name)
        else:
            # Class-level access (obj is None) returns None.
            return None

    def __set__(self, obj, value):
        # Any falsy value (None, 0, '', ...) is replaced by the default.
        if not value:
            value = self.default
        else:
            # NOTE(review): what gets stored is a dataclasses Field object
            # wrapping int(value), not the integer itself.
            value = field(default=int(value), repr=False)
        vars(obj)[self.name] = value


@dataclass
class Person:
    # NOTE(review): the Field object is handed to the descriptor as its
    # *default*, so the class attribute is a Descriptor, not a field() spec.
    age: str = Descriptor(field(default=3, repr=False))

    # Many additional attributes
    # using same descriptor class


# Construct once without an argument and once with an explicit value.
p = Person()
r = Person(2.37)


Solution 1:[1]

Your idea is OK, but it isn't pretty and it doesn't work.
My first idea was for Descriptor to inherit from dataclasses.Field,
but that doesn't work either.

What you should do instead is somewhat roundabout,
but it is how this should be done even without dataclasses:
use two attributes, age (the validated public accessor) and _age (the underlying storage).

import dataclasses as dc

class GetSet:
    """Data descriptor that validates assignments and stores them privately.

    The accepted value is kept on the instance under ``'_' + <attribute name>``
    (``age`` -> ``_age``), so one descriptor class can guard any number of
    attributes instead of being tied to ``_age``.

    Args:
        predicate: Callable that receives the candidate value and returns a
            truthy result when the value is acceptable. The default accepts
            everything.

    Raises:
        ValueError: From ``__set__`` when ``predicate`` rejects the value.
    """

    def __init__(self, predicate = lambda value: True):
        self.predicate = predicate
        # Fallback slot for descriptors attached after class creation, where
        # __set_name__ is never invoked.
        self.private_name = '_age'

    def __set_name__(self, owner, name):
        # Derive the backing slot from the public attribute name.
        self.private_name = '_' + name

    def __set__(self, obj, value):
        if isinstance(value, GetSet):
            # The dataclass __init__ treats the descriptor as the default
            # argument and tries to assign `obj.age = GetSet(...)`; ignore it
            # so the backing attribute keeps its own default.
            return
        if not self.predicate(value):
            raise ValueError(
                f'invalid value for {self.private_name[1:]!r}: {value!r}')
        setattr(obj, self.private_name, value)

    def __get__(self, obj, cls=None):
        if obj is not None:  # accessed as obj.age
            return getattr(obj, self.private_name)
        return self  # accessed as Person.age

@dc.dataclass
class Person:
    # Backing storage for `age`: kept out of __init__ (init=False) and out of
    # the repr; configure the remaining options as you like.
    _age: int = dc.field(init=False, repr=False, default=20)
    # Validated public attribute. Because it is annotated it becomes an
    # __init__ parameter; drop the `: int` annotation if that is not wanted.
    # The predicate rejects any age below 18.
    age: int = GetSet(lambda candidate: candidate >= 18)

An alternative is to use the built-in property:

@dc.dataclass
class Person:
    # Backing storage: not an __init__ parameter and hidden from the repr.
    _age: int = dc.field(init=False, repr=False, default=20)
    # Annotated so that the dataclass exposes `age` as an __init__ parameter;
    # the property defined below becomes its default value.
    age: int

    @property
    def age(self):
        """Return the stored age."""
        return self._age

    @age.setter
    def age(self, value):
        # When `age` is omitted, __init__ passes the property object itself as
        # the default; skip it so the `_age` default stays in place.
        if not isinstance(value, property):
            if value < 18:
                raise ValueError('no underaged')
            self._age = value

Solution 2:[2]

There is a way to get all the benefits. Note that dataclass generates code for you, so you can apply your modifications through inheritance.

from dataclasses import dataclass
from dataclasses import field
from typing import Union


class Descriptor:
    """Data descriptor that converts ``str`` assignments to ``int``.

    Values are kept in the instance ``__dict__`` under the attribute's own
    name. Class-level access returns ``None``.
    """

    def __set_name__(self, owner, name):
        # Remember the attribute name this descriptor was bound to.
        self.name = name

    def __get__(self, obj, obj_type=None):
        # Compare against None explicitly: a plain truthiness test would treat
        # instances that evaluate as false (e.g. ones defining __len__ or
        # __bool__) like class-level access and hide their stored value.
        if obj is None:
            return None
        # Raises KeyError if the attribute was never assigned.
        return vars(obj)[self.name]

    def __set__(self, obj, value):
        # Strings are converted; every other type is stored untouched.
        if isinstance(value, str):
            value = int(value)
        vars(obj)[self.name] = value


@dataclass
class PersonData:
    name: str
    # `age` defaults to 3 and is left out of the generated repr.
    age: Union[int, str] = field(repr=False, default=3)

    # Declare here any further attributes that should benefit from `field`


class Person(PersonData):
    # Shadow the inherited `age` attribute with the converting descriptor, so
    # assignments to `age` go through Descriptor.__set__.
    age = Descriptor()

    # Declare here any further attributes that should use descriptors


if __name__ == '__main__':
    # One person relying on the default age, one with an explicit age.
    p = Person('Mike')
    r = Person('Mary', 2)
    people = (p, r)

    for person in people:
        print(person, person.age)

    # Reassignment goes through the descriptor as well.
    r.age = 5
    for person in people:
        print(person, person.age)

You will get the following print:

Person(name='Mike') 3
Person(name='Mary') 2
Person(name='Mike') 3
Person(name='Mary') 5

You can see that the default value defined in PersonData is passed on to Person automatically, so there is no longer any need to store the default in Descriptor.

Solution 3:[3]

You can use a descriptor with a dataclasses.field (at least after this bug was fixed).

There are just a few things that needed to be changed in your code:

Starting in the dataclass, you had the nesting the wrong way round. The field object is what should be attached to the dataclass attribute, with the Descriptor passed as its default:

@dataclass
class Person:
    # field() is the outer object; the Descriptor is passed as its default.
    age: str = field(
        default=Descriptor(default=3),
        repr=False,
    )

Next in Descriptor.__set__, when the age argument is not provided to the constructor, the value argument will actually be the instance of the Descriptor class. So we need to change the guard to see if value is self:

class Descriptor:
    ...
    def __set__(self, obj, value):
        # When no argument is given to the constructor, `value` is this
        # descriptor instance itself, so fall back to the stored default.
        if value is self:
            value = self.default
        ...  # rest of the method elided

Finally, I made one more change to echo a pattern I've seen in the Python ecosystem: using the getattr and setattr functions for getting and setting attributes on objects.

Unfortunately, this introduced an infinite recursion bug (getattr(obj, 'age') is routed straight back into the descriptor's own __get__), so I changed the place where the value is stored on the Person object to _age.

All that being said, this works as you intended:

from dataclasses import dataclass, field


class Descriptor:
    """Data descriptor that coerces assigned values to ``int``.

    The value is stored on the instance under ``'_' + <attribute name>`` so
    that ``getattr``/``setattr`` on the backing name do not re-enter the
    descriptor.

    Args:
        default: Value stored when the descriptor itself is assigned, which is
            what happens when the constructor argument is omitted.
    """

    def __init__(self, default):
        self.default = default

    def __set_name__(self, owner, name):
        # `age` is backed by `_age`, and so on.
        self.private_name = '_' + name

    def __get__(self, obj, objtype=None):
        # Class-level access (`Person.age`) has no instance to read from;
        # return the descriptor itself instead of failing inside
        # getattr(None, ...) with a misleading AttributeError.
        if obj is None:
            return self
        return getattr(obj, self.private_name)

    def __set__(self, obj, value):
        if value is self:
            # No argument was supplied, so the descriptor arrived as the value.
            value = self.default
        else:
            value = int(value)
        setattr(obj, self.private_name, value)


@dataclass
class Person:
    # field() is the outer object; the Descriptor is passed as its default.
    age: str = field(
        default=Descriptor(default=3),
        repr=False,
    )

    # Any number of further attributes
    # can reuse the same descriptor class


r = Person(2.37)
p = Person()
# The explicit argument is converted by int(); the omitted one uses the default.
assert r.age == 2
assert p.age == 3
for person in (r, p):
    print(person)
print(vars(p))

Solution 4:[4]

Using a descriptor with a dataclass is quite tricky and requires some ugly hacks to make it work (for example, to support default_factory, frozen, ...). I have working example code at: https://github.com/google/etils/blob/main/etils/edc/field_utils.py

import dataclasses

from etils import edc  # pip install etils[edc]
from etils import epath


@dataclasses.dataclass
class A:
  # Each `validate` callable is applied to every value assigned to its field.
  path: epath.Path = edc.field(validate=epath.Path)
  x: int = edc.field(validate=int)
  y: int = edc.field(validate=lambda x: -x, default=5)


a = A(
    path='/some/path',  # Inputs auto-normalized `str` -> `epath.Path`
    x='123',
)
assert isinstance(a.path, epath.Path)
assert a.x == 123
assert a.y == -5

Here is the implementation:

"""Field utils."""

from __future__ import annotations

import dataclasses
import typing
from typing import Any, Callable, Generic, Optional, Type, TypeVar

_Dataclass = Any
_In = Any
_Out = Any
_InT = TypeVar('_InT')
_OutT = TypeVar('_OutT')


def field(
    *,
    validate: Optional[Callable[[_In], _OutT]] = None,
    **kwargs: Any,
) -> dataclasses.Field[_OutT]:
  """Drop-in for `dataclasses.field` that also accepts a `validate` callable.

  Args:
    validate: Optional callable `(x) -> x` applied on every assignment to the
      attribute. When omitted, a plain `dataclasses.field` is returned.
    **kwargs: Keyword arguments forwarded to `dataclasses.field`.

  Returns:
    The field.
  """
  if validate is not None:
    # Hand back the validating descriptor, typed as a `dataclasses.Field` for
    # the benefit of static checkers.
    descriptor = _Field(validate=validate, field_kwargs=kwargs)
    return typing.cast(dataclasses.Field, descriptor)
  return dataclasses.field(**kwargs)


class _Field(Generic[_InT, _OutT]):
  """Descriptor that runs a validator on every assignment to a field."""

  def __init__(
      self,
      validate: Callable[[_InT], _OutT],
      field_kwargs: dict[str, Any],
  ) -> None:
    """Stores the validator and the pending `dataclasses.field` arguments.

    Args:
      validate: Callable applied to each value assigned to the attribute.
      field_kwargs: Keyword arguments later forwarded to `dataclasses.field`.
    """
    # Both are filled in by `__set_name__` and describe where the descriptor
    # lives. For `A.x = edc.field()`:
    # * _attribute_name = 'x'
    # * _objtype = A
    self._attribute_name: Optional[str] = None
    self._objtype: Optional[Type[_Dataclass]] = None

    self._validate = validate
    self._field_kwargs = field_kwargs

    # True until the first class-level `__get__`; see `__get__` for why.
    self._first_getattr_call: bool = True

  def __set_name__(self, objtype: Type[_Dataclass], name: str) -> None:
    """Records the owning class and attribute name (PEP 487)."""
    self._objtype = objtype
    self._attribute_name = name

  def __get__(
      self,
      obj: Optional[_Dataclass],
      objtype: Optional[Type[_Dataclass]] = None,
  ) -> _OutT:
    """Handles both `MyDataclass.x` and `my_dataclass.x`."""
    if obj is not None:
      # Instance access: `my_dataclass.my_path`.
      return _getattr(obj, self._attribute_name)

    # Class access: `MyDataclass.my_attribute`.
    if not self._first_getattr_call:
      # TODO(epot): Could better handle default value: Either by returning
      # the default value, or raising an AttributeError. Currently, we just
      # return the descriptor:
      # assert isinstance(MyDataclass.my_attribute, _Field)
      return self

    # `dataclasses.dataclass(cls)` calls `getattr(cls, f.name)` more than
    # once, so the calls are counted.
    # On the first call a real `dataclasses.Field` is returned so that
    # dataclass can do its magic.
    # On the second call `dataclasses.dataclass` deletes the descriptor if
    # `isinstance(getattr(cls, f.name, None), Field)`, so from then on it is
    # essential to return anything except a `dataclasses.Field`.
    # This relies on an implementation detail, but seems to hold for python
    # 3.6-3.10.
    self._first_getattr_call = False
    return dataclasses.field(**self._field_kwargs)

  def __set__(self, obj: _Dataclass, value: _InT) -> None:
    """Handles `my_dataclass.x = value`."""
    # Validate the value during assignment, then store the result.
    validated = self._validate(value)
    _setattr(obj, self._attribute_name, validated)


# Because there is one instance of the `_Field` per class, shared across all
# class instances, we need to store the per-object state somewhere.
# The simplest is to attach the state in an extra `dict[str, value]`:
# `_dataclass_field_values`.


def _getattr(
    obj: _Dataclass,
    attribute_name: str,
) -> _Out:
  """Looks up `attribute_name` in the per-object value store.

  Raises:
    AttributeError: If nothing has been stored under `attribute_name` yet,
      e.g. when the attribute is read before `super().__init__` has run.
  """
  _init_dataclass_state(obj)
  stored_values = obj._dataclass_field_values  # pylint: disable=protected-access
  if attribute_name in stored_values:
    return stored_values[attribute_name]
  raise AttributeError(
      f"type object '{type(obj).__qualname__}' has no attribute "
      f"'{attribute_name}'")


def _setattr(
    obj: _Dataclass,
    attribute_name: str,
    value: _In,
) -> None:
  """Stores `value` under `attribute_name` in the per-object value store."""
  # Note: with `dataclasses.dataclass(frozen=True)`, `obj.__setattr__` already
  # raises a `FrozenInstanceError` before `DataclassField.__set__` is reached,
  # so no extra check is needed here.
  _init_dataclass_state(obj)
  stored_values = obj._dataclass_field_values  # pylint: disable=protected-access
  stored_values[attribute_name] = value


def _init_dataclass_state(obj: _Dataclass) -> None:
  """Creates the `_dataclass_field_values` dict on `obj` if it is missing."""
  if hasattr(obj, '_dataclass_field_values'):
    return
  # Go through object.__setattr__ so that frozen dataclasses work too.
  object.__setattr__(obj, '_dataclass_field_values', {})

There might be a simpler way, but this one was fully tested.

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1
Solution 2 daizhirui
Solution 3
Solution 4 Conchylicultor