Skip to content

Commit 0cbedba

Browse files
committed
minor refactoring and fixed default value of VersionField
1 parent 8e1e561 commit 0cbedba

File tree

4 files changed

+28
-12
lines changed

4 files changed

+28
-12
lines changed

concurrency/core.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@ class RecordModifiedError(DatabaseError):
1010
pass
1111

1212

13+
class InconsistencyError(DatabaseError):
14+
pass
15+
16+
1317
def apply_concurrency_check(model, fieldname, versionclass):
1418
"""
1519
Apply concurrency management to existing Models.
@@ -42,20 +46,20 @@ def concurrency_check(model_instance, force_insert=False, force_update=False, us
4246

4347

4448
def _select_lock(model_instance, version_value=None):
49+
version_field = model_instance.RevisionMetaInfo.field
50+
value = getattr(model_instance, version_field.name)
51+
is_versioned = value != version_field.get_default()
52+
4553
if model_instance.pk is not None:
46-
version_field = model_instance.RevisionMetaInfo.field
4754
kwargs = {'pk': model_instance.pk,
4855
version_field.name: version_value or getattr(model_instance, version_field.name)}
4956
alias = router.db_for_write(model_instance)
5057
NOWAIT = connections[alias].features.has_select_for_update_nowait
5158
entry = model_instance.__class__.objects.select_for_update(nowait=NOWAIT).filter(**kwargs)
5259
if not entry:
53-
value = getattr(model_instance, version_field.name)
54-
if value != version_field.get_default():
55-
raise RecordModifiedError(_('Version field is set (%s) but record has `pk`.' % value))
56-
elif value == version_field.get_default():
57-
return
5860
raise RecordModifiedError(_('Record has been modified'))
61+
elif is_versioned:
62+
raise InconsistencyError(_('Version field is set (%s) but record has `pk`.' % value))
5963

6064

6165
def _wrap_save(func):

concurrency/fields.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ def __init__(self, **kwargs):
2424
help_text = kwargs.get('help_text', _('record revision number'))
2525

2626
super(VersionField, self).__init__(verbose_name, name, editable=True,
27-
help_text=help_text, null=False, blank=False, default=1,
27+
help_text=help_text, null=False, blank=False,
28+
default=0,
2829
db_tablespace=db_tablespace, db_column=db_column)
2930

3031
def get_default(self):
@@ -61,7 +62,7 @@ def get_internal_type(self):
6162
return "BigIntegerField"
6263

6364
def pre_save(self, model_instance, add):
64-
old_value = getattr(model_instance, self.attname)
65+
old_value = getattr(model_instance, self.attname) or 0
6566
value = max(old_value + 1, (int(time.time() * 1000000) - OFFSET))
6667
setattr(model_instance, self.attname, value)
6768
return value
@@ -78,7 +79,7 @@ def get_internal_type(self):
7879
return "BigIntegerField"
7980

8081
def pre_save(self, model_instance, add):
81-
value = getattr(model_instance, self.attname) + 1
82+
value = (getattr(model_instance, self.attname) or 0) + 1
8283
setattr(model_instance, self.attname, value)
8384
return value
8485

@@ -127,6 +128,17 @@ def pre_save(self, model_instance, add):
127128
# return None
128129
# return time.strftime('%Y%m%d%H%M%S', value.timetuple())
129130

131+
try:
132+
from django_any import any_field
133+
import random
134+
from django.db.models.fields import Field, BigIntegerField
135+
any_field.register(IntegerVersionField,
136+
lambda x, **kwargs: random.randint(1, BigIntegerField.MAX_BIGINT))
137+
any_field.register(AutoIncVersionField,
138+
lambda x, **kwargs: random.randint(1, BigIntegerField.MAX_BIGINT))
139+
140+
except ImportError as e:
141+
pass
130142

131143
try:
132144
from south.modelsinspector import add_introspection_rules

concurrency/tests/all.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def _get_REVISION_NUMBER(obj):
4646

4747
def _check_save(self, obj):
4848
obj.save()
49-
assert obj.pk # sanity check
49+
assert obj.pk is not None # sanity check
5050

5151
def _get_form_data(self, **kwargs):
5252
data = {}
@@ -299,7 +299,7 @@ def _get_target(self):
299299

300300
def _check_save(self, obj):
301301
obj.save()
302-
assert obj.pk # sanity check
302+
assert obj.pk is not None # sanity check
303303
self.assertTrue(obj.version)
304304

305305
def _get_form_data(self, **kwargs):

concurrency/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def test_concurrency_conflict(self):
4848
v2 = get_revision_of_object(target_copy)
4949
assert v1 == v2, "got same row with different version (%s/%s)" % (v1, v2)
5050
target.save()
51-
assert target.pk
51+
assert target.pk is not None # sanity check
5252
self.assertRaises(RecordModifiedError, target_copy.save)
5353

5454
def test_concurrency_safety(self):

0 commit comments

Comments
 (0)