1"""
The functions in this module can be used for testing the constraints of
your models. Each assert function runs SQL UPDATEs that check for the existence
of the given constraint. Consider the following model::
5
6
7 class User(Base):
8 __tablename__ = 'user'
9 id = sa.Column(sa.Integer, primary_key=True)
10 name = sa.Column(sa.String(200), nullable=True)
11 email = sa.Column(sa.String(255), nullable=False)
12
13
14 user = User(name='John Doe', email='john@example.com')
15 session.add(user)
16 session.commit()
17
18
We can easily test these constraints with the assert_* functions::
20
21
22 from sqlalchemy_utils import (
23 assert_nullable,
24 assert_non_nullable,
25 assert_max_length
26 )
27
28 assert_nullable(user, 'name')
29 assert_non_nullable(user, 'email')
30 assert_max_length(user, 'name', 200)
31
32 # raises AssertionError because the max length of email is 255
33 assert_max_length(user, 'email', 300)
34"""
35
36from decimal import Decimal
37
38import sqlalchemy as sa
39from sqlalchemy.dialects.postgresql import ARRAY
40from sqlalchemy.exc import DataError, IntegrityError
41
42
def _update_field(obj, field, value):
    """
    Set ``field`` to ``value`` in the database row that backs ``obj``.

    The change is issued as a Core ``UPDATE`` statement on the column's table
    (not through the ORM attribute system), so the value is checked by the
    database itself. The statement is restricted to the row whose primary key
    matches ``obj``; other rows of the table are left untouched.

    :param obj: SQLAlchemy declarative model object attached to a session
    :param field: Mapped attribute name of the column to update
    :param value: Value to write into the column
    """
    session = sa.orm.object_session(obj)
    # Make sure a pending ``obj`` has been INSERTed (and therefore has a
    # primary key) before its key is used in the WHERE clause below.
    session.flush()
    mapper = sa.inspect(obj.__class__)
    column = mapper.columns[field]
    table = column.table
    # Use the primary key of the table that owns the column (this matters for
    # joined-table inheritance); fall back to the mapper's primary key if that
    # table declares none.
    pk_columns = list(table.primary_key.columns) or list(mapper.primary_key)
    criteria = [
        pk_column == getattr(obj, mapper.get_property_by_column(pk_column).key)
        for pk_column in pk_columns
    ]
    query = (
        table.update()
        .where(sa.and_(*criteria))
        .values(**{column.key: value})
    )
    session.execute(query)
    session.flush()
49
50
def _expect_successful_update(obj, field, value, reraise_exc):
    """
    Run an UPDATE of ``field`` to ``value`` that is expected to succeed.

    If the update raises ``reraise_exc``, the object's session is rolled back
    and the failure is reported as an ``AssertionError`` carrying the original
    error message. Exceptions of any other type propagate unchanged. On
    success the session is *not* rolled back, so the written value stays in
    the current transaction.

    :param obj: SQLAlchemy declarative model object
    :param field: Name of the column to update
    :param value: Value to write
    :param reraise_exc: Exception type (or tuple of types) that signals the
        update was rejected
    :raises AssertionError: if the update raises ``reraise_exc``
    """
    try:
        _update_field(obj, field, value)
    except reraise_exc as e:
        session = sa.orm.object_session(obj)
        session.rollback()
        # Raise explicitly rather than via ``assert False``: assert statements
        # are stripped under ``python -O``, which would turn this failure
        # into a silent pass.
        raise AssertionError(str(e)) from e
58
59
def _expect_failing_update(obj, field, value, expected_exc):
    """
    Run an UPDATE of ``field`` to ``value`` that the database must reject.

    The object's session is rolled back in every case: when ``expected_exc``
    is raised, when nothing is raised, and when some other exception escapes.

    :param obj: SQLAlchemy declarative model object
    :param field: Name of the column to update
    :param value: Value that should be rejected
    :param expected_exc: Exception type (or tuple of types) the update must
        raise
    :raises AssertionError: if the update completes without raising
        ``expected_exc``
    """
    update_was_rejected = False
    try:
        try:
            _update_field(obj, field, value)
        except expected_exc:
            update_was_rejected = True
        if not update_was_rejected:
            raise AssertionError('Expected update to raise %s' % expected_exc)
    finally:
        sa.orm.object_session(obj).rollback()
70
71
def _repeated_value(type_):
    """
    Return a one-unit sample value for a column of type ``type_``.

    Multiplying the result by ``n`` yields a value of length ``n``. PostgreSQL
    ``ARRAY`` columns get a one-element list whose item matches the array's
    item type. Every other type gets the one-character string ``'a'``.

    :param type_: SQLAlchemy column type instance
    :raises TypeError: for an ``ARRAY`` whose item type is not an Integer,
        String or Numeric type
    """
    if not isinstance(type_, ARRAY):
        return 'a'
    # Checked in order; the first matching item type wins.
    samples = (
        (sa.Integer, 0),
        (sa.String, 'a'),
        (sa.Numeric, Decimal('0')),
    )
    for item_class, sample in samples:
        if isinstance(type_.item_type, item_class):
            return [sample]
    raise TypeError('Unknown array item type')
84
85
def _expected_exception(type_):
    """
    Return the exception class an over-long value of ``type_`` should raise.

    PostgreSQL ``ARRAY`` columns map to ``IntegrityError``; all other types
    map to ``DataError``.

    :param type_: SQLAlchemy column type instance
    """
    return IntegrityError if isinstance(type_, ARRAY) else DataError
91
92
def assert_nullable(obj, column):
    """
    Assert that the given column accepts NULL.

    An SQL UPDATE sets the column to ``None``. If the database rejects it
    with an ``IntegrityError``, the assertion fails.

    :param obj: SQLAlchemy declarative model object
    :param column: Name of the column
    """
    _expect_successful_update(
        obj,
        column,
        value=None,
        reraise_exc=IntegrityError,
    )
102
103
def assert_non_nullable(obj, column):
    """
    Assert that the given column rejects NULL.

    An SQL UPDATE sets the column to ``None``. The database must reject it
    with an ``IntegrityError``, otherwise the assertion fails. The session is
    rolled back afterwards.

    :param obj: SQLAlchemy declarative model object
    :param column: Name of the column
    """
    _expect_failing_update(
        obj,
        column,
        value=None,
        expected_exc=IntegrityError,
    )
113
114
def assert_max_length(obj, column, max_length):
    """
    Assert that the given column is of given max length. This function supports
    string typed columns as well as PostgreSQL array typed columns.

    The check first writes a value of exactly ``max_length`` units (characters
    for strings, items for arrays), which must succeed. It then writes a value
    of ``max_length + 1`` units, which the database must reject. String
    columns are expected to fail with ``DataError`` and array columns with
    ``IntegrityError`` (typically raised by a check constraint, as in the
    example below).

    In the following example we add a check constraint that a user can have
    at most 5 favorite colors, and then test it::


        class User(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, primary_key=True)
            favorite_colors = sa.Column(ARRAY(sa.String), nullable=False)
            __table_args__ = (
                sa.CheckConstraint(
                    sa.func.array_length(favorite_colors, 1) <= 5
                ),
            )


        user = User(favorite_colors=['red', 'blue'])
        session.add(user)
        session.commit()


        assert_max_length(user, 'favorite_colors', 5)


    :param obj: SQLAlchemy declarative model object
    :param column: Name of the column
    :param max_length: Maximum length of given column
    :raises AssertionError: if the value of length ``max_length`` is rejected
        or the value of length ``max_length + 1`` is accepted
    :raises TypeError: if the column is an array of an unsupported item type
    """
    type_ = sa.inspect(obj.__class__).columns[column].type
    # Both values are built from the same unit and checked against the same
    # exception type, so compute these once.
    unit = _repeated_value(type_)
    expected_exc = _expected_exception(type_)
    # Exactly max_length units must be accepted...
    _expect_successful_update(obj, column, unit * max_length, expected_exc)
    # ...while one unit more must be rejected by the database.
    _expect_failing_update(
        obj, column, unit * (max_length + 1), expected_exc
    )
157
158
def assert_min_value(obj, column, min_value):
    """
    Assert that the given column must have a minimum value of `min_value`.

    Writing ``min_value`` must succeed. Writing ``min_value - 1`` must be
    rejected with an ``IntegrityError``.

    :param obj: SQLAlchemy declarative model object
    :param column: Name of the column
    :param min_value: The minimum allowed value for given column
    """
    _expect_successful_update(
        obj, column, value=min_value, reraise_exc=IntegrityError
    )
    _expect_failing_update(
        obj, column, value=min_value - 1, expected_exc=IntegrityError
    )
169
170
def assert_max_value(obj, column, max_value=None, min_value=None):
    """
    Assert that the given column must have a maximum value of `max_value`.

    Writing ``max_value`` must succeed. Writing ``max_value + 1`` must be
    rejected with an ``IntegrityError``.

    :param obj: SQLAlchemy declarative model object
    :param column: Name of the column
    :param max_value: The maximum allowed value for given column
    :param min_value: Deprecated alias of ``max_value``. It is kept so that
        callers that passed the value by keyword under its former, misleading
        name keep working.
    :raises TypeError: if neither or both of ``max_value`` and ``min_value``
        are given
    """
    import warnings

    if min_value is not None:
        if max_value is not None:
            raise TypeError(
                "assert_max_value() got values for both 'max_value' and its "
                "deprecated alias 'min_value'"
            )
        warnings.warn(
            "The 'min_value' argument of assert_max_value() is deprecated; "
            "use 'max_value' instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        max_value = min_value
    if max_value is None:
        raise TypeError(
            "assert_max_value() missing required argument: 'max_value'"
        )
    _expect_successful_update(obj, column, max_value, IntegrityError)
    _expect_failing_update(obj, column, max_value + 1, IntegrityError)