"""Form fields and widgets for editing two-bound range values."""

from django import forms
from django.core import exceptions
from django.db.backends.postgresql.psycopg_any import (
    DateRange,
    DateTimeTZRange,
    NumericRange,
)
from django.forms.widgets import HiddenInput, MultiWidget
from django.utils.translation import gettext_lazy as _

__all__ = [
    "BaseRangeField",
    "IntegerRangeField",
    "DecimalRangeField",
    "DateTimeRangeField",
    "DateRangeField",
    "HiddenRangeWidget",
    "RangeWidget",
]


class RangeWidget(MultiWidget):
    """Multi-widget made of the same base widget twice: lower, then upper."""

    def __init__(self, base_widget, attrs=None):
        # The very same ``base_widget`` object is handed over for both
        # positions; no copy is made here.
        super().__init__((base_widget, base_widget), attrs)

    def decompress(self, value):
        """Split ``value`` into a ``(lower, upper)`` pair.

        A falsy ``value`` yields ``(None, None)``.
        """
        if not value:
            return (None, None)
        return (value.lower, value.upper)


class HiddenRangeWidget(RangeWidget):
    """Range widget whose two sub-widgets are both ``HiddenInput``."""

    def __init__(self, attrs=None):
        super().__init__(HiddenInput, attrs)


class BaseRangeField(forms.MultiValueField):
    """Common machinery for form fields that produce a range object.

    Concrete subclasses supply ``base_field`` (the form field class used for
    each bound) and ``range_type`` (the range class built by ``compress``).
    """

    default_error_messages = {
        "invalid": _("Enter two valid values."),
        "bound_ordering": _(
            "The start of the range must not exceed the end of the range."
        ),
    }
    hidden_widget = HiddenRangeWidget

    def __init__(self, **kwargs):
        """Fill in range-specific defaults, then defer to the parent class.

        Keyword arguments:
            widget: defaults to a ``RangeWidget`` built from the base field's
                widget.
            fields: defaults to two non-required ``base_field`` instances.
            required: defaults to ``False``.
            require_all_fields: defaults to ``False``.
            default_bounds: optional; when truthy it is forwarded to
                ``range_type`` as ``bounds``. It is always removed from
                ``kwargs`` before the parent initializer runs.
        """
        # Explicit membership tests (rather than ``setdefault``) so that
        # ``self.base_field`` is only touched when a default is really needed.
        if "widget" not in kwargs:
            kwargs["widget"] = RangeWidget(self.base_field.widget)
        if "fields" not in kwargs:
            kwargs["fields"] = [self.base_field(required=False) for _bound in range(2)]
        kwargs.setdefault("required", False)
        kwargs.setdefault("require_all_fields", False)

        default_bounds = kwargs.pop("default_bounds", None)
        self.range_kwargs = {"bounds": default_bounds} if default_bounds else {}

        super().__init__(**kwargs)

    def prepare_value(self, value):
        """Return ``value`` in the two-item list form used for display.

        Range instances and ``None`` are expanded into
        ``[lower, upper]``, each bound passed through the matching sub-field's
        ``prepare_value``. Anything else is returned untouched.
        """
        lower_field, upper_field = self.fields
        if isinstance(value, self.range_type):
            bounds = (value.lower, value.upper)
        elif value is None:
            bounds = (None, None)
        else:
            return value
        return [
            lower_field.prepare_value(bounds[0]),
            upper_field.prepare_value(bounds[1]),
        ]

    def compress(self, values):
        """Combine the cleaned ``[lower, upper]`` pair into a range object.

        Returns ``None`` when ``values`` is empty.

        Raises:
            ValidationError: with code ``bound_ordering`` when both bounds are
                present and the lower one is greater than the upper one, or
                with code ``invalid`` when ``range_type`` raises ``TypeError``.
        """
        if not values:
            return None

        lower, upper = values
        both_bounded = lower is not None and upper is not None
        if both_bounded and lower > upper:
            raise exceptions.ValidationError(
                self.error_messages["bound_ordering"],
                code="bound_ordering",
            )

        try:
            return self.range_type(lower, upper, **self.range_kwargs)
        except TypeError:
            raise exceptions.ValidationError(
                self.error_messages["invalid"],
                code="invalid",
            )


class IntegerRangeField(BaseRangeField):
    """Range field whose bounds are ``forms.IntegerField`` values."""

    base_field = forms.IntegerField
    range_type = NumericRange
    default_error_messages = {"invalid": _("Enter two whole numbers.")}


class DecimalRangeField(BaseRangeField):
    """Range field whose bounds are ``forms.DecimalField`` values."""

    base_field = forms.DecimalField
    range_type = NumericRange
    default_error_messages = {"invalid": _("Enter two numbers.")}


class DateTimeRangeField(BaseRangeField):
    """Range field whose bounds are ``forms.DateTimeField`` values."""

    base_field = forms.DateTimeField
    range_type = DateTimeTZRange
    default_error_messages = {"invalid": _("Enter two valid date/times.")}


class DateRangeField(BaseRangeField):
    """Range field whose bounds are ``forms.DateField`` values."""

    base_field = forms.DateField
    range_type = DateRange
    default_error_messages = {"invalid": _("Enter two valid dates.")}