1656 lines
56 KiB
Python
1656 lines
56 KiB
Python
import inspect
|
|
import re
|
|
import uuid
|
|
from datetime import date
|
|
from enum import Enum
|
|
from typing import TYPE_CHECKING, Any, Iterable, Iterator, List, Optional, Sequence, Set, Type, TypeVar, Union
|
|
|
|
from pypika.enums import Arithmetic, Boolean, Comparator, Dialects, Equality, JSONOperators, Matching, Order
|
|
from pypika.utils import (
|
|
CaseException,
|
|
FunctionException,
|
|
builder,
|
|
format_alias_sql,
|
|
format_quotes,
|
|
ignore_copy,
|
|
resolve_is_aggregate,
|
|
)
|
|
|
|
if TYPE_CHECKING:
|
|
from pypika.queries import QueryBuilder, Selectable, Table
|
|
|
|
|
|
# Module metadata: name recorded as the author of this module.
__author__ = "Timothy Heys"
|
|
# Module metadata: contact e-mail address recorded alongside ``__author__``.
__email__ = "theys@kayak.com"
|
|
|
|
|
|
# Type variable for "some kind of Node"; used by the tree-walking helpers below.
NodeT = TypeVar("NodeT", bound="Node")


class Node:
    """
    Root of the expression tree hierarchy.

    A node can enumerate itself together with its descendants (``nodes_``) and
    filter that enumeration by class (``find_``).
    """

    is_aggregate = None

    def nodes_(self) -> Iterator[NodeT]:
        """Yield this node; subclasses with children extend the walk."""
        yield self

    def find_(self, type: Type[NodeT]) -> List[NodeT]:
        """
        Collect every node produced by ``nodes_`` that is an instance of ``type``.

        :param type:
            The class to filter on.
        :return:
            Matching nodes in traversal order.
        """
        matches = []
        for candidate in self.nodes_():
            if isinstance(candidate, type):
                matches.append(candidate)
        return matches
|
|
|
|
|
|
class Term(Node):
    """
    Base class for expression nodes that can be rendered to SQL.

    Holds an optional alias and provides the helper methods and operator overloads that combine terms into
    criterion and arithmetic expression objects. Raw Python values given to those helpers are converted to terms
    with :meth:`wrap_constant`. Subclasses must implement :meth:`get_sql`.
    """

    is_aggregate = False

    def __init__(self, alias: Optional[str] = None) -> None:
        """
        :param alias:
            (Optional) alias stored on the term.
        """
        self.alias = alias

    @builder
    def as_(self, alias: str) -> "Term":
        """
        Sets the alias of the term.

        NOTE(review): the return value is produced by the ``@builder`` decorator from ``pypika.utils``; confirm its
        copy semantics there.

        :param alias:
            The alias to store.
        """
        self.alias = alias

    @property
    def tables_(self) -> Set["Table"]:
        """The set of ``Table`` nodes found by walking this term with :meth:`find_`."""
        # Imported lazily; the module-level import of Table is only done under TYPE_CHECKING.
        from pypika import Table

        return set(self.find_(Table))

    def fields_(self) -> Set["Field"]:
        """Return the set of ``Field`` nodes found by walking this term with :meth:`find_`."""
        return set(self.find_(Field))

    @staticmethod
    def wrap_constant(
        val, wrapper_cls: Optional[Type["Term"]] = None
    ) -> Union[ValueError, NodeT, "LiteralValue", "Array", "Tuple", "ValueWrapper"]:
        """
        Used for wrapping raw inputs such as numbers in Criterions and Operator.

        For example, the expression F('abc')+1 stores the integer part in a ValueWrapper object.

        :param val:
            Any value.
        :param wrapper_cls:
            A pypika class which wraps a constant value so it can be handled as a component of the query.
        :return:
            ``Node`` instances are returned as inputted, ``None`` becomes a ``NullValue``, a list becomes an
            ``Array`` and a tuple becomes a ``Tuple``. Any other value is returned wrapped in ``wrapper_cls``
            (``ValueWrapper`` when no wrapper class is given).
        """

        if isinstance(val, Node):
            return val
        if val is None:
            return NullValue()
        if isinstance(val, list):
            return Array(*val)
        if isinstance(val, tuple):
            return Tuple(*val)

        # Need to default here to avoid the recursion. ValueWrapper extends this class.
        wrapper_cls = wrapper_cls or ValueWrapper
        return wrapper_cls(val)

    @staticmethod
    def wrap_json(
        val: Union["Term", "QueryBuilder", "Interval", None, str, int, bool], wrapper_cls=None
    ) -> Union["Term", "QueryBuilder", "Interval", "NullValue", "ValueWrapper", "JSON"]:
        """
        Wraps a value used as the operand of a JSON operator.

        :param val:
            Any value.
        :param wrapper_cls:
            (Optional) class used to wrap ``str``, ``int`` and ``bool`` values; defaults to ``ValueWrapper``.
        :return:
            ``Term``, ``QueryBuilder`` and ``Interval`` instances are returned as inputted, ``None`` becomes a
            ``NullValue``, ``str``/``int``/``bool`` values are wrapped in ``wrapper_cls`` and anything else is
            wrapped in ``JSON``.
        """
        # Imported lazily; the module-level import of QueryBuilder is only done under TYPE_CHECKING.
        from .queries import QueryBuilder

        if isinstance(val, (Term, QueryBuilder, Interval)):
            return val
        if val is None:
            return NullValue()
        if isinstance(val, (str, int, bool)):
            wrapper_cls = wrapper_cls or ValueWrapper
            return wrapper_cls(val)

        return JSON(val)

    def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]) -> "Term":
        """
        Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries.
        The base implementation returns self because not all terms have a table property.

        :param current_table:
            The table to be replaced.
        :param new_table:
            The table to replace with.
        :return:
            Self.
        """
        return self

    def eq(self, other: Any) -> "BasicCriterion":
        """Method form of ``self == other``."""
        return self == other

    def isnull(self) -> "NullCriterion":
        """Return a ``NullCriterion`` wrapping this term."""
        return NullCriterion(self)

    def notnull(self) -> "Not":
        """Return the negation of :meth:`isnull`."""
        return self.isnull().negate()

    def isnotnull(self) -> 'NotNullCriterion':
        """Return a ``NotNullCriterion`` wrapping this term."""
        return NotNullCriterion(self)

    def bitwiseand(self, value: int) -> "BitwiseAndCriterion":
        """Return a ``BitwiseAndCriterion`` of this term and the wrapped ``value``."""
        return BitwiseAndCriterion(self, self.wrap_constant(value))

    def gt(self, other: Any) -> "BasicCriterion":
        """Method form of ``self > other``."""
        return self > other

    def gte(self, other: Any) -> "BasicCriterion":
        """Method form of ``self >= other``."""
        return self >= other

    def lt(self, other: Any) -> "BasicCriterion":
        """Method form of ``self < other``."""
        return self < other

    def lte(self, other: Any) -> "BasicCriterion":
        """Method form of ``self <= other``."""
        return self <= other

    def ne(self, other: Any) -> "BasicCriterion":
        """Method form of ``self != other``."""
        return self != other

    def glob(self, expr: str) -> "BasicCriterion":
        """Return a ``Matching.glob`` criterion of this term against the wrapped ``expr``."""
        return BasicCriterion(Matching.glob, self, self.wrap_constant(expr))

    def like(self, expr: str) -> "BasicCriterion":
        """Return a ``Matching.like`` criterion of this term against the wrapped ``expr``."""
        return BasicCriterion(Matching.like, self, self.wrap_constant(expr))

    def not_like(self, expr: str) -> "BasicCriterion":
        """Return a ``Matching.not_like`` criterion of this term against the wrapped ``expr``."""
        return BasicCriterion(Matching.not_like, self, self.wrap_constant(expr))

    def ilike(self, expr: str) -> "BasicCriterion":
        """Return a ``Matching.ilike`` criterion of this term against the wrapped ``expr``."""
        return BasicCriterion(Matching.ilike, self, self.wrap_constant(expr))

    def not_ilike(self, expr: str) -> "BasicCriterion":
        """Return a ``Matching.not_ilike`` criterion of this term against the wrapped ``expr``."""
        return BasicCriterion(Matching.not_ilike, self, self.wrap_constant(expr))

    def rlike(self, expr: str) -> "BasicCriterion":
        """Return a ``Matching.rlike`` criterion of this term against the wrapped ``expr``."""
        return BasicCriterion(Matching.rlike, self, self.wrap_constant(expr))

    def regex(self, pattern: str) -> "BasicCriterion":
        """Return a ``Matching.regex`` criterion of this term against the wrapped ``pattern``."""
        return BasicCriterion(Matching.regex, self, self.wrap_constant(pattern))

    def regexp(self, pattern: str) -> "BasicCriterion":
        """Return a ``Matching.regexp`` criterion of this term against the wrapped ``pattern``."""
        return BasicCriterion(Matching.regexp, self, self.wrap_constant(pattern))

    def between(self, lower: Any, upper: Any) -> "BetweenCriterion":
        """Return a ``BetweenCriterion`` of this term with the wrapped ``lower`` and ``upper`` bounds."""
        return BetweenCriterion(self, self.wrap_constant(lower), self.wrap_constant(upper))

    def from_to(self, start: Any, end: Any) -> "PeriodCriterion":
        """Return a ``PeriodCriterion`` of this term with the wrapped ``start`` and ``end`` values."""
        return PeriodCriterion(self, self.wrap_constant(start), self.wrap_constant(end))

    def as_of(self, expr: str) -> "BasicCriterion":
        """Return a ``Matching.as_of`` criterion of this term against the wrapped ``expr``."""
        return BasicCriterion(Matching.as_of, self, self.wrap_constant(expr))

    def all_(self) -> "All":
        """Return this term wrapped in ``All``."""
        return All(self)

    def isin(self, arg: Union[list, tuple, set, "Term"]) -> "ContainsCriterion":
        """
        Return a ``ContainsCriterion`` testing this term against ``arg``.

        :param arg:
            A list, tuple or set of values (each value is wrapped and the collection becomes a ``Tuple``), or any
            other object, which is used as the container unchanged.
        """
        if isinstance(arg, (list, tuple, set)):
            return ContainsCriterion(self, Tuple(*[self.wrap_constant(value) for value in arg]))
        return ContainsCriterion(self, arg)

    def notin(self, arg: Union[list, tuple, set, "Term"]) -> "ContainsCriterion":
        """Return the negation of :meth:`isin` for ``arg``."""
        return self.isin(arg).negate()

    def bin_regex(self, pattern: str) -> "BasicCriterion":
        """Return a ``Matching.bin_regex`` criterion of this term against the wrapped ``pattern``."""
        return BasicCriterion(Matching.bin_regex, self, self.wrap_constant(pattern))

    def negate(self) -> "Not":
        """Return this term wrapped in ``Not``."""
        return Not(self)

    def lshift(self, other: Any) -> "ArithmeticExpression":
        """Method form of ``self << other``."""
        return self << other

    def rshift(self, other: Any) -> "ArithmeticExpression":
        """Method form of ``self >> other``."""
        return self >> other

    def __invert__(self) -> "Not":
        return Not(self)

    def __pos__(self) -> "Term":
        return self

    def __neg__(self) -> "Negative":
        return Negative(self)

    def __add__(self, other: Any) -> "ArithmeticExpression":
        return ArithmeticExpression(Arithmetic.add, self, self.wrap_constant(other))

    def __sub__(self, other: Any) -> "ArithmeticExpression":
        return ArithmeticExpression(Arithmetic.sub, self, self.wrap_constant(other))

    def __mul__(self, other: Any) -> "ArithmeticExpression":
        return ArithmeticExpression(Arithmetic.mul, self, self.wrap_constant(other))

    def __truediv__(self, other: Any) -> "ArithmeticExpression":
        return ArithmeticExpression(Arithmetic.div, self, self.wrap_constant(other))

    def __pow__(self, other: Any) -> "Pow":
        return Pow(self, other)

    def __mod__(self, other: Any) -> "Mod":
        return Mod(self, other)

    # Reflected operators: the non-term operand is wrapped and placed on the left.
    def __radd__(self, other: Any) -> "ArithmeticExpression":
        return ArithmeticExpression(Arithmetic.add, self.wrap_constant(other), self)

    def __rsub__(self, other: Any) -> "ArithmeticExpression":
        return ArithmeticExpression(Arithmetic.sub, self.wrap_constant(other), self)

    def __rmul__(self, other: Any) -> "ArithmeticExpression":
        return ArithmeticExpression(Arithmetic.mul, self.wrap_constant(other), self)

    def __rtruediv__(self, other: Any) -> "ArithmeticExpression":
        return ArithmeticExpression(Arithmetic.div, self.wrap_constant(other), self)

    def __lshift__(self, other: Any) -> "ArithmeticExpression":
        return ArithmeticExpression(Arithmetic.lshift, self, self.wrap_constant(other))

    def __rshift__(self, other: Any) -> "ArithmeticExpression":
        return ArithmeticExpression(Arithmetic.rshift, self, self.wrap_constant(other))

    def __rlshift__(self, other: Any) -> "ArithmeticExpression":
        return ArithmeticExpression(Arithmetic.lshift, self.wrap_constant(other), self)

    def __rrshift__(self, other: Any) -> "ArithmeticExpression":
        return ArithmeticExpression(Arithmetic.rshift, self.wrap_constant(other), self)

    # Comparison operators build criterion objects instead of returning booleans.
    def __eq__(self, other: Any) -> "BasicCriterion":
        return BasicCriterion(Equality.eq, self, self.wrap_constant(other))

    def __ne__(self, other: Any) -> "BasicCriterion":
        return BasicCriterion(Equality.ne, self, self.wrap_constant(other))

    def __gt__(self, other: Any) -> "BasicCriterion":
        return BasicCriterion(Equality.gt, self, self.wrap_constant(other))

    def __ge__(self, other: Any) -> "BasicCriterion":
        return BasicCriterion(Equality.gte, self, self.wrap_constant(other))

    def __lt__(self, other: Any) -> "BasicCriterion":
        return BasicCriterion(Equality.lt, self, self.wrap_constant(other))

    def __le__(self, other: Any) -> "BasicCriterion":
        return BasicCriterion(Equality.lte, self, self.wrap_constant(other))

    def __getitem__(self, item: slice) -> "BetweenCriterion":
        """
        Slice syntax for :meth:`between`: ``term[a:b]`` is ``term.between(a, b)``.

        :raises TypeError:
            If ``item`` is not a slice.
        """
        if not isinstance(item, slice):
            # The original message was missing its opening quote ("Field' object ...").
            raise TypeError("'Field' object is not subscriptable")
        return self.between(item.start, item.stop)

    def __str__(self) -> str:
        return self.get_sql(quote_char='"', secondary_quote_char="'")

    def __hash__(self) -> int:
        # __eq__ is overloaded to build criteria, so hashing is based on the rendered SQL.
        return hash(self.get_sql(with_alias=True, with_namespace=True))

    def get_sql(self, **kwargs: Any) -> str:
        """
        Render the term as SQL. Must be implemented by subclasses.

        :raises NotImplementedError:
            Always, in this base implementation.
        """
        raise NotImplementedError()
|
|
|
|
|
|
class Parameter(Term):
    """Placeholder term; its SQL is the string form of ``placeholder``."""

    is_aggregate = None

    def __init__(self, placeholder: Union[str, int]) -> None:
        # No alias is passed, so the base initialiser uses its default.
        super().__init__()
        self.placeholder = placeholder

    def get_sql(self, **kwargs: Any) -> str:
        """Return ``str(placeholder)``; keyword arguments are not used."""
        rendered = str(self.placeholder)
        return rendered
|
|
|
|
|
|
class QmarkParameter(Parameter):
    """Question mark style, e.g. ...WHERE name=?"""

    def __init__(self) -> None:
        # Run the base initialisers so that ``alias`` and ``placeholder`` exist on the instance.
        # Previously this was a bare ``pass``, so reading either attribute raised AttributeError.
        super().__init__("?")

    def get_sql(self, **kwargs: Any) -> str:
        """Return the literal ``?`` marker; keyword arguments are not used."""
        return "?"
|
|
|
|
|
|
class NumericParameter(Parameter):
    """Numeric, positional style, e.g. ...WHERE name=:1"""

    def get_sql(self, **kwargs: Any) -> str:
        """Return the placeholder prefixed with a colon; keyword arguments are not used."""
        return ":{}".format(self.placeholder)
|
|
|
|
|
|
class NamedParameter(Parameter):
    """Named style, e.g. ...WHERE name=:name"""

    def get_sql(self, **kwargs: Any) -> str:
        """Return the placeholder prefixed with a colon; keyword arguments are not used."""
        return ":{}".format(self.placeholder)
|
|
|
|
|
|
class FormatParameter(Parameter):
    """ANSI C printf format codes, e.g. ...WHERE name=%s"""

    def __init__(self) -> None:
        # Run the base initialisers so that ``alias`` and ``placeholder`` exist on the instance.
        # Previously this was a bare ``pass``, so reading either attribute raised AttributeError.
        super().__init__("%s")

    def get_sql(self, **kwargs: Any) -> str:
        """Return the literal ``%s`` marker; keyword arguments are not used."""
        return "%s"
|
|
|
|
|
|
class PyformatParameter(Parameter):
    """Python extended format codes, e.g. ...WHERE name=%(name)s"""

    def get_sql(self, **kwargs: Any) -> str:
        """Return the placeholder wrapped as ``%(...)s``; keyword arguments are not used."""
        return "%({})s".format(self.placeholder)
|
|
|
|
|
|
class Negative(Term):
    """Unary minus applied to another term."""

    def __init__(self, term: Term) -> None:
        super().__init__()
        self.term = term

    @property
    def is_aggregate(self) -> Optional[bool]:
        """Mirrors the aggregate flag of the wrapped term."""
        return self.term.is_aggregate

    def get_sql(self, **kwargs: Any) -> str:
        """Render the wrapped term, forwarding ``kwargs``, and prefix it with ``-``."""
        inner_sql = self.term.get_sql(**kwargs)
        return "-{}".format(inner_sql)
|
|
|
|
|
|
class ValueWrapper(Term):
    """Term that holds a raw Python value and renders it as a SQL literal."""

    is_aggregate = None

    def __init__(self, value: Any, alias: Optional[str] = None) -> None:
        super().__init__(alias)
        self.value = value

    def get_value_sql(self, **kwargs: Any) -> str:
        """Render the wrapped value (without alias) via :meth:`get_formatted_value`."""
        return self.get_formatted_value(self.value, **kwargs)

    @classmethod
    def get_formatted_value(cls, value: Any, **kwargs):
        """
        Render ``value`` as a SQL literal.

        The type checks run in a fixed order: ``Term``, ``Enum``, ``date``, ``str``, ``bool``, ``uuid.UUID``,
        ``None``, then a plain ``str()`` fallback. Enum members, dates and UUIDs are reduced to a simpler value and
        formatted again.

        :param value:
            The value to render.
        :param kwargs:
            ``secondary_quote_char`` selects the quote used around strings; all keyword arguments are forwarded
            when ``value`` is itself a ``Term``.
        """
        quote_char = kwargs.get("secondary_quote_char") or ""

        # FIXME escape values
        if isinstance(value, Term):
            return value.get_sql(**kwargs)

        if isinstance(value, Enum):
            return cls.get_formatted_value(value.value, **kwargs)

        if isinstance(value, date):
            return cls.get_formatted_value(value.isoformat(), **kwargs)

        if isinstance(value, str):
            # Double any embedded quote character before wrapping the text in quotes.
            doubled = value.replace(quote_char, quote_char * 2)
            return format_quotes(doubled, quote_char)

        if isinstance(value, bool):
            return str(value).lower()

        if isinstance(value, uuid.UUID):
            return cls.get_formatted_value(str(value), **kwargs)

        return "null" if value is None else str(value)

    def get_sql(self, quote_char: Optional[str] = None, secondary_quote_char: str = "'", **kwargs: Any) -> str:
        """Render the value and pass it, with the alias, to ``format_alias_sql``."""
        value_sql = self.get_value_sql(
            quote_char=quote_char,
            secondary_quote_char=secondary_quote_char,
            **kwargs
        )
        return format_alias_sql(value_sql, self.alias, quote_char=quote_char, **kwargs)
|
|
|
|
|
|
class JSON(Term):
    """Term wrapping a JSON-like Python value, plus builders for the JSON operators."""

    table = None

    def __init__(self, value: Any = None, alias: Optional[str] = None) -> None:
        super().__init__(alias)
        self.value = value

    def _recursive_get_sql(self, value: Any, **kwargs: Any) -> str:
        """Serialise ``value``: dicts, lists and strings get dedicated handling, the rest goes through ``str``."""
        if isinstance(value, dict):
            return self._get_dict_sql(value, **kwargs)
        if isinstance(value, list):
            return self._get_list_sql(value, **kwargs)
        if isinstance(value, str):
            return self._get_str_sql(value, **kwargs)
        return str(value)

    def _get_dict_sql(self, value: dict, **kwargs: Any) -> str:
        """Serialise a dict as ``{key:value,...}`` with both sides serialised recursively."""
        entries = []
        for key, item in value.items():
            key_sql = self._recursive_get_sql(key, **kwargs)
            item_sql = self._recursive_get_sql(item, **kwargs)
            entries.append("{}:{}".format(key_sql, item_sql))
        return "{" + ",".join(entries) + "}"

    def _get_list_sql(self, value: list, **kwargs: Any) -> str:
        """Serialise a list as ``[item,...]`` with each item serialised recursively."""
        items = [self._recursive_get_sql(element, **kwargs) for element in value]
        return "[" + ",".join(items) + "]"

    @staticmethod
    def _get_str_sql(value: str, quote_char: str = '"', **kwargs: Any) -> str:
        """Quote a string with ``quote_char`` (double quote by default)."""
        return format_quotes(value, quote_char)

    def get_sql(self, secondary_quote_char: str = "'", **kwargs: Any) -> str:
        """Serialise the value, wrap it in ``secondary_quote_char`` and pass it to ``format_alias_sql``."""
        # The serialisation is called without kwargs, so inner strings use the default quote.
        serialised = self._recursive_get_sql(self.value)
        quoted = format_quotes(serialised, secondary_quote_char)
        return format_alias_sql(quoted, self.alias, **kwargs)

    def get_json_value(self, key_or_index: Union[str, int]) -> "BasicCriterion":
        """Build a ``JSONOperators.GET_JSON_VALUE`` criterion with the wrapped key or index."""
        operand = self.wrap_constant(key_or_index)
        return BasicCriterion(JSONOperators.GET_JSON_VALUE, self, operand)

    def get_text_value(self, key_or_index: Union[str, int]) -> "BasicCriterion":
        """Build a ``JSONOperators.GET_TEXT_VALUE`` criterion with the wrapped key or index."""
        operand = self.wrap_constant(key_or_index)
        return BasicCriterion(JSONOperators.GET_TEXT_VALUE, self, operand)

    def get_path_json_value(self, path_json: str) -> "BasicCriterion":
        """Build a ``JSONOperators.GET_PATH_JSON_VALUE`` criterion with the JSON-wrapped path."""
        operand = self.wrap_json(path_json)
        return BasicCriterion(JSONOperators.GET_PATH_JSON_VALUE, self, operand)

    def get_path_text_value(self, path_json: str) -> "BasicCriterion":
        """Build a ``JSONOperators.GET_PATH_TEXT_VALUE`` criterion with the JSON-wrapped path."""
        operand = self.wrap_json(path_json)
        return BasicCriterion(JSONOperators.GET_PATH_TEXT_VALUE, self, operand)

    def has_key(self, other: Any) -> "BasicCriterion":
        """Build a ``JSONOperators.HAS_KEY`` criterion with the JSON-wrapped operand."""
        operand = self.wrap_json(other)
        return BasicCriterion(JSONOperators.HAS_KEY, self, operand)

    def contains(self, other: Any) -> "BasicCriterion":
        """Build a ``JSONOperators.CONTAINS`` criterion with the JSON-wrapped operand."""
        operand = self.wrap_json(other)
        return BasicCriterion(JSONOperators.CONTAINS, self, operand)

    def contained_by(self, other: Any) -> "BasicCriterion":
        """Build a ``JSONOperators.CONTAINED_BY`` criterion with the JSON-wrapped operand."""
        operand = self.wrap_json(other)
        return BasicCriterion(JSONOperators.CONTAINED_BY, self, operand)

    def has_keys(self, other: Iterable) -> "BasicCriterion":
        """Build a ``JSONOperators.HAS_KEYS`` criterion against an ``Array`` of the given items."""
        keys = Array(*other)
        return BasicCriterion(JSONOperators.HAS_KEYS, self, keys)

    def has_any_keys(self, other: Iterable) -> "BasicCriterion":
        """Build a ``JSONOperators.HAS_ANY_KEYS`` criterion against an ``Array`` of the given items."""
        keys = Array(*other)
        return BasicCriterion(JSONOperators.HAS_ANY_KEYS, self, keys)
|
|
|
|
|
|
class Values(Term):
    """Renders ``VALUES(<field>)`` for a single field."""

    def __init__(self, field: Union[str, "Field"]) -> None:
        super().__init__(None)
        # Anything that is not already a Field is turned into one.
        self.field = field if isinstance(field, Field) else Field(field)

    def get_sql(self, quote_char: Optional[str] = None, **kwargs: Any) -> str:
        """Render the field with the given quoting and wrap it in ``VALUES(...)``."""
        field_sql = self.field.get_sql(quote_char=quote_char, **kwargs)
        return "VALUES({})".format(field_sql)
|
|
|
|
|
|
class LiteralValue(Term):
    """Term whose stored value is passed to ``format_alias_sql`` without further formatting."""

    def __init__(self, value, alias: Optional[str] = None) -> None:
        super().__init__(alias)
        self._value = value

    def get_sql(self, **kwargs: Any) -> str:
        """Return the stored value combined with the alias by ``format_alias_sql``."""
        literal = self._value
        return format_alias_sql(literal, self.alias, **kwargs)
|
|
|
|
|
|
class NullValue(LiteralValue):
    """Literal value fixed to the text ``NULL``."""

    def __init__(self, alias: Optional[str] = None) -> None:
        literal = "NULL"
        super().__init__(literal, alias)
|
|
|
|
|
|
class SystemTimeValue(LiteralValue):
    """Literal value fixed to the text ``SYSTEM_TIME``."""

    def __init__(self, alias: Optional[str] = None) -> None:
        literal = "SYSTEM_TIME"
        super().__init__(literal, alias)
|
|
|
|
|
|
class Criterion(Term):
    """
    Base class for boolean conditions.

    The ``&``, ``|`` and ``^`` operators combine two criteria into a ``ComplexCriterion``.
    """

    def __and__(self, other: Any) -> "ComplexCriterion":
        return ComplexCriterion(Boolean.and_, self, other)

    def __or__(self, other: Any) -> "ComplexCriterion":
        return ComplexCriterion(Boolean.or_, self, other)

    def __xor__(self, other: Any) -> "ComplexCriterion":
        return ComplexCriterion(Boolean.xor_, self, other)

    @staticmethod
    def any(terms: Iterable[Term] = ()) -> "EmptyCriterion":
        """
        Combine ``terms`` with ``|``, starting from an ``EmptyCriterion``.

        :param terms:
            The criteria to combine.
        :return:
            The ``EmptyCriterion`` itself when ``terms`` is empty, otherwise the combined criterion.
        """
        crit = EmptyCriterion()

        for term in terms:
            crit |= term

        return crit

    @staticmethod
    def all(terms: Iterable[Any] = ()) -> "EmptyCriterion":
        """
        Combine ``terms`` with ``&``, starting from an ``EmptyCriterion``.

        :param terms:
            The criteria to combine.
        :return:
            The ``EmptyCriterion`` itself when ``terms`` is empty, otherwise the combined criterion.
        """
        crit = EmptyCriterion()

        for term in terms:
            crit &= term

        return crit

    def get_sql(self, **kwargs: Any) -> str:
        """
        Render the criterion as SQL. Must be implemented by subclasses.

        Keyword arguments are accepted to match ``Term.get_sql`` and the subclass implementations, so a keyword
        call on an unimplemented criterion raises ``NotImplementedError`` rather than ``TypeError``.

        :raises NotImplementedError:
            Always, in this base implementation.
        """
        raise NotImplementedError()
|
|
|
|
|
|
class EmptyCriterion(Criterion):
    """
    Criterion that contributes nothing: combining it with another operand via ``&``, ``|`` or ``^`` returns that
    operand unchanged.
    """

    is_aggregate = None
    # Plain class attribute; an empty criterion references no tables.
    tables_ = set()

    def fields_(self) -> Set["Field"]:
        """Return an empty set; an empty criterion references no fields."""
        no_fields = set()
        return no_fields

    def __and__(self, other: Any) -> Any:
        # The other operand is the whole result.
        return other

    def __or__(self, other: Any) -> Any:
        # The other operand is the whole result.
        return other

    def __xor__(self, other: Any) -> Any:
        # The other operand is the whole result.
        return other
|
|
|
|
|
|
class Field(Criterion, JSON):
    """A named column, optionally attached to a table."""

    def __init__(
        self, name: str, alias: Optional[str] = None, table: Optional[Union[str, "Selectable"]] = None
    ) -> None:
        super().__init__(alias=alias)
        self.name = name
        self.table = table

    def nodes_(self) -> Iterator[NodeT]:
        """Yield this field, followed by the nodes of its table when one is set."""
        yield self
        if self.table is None:
            return
        yield from self.table.nodes_()

    @builder
    def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]) -> "Field":
        """
        Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries.

        :param current_table:
            The table to be replaced.
        :param new_table:
            The table to replace with.
        :return:
            A copy of the field with the tables replaced.
        """
        if self.table == current_table:
            self.table = new_table
        else:
            self.table = self.table

    def get_sql(self, **kwargs: Any) -> str:
        """
        Render the field name, optionally qualified with its table and followed by its alias.

        ``with_alias``, ``with_namespace`` and ``quote_char`` are read from ``kwargs`` and removed from it; the
        remaining keyword arguments are forwarded to ``format_alias_sql`` when the alias is rendered.
        """
        with_alias = kwargs.pop("with_alias", False)
        with_namespace = kwargs.pop("with_namespace", False)
        quote_char = kwargs.pop("quote_char", None)

        rendered = format_quotes(self.name, quote_char)

        # Need to add namespace if the table has an alias
        if self.table and (with_namespace or self.table.alias):
            namespace = format_quotes(self.table.get_table_name(), quote_char)
            rendered = "{}.{}".format(namespace, rendered)

        if not with_alias:
            return rendered

        field_alias = getattr(self, "alias", None)
        return format_alias_sql(rendered, field_alias, quote_char=quote_char, **kwargs)
|
|
|
|
|
|
class Index(Term):
    """A named index; rendered as its quoted name."""

    def __init__(self, name: str, alias: Optional[str] = None) -> None:
        super().__init__(alias)
        self.name = name

    def get_sql(self, quote_char: Optional[str] = None, **kwargs: Any) -> str:
        """Return the name passed through ``format_quotes``; other keyword arguments are not used."""
        quoted_name = format_quotes(self.name, quote_char)
        return quoted_name
|
|
|
|
|
|
class Star(Field):
    """The ``*`` selector, optionally qualified with a table."""

    def __init__(self, table: Optional[Union[str, "Selectable"]] = None) -> None:
        super().__init__("*", table=table)

    def nodes_(self) -> Iterator[NodeT]:
        """Yield this term, followed by the nodes of its table when one is set."""
        yield self
        if self.table is None:
            return
        yield from self.table.nodes_()

    def get_sql(
        self, with_alias: bool = False, with_namespace: bool = False, quote_char: Optional[str] = None, **kwargs: Any
    ) -> str:
        """
        Render ``*``, or ``<namespace>.*`` when a table is set and either ``with_namespace`` is true or the table
        has an alias. The namespace is the table alias, falling back to the table's ``_table_name`` attribute.
        """
        if not (self.table and (with_namespace or self.table.alias)):
            return "*"

        namespace = self.table.alias or getattr(self.table, "_table_name")
        quoted_namespace = format_quotes(namespace, quote_char)
        return "{}.*".format(quoted_namespace)
|
|
|
|
|
|
class Tuple(Criterion):
    """A parenthesised, comma-separated list of terms."""

    def __init__(self, *values: Any) -> None:
        super().__init__()
        # Every element is converted to a term up front.
        self.values = [self.wrap_constant(item) for item in values]

    def nodes_(self) -> Iterator[NodeT]:
        """Yield this term, followed by the nodes of every element."""
        yield self
        for element in self.values:
            yield from element.nodes_()

    def get_sql(self, **kwargs: Any) -> str:
        """Render the elements joined by commas inside parentheses, then apply the alias."""
        rendered = [element.get_sql(**kwargs) for element in self.values]
        sql = "({})".format(",".join(rendered))
        return format_alias_sql(sql, self.alias, **kwargs)

    @property
    def is_aggregate(self) -> bool:
        """Aggregate flag resolved from the flags of all elements by ``resolve_is_aggregate``."""
        flags = [element.is_aggregate for element in self.values]
        return resolve_is_aggregate(flags)

    @builder
    def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]) -> "Tuple":
        """
        Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries.

        :param current_table:
            The table to be replaced.
        :param new_table:
            The table to replace with.
        :return:
            A copy of the field with the tables replaced.
        """
        replaced = []
        for element in self.values:
            replaced.append(element.replace_table(current_table, new_table))
        self.values = replaced
|
|
|
|
|
|
class Array(Tuple):
    """A bracketed list of terms, with a dedicated form for the PostgreSQL and Redshift dialects."""

    def get_sql(self, **kwargs: Any) -> str:
        """
        Render the elements as ``[a,b,...]``.

        When ``dialect`` in ``kwargs`` is ``Dialects.POSTGRESQL`` or ``Dialects.REDSHIFT``, render
        ``ARRAY[a,b,...]`` instead, or ``'{}'`` when the joined element text is empty.
        """
        joined = ",".join([element.get_sql(**kwargs) for element in self.values])

        if kwargs.get("dialect", None) in (Dialects.POSTGRESQL, Dialects.REDSHIFT):
            sql = "ARRAY[{}]".format(joined) if joined else "'{}'"
        else:
            sql = "[{}]".format(joined)

        return format_alias_sql(sql, self.alias, **kwargs)
|
|
|
|
|
|
class Bracket(Tuple):
    """A ``Tuple`` restricted to a single term."""

    def __init__(self, term: Any) -> None:
        # Delegate to Tuple with exactly one element.
        super().__init__(term)
|
|
|
|
|
|
class NestedCriterion(Criterion):
    """
    Criterion made of three terms and two comparators, rendered as
    ``<left><comparator><right><nested_comparator><nested>``.
    """

    def __init__(
        self,
        comparator: Comparator,
        nested_comparator: "ComplexCriterion",
        left: Any,
        right: Any,
        nested: Any,
        alias: Optional[str] = None,
    ) -> None:
        """
        :param comparator:
            Comparator rendered between ``left`` and ``right``.
        :param nested_comparator:
            Comparator rendered between ``right`` and ``nested``.
        :param left:
            The first term of the expression.
        :param right:
            The second term of the expression.
        :param nested:
            The third term of the expression.
        :param alias:
            (Optional) alias for the criterion.
        """
        super().__init__(alias)
        self.left = left
        self.comparator = comparator
        self.nested_comparator = nested_comparator
        self.right = right
        self.nested = nested

    def nodes_(self) -> Iterator[NodeT]:
        """Yield this criterion, then the nodes of ``right``, ``left`` and ``nested`` in that order."""
        yield self
        yield from self.right.nodes_()
        yield from self.left.nodes_()
        yield from self.nested.nodes_()

    @property
    def is_aggregate(self) -> Optional[bool]:
        """Aggregate flag resolved from the three terms by ``resolve_is_aggregate``."""
        return resolve_is_aggregate([term.is_aggregate for term in [self.left, self.right, self.nested]])

    @builder
    def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]) -> "NestedCriterion":
        """
        Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries.

        :param current_table:
            The table to be replaced.
        :param new_table:
            The table to replace with.
        :return:
            A copy of the criterion with the tables replaced.
        """
        self.left = self.left.replace_table(current_table, new_table)
        self.right = self.right.replace_table(current_table, new_table)
        # Derive the new nested term from the nested term itself. The previous code used ``self.right`` here,
        # which overwrote the nested term with the right-hand term.
        self.nested = self.nested.replace_table(current_table, new_table)

    def get_sql(self, with_alias: bool = False, **kwargs: Any) -> str:
        """
        Render the three terms and two comparators with no separators between them.

        :param with_alias:
            When true, the result is passed with the alias to ``format_alias_sql``.
        """
        sql = "{left}{comparator}{right}{nested_comparator}{nested}".format(
            left=self.left.get_sql(**kwargs),
            comparator=self.comparator.value,
            right=self.right.get_sql(**kwargs),
            nested_comparator=self.nested_comparator.value,
            nested=self.nested.get_sql(**kwargs),
        )

        if with_alias:
            return format_alias_sql(sql=sql, alias=self.alias, **kwargs)

        return sql
|
|
|
|
|
|
class BasicCriterion(Criterion):
    """A binary comparison: a left term, a comparator and a right term."""

    def __init__(self, comparator: Comparator, left: Term, right: Term, alias: Optional[str] = None) -> None:
        """
        Stores the three parts of a basic criterion such as an equality or inequality.

        :param comparator:
            Type: Comparator
            The kind of comparison to render between the two terms.
        :param left:
            The term on the left side of the expression.
        :param right:
            The term on the right side of the expression.
        :param alias:
            (Optional) alias for the criterion.
        """
        super().__init__(alias)
        self.comparator = comparator
        self.left = left
        self.right = right

    def nodes_(self) -> Iterator[NodeT]:
        """Yield this criterion, then the nodes of the right term, then those of the left term."""
        yield self
        for side in (self.right, self.left):
            yield from side.nodes_()

    @property
    def is_aggregate(self) -> Optional[bool]:
        """Aggregate flag resolved from both sides by ``resolve_is_aggregate``."""
        flags = [self.left.is_aggregate, self.right.is_aggregate]
        return resolve_is_aggregate(flags)

    @builder
    def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]) -> "BasicCriterion":
        """
        Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries.

        :param current_table:
            The table to be replaced.
        :param new_table:
            The table to replace with.
        :return:
            A copy of the criterion with the tables replaced.
        """
        new_left = self.left.replace_table(current_table, new_table)
        self.left = new_left
        new_right = self.right.replace_table(current_table, new_table)
        self.right = new_right

    def get_sql(self, quote_char: str = '"', with_alias: bool = False, **kwargs: Any) -> str:
        """
        Render ``<left><comparator><right>`` with no separators.

        :param quote_char:
            Quote character forwarded to both sides.
        :param with_alias:
            When true, the result is passed with the alias to ``format_alias_sql``.
        """
        operator_sql = self.comparator.value
        left_sql = self.left.get_sql(quote_char=quote_char, **kwargs)
        right_sql = self.right.get_sql(quote_char=quote_char, **kwargs)
        sql = "{}{}{}".format(left_sql, operator_sql, right_sql)

        if not with_alias:
            return sql
        return format_alias_sql(sql, self.alias, **kwargs)
|
|
|
|
|
|
class ContainsCriterion(Criterion):
    """Membership test rendered as ``<term> IN <container>`` or ``<term> NOT IN <container>``."""

    def __init__(self, term: Any, container: Term, alias: Optional[str] = None) -> None:
        """
        Stores the two parts of an "IN" criterion: the term whose membership is checked and the container it is
        checked against. The container can either be a list or a subquery.

        :param term:
            The term to assert membership for within the container.
        :param container:
            A list or subquery.
        :param alias:
            (Optional) alias for the criterion.
        """
        super().__init__(alias)
        self.term = term
        self.container = container
        # Flipped by negate() to render NOT IN.
        self._is_negated = False

    def nodes_(self) -> Iterator[NodeT]:
        """Yield this criterion, then the nodes of the term, then those of the container."""
        yield self
        for part in (self.term, self.container):
            yield from part.nodes_()

    @property
    def is_aggregate(self) -> Optional[bool]:
        """Mirrors the aggregate flag of the term."""
        return self.term.is_aggregate

    @builder
    def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]) -> "ContainsCriterion":
        """
        Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries.
        Only the term is updated; the container is left as it is.

        :param current_table:
            The table to be replaced.
        :param new_table:
            The table to replace with.
        :return:
            A copy of the criterion with the tables replaced.
        """
        updated_term = self.term.replace_table(current_table, new_table)
        self.term = updated_term

    def get_sql(self, subquery: Any = None, **kwargs: Any) -> str:
        """
        Render the membership test and pass it with the alias to ``format_alias_sql``.

        :param subquery:
            Accepted but not used; the container is always rendered with ``subquery=True``.
        """
        term_sql = self.term.get_sql(**kwargs)
        container_sql = self.container.get_sql(subquery=True, **kwargs)
        keyword = "NOT IN" if self._is_negated else "IN"
        sql = "{} {} {}".format(term_sql, keyword, container_sql)
        return format_alias_sql(sql, self.alias, **kwargs)

    @builder
    def negate(self) -> "ContainsCriterion":
        """Mark the criterion as negated so that it renders ``NOT IN``."""
        self._is_negated = True
|
|
|
|
|
|
class ExistsCriterion(Criterion):
    """Criterion rendered as ``EXISTS <container>`` or ``NOT EXISTS <container>``."""

    def __init__(self, container: Any, alias: Optional[str] = None) -> None:
        """
        :param container:
            Object whose ``get_sql`` output follows the ``EXISTS`` keyword.
        :param alias:
            (Optional) alias stored on the criterion; ``get_sql`` does not render it.
        """
        super().__init__(alias)
        self.container = container
        # Flipped by negate() to render NOT EXISTS.
        self._is_negated = False

    def get_sql(self, **kwargs: Any) -> str:
        """Render the container, forwarding ``kwargs``, after ``EXISTS`` or ``NOT EXISTS``."""
        # FIXME escape
        return "{not_}EXISTS {container}".format(
            container=self.container.get_sql(**kwargs), not_='NOT ' if self._is_negated else ''
        )

    def negate(self) -> "ExistsCriterion":
        """
        Mark this criterion as negated.

        Unlike the ``@builder``-decorated ``negate`` of ``ContainsCriterion``, this modifies the instance in place
        and returns it.
        """
        self._is_negated = True
        return self
|
|
|
|
|
|
class RangeCriterion(Criterion):
    """
    Base class for criteria that relate a term to a start value and an end value.

    No ``get_sql`` is defined here; the concrete range criteria provide it.
    """

    def __init__(self, term: Term, start: Any, end: Any, alias: Optional[str] = None) -> None:
        """
        :param term:
            The term being tested.
        :param start:
            The start of the range; must provide ``nodes_`` for tree traversal.
        :param end:
            The end of the range; must provide ``nodes_`` for tree traversal.
        :param alias:
            (Optional) alias for the criterion.
        """
        # The return annotation was previously ``str``; an initialiser returns None.
        super().__init__(alias)
        self.term = term
        self.start = start
        self.end = end

    def nodes_(self) -> Iterator[NodeT]:
        """Yield this criterion, then the nodes of the term, the start and the end."""
        yield self
        yield from self.term.nodes_()
        yield from self.start.nodes_()
        yield from self.end.nodes_()

    @property
    def is_aggregate(self) -> Optional[bool]:
        """Mirrors the aggregate flag of the term."""
        return self.term.is_aggregate
|
|
|
|
|
|
class BetweenCriterion(RangeCriterion):
    """Range criterion rendered as ``<term> BETWEEN <start> AND <end>``."""

    @builder
    def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]) -> "BetweenCriterion":
        """
        Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries.
        Only the term is updated; the start and end values are left as they are.

        :param current_table:
            The table to be replaced.
        :param new_table:
            The table to replace with.
        :return:
            A copy of the criterion with the tables replaced.
        """
        updated_term = self.term.replace_table(current_table, new_table)
        self.term = updated_term

    def get_sql(self, **kwargs: Any) -> str:
        """Render the three parts, forwarding ``kwargs`` to each, and apply the alias."""
        # FIXME escape
        term_sql = self.term.get_sql(**kwargs)
        start_sql = self.start.get_sql(**kwargs)
        end_sql = self.end.get_sql(**kwargs)
        sql = "{} BETWEEN {} AND {}".format(term_sql, start_sql, end_sql)
        return format_alias_sql(sql, self.alias, **kwargs)
|
|
|
|
|
|
class PeriodCriterion(RangeCriterion):
    """Range criterion rendered as ``<term> FROM <start> TO <end>``."""

    def get_sql(self, **kwargs: Any) -> str:
        """Render the three parts, forwarding ``kwargs`` to each, and apply the alias."""
        term_sql = self.term.get_sql(**kwargs)
        start_sql = self.start.get_sql(**kwargs)
        end_sql = self.end.get_sql(**kwargs)
        sql = "{} FROM {} TO {}".format(term_sql, start_sql, end_sql)
        return format_alias_sql(sql, self.alias, **kwargs)
|
|
|
|
|
|
class BitwiseAndCriterion(Criterion):
    """Criterion rendered as ``(<term> & <value>)``."""

    def __init__(self, term: Term, value: Any, alias: Optional[str] = None) -> None:
        super().__init__(alias)
        self.term = term
        self.value = value

    def nodes_(self) -> Iterator[NodeT]:
        """Yield this criterion, then the nodes of the term, then those of the value."""
        yield self
        for part in (self.term, self.value):
            yield from part.nodes_()

    @builder
    def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]) -> "BitwiseAndCriterion":
        """
        Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries.
        Only the term is updated.

        :param current_table:
            The table to be replaced.
        :param new_table:
            The table to replace with.
        :return:
            A copy of the criterion with the tables replaced.
        """
        updated_term = self.term.replace_table(current_table, new_table)
        self.term = updated_term

    def get_sql(self, **kwargs: Any) -> str:
        """Render the term with ``kwargs`` and the value through string formatting, then apply the alias."""
        term_sql = self.term.get_sql(**kwargs)
        # The value is interpolated directly rather than rendered via get_sql(**kwargs).
        sql = "({} & {})".format(term_sql, self.value)
        return format_alias_sql(sql, self.alias, **kwargs)
|
|
|
|
|
|
class NullCriterion(Criterion):
    """Criterion rendered as ``<term> IS NULL``."""

    def __init__(self, term: Term, alias: Optional[str] = None) -> None:
        super().__init__(alias)
        self.term = term

    def nodes_(self) -> Iterator[NodeT]:
        """Yield this criterion, then the nodes of the term."""
        yield self
        yield from self.term.nodes_()

    @builder
    def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]) -> "NullCriterion":
        """
        Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries.

        :param current_table:
            The table to be replaced.
        :param new_table:
            The table to replace with.
        :return:
            A copy of the criterion with the tables replaced.
        """
        updated_term = self.term.replace_table(current_table, new_table)
        self.term = updated_term

    def get_sql(self, with_alias: bool = False, **kwargs: Any) -> str:
        """
        Render the null test and pass it with the alias to ``format_alias_sql``.

        :param with_alias:
            Accepted but not consulted; the alias is always handed to ``format_alias_sql``.
        """
        term_sql = self.term.get_sql(**kwargs)
        sql = "{} IS NULL".format(term_sql)
        return format_alias_sql(sql, self.alias, **kwargs)
|
|
|
|
|
|
class NotNullCriterion(NullCriterion):
    """Criterion rendered as ``<term> IS NOT NULL``."""

    def get_sql(self, with_alias: bool = False, **kwargs: Any) -> str:
        """
        Render the not-null test and pass it with the alias to ``format_alias_sql``.

        :param with_alias:
            Accepted but not consulted; the alias is always handed to ``format_alias_sql``.
        """
        term_sql = self.term.get_sql(**kwargs)
        sql = "{} IS NOT NULL".format(term_sql)
        return format_alias_sql(sql, self.alias, **kwargs)
|
|
|
|
|
|
class ComplexCriterion(BasicCriterion):
    """
    Criterion that joins a left and a right operand with ``self.comparator``.
    """

    def get_sql(self, subcriterion: bool = False, **kwargs: Any) -> str:
        """
        Render ``<left> <comparator> <right>``.

        :param subcriterion:
            When true, the rendered criterion is enclosed in parentheses.
        :param kwargs:
            Rendering options forwarded to both operands.
        :return:
            The SQL string for the criterion.
        """
        # Each operand is told whether it must bracket itself, see needs_brackets().
        left_sql = self.left.get_sql(subcriterion=self.needs_brackets(self.left), **kwargs)
        right_sql = self.right.get_sql(subcriterion=self.needs_brackets(self.right), **kwargs)

        joined = "{left} {comparator} {right}".format(
            left=left_sql,
            comparator=self.comparator.value,
            right=right_sql,
        )

        return "({criterion})".format(criterion=joined) if subcriterion else joined

    def needs_brackets(self, term: Term) -> bool:
        """
        Decide whether an operand has to be enclosed in parentheses.

        :param term:
            The left or right operand of this criterion.
        :return:
            True only when the operand is itself a ComplexCriterion whose comparator differs from this one's.
        """
        if not isinstance(term, ComplexCriterion):
            return False
        return not (term.comparator == self.comparator)
|
|
|
|
|
|
class ArithmeticExpression(Term):
    """
    Wrapper for an arithmetic expression. It can be simple with two terms or complex with nested terms. The order of
    operations is preserved by adding parentheses where needed.
    """

    # Operators treated as additive when deciding on parentheses.
    add_order = [Arithmetic.add, Arithmetic.sub]

    def __init__(self, operator: Arithmetic, left: Any, right: Any, alias: Optional[str] = None) -> None:
        """
        Wrapper for an arithmetic expression.

        :param operator:
            Type: Arithmetic
            An operator for the expression such as '+' or '/'

        :param left:
            The term on the left side of the expression.
        :param right:
            The term on the right side of the expression.
        :param alias:
            (Optional) an alias for the term which can be used inside a select statement.
        :return:
        """
        super().__init__(alias)
        self.operator = operator
        self.left, self.right = left, right

    def nodes_(self) -> Iterator[NodeT]:
        """
        Yield this expression, then the nodes of the left operand, then those of the right operand.
        """
        yield self
        for operand in (self.left, self.right):
            yield from operand.nodes_()

    @property
    def is_aggregate(self) -> Optional[bool]:
        # True if both left and right terms are True or None. None if both terms are None. Otherwise, False
        operand_flags = [self.left.is_aggregate, self.right.is_aggregate]
        return resolve_is_aggregate(operand_flags)

    @builder
    def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]) -> "ArithmeticExpression":
        """
        Swap one table for another in both operands. Useful when reusing fields across queries.

        :param current_table:
            The table to be replaced.
        :param new_table:
            The table to replace with.
        :return:
            A copy of the term with the tables replaced.
        """
        new_left = self.left.replace_table(current_table, new_table)
        self.left = new_left
        new_right = self.right.replace_table(current_table, new_table)
        self.right = new_right

    def left_needs_parens(self, curr_op, left_op) -> bool:
        """
        Returns true if the expression on the left of the current operator needs to be enclosed in parentheses.

        :param curr_op:
            The current operator.
        :param left_op:
            The highest level operator of the left expression.
        """
        # A single item on the left, or a current operator of '+' or '-', never needs parentheses.
        if left_op is None or curr_op in self.add_order:
            return False
        # The current operator is not additive (e.g. '*' or '/'). An additive left operator needs parentheses:
        # e.g. (A + B) / ..., (A - B) / ...
        # Otherwise, no parentheses are necessary:
        # e.g. A * B / ..., A / B / ...
        return left_op in self.add_order

    def right_needs_parens(self, curr_op, right_op) -> bool:
        """
        Returns true if the expression on the right of the current operator needs to be enclosed in parentheses.

        :param curr_op:
            The current operator.
        :param right_op:
            The highest level operator of the right expression.
        """
        # A single item on the right, or a current operator of '+', never needs parentheses.
        if right_op is None or curr_op == Arithmetic.add:
            return False
        # Any compound expression on the right of '/' is parenthesised.
        if curr_op == Arithmetic.div:
            return True
        # For the remaining operators (e.g. '*' or '-'), an additive right operator needs parentheses:
        # e.g. ... - (A + B), ... - (A - B)
        # Otherwise, no parentheses are necessary:
        # e.g. ... - A / B, ... - A * B
        return right_op in self.add_order

    def get_sql(self, with_alias: bool = False, **kwargs: Any) -> str:
        """
        Render ``<left><operator><right>``, parenthesising operands where required.

        :param with_alias:
            When true, the result is passed through ``format_alias_sql`` with this term's alias.
        :param kwargs:
            Rendering options forwarded to both operands.
        :return:
            The SQL string for the expression.
        """
        # Operands without an ``operator`` attribute count as single items.
        left_op = getattr(self.left, "operator", None)
        right_op = getattr(self.right, "operator", None)

        left_template = "({})" if self.left_needs_parens(self.operator, left_op) else "{}"
        left_sql = left_template.format(self.left.get_sql(**kwargs))

        right_template = "({})" if self.right_needs_parens(self.operator, right_op) else "{}"
        right_sql = right_template.format(self.right.get_sql(**kwargs))

        arithmetic_sql = "{left}{operator}{right}".format(
            left=left_sql,
            operator=self.operator.value,
            right=right_sql,
        )

        if not with_alias:
            return arithmetic_sql

        return format_alias_sql(arithmetic_sql, self.alias, **kwargs)
|
|
|
|
|
|
class Case(Criterion):
    """
    Builder for a ``CASE WHEN ... THEN ... [ELSE ...] END`` expression.
    """

    def __init__(self, alias: Optional[str] = None) -> None:
        """
        :param alias:
            (Optional) an alias for the expression.
        """
        super().__init__(alias=alias)
        self._else = None
        # (criterion, term) pairs in the order they were added.
        self._cases = []

    def nodes_(self) -> Iterator[NodeT]:
        """
        Yield this node, then the nodes of every WHEN criterion and THEN term, then those of the ELSE term if set.
        """
        yield self

        for when_criterion, then_term in self._cases:
            yield from when_criterion.nodes_()
            yield from then_term.nodes_()

        if self._else is not None:
            yield from self._else.nodes_()

    @property
    def is_aggregate(self) -> Optional[bool]:
        # True if all criterions/cases are True or None. None all cases are None. Otherwise, False
        branch_flags = [criterion.is_aggregate or term.is_aggregate for criterion, term in self._cases]
        else_flag = self._else.is_aggregate if self._else else None
        return resolve_is_aggregate(branch_flags + [else_flag])

    @builder
    def when(self, criterion: Any, term: Any) -> "Case":
        """
        Add a ``WHEN <criterion> THEN <term>`` branch.

        :param criterion:
            The condition of the branch.
        :param term:
            The result of the branch; it is passed through ``wrap_constant`` before being stored.
        """
        wrapped_term = self.wrap_constant(term)
        self._cases.append((criterion, wrapped_term))

    @builder
    def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]) -> "Case":
        """
        Swap one table for another in every branch and in the ELSE term. Useful when reusing fields across queries.

        :param current_table:
            The table to be replaced.
        :param new_table:
            The table to replace with.
        :return:
            A copy of the term with the tables replaced.
        """
        swapped_cases = []
        for criterion, term in self._cases:
            swapped_cases.append(
                [
                    criterion.replace_table(current_table, new_table),
                    term.replace_table(current_table, new_table),
                ]
            )
        self._cases = swapped_cases

        if self._else:
            self._else = self._else.replace_table(current_table, new_table)
        else:
            self._else = None

    @builder
    def else_(self, term: Any) -> "Case":
        """
        Set the ``ELSE`` term.

        :param term:
            The fallback result; it is passed through ``wrap_constant`` before being stored.
        """
        self._else = self.wrap_constant(term)
        return self

    def get_sql(self, with_alias: bool = False, **kwargs: Any) -> str:
        """
        Render the CASE expression.

        :param with_alias:
            When true, the result is passed through ``format_alias_sql`` with this term's alias.
        :param kwargs:
            Rendering options forwarded to every criterion and term.
        :return:
            The SQL string for the expression.
        :raises CaseException:
            If no WHEN branch has been added.
        """
        if not self._cases:
            raise CaseException("At least one 'when' case is required for a CASE statement.")

        when_clauses = []
        for criterion, term in self._cases:
            when_clauses.append(
                "WHEN {when} THEN {then}".format(when=criterion.get_sql(**kwargs), then=term.get_sql(**kwargs))
            )
        cases = " ".join(when_clauses)

        else_ = " ELSE {}".format(self._else.get_sql(**kwargs)) if self._else else ""

        case_sql = "CASE {cases}{else_} END".format(cases=cases, else_=else_)

        if not with_alias:
            return case_sql

        return format_alias_sql(case_sql, self.alias, **kwargs)
|
|
|
|
|
|
class Not(Criterion):
    """
    Criterion rendering ``NOT <term>``.

    Attributes that are not found on this object are looked up on the wrapped term, see ``__getattr__``.
    """

    def __init__(self, term: Any, alias: Optional[str] = None) -> None:
        """
        :param term:
            The term to negate.
        :param alias:
            (Optional) an alias for the criterion.
        """
        super().__init__(alias=alias)
        self.term = term

    def nodes_(self) -> Iterator[NodeT]:
        """
        Yield this criterion followed by every node of the wrapped term.
        """
        yield self
        yield from self.term.nodes_()

    def get_sql(self, **kwargs: Any) -> str:
        """
        Render the criterion as ``NOT <term>``.

        :param kwargs:
            Rendering options forwarded to the wrapped term and to ``format_alias_sql``; ``subcriterion`` is always
            overridden with True before forwarding.
        :return:
            The SQL string after it has been passed through ``format_alias_sql`` with this criterion's alias.
        """
        kwargs["subcriterion"] = True
        sql = "NOT {term}".format(term=self.term.get_sql(**kwargs))
        return format_alias_sql(sql, self.alias, **kwargs)

    @ignore_copy
    def __getattr__(self, name: str) -> Any:
        """
        Delegate method calls to the class wrapped by Not().
        Re-wrap methods on child classes of Term (e.g. isin, eg...) to retain 'NOT <term>' output.

        :param name:
            The attribute that normal lookup did not find on this object.
        :return:
            The wrapped term's attribute as-is when it is not a bound method; otherwise a callable that forwards
            all arguments to that method and wraps a Term result in Not().
        """
        if name == "term":
            # Reached only when ``term`` has not been set on the instance. Looking up ``self.term`` below
            # would re-enter this method and recurse without end, so report the attribute as missing.
            raise AttributeError(name)

        item_func = getattr(self.term, name)

        if not inspect.ismethod(item_func):
            return item_func

        # ``item_func`` is already bound to the wrapped term, so the wrapper must not demand a positional
        # argument of its own; otherwise delegating a zero-argument method fails with a TypeError.
        def inner(*args, **kwargs):
            result = item_func(*args, **kwargs)
            if isinstance(result, Term):
                return Not(result)
            return result

        return inner

    @builder
    def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]) -> "Not":
        """
        Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries.

        :param current_table:
            The table to be replaced.
        :param new_table:
            The table to replace with.
        :return:
            A copy of the criterion with the tables replaced.
        """
        self.term = self.term.replace_table(current_table, new_table)
|
|
|
|
|
|
class All(Criterion):
    """
    Criterion rendering ``<term> ALL``.
    """

    def __init__(self, term: Any, alias: Optional[str] = None) -> None:
        """
        :param term:
            The term rendered in front of the ``ALL`` keyword.
        :param alias:
            (Optional) an alias for the criterion.
        """
        super().__init__(alias=alias)
        self.term = term

    def nodes_(self) -> Iterator[NodeT]:
        """
        Yield this criterion followed by every node of the wrapped term.
        """
        yield self
        yield from self.term.nodes_()

    def get_sql(self, **kwargs: Any) -> str:
        """
        Render the criterion as SQL.

        :param kwargs:
            Rendering options forwarded to the term's ``get_sql`` and to ``format_alias_sql``.
        :return:
            The SQL string after it has been passed through ``format_alias_sql`` with this criterion's alias.
        """
        term_sql = self.term.get_sql(**kwargs)
        all_sql = "{term} ALL".format(term=term_sql)
        return format_alias_sql(all_sql, self.alias, **kwargs)
|
|
|
|
|
|
class CustomFunction:
    """
    Factory for ``Function`` terms with a fixed name and, optionally, a declared parameter list.
    """

    def __init__(self, name: str, params: Optional[Sequence] = None) -> None:
        """
        :param name:
            The name of the SQL function.
        :param params:
            (Optional) the declared parameters; only their count is checked when the function is called.
        """
        self.name = name
        self.params = params

    def __call__(self, *args: Any, **kwargs: Any) -> "Function":
        """
        Build a ``Function`` term for this custom function.

        :param args:
            The call arguments. They are not forwarded when no parameters were declared.
        :param kwargs:
            Only the ``alias`` keyword is read.
        :return:
            The constructed ``Function``.
        :raises FunctionException:
            If parameters were declared and the number of arguments does not match.
        """
        alias = kwargs.get("alias")

        if not self._has_params():
            return Function(self.name, alias=alias)

        if self._is_valid_function_call(*args):
            return Function(self.name, *args, alias=alias)

        raise FunctionException(
            "Function {name} require these arguments ({params}), ({args}) passed".format(
                name=self.name,
                params=", ".join(map(str, self.params)),
                args=", ".join(map(str, args)),
            )
        )

    def _has_params(self):
        """Return True when a parameter list (possibly empty) was declared."""
        return self.params is not None

    def _is_valid_function_call(self, *args):
        """Return True when the number of arguments equals the number of declared parameters."""
        return len(self.params) == len(args)
|
|
|
|
|
|
class Function(Criterion):
    """
    A SQL function call rendered as ``name(arg1,arg2,...)``, optionally prefixed with a schema.
    """

    def __init__(self, name: str, *args: Any, **kwargs: Any) -> None:
        """
        :param name:
            The name of the function.
        :param args:
            The function arguments; each one is passed through ``wrap_constant``.
        :param kwargs:
            ``alias`` and ``schema`` are read; both default to None.
        """
        super().__init__(kwargs.get("alias"))
        self.name = name
        self.args = [self.wrap_constant(param) for param in args]
        self.schema = kwargs.get("schema")

    def nodes_(self) -> Iterator[NodeT]:
        """
        Yield this function followed by the nodes of each argument, in argument order.
        """
        yield self
        for argument in self.args:
            yield from argument.nodes_()

    @property
    def is_aggregate(self) -> Optional[bool]:
        """
        Derive the aggregate flag of this function from its arguments.

        :returns:
            The result of ``resolve_is_aggregate`` applied to the ``is_aggregate`` flags of all arguments.
        """
        argument_flags = [arg.is_aggregate for arg in self.args]
        return resolve_is_aggregate(argument_flags)

    @builder
    def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]) -> "Function":
        """
        Swap one table for another in every argument. Useful when reusing fields across queries.

        :param current_table:
            The table to be replaced.
        :param new_table:
            The table to replace with.
        :return:
            A copy of the criterion with the tables replaced.
        """
        swapped_args = [param.replace_table(current_table, new_table) for param in self.args]
        self.args = swapped_args

    def get_special_params_sql(self, **kwargs: Any) -> Any:
        """
        Hook for extra SQL appended after the arguments, inside the parentheses. This implementation has none.
        """
        return None

    @staticmethod
    def get_arg_sql(arg, **kwargs):
        """
        Render one argument: through its ``get_sql`` (without alias) when it has one, otherwise through ``str``.
        """
        if hasattr(arg, "get_sql"):
            return arg.get_sql(with_alias=False, **kwargs)
        return str(arg)

    def get_function_sql(self, **kwargs: Any) -> str:
        """
        Render ``name(args[ special])`` without schema or alias.

        :param kwargs:
            Rendering options forwarded to the arguments and to ``get_special_params_sql``.
        :return:
            The SQL string for the bare function call.
        """
        special_params_sql = self.get_special_params_sql(**kwargs)

        rendered_args = []
        for p in self.args:
            if hasattr(p, "get_sql"):
                rendered_args.append(p.get_sql(with_alias=False, subquery=True, **kwargs))
            else:
                rendered_args.append(self.get_arg_sql(p, **kwargs))

        special = (" " + special_params_sql) if special_params_sql else ""

        return "{name}({args}{special})".format(
            name=self.name,
            args=",".join(rendered_args),
            special=special,
        )

    def get_sql(self, **kwargs: Any) -> str:
        """
        Render the function call, with schema prefix and alias where applicable.

        :param kwargs:
            ``with_alias``, ``with_namespace``, ``quote_char`` and ``dialect`` are consumed here. Only the latter
            three reach ``get_function_sql``; any remaining options go to the schema and alias rendering only.
        :return:
            The SQL string for the function.
        """
        with_alias = kwargs.pop("with_alias", False)
        with_namespace = kwargs.pop("with_namespace", False)
        quote_char = kwargs.pop("quote_char", None)
        dialect = kwargs.pop("dialect", None)

        # FIXME escape
        function_sql = self.get_function_sql(with_namespace=with_namespace, quote_char=quote_char, dialect=dialect)

        if self.schema is not None:
            schema_sql = self.schema.get_sql(quote_char=quote_char, dialect=dialect, **kwargs)
            function_sql = "{schema}.{function}".format(schema=schema_sql, function=function_sql)

        if not with_alias:
            return function_sql

        return format_alias_sql(function_sql, self.alias, quote_char=quote_char, **kwargs)
|
|
|
|
|
|
class AggregateFunction(Function):
    """
    A function that is always flagged as aggregate and can carry a ``FILTER(WHERE ...)`` clause.
    """

    is_aggregate = True

    def __init__(self, name, *args, **kwargs):
        """
        :param name:
            The name of the function.
        :param args:
            The function arguments.
        :param kwargs:
            Keyword options forwarded to ``Function``.
        """
        super().__init__(name, *args, **kwargs)

        self._include_filter = False
        self._filters = []

    @builder
    def filter(self, *filters: Any) -> "AnalyticFunction":
        """
        Add criteria to the FILTER clause of the function.

        :param filters:
            The criteria to add; calling this without arguments still enables the FILTER clause.
        """
        self._include_filter = True
        self._filters += filters

    def get_filter_sql(self, **kwargs: Any) -> str:
        """
        Render the ``WHERE ...`` part of the FILTER clause.

        :param kwargs:
            Rendering options forwarded to the combined criterion.
        :return:
            The SQL string, or None when no filter has been requested.
        """
        if not self._include_filter:
            return None

        criterions_sql = Criterion.all(self._filters).get_sql(**kwargs)
        return "WHERE {criterions}".format(criterions=criterions_sql)

    def get_function_sql(self, **kwargs: Any):
        """
        Render the function call followed by ``FILTER(...)`` when a filter has been requested.

        :param kwargs:
            Rendering options forwarded to the parent rendering and to ``get_filter_sql``.
        :return:
            The SQL string for the function call.
        """
        sql = super().get_function_sql(**kwargs)
        filter_sql = self.get_filter_sql(**kwargs)

        if not self._include_filter:
            return sql

        return sql + " FILTER({filter_sql})".format(filter_sql=filter_sql)
|
|
|
|
|
|
class AnalyticFunction(AggregateFunction):
    """
    A function that can carry an ``OVER(PARTITION BY ... ORDER BY ...)`` clause.
    """

    is_aggregate = False
    is_analytic = True

    def __init__(self, name: str, *args: Any, **kwargs: Any) -> None:
        """
        :param name:
            The name of the function.
        :param args:
            The function arguments.
        :param kwargs:
            Keyword options forwarded to the parent class.
        """
        super().__init__(name, *args, **kwargs)
        self._include_filter = False
        self._include_over = False
        self._filters = []
        self._partition = []
        # (term, order) pairs in the order they were added.
        self._orderbys = []

    @builder
    def over(self, *terms: Any) -> "AnalyticFunction":
        """
        Enable the OVER clause and add terms to its PARTITION BY list.

        :param terms:
            The terms to partition by; may be empty.
        """
        self._include_over = True
        self._partition += terms

    @builder
    def orderby(self, *terms: Any, **kwargs: Any) -> "AnalyticFunction":
        """
        Enable the OVER clause and add terms to its ORDER BY list.

        :param terms:
            The terms to order by.
        :param kwargs:
            ``order`` is read and applied to every term of this call; it defaults to None.
        """
        self._include_over = True
        order = kwargs.get("order")
        self._orderbys += [(term, order) for term in terms]

    def _orderby_field(self, field: Field, orient: Optional[Order], **kwargs: Any) -> str:
        """
        Render one ORDER BY entry, appending the orientation's value when one is given.
        """
        field_sql = field.get_sql(**kwargs)
        if orient is None:
            return field_sql

        return "{field} {orient}".format(field=field_sql, orient=orient.value)

    def get_partition_sql(self, **kwargs: Any) -> str:
        """
        Render the contents of the OVER clause.

        :param kwargs:
            Rendering options forwarded to the partition and ordering terms.
        :return:
            The PARTITION BY and ORDER BY parts that are present, joined by a space; empty if neither is.
        """
        clauses = []

        if self._partition:
            partition_args = ",".join(
                p.get_sql(**kwargs) if hasattr(p, "get_sql") else str(p) for p in self._partition
            )
            clauses.append("PARTITION BY {args}".format(args=partition_args))

        if self._orderbys:
            ordering = ",".join(self._orderby_field(field, orient, **kwargs) for field, orient in self._orderbys)
            clauses.append("ORDER BY {orderby}".format(orderby=ordering))

        return " ".join(clauses)

    def get_function_sql(self, **kwargs: Any) -> str:
        """
        Render the function call followed by ``OVER(...)`` when the clause has been enabled.

        :param kwargs:
            Rendering options forwarded to the parent rendering and to ``get_partition_sql``.
        :return:
            The SQL string for the function call.
        """
        function_sql = super().get_function_sql(**kwargs)
        partition_sql = self.get_partition_sql(**kwargs)

        if not self._include_over:
            return function_sql

        return function_sql + " OVER({partition_sql})".format(partition_sql=partition_sql)
|
|
|
|
|
|
# Type variable for window frame edges; bound to ``WindowFrameAnalyticFunction.Edge`` via a forward reference.
EdgeT = TypeVar("EdgeT", bound="WindowFrameAnalyticFunction.Edge")
|
|
|
|
|
|
class WindowFrameAnalyticFunction(AnalyticFunction):
    """
    An analytic function whose OVER clause can carry a ``ROWS`` or ``RANGE`` window frame.
    """

    class Edge:
        """
        One bound of a window frame, rendered as ``<value> <modifier>``.

        ``modifier`` is not defined here; it is read from the instance when rendering
        (presumably provided by subclasses; verify against the callers).
        """

        def __init__(self, value: Optional[Union[str, int]] = None) -> None:
            """
            :param value:
                The offset of the bound. None stands for an unbounded edge.
            """
            self.value = value

        def __str__(self) -> str:
            # Only a missing (None or empty) value means UNBOUNDED. A plain truthiness test would also turn
            # an explicit offset of 0 into UNBOUNDED, which describes a completely different frame.
            is_unbounded = self.value is None or self.value == ""
            return "{value} {modifier}".format(
                value="UNBOUNDED" if is_unbounded else self.value,
                modifier=self.modifier,
            )

    def __init__(self, name: str, *args: Any, **kwargs: Any) -> None:
        """
        :param name:
            The name of the function.
        :param args:
            The function arguments.
        :param kwargs:
            Keyword options forwarded to the parent class.
        """
        super().__init__(name, *args, **kwargs)
        self.frame = None
        self.bound = None

    def _set_frame_and_bounds(self, frame: str, bound: str, and_bound: Optional[EdgeT]) -> None:
        """
        Record the frame type and its bound(s).

        :param frame:
            The frame keyword, e.g. ``ROWS`` or ``RANGE``.
        :param bound:
            The single bound, or the lower bound when ``and_bound`` is given.
        :param and_bound:
            (Optional) the upper bound; when given the frame renders as ``BETWEEN ... AND ...``.
        :raises AttributeError:
            If a frame or bound has already been set.
        """
        if self.frame or self.bound:
            raise AttributeError()

        self.frame = frame
        self.bound = (bound, and_bound) if and_bound else bound

    @builder
    def rows(self, bound: Union[str, EdgeT], and_bound: Optional[EdgeT] = None) -> "WindowFrameAnalyticFunction":
        """
        Set a ``ROWS`` frame with one bound, or with a lower and an upper bound.
        """
        self._set_frame_and_bounds("ROWS", bound, and_bound)

    @builder
    def range(self, bound: Union[str, EdgeT], and_bound: Optional[EdgeT] = None) -> "WindowFrameAnalyticFunction":
        """
        Set a ``RANGE`` frame with one bound, or with a lower and an upper bound.
        """
        self._set_frame_and_bounds("RANGE", bound, and_bound)

    def get_frame_sql(self) -> str:
        """
        Render the frame as ``<frame> <bound>`` or ``<frame> BETWEEN <lower> AND <upper>``.
        """
        if not isinstance(self.bound, tuple):
            return "{frame} {bound}".format(frame=self.frame, bound=self.bound)

        lower, upper = self.bound
        return "{frame} BETWEEN {lower} AND {upper}".format(
            frame=self.frame,
            lower=lower,
            upper=upper,
        )

    def get_partition_sql(self, **kwargs: Any) -> str:
        """
        Render the contents of the OVER clause, followed by the frame when one has been set.

        :param kwargs:
            Rendering options forwarded to the parent rendering.
        :return:
            The SQL string for the contents of the OVER clause.
        """
        partition_sql = super(WindowFrameAnalyticFunction, self).get_partition_sql(**kwargs)

        if not self.frame and not self.bound:
            return partition_sql

        return "{over} {frame}".format(over=partition_sql, frame=self.get_frame_sql())
|
|
|
|
|
|
class IgnoreNullsAnalyticFunction(AnalyticFunction):
    """
    An analytic function that can append ``IGNORE NULLS`` after its arguments.
    """

    def __init__(self, name: str, *args: Any, **kwargs: Any) -> None:
        """
        :param name:
            The name of the function.
        :param args:
            The function arguments.
        :param kwargs:
            Keyword options forwarded to the parent class.
        """
        super().__init__(name, *args, **kwargs)
        self._ignore_nulls = False

    @builder
    def ignore_nulls(self) -> "IgnoreNullsAnalyticFunction":
        """
        Request that ``IGNORE NULLS`` is rendered for this function.
        """
        self._ignore_nulls = True

    def get_special_params_sql(self, **kwargs: Any) -> Optional[str]:
        """
        :return:
            ``IGNORE NULLS`` when it has been requested, otherwise None.
        """
        # No special params unless ignoring nulls
        return "IGNORE NULLS" if self._ignore_nulls else None
|
|
|
|
|
|
class Interval(Node):
    """
    A SQL ``INTERVAL`` literal built from date/time components.

    ``quarters`` and ``weeks`` take precedence, in that order: when one of them is non-zero it is the only
    component stored. Otherwise every non-zero component is stored as an absolute value and the sign is taken
    from the largest non-zero component.
    """

    templates = {
        # PostgreSQL, Redshift and Vertica require quotes around the expr and unit e.g. INTERVAL '1 week'
        Dialects.POSTGRESQL: "INTERVAL '{expr} {unit}'",
        Dialects.REDSHIFT: "INTERVAL '{expr} {unit}'",
        Dialects.VERTICA: "INTERVAL '{expr} {unit}'",
        # Oracle and MySQL requires just single quotes around the expr
        Dialects.ORACLE: "INTERVAL '{expr}' {unit}",
        Dialects.MYSQL: "INTERVAL '{expr}' {unit}",
    }

    # Attribute names and unit labels, ordered from the largest to the smallest component.
    units = ["years", "months", "days", "hours", "minutes", "seconds", "microseconds"]
    labels = ["YEAR", "MONTH", "DAY", "HOUR", "MINUTE", "SECOND", "MICROSECOND"]

    # Applied to the full "Y-M-D h:m:s.us" expression to strip zero-valued leading and trailing fields.
    trim_pattern = re.compile(r"(^0+\.)|(\.0+$)|(^[0\-.: ]+[\-: ])|([\-:. ][0\-.: ]+$)")

    def __init__(
        self,
        years: int = 0,
        months: int = 0,
        days: int = 0,
        hours: int = 0,
        minutes: int = 0,
        seconds: int = 0,
        microseconds: int = 0,
        quarters: int = 0,
        weeks: int = 0,
        dialect: Optional[Dialects] = None,
    ):
        """
        :param years, months, days, hours, minutes, seconds, microseconds:
            The components of the interval. Non-zero values are converted with ``int`` and stored as absolute
            values.
        :param quarters:
            When non-zero, the interval consists of this many quarters and all other components are ignored.
        :param weeks:
            When non-zero (and ``quarters`` is zero), the interval consists of this many weeks and the remaining
            components are ignored.
        :param dialect:
            (Optional) the dialect used to pick the template; it takes precedence over a dialect passed to
            ``get_sql``.
        """
        self.dialect = dialect
        self.largest = None
        self.smallest = None
        self.is_negative = False

        if quarters:
            self.quarters = quarters
            return

        if weeks:
            self.weeks = weeks
            return

        for unit, label, value in zip(
            self.units,
            self.labels,
            [years, months, days, hours, minutes, seconds, microseconds],
        ):
            if value:
                int_value = int(value)
                setattr(self, unit, abs(int_value))
                if self.largest is None:
                    # The first non-zero component is the largest one and determines the sign.
                    self.largest = label
                    self.is_negative = int_value < 0
                self.smallest = label

    def __str__(self) -> str:
        return self.get_sql()

    def get_sql(self, **kwargs: Any) -> str:
        """
        Render the interval literal.

        :param kwargs:
            ``dialect`` is read when no dialect was given to the constructor.
        :return:
            The SQL string, formatted with the dialect's template or with ``INTERVAL '{expr} {unit}'`` when the
            dialect has no template.
        """
        dialect = self.dialect or kwargs.get("dialect")

        if self.largest == "MICROSECOND":
            expr = getattr(self, "microseconds")
            # The magnitude is stored as an absolute value, so the sign has to be restored here just like
            # it is for the composite expression below.
            if self.is_negative:
                expr = "-{}".format(expr)
            unit = "MICROSECOND"

        elif hasattr(self, "quarters"):
            expr = getattr(self, "quarters")
            unit = "QUARTER"

        elif hasattr(self, "weeks"):
            expr = getattr(self, "weeks")
            unit = "WEEK"

        else:
            # Create the whole expression but trim out the unnecessary fields
            expr = "{years}-{months}-{days} {hours}:{minutes}:{seconds}.{microseconds}".format(
                years=getattr(self, "years", 0),
                months=getattr(self, "months", 0),
                days=getattr(self, "days", 0),
                hours=getattr(self, "hours", 0),
                minutes=getattr(self, "minutes", 0),
                seconds=getattr(self, "seconds", 0),
                microseconds=getattr(self, "microseconds", 0),
            )
            expr = self.trim_pattern.sub("", expr)
            if self.is_negative:
                expr = "-" + expr

            unit = (
                "{largest}_{smallest}".format(
                    largest=self.largest,
                    smallest=self.smallest,
                )
                if self.largest != self.smallest
                else self.largest
            )

            # Set default unit with DAY
            if unit is None:
                unit = "DAY"

        return self.templates.get(dialect, "INTERVAL '{expr} {unit}'").format(expr=expr, unit=unit)
|
|
|
|
|
|
class Pow(Function):
    """
    The ``POW`` function applied to a term and an exponent.
    """

    def __init__(self, term: Term, exponent: float, alias: Optional[str] = None) -> None:
        """
        :param term:
            The first argument of the function.
        :param exponent:
            The second argument of the function.
        :param alias:
            (Optional) an alias for the function.
        """
        function_name = "POW"
        super().__init__(function_name, term, exponent, alias=alias)
|
|
|
|
|
|
class Mod(Function):
    """
    The ``MOD`` function applied to a term and a modulus.
    """

    def __init__(self, term: Term, modulus: float, alias: Optional[str] = None) -> None:
        """
        :param term:
            The first argument of the function.
        :param modulus:
            The second argument of the function.
        :param alias:
            (Optional) an alias for the function.
        """
        function_name = "MOD"
        super().__init__(function_name, term, modulus, alias=alias)
|
|
|
|
|
|
class Rollup(Function):
    """
    The ``ROLLUP`` function applied to any number of terms.
    """

    def __init__(self, *terms: Any) -> None:
        """
        :param terms:
            The arguments of the function, in order.
        """
        function_name = "ROLLUP"
        super().__init__(function_name, *terms)
|
|
|
|
|
|
class PseudoColumn(Term):
    """
    Represents a pseudo column (a "column" which yields a value when selected
    but is not actually a real table column).
    """

    def __init__(self, name: str) -> None:
        """
        :param name:
            The name of the pseudo column; it is rendered verbatim.
        """
        super().__init__(alias=None)
        self.name = name

    def get_sql(self, **kwargs: Any) -> str:
        """
        :param kwargs:
            Accepted for interface compatibility; not used.
        :return:
            The name of the pseudo column, without quoting or alias.
        """
        column_name = self.name
        return column_name
|
|
|
|
|
|
class AtTimezone(Term):
    """
    Generates AT TIME ZONE SQL.
    Examples:
        AT TIME ZONE 'US/Eastern'
        AT TIME ZONE INTERVAL '-06:00'
    """

    is_aggregate = None

    def __init__(self, field, zone, interval=False, alias=None):
        """
        :param field:
            The field to convert. Anything that is not already a ``Field`` is wrapped in one.
        :param zone:
            The zone, rendered between single quotes.
        :param interval:
            When true, the ``INTERVAL`` keyword is rendered in front of the quoted zone.
        :param alias:
            (Optional) an alias for the term.
        """
        super().__init__(alias)
        self.field = field if isinstance(field, Field) else Field(field)
        self.zone = zone
        self.interval = interval

    def get_sql(self, **kwargs):
        """
        :param kwargs:
            Rendering options forwarded to the field's ``get_sql`` and to ``format_alias_sql``.
        :return:
            The SQL string after it has been passed through ``format_alias_sql`` with this term's alias.
        """
        field_sql = self.field.get_sql(**kwargs)
        interval_keyword = "INTERVAL " if self.interval else ""
        sql = "{name} AT TIME ZONE {interval}'{zone}'".format(
            name=field_sql,
            interval=interval_keyword,
            zone=self.zone,
        )
        return format_alias_sql(sql, self.alias, **kwargs)
|