Skip to content

Commit 0d74ed5

Browse files
authored
Fix parameter detection when using custom hql functions (#2964)
Fixes #2963
1 parent 09898a2 commit 0d74ed5

File tree

6 files changed

+71
-27
lines changed

6 files changed

+71
-27
lines changed

src/NHibernate.Test/Async/Linq/CustomExtensionsExample.cs

+13-1
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,12 @@
88
//------------------------------------------------------------------------------
99

1010

11+
using System;
1112
using System.Collections.ObjectModel;
1213
using System.Linq;
1314
using System.Linq.Expressions;
1415
using System.Reflection;
1516
using System.Text.RegularExpressions;
16-
using NHibernate.Cfg;
1717
using NHibernate.DomainModel.Northwind.Entities;
1818
using NHibernate.Hql.Ast;
1919
using NHibernate.Linq.Functions;
@@ -42,6 +42,18 @@ public async Task CanUseObjectEqualsAsync()
4242
Assert.That(users.All(c => c.NullableEnum1 == EnumStoredAsString.Medium), Is.True);
4343
}
4444

45+
[Test(Description = "GH-2963")]
46+
public async Task CanUseComparisonWithExtensionOnMappedPropertyAsync()
47+
{
48+
if (!TestDialect.SupportsTime)
49+
{
50+
Assert.Ignore("Time type is not supported");
51+
}
52+
53+
var time = DateTime.UtcNow.GetTime();
54+
await (db.Users.Where(u => u.RegisteredAt.GetTime() > time).Select(u => u.Id).ToListAsync());
55+
}
56+
4557
[Test]
4658
public async Task CanUseMyCustomExtensionAsync()
4759
{

src/NHibernate.Test/Linq/CustomExtensionsExample.cs

+32-1
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1+
using System;
12
using System.Collections.ObjectModel;
23
using System.Linq;
34
using System.Linq.Expressions;
45
using System.Reflection;
56
using System.Text.RegularExpressions;
6-
using NHibernate.Cfg;
77
using NHibernate.DomainModel.Northwind.Entities;
88
using NHibernate.Hql.Ast;
99
using NHibernate.Linq.Functions;
@@ -23,6 +23,11 @@ public static bool IsLike(this string source, string pattern)
2323

2424
return Regex.IsMatch(source, pattern);
2525
}
26+
27+
public static TimeSpan GetTime(this DateTime dateTime)
28+
{
29+
return dateTime.TimeOfDay;
30+
}
2631
}
2732

2833
public class MyLinqToHqlGeneratorsRegistry: DefaultLinqToHqlGeneratorsRegistry
@@ -32,6 +37,20 @@ public MyLinqToHqlGeneratorsRegistry():base()
3237
RegisterGenerator(ReflectHelper.GetMethodDefinition(() => MyLinqExtensions.IsLike(null, null)),
3338
new IsLikeGenerator());
3439
RegisterGenerator(ReflectHelper.GetMethodDefinition(() => new object().Equals(null)), new ObjectEqualsGenerator());
40+
RegisterGenerator(ReflectHelper.GetMethodDefinition(() => MyLinqExtensions.GetTime(default(DateTime))), new GetTimeGenerator());
41+
}
42+
}
43+
44+
public class GetTimeGenerator : BaseHqlGeneratorForMethod
45+
{
46+
public GetTimeGenerator()
47+
{
48+
SupportedMethods = new[] { ReflectHelper.GetMethodDefinition(() => MyLinqExtensions.GetTime(default(DateTime))) };
49+
}
50+
51+
public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor)
52+
{
53+
return treeBuilder.MethodCall("cast", visitor.Visit(arguments[0]).AsExpression(), treeBuilder.Ident(NHibernateUtil.TimeAsTimeSpan.Name));
3554
}
3655
}
3756

@@ -81,6 +100,18 @@ public void CanUseObjectEquals()
81100
Assert.That(users.All(c => c.NullableEnum1 == EnumStoredAsString.Medium), Is.True);
82101
}
83102

103+
[Test(Description = "GH-2963")]
104+
public void CanUseComparisonWithExtensionOnMappedProperty()
105+
{
106+
if (!TestDialect.SupportsTime)
107+
{
108+
Assert.Ignore("Time type is not supported");
109+
}
110+
111+
var time = DateTime.UtcNow.GetTime();
112+
db.Users.Where(u => u.RegisteredAt.GetTime() > time).Select(u => u.Id).ToList();
113+
}
114+
84115
[Test]
85116
public void CanUseMyCustomExtension()
86117
{

src/NHibernate.Test/TestDialect.cs

+3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using System;
2+
using System.Data;
23
using NHibernate.Hql.Ast.ANTLR;
34
using NHibernate.Id;
45
using NHibernate.SqlTypes;
@@ -42,6 +43,8 @@ public bool NativeGeneratorSupportsBulkInsertion
4243
(IIdentifierGenerator) Cfg.Environment.ObjectsFactory.CreateInstance(
4344
_dialect.NativeIdentifierGeneratorClass));
4445

46+
public virtual bool SupportsTime => _dialect.GetTypeName(new SqlType(DbType.Time)) != _dialect.GetTypeName(new SqlType(DbType.DateTime));
47+
4548
public virtual bool SupportsOperatorAll => true;
4649
public virtual bool SupportsOperatorSome => true;
4750

src/NHibernate/Hql/Ast/ANTLR/HqlSqlWalker.cs

+19
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ public partial class HqlSqlWalker
5656

5757
private readonly IDictionary<string, string> _tokenReplacements;
5858
private readonly IDictionary<string, NamedParameter> _namedParameters;
59+
private readonly IDictionary<IParameterSpecification, IType> _guessedParameterTypes = new Dictionary<IParameterSpecification, IType>();
5960

6061
private JoinType _impliedJoinType;
6162

@@ -98,6 +99,21 @@ public override void ReportError(RecognitionException e)
9899
_parseErrorHandler.ReportError(e);
99100
}
100101

102+
internal IStatement Transform()
103+
{
104+
var tree = (IStatement) statement().Tree;
105+
// Use the guessed type in case we weren't been able to detect the type
106+
foreach (var parameter in _parameters)
107+
{
108+
if (parameter.ExpectedType == null && _guessedParameterTypes.TryGetValue(parameter, out var guessedType))
109+
{
110+
parameter.ExpectedType = guessedType;
111+
}
112+
}
113+
114+
return tree;
115+
}
116+
101117
/*
102118
protected override void Mismatch(IIntStream input, int ttype, BitSet follow)
103119
{
@@ -1072,7 +1088,10 @@ IASTNode GenerateNamedParameter(IASTNode delimiterNode, IASTNode nameNode)
10721088
// Add the parameter type information so that we are able to calculate functions return types
10731089
// when the parameter is used as an argument.
10741090
if (namedParameter.IsGuessedType)
1091+
{
1092+
_guessedParameterTypes[paramSpec] = namedParameter.Type;
10751093
parameter.GuessedType = namedParameter.Type;
1094+
}
10761095
else
10771096
parameter.ExpectedType = namedParameter.Type;
10781097
}

src/NHibernate/Hql/Ast/ANTLR/QueryTranslatorImpl.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -610,7 +610,7 @@ public IStatement Translate()
610610
try
611611
{
612612
// Transform the tree.
613-
_resultAst = (IStatement) hqlSqlWalker.statement().Tree;
613+
_resultAst = hqlSqlWalker.Transform();
614614
}
615615
finally
616616
{

src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs

+3-24
Original file line numberDiff line numberDiff line change
@@ -159,11 +159,9 @@ private static IType GetParameterType(
159159
return candidateType;
160160
}
161161

162-
if (visitor.NotGuessableConstants.Contains(constantExpression) && constantExpression.Value != null)
163-
{
164-
tryProcessInHql = true;
165-
}
166-
162+
// Leave hql logic to determine the type except when the value is a char. Hql logic detects a char as a string, which causes an exception
163+
// when trying to set a string db parameter with a char value.
164+
tryProcessInHql = !(constantExpression.Value is char);
167165
// No related MemberExpressions was found, guess the type by value or its type when null.
168166
// When a numeric parameter is compared to different columns with different types (e.g. Where(o => o.Single >= singleParam || o.Double <= singleParam))
169167
// do not change the parameter type, but instead cast the parameter when comparing with different column types.
@@ -174,13 +172,10 @@ private static IType GetParameterType(
174172

175173
private class ConstantTypeLocatorVisitor : RelinqExpressionVisitor
176174
{
177-
private bool _hqlGenerator;
178175
private readonly bool _removeMappedAsCalls;
179176
private readonly System.Type _targetType;
180177
private readonly IDictionary<ConstantExpression, NamedParameter> _parameters;
181178
private readonly ISessionFactoryImplementor _sessionFactory;
182-
private readonly ILinqToHqlGeneratorsRegistry _functionRegistry;
183-
public readonly HashSet<ConstantExpression> NotGuessableConstants = new HashSet<ConstantExpression>();
184179
public readonly Dictionary<ConstantExpression, IType> ConstantExpressions =
185180
new Dictionary<ConstantExpression, IType>();
186181
public readonly Dictionary<NamedParameter, HashSet<ConstantExpression>> ParameterConstants =
@@ -198,7 +193,6 @@ public ConstantTypeLocatorVisitor(
198193
_targetType = targetType;
199194
_sessionFactory = sessionFactory;
200195
_parameters = parameters;
201-
_functionRegistry = sessionFactory.Settings.LinqToHqlGeneratorsRegistry;
202196
}
203197

204198
protected override Expression VisitBinary(BinaryExpression node)
@@ -269,16 +263,6 @@ protected override Expression VisitMethodCall(MethodCallExpression node)
269263
return node;
270264
}
271265

272-
// For hql method generators we do not want to guess the parameter type here, let hql logic figure it out.
273-
if (_functionRegistry.TryGetGenerator(node.Method, out _))
274-
{
275-
var origHqlGenerator = _hqlGenerator;
276-
_hqlGenerator = true;
277-
var expression = base.VisitMethodCall(node);
278-
_hqlGenerator = origHqlGenerator;
279-
return expression;
280-
}
281-
282266
return base.VisitMethodCall(node);
283267
}
284268

@@ -289,11 +273,6 @@ protected override Expression VisitConstant(ConstantExpression node)
289273
return node;
290274
}
291275

292-
if (_hqlGenerator)
293-
{
294-
NotGuessableConstants.Add(node);
295-
}
296-
297276
RelatedExpressions.Add(node, new HashSet<Expression>());
298277
ConstantExpressions.Add(node, null);
299278
if (!ParameterConstants.TryGetValue(param, out var set))

0 commit comments

Comments
 (0)