Skip to content

Commit 8201004

Browse files
NH-3787 - Decimal truncation in some Linq queries
* Fixes #1335
1 parent 70413b7 commit 8201004

File tree

11 files changed

+332
-5
lines changed

11 files changed

+332
-5
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
//------------------------------------------------------------------------------
2+
// <auto-generated>
3+
// This code was generated by AsyncGenerator.
4+
//
5+
// Changes to this file may cause incorrect behavior and will be lost if
6+
// the code is regenerated.
7+
// </auto-generated>
8+
//------------------------------------------------------------------------------
9+
10+
11+
using System.Linq;
12+
using NHibernate.Criterion;
13+
using NHibernate.Transform;
14+
using NUnit.Framework;
15+
using NHibernate.Linq;
16+
17+
namespace NHibernate.Test.NHSpecificTest.NH3787
18+
{
19+
using System.Threading.Tasks;
20+
[TestFixture]
21+
public class TestFixtureAsync : BugTestCase
22+
{
23+
private const decimal _testRate = 12345.123456789M;
24+
25+
protected override void OnSetUp()
26+
{
27+
base.OnSetUp();
28+
29+
using (var s = OpenSession())
30+
using (var t = s.BeginTransaction())
31+
{
32+
var testEntity = new TestEntity
33+
{
34+
UsePreviousRate = true,
35+
PreviousRate = _testRate,
36+
Rate = 54321.123456789M
37+
};
38+
s.Save(testEntity);
39+
t.Commit();
40+
}
41+
}
42+
43+
protected override void OnTearDown()
44+
{
45+
using (var s = OpenSession())
46+
using (var t = s.BeginTransaction())
47+
{
48+
s.CreateQuery("delete from TestEntity").ExecuteUpdate();
49+
t.Commit();
50+
}
51+
}
52+
53+
[Test]
54+
public async Task TestLinqQueryAsync()
55+
{
56+
using (var s = OpenSession())
57+
using (var t = s.BeginTransaction())
58+
{
59+
var queryResult = await (s
60+
.Query<TestEntity>()
61+
.Where(e => e.PreviousRate == _testRate)
62+
.ToListAsync());
63+
64+
Assert.That(queryResult.Count, Is.EqualTo(1));
65+
Assert.That(queryResult[0].PreviousRate, Is.EqualTo(_testRate));
66+
await (t.CommitAsync());
67+
}
68+
}
69+
70+
[Test]
71+
public async Task TestLinqProjectionAsync()
72+
{
73+
using (var s = OpenSession())
74+
using (var t = s.BeginTransaction())
75+
{
76+
var queryResult = await ((from test in s.Query<TestEntity>()
77+
select new RateDto { Rate = test.UsePreviousRate ? test.PreviousRate : test.Rate }).ToListAsync());
78+
79+
// Check it has not been truncated to the default 5 positions of NHibernate.
80+
Assert.That(queryResult[0].Rate, Is.EqualTo(_testRate));
81+
await (t.CommitAsync());
82+
}
83+
}
84+
85+
[Test]
86+
public async Task TestLinqQueryOnExpressionAsync()
87+
{
88+
using (var s = OpenSession())
89+
using (var t = s.BeginTransaction())
90+
{
91+
var queryResult = await (s
92+
.Query<TestEntity>()
93+
.Where(e => (e.UsePreviousRate ? e.PreviousRate : e.Rate) == _testRate)
94+
.ToListAsync());
95+
96+
Assert.That(queryResult.Count, Is.EqualTo(1));
97+
Assert.That(queryResult[0].PreviousRate, Is.EqualTo(_testRate));
98+
await (t.CommitAsync());
99+
}
100+
}
101+
102+
[Test]
103+
public async Task TestQueryOverProjectionAsync()
104+
{
105+
using (var s = OpenSession())
106+
using (var t = s.BeginTransaction())
107+
{
108+
TestEntity testEntity = null;
109+
110+
var rateDto = new RateDto();
111+
//Generated sql
112+
//exec sp_executesql N'SELECT (case when this_.UsePreviousRate = @p0 then this_.PreviousRate else this_.Rate end) as y0_ FROM [TestEntity] this_',N'@p0 bit',@p0=1
113+
var query = s
114+
.QueryOver(() => testEntity)
115+
.Select(
116+
Projections.Alias(
117+
Projections.Conditional(
118+
Restrictions.Eq(Projections.Property(() => testEntity.UsePreviousRate), true),
119+
Projections.Property(() => testEntity.PreviousRate),
120+
Projections.Property(() => testEntity.Rate)),
121+
"Rate")
122+
.WithAlias(() => rateDto.Rate));
123+
124+
var queryResult = await (query.TransformUsing(Transformers.AliasToBean<RateDto>()).ListAsync<RateDto>());
125+
126+
Assert.That(queryResult[0].Rate, Is.EqualTo(_testRate));
127+
await (t.CommitAsync());
128+
}
129+
}
130+
}
131+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
<?xml version="1.0" encoding="utf-8" ?>
2+
<hibernate-mapping xmlns="urn:nhibernate-mapping-2.2" assembly="NHibernate.Test"
3+
namespace="NHibernate.Test.NHSpecificTest.NH3787">
4+
<class name="TestEntity" table="TestEntity">
5+
<id name="Id">
6+
<generator class="native"/>
7+
</id>
8+
<property name="UsePreviousRate" type="boolean" not-null="true"/>
9+
<property name="PreviousRate" type="decimal(18,13)" not-null="true"/>
10+
<property name="Rate" type="decimal(18,13)" not-null="true"/>
11+
</class>
12+
</hibernate-mapping>
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
namespace NHibernate.Test.NHSpecificTest.NH3787
2+
{
3+
public class RateDto
4+
{
5+
public decimal Rate { get; set; }
6+
}
7+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
namespace NHibernate.Test.NHSpecificTest.NH3787
2+
{
3+
public class TestEntity
4+
{
5+
public virtual int Id { get; set; }
6+
public virtual bool UsePreviousRate { get; set; }
7+
public virtual decimal Rate { get; set; }
8+
public virtual decimal PreviousRate { get; set; }
9+
}
10+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
using System.Linq;
2+
using NHibernate.Criterion;
3+
using NHibernate.Linq;
4+
using NHibernate.Transform;
5+
using NHibernate.Type;
6+
using NUnit.Framework;
7+
8+
namespace NHibernate.Test.NHSpecificTest.NH3787
9+
{
10+
[TestFixture]
11+
public class TestFixture : BugTestCase
12+
{
13+
private const decimal _testRate = 12345.1234567890123M;
14+
15+
protected override void OnSetUp()
16+
{
17+
base.OnSetUp();
18+
19+
using (var s = OpenSession())
20+
using (var t = s.BeginTransaction())
21+
{
22+
var testEntity = new TestEntity
23+
{
24+
UsePreviousRate = true,
25+
PreviousRate = _testRate,
26+
Rate = 54321.1234567890123M
27+
};
28+
s.Save(testEntity);
29+
t.Commit();
30+
}
31+
}
32+
33+
protected override void OnTearDown()
34+
{
35+
using (var s = OpenSession())
36+
using (var t = s.BeginTransaction())
37+
{
38+
s.CreateQuery("delete from TestEntity").ExecuteUpdate();
39+
t.Commit();
40+
}
41+
}
42+
43+
[Test]
44+
public void TestLinqQuery()
45+
{
46+
using (var s = OpenSession())
47+
using (var t = s.BeginTransaction())
48+
{
49+
var queryResult = s
50+
.Query<TestEntity>()
51+
.Where(e => e.PreviousRate == _testRate)
52+
.ToList();
53+
54+
Assert.That(queryResult.Count, Is.EqualTo(1));
55+
Assert.That(queryResult[0].PreviousRate, Is.EqualTo(_testRate));
56+
t.Commit();
57+
}
58+
}
59+
60+
[Test]
61+
public void TestLinqProjection()
62+
{
63+
using (var s = OpenSession())
64+
using (var t = s.BeginTransaction())
65+
{
66+
var queryResult = (from test in s.Query<TestEntity>()
67+
select new RateDto { Rate = test.UsePreviousRate ? test.PreviousRate : test.Rate }).ToList();
68+
69+
// Check it has not been truncated to the default scale (10) of NHibernate.
70+
Assert.That(queryResult[0].Rate, Is.EqualTo(_testRate));
71+
t.Commit();
72+
}
73+
}
74+
75+
[Test]
76+
public void TestLinqQueryOnExpression()
77+
{
78+
using (var s = OpenSession())
79+
using (var t = s.BeginTransaction())
80+
{
81+
var queryResult = s
82+
.Query<TestEntity>()
83+
.Where(
84+
// Without MappedAs, the test fails for SQL Server because it would restrict its parameter to the dialect's default scale.
85+
e => (e.UsePreviousRate ? e.PreviousRate : e.Rate) == _testRate.MappedAs(TypeFactory.Basic("decimal(18,13)")))
86+
.ToList();
87+
88+
Assert.That(queryResult.Count, Is.EqualTo(1));
89+
Assert.That(queryResult[0].PreviousRate, Is.EqualTo(_testRate));
90+
t.Commit();
91+
}
92+
}
93+
94+
[Test]
95+
public void TestQueryOverProjection()
96+
{
97+
using (var s = OpenSession())
98+
using (var t = s.BeginTransaction())
99+
{
100+
TestEntity testEntity = null;
101+
102+
var rateDto = new RateDto();
103+
//Generated sql
104+
//exec sp_executesql N'SELECT (case when this_.UsePreviousRate = @p0 then this_.PreviousRate else this_.Rate end) as y0_ FROM [TestEntity] this_',N'@p0 bit',@p0=1
105+
var query = s
106+
.QueryOver(() => testEntity)
107+
.Select(
108+
Projections
109+
.Alias(
110+
Projections.Conditional(
111+
Restrictions.Eq(Projections.Property(() => testEntity.UsePreviousRate), true),
112+
Projections.Property(() => testEntity.PreviousRate),
113+
Projections.Property(() => testEntity.Rate)),
114+
"Rate")
115+
.WithAlias(() => rateDto.Rate));
116+
117+
var queryResult = query.TransformUsing(Transformers.AliasToBean<RateDto>()).List<RateDto>();
118+
119+
Assert.That(queryResult[0].Rate, Is.EqualTo(_testRate));
120+
t.Commit();
121+
}
122+
}
123+
}
124+
}

src/NHibernate/Dialect/Dialect.cs

+1
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ protected Dialect()
9898
RegisterFunction("upper", new StandardSQLFunction("upper"));
9999
RegisterFunction("lower", new StandardSQLFunction("lower"));
100100
RegisterFunction("cast", new CastFunction());
101+
RegisterFunction("transparentcast", new TransparentCastFunction());
101102
RegisterFunction("extract", new AnsiExtractFunction());
102103
RegisterFunction("concat", new VarArgsSQLFunction(NHibernateUtil.String, "(", "||", ")"));
103104

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
using System;
2+
3+
namespace NHibernate.Dialect.Function
4+
{
5+
/// <summary>
6+
/// A HQL only cast for helping HQL knowing the type. Does not generates any actual cast in SQL code.
7+
/// </summary>
8+
[Serializable]
9+
public class TransparentCastFunction : CastFunction
10+
{
11+
protected override bool CastingIsRequired(string sqlType)
12+
{
13+
return false;
14+
}
15+
}
16+
}

src/NHibernate/Hql/Ast/ANTLR/SessionFactoryHelperExtensions.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ public IType FindFunctionReturnType(String functionName, IASTNode first)
7777

7878
if (first != null)
7979
{
80-
if (functionName == "cast")
80+
if (sqlFunction is CastFunction)
8181
{
8282
argumentType = TypeFactory.HeuristicType(first.NextSibling.Text);
8383
}

src/NHibernate/Hql/Ast/HqlTreeBuilder.cs

+11
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,17 @@ public HqlCast Cast(HqlExpression expression, System.Type type)
301301
return new HqlCast(_factory, expression, type);
302302
}
303303

304+
/// <summary>
305+
/// Generate a cast node intended solely to hint HQL at the resulting type, without issuing an actual SQL cast.
306+
/// </summary>
307+
/// <param name="expression">The expression to cast.</param>
308+
/// <param name="type">The resulting type.</param>
309+
/// <returns>A <see cref="HqlTransparentCast"/> node.</returns>
310+
public HqlTransparentCast TransparentCast(HqlExpression expression, System.Type type)
311+
{
312+
return new HqlTransparentCast(_factory, expression, type);
313+
}
314+
304315
public HqlBitwiseNot BitwiseNot()
305316
{
306317
return new HqlBitwiseNot(_factory);

src/NHibernate/Hql/Ast/HqlTreeNode.cs

+13
Original file line numberDiff line numberDiff line change
@@ -701,6 +701,19 @@ public HqlCast(IASTFactory factory, HqlExpression expression, System.Type type)
701701
}
702702
}
703703

704+
/// <summary>
705+
/// Cast node intended solely to hint HQL at the resulting type, without issuing an actual SQL cast.
706+
/// </summary>
707+
public class HqlTransparentCast : HqlExpression
708+
{
709+
public HqlTransparentCast(IASTFactory factory, HqlExpression expression, System.Type type)
710+
: base(HqlSqlWalker.METHOD_CALL, "method", factory)
711+
{
712+
AddChild(new HqlIdent(factory, "transparentcast"));
713+
AddChild(new HqlExpressionList(factory, expression, new HqlIdent(factory, type)));
714+
}
715+
}
716+
704717
public class HqlCoalesce : HqlExpression
705718
{
706719
public HqlCoalesce(IASTFactory factory, HqlExpression lhs, HqlExpression rhs)

src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs

+6-4
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
using NHibernate.Param;
99
using NHibernate.Util;
1010
using Remotion.Linq.Clauses.Expressions;
11-
using Remotion.Linq.Clauses.ResultOperators;
1211

1312
namespace NHibernate.Linq.Visitors
1413
{
@@ -538,9 +537,12 @@ protected HqlTreeNode VisitConditionalExpression(ConditionalExpression expressio
538537

539538
HqlExpression @case = _hqlTreeBuilder.Case(new[] {_hqlTreeBuilder.When(test, ifTrue)}, ifFalse);
540539

541-
return (expression.Type == typeof (bool) || expression.Type == (typeof (bool?)))
542-
? @case
543-
: _hqlTreeBuilder.Cast(@case, expression.Type);
540+
// If both operands are parameters, HQL will not be able to determine the resulting type before
541+
// parameters binding. But it has to compute result set columns type before parameters are bound,
542+
// so an artificial cast is introduced to hint HQL at the resulting type.
543+
return expression.Type == typeof(bool) || expression.Type == typeof(bool?)
544+
? @case
545+
: _hqlTreeBuilder.TransparentCast(@case, expression.Type);
544546
}
545547

546548
protected HqlTreeNode VisitSubQueryExpression(SubQueryExpression expression)

0 commit comments

Comments
 (0)