Source code for patent_client.util.pydantic_util

import datetime
import importlib
import re
import typing as tp

from async_property.base import AsyncPropertyDescriptor
from dateutil.parser import isoparse
from dateutil.parser import parse as parse_dt
from pydantic import BaseModel as PydanticBaseModel
from pydantic import BeforeValidator, ConfigDict
from typing_extensions import Annotated

from patent_client.util.manager import Manager


def parse_datetime(date_obj: tp.Union[str, datetime.datetime]) -> datetime.datetime:
    if isinstance(date_obj, datetime.datetime):
        return date_obj
    try:
        return isoparse(date_obj)
    except ValueError:
        return parse_dt(date_obj)


def parse_date(
    date_obj: tp.Union[str, datetime.datetime, datetime.date],
) -> datetime.date:
    if isinstance(date_obj, datetime.date):
        return date_obj
    elif isinstance(date_obj, datetime.datetime):
        return date_obj.date()
    try:
        return isoparse(date_obj).date()
    except ValueError:
        return parse_dt(date_obj).date()


DateTime = Annotated[datetime.datetime, BeforeValidator(parse_datetime)]
Date = Annotated[datetime.date, BeforeValidator(parse_date)]


class ClassProperty:
    def __init__(self, getter):
        self.getter = getter

    def __get__(self, instance, owner):
        return self.getter(owner)


class UnmanagedModelException(Exception):
    pass


M = tp.TypeVar("M", bound=Manager)


leading_period_re = re.compile(r"^\.+")


def get_class(class_name: str, base_class: type) -> type:
    if class_name.startswith("."):
        base_module = base_class.__module__
        # Count the number of dots to determine how many levels to go up
        level_up = leading_period_re.match(class_name).span()[1]
        module_path = base_module.split(".")[:-level_up]
        absolute_model_name = ".".join(module_path) + "." + class_name[level_up:]
    else:
        absolute_model_name = class_name

    module_name, class_name = absolute_model_name.rsplit(".", 1)
    try:
        module = importlib.import_module(module_name)
        return getattr(module, class_name)
    except (ImportError, AttributeError) as e:
        raise ImportError(f"Could not import {absolute_model_name}: {e}")


[docs]class BaseModel(PydanticBaseModel, tp.Generic[M]): __manager__: tp.Optional[str] = None model_config = ConfigDict( ignored_types=(ClassProperty, AsyncPropertyDescriptor), ) @ClassProperty def objects(cls) -> M: if cls.__manager__ is not None: manager_path = cls.__manager__ else: manager_module = cls.__module__.split(".model")[0] + ".manager" manager_path = manager_module + "." + cls.__name__ + "Manager" return get_class(manager_path, base_class=cls)() def to_dict(self): return self.model_dump() def items(self): return iter(self) def _get_model(self, model_name: str, base_class=None) -> tp.Type["BaseModel"]: base_class = self.__class__ if base_class is None else base_class return get_class(model_name, base_class)