diff --git a/src/NHibernate.Test/Async/Linq/CustomExtensionsExample.cs b/src/NHibernate.Test/Async/Linq/CustomExtensionsExample.cs index cc04d951659..e1d32918feb 100644 --- a/src/NHibernate.Test/Async/Linq/CustomExtensionsExample.cs +++ b/src/NHibernate.Test/Async/Linq/CustomExtensionsExample.cs @@ -8,12 +8,12 @@ //------------------------------------------------------------------------------ +using System; using System.Collections.ObjectModel; using System.Linq; using System.Linq.Expressions; using System.Reflection; using System.Text.RegularExpressions; -using NHibernate.Cfg; using NHibernate.DomainModel.Northwind.Entities; using NHibernate.Hql.Ast; using NHibernate.Linq.Functions; @@ -42,6 +42,18 @@ public async Task CanUseObjectEqualsAsync() Assert.That(users.All(c => c.NullableEnum1 == EnumStoredAsString.Medium), Is.True); } + [Test(Description = "GH-2963")] + public async Task CanUseComparisonWithExtensionOnMappedPropertyAsync() + { + if (!TestDialect.SupportsTime) + { + Assert.Ignore("Time type is not supported"); + } + + var time = DateTime.UtcNow.GetTime(); + await (db.Users.Where(u => u.RegisteredAt.GetTime() > time).Select(u => u.Id).ToListAsync()); + } + [Test] public async Task CanUseMyCustomExtensionAsync() { diff --git a/src/NHibernate.Test/Linq/CustomExtensionsExample.cs b/src/NHibernate.Test/Linq/CustomExtensionsExample.cs index c9b76f92cec..58e6a5e0f4d 100644 --- a/src/NHibernate.Test/Linq/CustomExtensionsExample.cs +++ b/src/NHibernate.Test/Linq/CustomExtensionsExample.cs @@ -1,9 +1,9 @@ +using System; using System.Collections.ObjectModel; using System.Linq; using System.Linq.Expressions; using System.Reflection; using System.Text.RegularExpressions; -using NHibernate.Cfg; using NHibernate.DomainModel.Northwind.Entities; using NHibernate.Hql.Ast; using NHibernate.Linq.Functions; @@ -23,6 +23,11 @@ public static bool IsLike(this string source, string pattern) return Regex.IsMatch(source, pattern); } + + public static TimeSpan GetTime(this DateTime dateTime) + { + return dateTime.TimeOfDay; + } } public class MyLinqToHqlGeneratorsRegistry: DefaultLinqToHqlGeneratorsRegistry @@ -32,6 +37,20 @@ public MyLinqToHqlGeneratorsRegistry():base() RegisterGenerator(ReflectHelper.GetMethodDefinition(() => MyLinqExtensions.IsLike(null, null)), new IsLikeGenerator()); RegisterGenerator(ReflectHelper.GetMethodDefinition(() => new object().Equals(null)), new ObjectEqualsGenerator()); + RegisterGenerator(ReflectHelper.GetMethodDefinition(() => MyLinqExtensions.GetTime(default(DateTime))), new GetTimeGenerator()); + } + } + + public class GetTimeGenerator : BaseHqlGeneratorForMethod + { + public GetTimeGenerator() + { + SupportedMethods = new[] { ReflectHelper.GetMethodDefinition(() => MyLinqExtensions.GetTime(default(DateTime))) }; + } + + public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor) + { + return treeBuilder.MethodCall("cast", visitor.Visit(arguments[0]).AsExpression(), treeBuilder.Ident(NHibernateUtil.TimeAsTimeSpan.Name)); } } @@ -81,6 +100,18 @@ public void CanUseObjectEquals() Assert.That(users.All(c => c.NullableEnum1 == EnumStoredAsString.Medium), Is.True); } + [Test(Description = "GH-2963")] + public void CanUseComparisonWithExtensionOnMappedProperty() + { + if (!TestDialect.SupportsTime) + { + Assert.Ignore("Time type is not supported"); + } + + var time = DateTime.UtcNow.GetTime(); + db.Users.Where(u => u.RegisteredAt.GetTime() > time).Select(u => u.Id).ToList(); + } + [Test] public void CanUseMyCustomExtension() { diff --git a/src/NHibernate.Test/TestDialect.cs b/src/NHibernate.Test/TestDialect.cs index cc925423d2a..ce73d6f2cb7 100644 --- a/src/NHibernate.Test/TestDialect.cs +++ b/src/NHibernate.Test/TestDialect.cs @@ -1,4 +1,5 @@ using System; +using System.Data; using NHibernate.Hql.Ast.ANTLR; using NHibernate.Id; using NHibernate.SqlTypes; @@ -42,6 +43,8 @@ public bool NativeGeneratorSupportsBulkInsertion (IIdentifierGenerator) Cfg.Environment.ObjectsFactory.CreateInstance( _dialect.NativeIdentifierGeneratorClass)); + public virtual bool SupportsTime => _dialect.GetTypeName(new SqlType(DbType.Time)) != _dialect.GetTypeName(new SqlType(DbType.DateTime)); + public virtual bool SupportsOperatorAll => true; public virtual bool SupportsOperatorSome => true; diff --git a/src/NHibernate/Hql/Ast/ANTLR/HqlSqlWalker.cs b/src/NHibernate/Hql/Ast/ANTLR/HqlSqlWalker.cs index 3dc2cd30640..271c5143b48 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/HqlSqlWalker.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/HqlSqlWalker.cs @@ -56,6 +56,7 @@ public partial class HqlSqlWalker private readonly IDictionary _tokenReplacements; private readonly IDictionary _namedParameters; + private readonly IDictionary _guessedParameterTypes = new Dictionary(); private JoinType _impliedJoinType; @@ -98,6 +99,21 @@ public override void ReportError(RecognitionException e) _parseErrorHandler.ReportError(e); } + internal IStatement Transform() + { + var tree = (IStatement) statement().Tree; + // Use the guessed type in case we weren't been able to detect the type + foreach (var parameter in _parameters) + { + if (parameter.ExpectedType == null && _guessedParameterTypes.TryGetValue(parameter, out var guessedType)) + { + parameter.ExpectedType = guessedType; + } + } + + return tree; + } + /* protected override void Mismatch(IIntStream input, int ttype, BitSet follow) { @@ -1072,7 +1088,10 @@ IASTNode GenerateNamedParameter(IASTNode delimiterNode, IASTNode nameNode) // Add the parameter type information so that we are able to calculate functions return types // when the parameter is used as an argument. if (namedParameter.IsGuessedType) + { + _guessedParameterTypes[paramSpec] = namedParameter.Type; parameter.GuessedType = namedParameter.Type; + } else parameter.ExpectedType = namedParameter.Type; } diff --git a/src/NHibernate/Hql/Ast/ANTLR/QueryTranslatorImpl.cs b/src/NHibernate/Hql/Ast/ANTLR/QueryTranslatorImpl.cs index bcf3dc14e11..1dea389b9b0 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/QueryTranslatorImpl.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/QueryTranslatorImpl.cs @@ -610,7 +610,7 @@ public IStatement Translate() try { // Transform the tree. - _resultAst = (IStatement) hqlSqlWalker.statement().Tree; + _resultAst = hqlSqlWalker.Transform(); } finally { diff --git a/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs index de5f3511e23..2d3038abc64 100644 --- a/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs +++ b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs @@ -159,11 +159,9 @@ private static IType GetParameterType( return candidateType; } - if (visitor.NotGuessableConstants.Contains(constantExpression) && constantExpression.Value != null) - { - tryProcessInHql = true; - } - + // Leave hql logic to determine the type except when the value is a char. Hql logic detects a char as a string, which causes an exception + // when trying to set a string db parameter with a char value. + tryProcessInHql = !(constantExpression.Value is char); // No related MemberExpressions was found, guess the type by value or its type when null. // When a numeric parameter is compared to different columns with different types (e.g. Where(o => o.Single >= singleParam || o.Double <= singleParam)) // do not change the parameter type, but instead cast the parameter when comparing with different column types. @@ -174,13 +172,10 @@ private static IType GetParameterType( private class ConstantTypeLocatorVisitor : RelinqExpressionVisitor { - private bool _hqlGenerator; private readonly bool _removeMappedAsCalls; private readonly System.Type _targetType; private readonly IDictionary _parameters; private readonly ISessionFactoryImplementor _sessionFactory; - private readonly ILinqToHqlGeneratorsRegistry _functionRegistry; - public readonly HashSet NotGuessableConstants = new HashSet(); public readonly Dictionary ConstantExpressions = new Dictionary(); public readonly Dictionary> ParameterConstants = @@ -198,7 +193,6 @@ public ConstantTypeLocatorVisitor( _targetType = targetType; _sessionFactory = sessionFactory; _parameters = parameters; - _functionRegistry = sessionFactory.Settings.LinqToHqlGeneratorsRegistry; } protected override Expression VisitBinary(BinaryExpression node) @@ -269,16 +263,6 @@ protected override Expression VisitMethodCall(MethodCallExpression node) return node; } - // For hql method generators we do not want to guess the parameter type here, let hql logic figure it out. - if (_functionRegistry.TryGetGenerator(node.Method, out _)) - { - var origHqlGenerator = _hqlGenerator; - _hqlGenerator = true; - var expression = base.VisitMethodCall(node); - _hqlGenerator = origHqlGenerator; - return expression; - } - return base.VisitMethodCall(node); } @@ -289,11 +273,6 @@ protected override Expression VisitConstant(ConstantExpression node) return node; } - if (_hqlGenerator) - { - NotGuessableConstants.Add(node); - } - RelatedExpressions.Add(node, new HashSet()); ConstantExpressions.Add(node, null); if (!ParameterConstants.TryGetValue(param, out var set))