Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix parameter detection when using custom hql functions #2964

Merged
merged 8 commits into from
Jan 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion src/NHibernate.Test/Async/Linq/CustomExtensionsExample.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
//------------------------------------------------------------------------------


using System;
using System.Collections.ObjectModel;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Text.RegularExpressions;
using NHibernate.Cfg;
using NHibernate.DomainModel.Northwind.Entities;
using NHibernate.Hql.Ast;
using NHibernate.Linq.Functions;
Expand Down Expand Up @@ -42,6 +42,18 @@ public async Task CanUseObjectEqualsAsync()
Assert.That(users.All(c => c.NullableEnum1 == EnumStoredAsString.Medium), Is.True);
}

[Test(Description = "GH-2963")]
public async Task CanUseComparisonWithExtensionOnMappedPropertyAsync()
{
if (!TestDialect.SupportsTime)
{
Assert.Ignore("Time type is not supported");
}

var time = DateTime.UtcNow.GetTime();
await (db.Users.Where(u => u.RegisteredAt.GetTime() > time).Select(u => u.Id).ToListAsync());
}

[Test]
public async Task CanUseMyCustomExtensionAsync()
{
Expand Down
33 changes: 32 additions & 1 deletion src/NHibernate.Test/Linq/CustomExtensionsExample.cs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
using System;
using System.Collections.ObjectModel;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Text.RegularExpressions;
using NHibernate.Cfg;
using NHibernate.DomainModel.Northwind.Entities;
using NHibernate.Hql.Ast;
using NHibernate.Linq.Functions;
Expand All @@ -23,6 +23,11 @@ public static bool IsLike(this string source, string pattern)

return Regex.IsMatch(source, pattern);
}

public static TimeSpan GetTime(this DateTime dateTime)
{
return dateTime.TimeOfDay;
}
}

public class MyLinqToHqlGeneratorsRegistry: DefaultLinqToHqlGeneratorsRegistry
Expand All @@ -32,6 +37,20 @@ public MyLinqToHqlGeneratorsRegistry():base()
RegisterGenerator(ReflectHelper.GetMethodDefinition(() => MyLinqExtensions.IsLike(null, null)),
new IsLikeGenerator());
RegisterGenerator(ReflectHelper.GetMethodDefinition(() => new object().Equals(null)), new ObjectEqualsGenerator());
RegisterGenerator(ReflectHelper.GetMethodDefinition(() => MyLinqExtensions.GetTime(default(DateTime))), new GetTimeGenerator());
}
}

public class GetTimeGenerator : BaseHqlGeneratorForMethod
{
public GetTimeGenerator()
{
SupportedMethods = new[] { ReflectHelper.GetMethodDefinition(() => MyLinqExtensions.GetTime(default(DateTime))) };
}

public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor)
{
return treeBuilder.MethodCall("cast", visitor.Visit(arguments[0]).AsExpression(), treeBuilder.Ident(NHibernateUtil.TimeAsTimeSpan.Name));
}
}

Expand Down Expand Up @@ -81,6 +100,18 @@ public void CanUseObjectEquals()
Assert.That(users.All(c => c.NullableEnum1 == EnumStoredAsString.Medium), Is.True);
}

[Test(Description = "GH-2963")]
public void CanUseComparisonWithExtensionOnMappedProperty()
{
if (!TestDialect.SupportsTime)
{
Assert.Ignore("Time type is not supported");
}

var time = DateTime.UtcNow.GetTime();
db.Users.Where(u => u.RegisteredAt.GetTime() > time).Select(u => u.Id).ToList();
}

[Test]
public void CanUseMyCustomExtension()
{
Expand Down
3 changes: 3 additions & 0 deletions src/NHibernate.Test/TestDialect.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System;
using System.Data;
using NHibernate.Hql.Ast.ANTLR;
using NHibernate.Id;
using NHibernate.SqlTypes;
Expand Down Expand Up @@ -42,6 +43,8 @@ public bool NativeGeneratorSupportsBulkInsertion
(IIdentifierGenerator) Cfg.Environment.ObjectsFactory.CreateInstance(
_dialect.NativeIdentifierGeneratorClass));

public virtual bool SupportsTime => _dialect.GetTypeName(new SqlType(DbType.Time)) != _dialect.GetTypeName(new SqlType(DbType.DateTime));

public virtual bool SupportsOperatorAll => true;
public virtual bool SupportsOperatorSome => true;

Expand Down
19 changes: 19 additions & 0 deletions src/NHibernate/Hql/Ast/ANTLR/HqlSqlWalker.cs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ public partial class HqlSqlWalker

private readonly IDictionary<string, string> _tokenReplacements;
private readonly IDictionary<string, NamedParameter> _namedParameters;
private readonly IDictionary<IParameterSpecification, IType> _guessedParameterTypes = new Dictionary<IParameterSpecification, IType>();

private JoinType _impliedJoinType;

Expand Down Expand Up @@ -98,6 +99,21 @@ public override void ReportError(RecognitionException e)
_parseErrorHandler.ReportError(e);
}

internal IStatement Transform()
{
var tree = (IStatement) statement().Tree;
// Use the guessed type in case we weren't been able to detect the type
foreach (var parameter in _parameters)
{
if (parameter.ExpectedType == null && _guessedParameterTypes.TryGetValue(parameter, out var guessedType))
{
parameter.ExpectedType = guessedType;
}
}

return tree;
}

/*
protected override void Mismatch(IIntStream input, int ttype, BitSet follow)
{
Expand Down Expand Up @@ -1072,7 +1088,10 @@ IASTNode GenerateNamedParameter(IASTNode delimiterNode, IASTNode nameNode)
// Add the parameter type information so that we are able to calculate functions return types
// when the parameter is used as an argument.
if (namedParameter.IsGuessedType)
{
_guessedParameterTypes[paramSpec] = namedParameter.Type;
parameter.GuessedType = namedParameter.Type;
}
else
parameter.ExpectedType = namedParameter.Type;
}
Expand Down
2 changes: 1 addition & 1 deletion src/NHibernate/Hql/Ast/ANTLR/QueryTranslatorImpl.cs
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,7 @@ public IStatement Translate()
try
{
// Transform the tree.
_resultAst = (IStatement) hqlSqlWalker.statement().Tree;
_resultAst = hqlSqlWalker.Transform();
}
finally
{
Expand Down
27 changes: 3 additions & 24 deletions src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,9 @@ private static IType GetParameterType(
return candidateType;
}

if (visitor.NotGuessableConstants.Contains(constantExpression) && constantExpression.Value != null)
{
tryProcessInHql = true;
}

// Leave hql logic to determine the type except when the value is a char. Hql logic detects a char as a string, which causes an exception
// when trying to set a string db parameter with a char value.
tryProcessInHql = !(constantExpression.Value is char);
// No related MemberExpressions was found, guess the type by value or its type when null.
// When a numeric parameter is compared to different columns with different types (e.g. Where(o => o.Single >= singleParam || o.Double <= singleParam))
// do not change the parameter type, but instead cast the parameter when comparing with different column types.
Expand All @@ -174,13 +172,10 @@ private static IType GetParameterType(

private class ConstantTypeLocatorVisitor : RelinqExpressionVisitor
{
private bool _hqlGenerator;
private readonly bool _removeMappedAsCalls;
private readonly System.Type _targetType;
private readonly IDictionary<ConstantExpression, NamedParameter> _parameters;
private readonly ISessionFactoryImplementor _sessionFactory;
private readonly ILinqToHqlGeneratorsRegistry _functionRegistry;
public readonly HashSet<ConstantExpression> NotGuessableConstants = new HashSet<ConstantExpression>();
public readonly Dictionary<ConstantExpression, IType> ConstantExpressions =
new Dictionary<ConstantExpression, IType>();
public readonly Dictionary<NamedParameter, HashSet<ConstantExpression>> ParameterConstants =
Expand All @@ -198,7 +193,6 @@ public ConstantTypeLocatorVisitor(
_targetType = targetType;
_sessionFactory = sessionFactory;
_parameters = parameters;
_functionRegistry = sessionFactory.Settings.LinqToHqlGeneratorsRegistry;
}

protected override Expression VisitBinary(BinaryExpression node)
Expand Down Expand Up @@ -269,16 +263,6 @@ protected override Expression VisitMethodCall(MethodCallExpression node)
return node;
}

// For hql method generators we do not want to guess the parameter type here, let hql logic figure it out.
if (_functionRegistry.TryGetGenerator(node.Method, out _))
{
var origHqlGenerator = _hqlGenerator;
_hqlGenerator = true;
var expression = base.VisitMethodCall(node);
_hqlGenerator = origHqlGenerator;
return expression;
}

return base.VisitMethodCall(node);
}

Expand All @@ -289,11 +273,6 @@ protected override Expression VisitConstant(ConstantExpression node)
return node;
}

if (_hqlGenerator)
{
NotGuessableConstants.Add(node);
}

RelatedExpressions.Add(node, new HashSet<Expression>());
ConstantExpressions.Add(node, null);
if (!ParameterConstants.TryGetValue(param, out var set))
Expand Down