Last active
August 29, 2015 14:12
-
-
Save jarshwah/76dbc87577b7fec05807 to your computer and use it in GitHub Desktop.
Fixes problems with Oracle and cast issues with PostgreSQL, and removes the dependence on the type of compiler
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py | |
index 64b5ba0..8b69406 100644 | |
--- a/django/db/models/expressions.py | |
+++ b/django/db/models/expressions.py | |
@@ -487,11 +487,7 @@ class Value(ExpressionNode): | |
def as_sql(self, compiler, connection): | |
val = self.value | |
if self._output_field_or_none is not None: | |
- from django.db.models.sql.compiler import SQLUpdateCompiler | |
- if isinstance(compiler, SQLUpdateCompiler): | |
- val = self.output_field.get_db_prep_save(val, connection=connection) | |
- else: | |
- val = self.output_field.get_db_prep_value(val, connection=connection) | |
+ val = self.output_field.get_db_prep_value(val, connection=connection) | |
return '%s', [val] | |
@@ -635,10 +631,12 @@ class BaseCaseExpression(ExpressionNode): | |
def as_postgresql(self, compiler, connection): | |
sql, params = self.as_sql(compiler, connection) | |
if self._output_field_or_none is not None: | |
- from django.db.models.sql.compiler import SQLUpdateCompiler | |
- if isinstance(compiler, SQLUpdateCompiler): | |
- # cast expression for postgres | |
- return 'CAST(%s AS %s)' % (sql, self.output_field.db_type(connection)), params | |
+ # cast expression for postgres - removing components of the type | |
+ # within brackets: varchar(255) -> varchar. Required for values | |
+ # that look like strings but are more specific types like uuid or | |
+ # inet. | |
+ cast_type = self.output_field.db_type(connection).split('(')[0] | |
+ return 'CAST(%s AS %s)' % (sql, cast_type), params | |
return sql, params | |
diff --git a/django/db/models/fields/__init__.py b/django/db/models/fields/__init__.py | |
index cef0c97..d697496 100644 | |
--- a/django/db/models/fields/__init__.py | |
+++ b/django/db/models/fields/__init__.py | |
@@ -1874,7 +1874,7 @@ class IPAddressField(Field): | |
class GenericIPAddressField(Field): | |
- empty_strings_allowed = True | |
+ empty_strings_allowed = False | |
description = _("IP address") | |
default_error_messages = {} | |
diff --git a/tests/expressions_case/models.py b/tests/expressions_case/models.py | |
index ff1e927..2a63e40 100644 | |
--- a/tests/expressions_case/models.py | |
+++ b/tests/expressions_case/models.py | |
@@ -11,14 +11,14 @@ class CaseTestModel(models.Model): | |
string = models.CharField(max_length=100) | |
big_integer = models.BigIntegerField(null=True) | |
- binary = models.BinaryField(null=True) | |
+ binary = models.BinaryField(default=b'') | |
boolean = models.BooleanField(default=False) | |
- comma_separated_integer = models.CommaSeparatedIntegerField(max_length=100, null=True) | |
+ comma_separated_integer = models.CommaSeparatedIntegerField(max_length=100, default='') | |
date = models.DateField(null=True) | |
date_time = models.DateTimeField(null=True) | |
decimal = models.DecimalField(max_digits=2, decimal_places=1, null=True) | |
duration = models.DurationField(null=True) | |
- email = models.EmailField(null=True) | |
+ email = models.EmailField(default='') | |
file = models.FileField(null=True) | |
file_path = models.FilePathField(null=True) | |
float = models.FloatField(null=True) | |
@@ -28,11 +28,11 @@ class CaseTestModel(models.Model): | |
null_boolean = models.NullBooleanField() | |
positive_integer = models.PositiveIntegerField(null=True) | |
positive_small_integer = models.PositiveSmallIntegerField(null=True) | |
- slug = models.SlugField(null=True) | |
+ slug = models.SlugField(default='') | |
small_integer = models.SmallIntegerField(null=True) | |
- text = models.TextField(null=True) | |
+ text = models.TextField(default='') | |
time = models.TimeField(null=True) | |
- url = models.URLField(null=True) | |
+ url = models.URLField(default='') | |
uuid = models.UUIDField(null=True) | |
fk = models.ForeignKey('self', null=True) | |
diff --git a/tests/expressions_case/tests.py b/tests/expressions_case/tests.py | |
index 1bf4422..d34f318 100644 | |
--- a/tests/expressions_case/tests.py | |
+++ b/tests/expressions_case/tests.py | |
@@ -10,7 +10,7 @@ from django.db import models | |
from django.db.models import F, Q, Value | |
from django.db.models.expressions import SearchedCase, SimpleCase | |
from django.test import TestCase | |
-from django.utils.six import binary_type | |
+from django.utils.six import binary_type, text_type | |
from .models import CaseTestModel, FKCaseTestModel | |
@@ -236,12 +236,13 @@ class BaseCaseExpressionTests(TestCase): | |
# set explicitly | |
[(Value(1), Value(b'one')), | |
(Value(2), Value(b'two'))], | |
+ default=Value(b''), | |
output_field=models.BinaryField())) | |
self.assertQuerysetEqual( | |
CaseTestModel.objects.all().order_by('pk'), | |
- [(1, b'one'), (2, b'two'), (3, None), (2, b'two'), (3, None), (3, None), (4, None)], | |
- transform=lambda o: (o.integer, None if o.binary is None else binary_type(o.binary))) | |
+ [(1, b'one'), (2, b'two'), (3, b''), (2, b'two'), (3, b''), (3, b''), (4, b'')], | |
+ transform=lambda o: (o.integer, binary_type(o.binary))) | |
def test_update_boolean(self): | |
CaseTestModel.objects.update( | |
@@ -261,11 +262,11 @@ class BaseCaseExpressionTests(TestCase): | |
comma_separated_integer=self.create_expression( | |
'integer', | |
[(Value(1), Value('1')), | |
- (Value(2), Value('2,2'))])) | |
+ (Value(2), Value('2,2'))], default=Value(''))) | |
self.assertQuerysetEqual( | |
CaseTestModel.objects.all().order_by('pk'), | |
- [(1, '1'), (2, '2,2'), (3, None), (2, '2,2'), (3, None), (3, None), (4, None)], | |
+ [(1, '1'), (2, '2,2'), (3, ''), (2, '2,2'), (3, ''), (3, ''), (4, '')], | |
transform=attrgetter('integer', 'comma_separated_integer')) | |
def test_update_date(self): | |
@@ -325,12 +326,12 @@ class BaseCaseExpressionTests(TestCase): | |
email=self.create_expression( | |
'integer', | |
[(Value(1), Value('[email protected]')), | |
- (Value(2), Value('[email protected]'))])) | |
+ (Value(2), Value('[email protected]'))], default=Value(''))) | |
self.assertQuerysetEqual( | |
CaseTestModel.objects.all().order_by('pk'), | |
- [(1, '[email protected]'), (2, '[email protected]'), (3, None), (2, '[email protected]'), (3, None), (3, None), | |
- (4, None)], | |
+ [(1, '[email protected]'), (2, '[email protected]'), (3, ''), (2, '[email protected]'), (3, ''), (3, ''), | |
+ (4, '')], | |
transform=attrgetter('integer', 'email')) | |
def test_update_file(self): | |
@@ -342,19 +343,19 @@ class BaseCaseExpressionTests(TestCase): | |
self.assertQuerysetEqual( | |
CaseTestModel.objects.all().order_by('pk'), | |
- [(1, '~/1'), (2, '~/2'), (3, None), (2, '~/2'), (3, None), (3, None), (4, None)], | |
- transform=attrgetter('integer', 'file')) | |
+ [(1, '~/1'), (2, '~/2'), (3, ''), (2, '~/2'), (3, ''), (3, ''), (4, '')], | |
+ transform=lambda o: (o.integer, text_type(o.file))) | |
def test_update_file_path(self): | |
CaseTestModel.objects.update( | |
file_path=self.create_expression( | |
'integer', | |
[(Value(1), Value('~/1')), | |
- (Value(2), Value('~/2'))])) | |
+ (Value(2), Value('~/2'))], default=Value(''))) | |
self.assertQuerysetEqual( | |
CaseTestModel.objects.all().order_by('pk'), | |
- [(1, '~/1'), (2, '~/2'), (3, None), (2, '~/2'), (3, None), (3, None), (4, None)], | |
+ [(1, '~/1'), (2, '~/2'), (3, ''), (2, '~/2'), (3, ''), (3, ''), (4, '')], | |
transform=attrgetter('integer', 'file_path')) | |
def test_update_float(self): | |
@@ -375,11 +376,10 @@ class BaseCaseExpressionTests(TestCase): | |
'integer', | |
[(Value(1), Value('~/1')), | |
(Value(2), Value('~/2'))])) | |
- | |
self.assertQuerysetEqual( | |
CaseTestModel.objects.all().order_by('pk'), | |
- [(1, '~/1'), (2, '~/2'), (3, None), (2, '~/2'), (3, None), (3, None), (4, None)], | |
- transform=attrgetter('integer', 'image')) | |
+ [(1, '~/1'), (2, '~/2'), (3, ''), (2, '~/2'), (3, ''), (3, ''), (4, '')], | |
+ transform=lambda o: (o.integer, text_type(o.image))) | |
def test_update_ip_address(self): | |
CaseTestModel.objects.update( | |
@@ -450,11 +450,11 @@ class BaseCaseExpressionTests(TestCase): | |
slug=self.create_expression( | |
'integer', | |
[(Value(1), Value('1')), | |
- (Value(2), Value('2'))])) | |
+ (Value(2), Value('2'))], default=Value(''))) | |
self.assertQuerysetEqual( | |
CaseTestModel.objects.all().order_by('pk'), | |
- [(1, '1'), (2, '2'), (3, None), (2, '2'), (3, None), (3, None), (4, None)], | |
+ [(1, '1'), (2, '2'), (3, ''), (2, '2'), (3, ''), (3, ''), (4, '')], | |
transform=attrgetter('integer', 'slug')) | |
def test_update_small_integer(self): | |
@@ -469,16 +469,28 @@ class BaseCaseExpressionTests(TestCase): | |
[(1, 1), (2, 2), (3, None), (2, 2), (3, None), (3, None), (4, None)], | |
transform=attrgetter('integer', 'small_integer')) | |
+ def test_update_string(self): | |
+ CaseTestModel.objects.filter(string__in=['1', '2']).update( | |
+ string=self.create_expression( | |
+ 'integer', | |
+ [(Value(1), Value('1', output_field=models.CharField())), | |
+ (Value(2), Value('2', output_field=models.CharField()))])) | |
+ | |
+ self.assertQuerysetEqual( | |
+ CaseTestModel.objects.filter(string__in=['1', '2']).order_by('pk'), | |
+ [(1, '1'), (2, '2'), (2, '2')], | |
+ transform=attrgetter('integer', 'string')) | |
+ | |
def test_update_text(self): | |
CaseTestModel.objects.update( | |
text=self.create_expression( | |
'integer', | |
[(Value(1), Value('1')), | |
- (Value(2), Value('2'))])) | |
+ (Value(2), Value('2'))], default=Value(''))) | |
self.assertQuerysetEqual( | |
CaseTestModel.objects.all().order_by('pk'), | |
- [(1, '1'), (2, '2'), (3, None), (2, '2'), (3, None), (3, None), (4, None)], | |
+ [(1, '1'), (2, '2'), (3, ''), (2, '2'), (3, ''), (3, ''), (4, '')], | |
transform=attrgetter('integer', 'text')) | |
def test_update_time(self): | |
@@ -500,12 +512,12 @@ class BaseCaseExpressionTests(TestCase): | |
url=self.create_expression( | |
'integer', | |
[(Value(1), Value('http://1.example.com/')), | |
- (Value(2), Value('http://2.example.com/'))])) | |
+ (Value(2), Value('http://2.example.com/'))], default=Value(''))) | |
self.assertQuerysetEqual( | |
CaseTestModel.objects.all().order_by('pk'), | |
- [(1, 'http://1.example.com/'), (2, 'http://2.example.com/'), (3, None), (2, 'http://2.example.com/'), | |
- (3, None), (3, None), (4, None)], | |
+ [(1, 'http://1.example.com/'), (2, 'http://2.example.com/'), (3, ''), (2, 'http://2.example.com/'), | |
+ (3, ''), (3, ''), (4, '')], | |
transform=attrgetter('integer', 'url')) | |
def test_update_uuid(self): |
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment