from django.contrib.gis.db.models.fields import BaseSpatialField
from django.contrib.gis.measure import Distance
from django.db import NotSupportedError
from django.db.models import Expression, Lookup, Transform
from django.db.models.sql.query import Query
from django.utils.regex_helper import _lazy_re_compile


class RasterBandTransform(Transform):
    """Transform whose SQL is exactly the SQL of its left-hand side."""

    def as_sql(self, compiler, connection):
        # The transform contributes nothing to the SQL itself; GISLookup
        # reads its band_index attribute in process_band_indices().
        return compiler.compile(self.lhs)


class GISLookup(Lookup):
    """
    Base class for the spatial lookups registered on BaseSpatialField.

    The right-hand side may be a single value or a list/tuple. In the latter
    case the first element is the lookup value and the remaining elements are
    kept in ``self.rhs_params`` as extra lookup parameters.
    """

    sql_template = None
    transform_func = None
    distance = False
    band_rhs = None
    band_lhs = None

    def __init__(self, lhs, rhs):
        # Split the lookup value from any extra parameters before handing the
        # value to Lookup.__init__().
        rhs, *self.rhs_params = rhs if isinstance(rhs, (list, tuple)) else [rhs]
        super().__init__(lhs, rhs)
        self.template_params = {}
        self.process_rhs_params()

    def process_rhs_params(self):
        """
        Validate the extra rhs parameters and extract band indices.

        Raise ValueError if more extra parameters were given than the lookup
        accepts.
        """
        if self.rhs_params:
            # Check if a band index was passed in the query argument. The
            # "relate" lookup carries one additional parameter (the pattern).
            if len(self.rhs_params) == (2 if self.lookup_name == "relate" else 1):
                self.process_band_indices()
            elif len(self.rhs_params) > 1:
                raise ValueError("Tuple too long for lookup %s." % self.lookup_name)
        elif isinstance(self.lhs, RasterBandTransform):
            self.process_band_indices(only_lhs=True)

    def process_band_indices(self, only_lhs=False):
        """
        Extract the lhs band index from the band transform class and the rhs
        band index from the input tuple.

        With only_lhs=True the rhs band defaults to 1 and self.rhs_params is
        left untouched. Otherwise the first element of self.rhs_params is
        consumed as the rhs band index.
        """
        # PostGIS band indices are 1-based, so the band index needs to be
        # increased to be consistent with the GDALRaster band indices.
        if only_lhs:
            self.band_rhs = 1
            self.band_lhs = self.lhs.band_index + 1
            return

        if isinstance(self.lhs, RasterBandTransform):
            self.band_lhs = self.lhs.band_index + 1
        else:
            self.band_lhs = 1

        self.band_rhs, *self.rhs_params = self.rhs_params

    def get_db_prep_lookup(self, value, connection):
        # get_db_prep_lookup is called by process_rhs from super class
        return ("%s", [connection.ops.Adapter(value)])

    def process_rhs(self, compiler, connection):
        """
        Return the (sql, params) pair for the right-hand side, with the SQL
        wrapped in the backend's geometry placeholder unless rhs is a Query.
        """
        if isinstance(self.rhs, Query):
            # If rhs is some Query, don't touch it.
            return super().process_rhs(compiler, connection)
        if isinstance(self.rhs, Expression):
            self.rhs = self.rhs.resolve_expression(compiler.query)
        rhs, rhs_params = super().process_rhs(compiler, connection)
        placeholder = connection.ops.get_geom_placeholder(
            self.lhs.output_field, self.rhs, compiler
        )
        return placeholder % rhs, rhs_params

    def get_rhs_op(self, connection, rhs):
        # Unlike BuiltinLookup, the GIS get_rhs_op() implementation should return
        # an object (SpatialOperator) with an as_sql() method to allow for more
        # complex computations (where the lhs part can be mixed in).
        return connection.ops.gis_operators[self.lookup_name]

    def as_sql(self, compiler, connection):
        """
        Compile both sides and delegate SQL generation to the backend
        operator registered for this lookup name.
        """
        lhs_sql, lhs_params = self.process_lhs(compiler, connection)
        rhs_sql, rhs_params = self.process_rhs(compiler, connection)
        sql_params = (*lhs_params, *rhs_params)

        # Entries in self.template_params (e.g. "value" set by subclasses)
        # take precedence over the defaults listed before them.
        template_params = {
            "lhs": lhs_sql,
            "rhs": rhs_sql,
            "value": "%s",
            **self.template_params,
        }
        rhs_op = self.get_rhs_op(connection, rhs_sql)
        return rhs_op.as_sql(connection, self, template_params, sql_params)


# ------------------
# Geometry operators
# ------------------


@BaseSpatialField.register_lookup
class OverlapsLeftLookup(GISLookup):
    """
    The 'overlaps_left' operator returns true if A's bounding box overlaps or
    is to the left of B's bounding box.
    """

    lookup_name = "overlaps_left"


@BaseSpatialField.register_lookup
class OverlapsRightLookup(GISLookup):
    """
    The 'overlaps_right' operator returns true if A's bounding box overlaps or
    is to the right of B's bounding box.
    """

    lookup_name = "overlaps_right"


@BaseSpatialField.register_lookup
class OverlapsBelowLookup(GISLookup):
    """
    The 'overlaps_below' operator returns true if A's bounding box overlaps or
    is below B's bounding box.
    """

    lookup_name = "overlaps_below"


@BaseSpatialField.register_lookup
class OverlapsAboveLookup(GISLookup):
    """
    The 'overlaps_above' operator returns true if A's bounding box overlaps or
    is above B's bounding box.
    """

    lookup_name = "overlaps_above"


@BaseSpatialField.register_lookup
class LeftLookup(GISLookup):
    """
    The 'left' operator returns true if A's bounding box is strictly to the
    left of B's bounding box.
    """

    lookup_name = "left"


@BaseSpatialField.register_lookup
class RightLookup(GISLookup):
    """
    The 'right' operator returns true if A's bounding box is strictly to the
    right of B's bounding box.
    """

    lookup_name = "right"


@BaseSpatialField.register_lookup
class StrictlyBelowLookup(GISLookup):
    """
    The 'strictly_below' operator returns true if A's bounding box is strictly
    below B's bounding box.
    """

    lookup_name = "strictly_below"


@BaseSpatialField.register_lookup
class StrictlyAboveLookup(GISLookup):
    """
    The 'strictly_above' operator returns true if A's bounding box is strictly
    above B's bounding box.
    """

    lookup_name = "strictly_above"


@BaseSpatialField.register_lookup
class SameAsLookup(GISLookup):
    """
    The "~=" operator is the "same as" operator. It tests actual geometric
    equality of two features. So if A and B are the same feature,
    vertex-by-vertex, the operator returns true.
    """

    lookup_name = "same_as"


# "exact" is registered as an alias of the same lookup class.
BaseSpatialField.register_lookup(SameAsLookup, "exact")


@BaseSpatialField.register_lookup
class BBContainsLookup(GISLookup):
    """
    The 'bbcontains' operator returns true if A's bounding box completely
    contains B's bounding box.
    """

    lookup_name = "bbcontains"


@BaseSpatialField.register_lookup
class BBOverlapsLookup(GISLookup):
    """
    The 'bboverlaps' operator returns true if A's bounding box overlaps B's
    bounding box.
    """

    lookup_name = "bboverlaps"


@BaseSpatialField.register_lookup
class ContainedLookup(GISLookup):
    """
    The 'contained' operator returns true if A's bounding box is completely
    contained by B's bounding box.
    """

    lookup_name = "contained"


# ------------------
# Geometry functions
# ------------------


@BaseSpatialField.register_lookup
class ContainsLookup(GISLookup):
    lookup_name = "contains"


@BaseSpatialField.register_lookup
class ContainsProperlyLookup(GISLookup):
    lookup_name = "contains_properly"


@BaseSpatialField.register_lookup
class CoveredByLookup(GISLookup):
    lookup_name = "coveredby"


@BaseSpatialField.register_lookup
class CoversLookup(GISLookup):
    lookup_name = "covers"


@BaseSpatialField.register_lookup
class CrossesLookup(GISLookup):
    lookup_name = "crosses"


@BaseSpatialField.register_lookup
class DisjointLookup(GISLookup):
    lookup_name = "disjoint"


@BaseSpatialField.register_lookup
class EqualsLookup(GISLookup):
    lookup_name = "equals"


@BaseSpatialField.register_lookup
class IntersectsLookup(GISLookup):
    lookup_name = "intersects"


@BaseSpatialField.register_lookup
class OverlapsLookup(GISLookup):
    lookup_name = "overlaps"


@BaseSpatialField.register_lookup
class RelateLookup(GISLookup):
    """
    The 'relate' lookup: the first extra rhs parameter is an intersection
    matrix pattern that is appended to the rhs SQL parameters.
    """

    lookup_name = "relate"
    sql_template = "%(func)s(%(lhs)s, %(rhs)s, %%s)"
    pattern_regex = _lazy_re_compile(r"^[012TF\*]{9}$")

    def process_rhs(self, compiler, connection):
        """
        Validate the pattern argument and return the rhs (sql, params) pair
        with the pattern appended to the parameters.

        Raise ValueError if the backend operator offers no
        check_relate_argument() hook and the pattern is not a string matching
        pattern_regex.
        """
        # Check the pattern argument
        pattern = self.rhs_params[0]
        backend_op = connection.ops.gis_operators[self.lookup_name]
        if hasattr(backend_op, "check_relate_argument"):
            backend_op.check_relate_argument(pattern)
        elif not isinstance(pattern, str) or not self.pattern_regex.match(pattern):
            raise ValueError('Invalid intersection matrix pattern "%s".' % pattern)
        sql, params = super().process_rhs(compiler, connection)
        # Compiled params may be a list or a tuple; unpacking avoids the
        # TypeError that "tuple + list" concatenation would raise.
        return sql, [*params, pattern]


@BaseSpatialField.register_lookup
class TouchesLookup(GISLookup):
    lookup_name = "touches"


@BaseSpatialField.register_lookup
class WithinLookup(GISLookup):
    lookup_name = "within"


class DistanceLookupBase(GISLookup):
    """
    Base class for lookups whose first extra rhs parameter is a distance.

    self.rhs_params excludes the lookup value itself, which is why one to
    three extra parameters correspond to the "2, 3, or 4-element tuple" named
    in the error message.
    """

    distance = True
    sql_template = "%(func)s(%(lhs)s, %(rhs)s) %(op)s %(value)s"

    def process_rhs_params(self):
        """
        Validate the number of extra parameters and the 'spheroid' directive,
        then extract band indices if a band index was supplied.

        Raise ValueError on a wrong parameter count or when the last of three
        extra parameters is not 'spheroid'.
        """
        if not 1 <= len(self.rhs_params) <= 3:
            raise ValueError(
                "2, 3, or 4-element tuple required for '%s' lookup." % self.lookup_name
            )
        elif len(self.rhs_params) == 3 and self.rhs_params[2] != "spheroid":
            raise ValueError(
                "For 4-element tuples the last argument must be the 'spheroid' "
                "directive."
            )

        # Check if the second parameter is a band index.
        if len(self.rhs_params) > 1 and self.rhs_params[1] != "spheroid":
            self.process_band_indices()

    def process_distance(self, compiler, connection):
        """
        Return the (sql, params) pair for the distance parameter.

        Objects exposing resolve_expression() are resolved and compiled; any
        other value is converted by the backend's get_distance() and bound to
        a single "%s" placeholder.
        """
        dist_param = self.rhs_params[0]
        return (
            compiler.compile(dist_param.resolve_expression(compiler.query))
            if hasattr(dist_param, "resolve_expression")
            else (
                "%s",
                connection.ops.get_distance(
                    self.lhs.output_field, self.rhs_params, self.lookup_name
                ),
            )
        )


@BaseSpatialField.register_lookup
class DWithinLookup(DistanceLookupBase):
    lookup_name = "dwithin"
    sql_template = "%(func)s(%(lhs)s, %(rhs)s, %(value)s)"

    def process_distance(self, compiler, connection):
        """
        Same as DistanceLookupBase.process_distance(), but raise
        NotSupportedError for expression distances on backends whose
        supports_dwithin_distance_expr feature flag is false.
        """
        dist_param = self.rhs_params[0]
        if (
            not connection.features.supports_dwithin_distance_expr
            and hasattr(dist_param, "resolve_expression")
            and not isinstance(dist_param, Distance)
        ):
            raise NotSupportedError(
                "This backend does not support expressions for specifying "
                "distance in the dwithin lookup."
            )
        return super().process_distance(compiler, connection)

    def process_rhs(self, compiler, connection):
        """
        Return the rhs (sql, params) pair with the distance parameters
        appended, and store the distance SQL as the "value" template param.
        """
        dist_sql, dist_params = self.process_distance(compiler, connection)
        self.template_params["value"] = dist_sql
        rhs_sql, params = super().process_rhs(compiler, connection)
        # Either side may be a list or a tuple; unpacking tolerates a mix.
        return rhs_sql, [*params, *dist_params]


class DistanceLookupFromFunction(DistanceLookupBase):
    """
    Distance comparison built from the backend's distance expression.

    Subclasses supply the comparison operator through the ``op`` attribute.
    """

    def as_sql(self, compiler, connection):
        """
        Return "<distance expression> <op> <distance>" together with the
        combined parameters of both parts.
        """
        # True when exactly two extra parameters are left and the last one is
        # the 'spheroid' directive; otherwise None (never False).
        spheroid = (
            len(self.rhs_params) == 2 and self.rhs_params[-1] == "spheroid"
        ) or None
        distance_expr = connection.ops.distance_expr_for_lookup(
            self.lhs, self.rhs, spheroid=spheroid
        )
        sql, params = compiler.compile(distance_expr.resolve_expression(compiler.query))
        dist_sql, dist_params = self.process_distance(compiler, connection)
        return (
            "%(func)s %(op)s %(dist)s" % {"func": sql, "op": self.op, "dist": dist_sql},
            # Either side may be a list or a tuple; unpacking tolerates a mix.
            [*params, *dist_params],
        )


@BaseSpatialField.register_lookup
class DistanceGTLookup(DistanceLookupFromFunction):
    lookup_name = "distance_gt"
    op = ">"


@BaseSpatialField.register_lookup
class DistanceGTELookup(DistanceLookupFromFunction):
    lookup_name = "distance_gte"
    op = ">="


@BaseSpatialField.register_lookup
class DistanceLTLookup(DistanceLookupFromFunction):
    lookup_name = "distance_lt"
    op = "<"


@BaseSpatialField.register_lookup
class DistanceLTELookup(DistanceLookupFromFunction):
    lookup_name = "distance_lte"
    op = "<="