Creating .In and .NotIn Extension Methods for NHibernate 3 Linq Provider

22 Jan 2012, CPOL

In Bringing the IN clause from SQL to C#, I showed how to create C# extension methods that mimic the “in” clause from SQL. I like these methods a lot, but they cannot be used in Linq to NHibernate queries, because the Linq provider does not know how to translate them by default. Luckily, it’s not that hard to extend the behavior of NHibernate’s Linq provider. Fabio Maulo has already blogged about the extension points in his post NHibernate LINQ provider extension, so I’m just going to jump straight into the code.

For starters, here is the extension method class we’re going to be using:

C#
public static class ObjectExtensions
{
    public static bool In<T>(this T @value, params T[] values)
    {
        return values.Contains(@value);
    }

    public static bool In<T>(this T @value, IQueryable<T> values)
    {
        return values.Contains(@value);
    }

    public static bool NotIn<T>(this T @value, params T[] values)
    {
        return !values.Contains(@value);
    }

    public static bool NotIn<T>(this T @value, IQueryable<T> values)
    {
        return !values.Contains(@value);
    }
}

These are very simple methods that let you use syntax like:

C#
if (1.In(1, 2, 3) && 3.NotIn(1, 2))
    ...

Notice that there are also overloads that accept an IQueryable as an argument. These are meant for subqueries, which I had a hard time getting NHibernate to generate. What I want NH to generate is something like this:

SQL
...
where id in(select id from some_table where id > 100)

So here’s what I came up with for SQL Server:

C#
public class InGenerator : BaseHqlGeneratorForMethod
{
    public InGenerator()
    {
        SupportedMethods = new[]
        {
            ReflectionHelper.GetMethodDefinition(() =>
                ObjectExtensions.In<object>(null, (object[]) null)),
            ReflectionHelper.GetMethodDefinition(() =>
                ObjectExtensions.In<object>(null, (IQueryable<object>) null)),
            ReflectionHelper.GetMethodDefinition(() =>
                ObjectExtensions.NotIn<object>(null, (object[]) null)),
            ReflectionHelper.GetMethodDefinition(() =>
                ObjectExtensions.NotIn<object>(null, (IQueryable<object>) null))
        };
    }

    public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject,
        ReadOnlyCollection<Expression> arguments, HqlTreeBuilder treeBuilder,
        IHqlExpressionVisitor visitor)
    {
        // For an extension method, arguments[0] is the value being tested.
        var value = visitor.Visit(arguments[0]).AsExpression();
        HqlTreeNode inClauseNode;

        // An array of constant values arrives as a ConstantExpression;
        // anything else is treated as a subquery expression.
        if (arguments[1] is ConstantExpression)
            inClauseNode = BuildFromArray(
                (Array) ((ConstantExpression) arguments[1]).Value, treeBuilder);
        else
            inClauseNode = BuildFromExpression(arguments[1], visitor);

        HqlTreeNode inClause = treeBuilder.In(value, inClauseNode);

        if (method.Name == "NotIn")
            inClause = treeBuilder.BooleanNot((HqlBooleanExpression)inClause);

        return inClause;
    }

    private HqlTreeNode BuildFromExpression(Expression expression, 
                IHqlExpressionVisitor visitor)
    {
        //TODO: check if it's a valid expression for in clause, 
        //i.e. it selects only one column
        return visitor.Visit(expression).AsExpression();
    }

    private HqlTreeNode BuildFromArray(Array valueArray, HqlTreeBuilder treeBuilder)
    {
        var elementType = valueArray.GetType().GetElementType();

        if (!elementType.IsValueType && elementType != typeof(string))
            throw new ArgumentException("Only value types and strings can be used");

        // Enums are converted to their underlying numeric type before use.
        Type enumUnderlyingType = elementType.IsEnum
            ? Enum.GetUnderlyingType(elementType)
            : null;
        var variants = new HqlExpression[valueArray.Length];

        for (int index = 0; index < valueArray.Length; index++)
        {
            var variant = valueArray.GetValue(index);
            var val = variant;

            if (elementType.IsEnum)
                val = Convert.ChangeType(variant, enumUnderlyingType);

            variants[index] = treeBuilder.Constant(val);
        }

        return treeBuilder.DistinctHolder(variants);
    }
}

The constructor starts by listing the supported extension methods, which are the four methods shown at the beginning of this post. The BuildHql method creates the objects used to translate Linq expressions into an HQL query. Here we build an instance of the HqlIn class, giving it the expression we’re comparing (the variable value) and the contents of the in clause (the variable inClauseNode). Two ways of calling the extension methods are handled: with an array of constants and with an IQueryable. If the method called was NotIn, the HqlIn instance is wrapped in an instance of the HqlBooleanNot class, which effectively puts the keyword ‘not’ in front of the in clause.
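
The lambdas passed to ReflectionHelper.GetMethodDefinition use object only as a placeholder type argument; what gets registered is the open generic method definition, so a call such as x.Id.In(1, 2, 3) with T = int still matches. The sketch below illustrates that idea with plain reflection. It assumes the NHibernate helper behaves roughly like the hypothetical GetMethodDefinition shown here; the MethodDefinitionDemo class and the trimmed copy of ObjectExtensions exist only to keep the example self-contained and are not NHibernate code.

```csharp
using System;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;

public static class ObjectExtensions
{
    // Trimmed copy of the article's extension method so the sketch compiles on its own.
    public static bool In<T>(this T value, params T[] values)
    {
        return values.Contains(value);
    }
}

public static class MethodDefinitionDemo
{
    // Hypothetical stand-in for ReflectionHelper.GetMethodDefinition (an assumption
    // about its behavior): take the method called in the lambda body and, if it is
    // generic, return its open generic definition.
    public static MethodInfo GetMethodDefinition(Expression<Action> call)
    {
        MethodInfo method = ((MethodCallExpression) call.Body).Method;
        return method.IsGenericMethod ? method.GetGenericMethodDefinition() : method;
    }

    public static void Main()
    {
        MethodInfo viaObject = GetMethodDefinition(
            () => ObjectExtensions.In<object>(null, (object[]) null));
        MethodInfo viaInt = GetMethodDefinition(
            () => ObjectExtensions.In<int>(1, 2, 3));

        // Both closed instantiations resolve to the same open definition In<T>.
        Console.WriteLine(viaObject.Equals(viaInt));
        Console.WriteLine(viaObject.IsGenericMethodDefinition);
    }
}
```

Matching on the open definition is what lets a single generator instance serve every element type the extension methods are called with.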

Building the subquery from an IQueryable is easy enough: we just let NHibernate handle it in its default way. What’s missing is a check that the expression we’re passing is valid for an in clause in SQL, i.e. that it selects only a single column from the subquery.

Building the list of possible values for the in clause is a little trickier. Currently, only an array of constants is handled. Most of the code in the BuildFromArray method deals with converting enums to values of their underlying type, so they can be used in the query directly. What’s missing is support for normal arrays (not constants, but arrays of variables), but it might not be necessary, because we can achieve similar functionality by using the Contains Linq method in the query.
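
To make the enum handling concrete, here is a minimal sketch of the conversion loop on its own, without any NHibernate types. The ToQueryConstants helper and the EnumConstantsDemo class are hypothetical names, and the numeric values given to OrderStates are an assumption chosen to match the "in (1 , 2)" shown in the generated SQL later in the article.

```csharp
using System;

// Assumed definition: the article does not show the enum, so these values are illustrative.
public enum OrderStates
{
    Created = 1,
    Executed = 2
}

public static class EnumConstantsDemo
{
    // Mirrors the loop in BuildFromArray, but returns the raw values instead of
    // wrapping each one in an HQL constant node.
    public static object[] ToQueryConstants(Array valueArray)
    {
        Type elementType = valueArray.GetType().GetElementType();
        Type underlyingType = elementType.IsEnum
            ? Enum.GetUnderlyingType(elementType)
            : null;

        var constants = new object[valueArray.Length];
        for (int index = 0; index < valueArray.Length; index++)
        {
            object item = valueArray.GetValue(index);
            // Enum values are converted to their underlying numeric type;
            // everything else is passed through unchanged.
            constants[index] = underlyingType != null
                ? Convert.ChangeType(item, underlyingType)
                : item;
        }

        return constants;
    }

    public static void Main()
    {
        object[] constants = ToQueryConstants(
            new[] { OrderStates.Created, OrderStates.Executed });

        foreach (object constant in constants)
            Console.WriteLine("{0} ({1})", constant, constant.GetType().Name);
    }
}
```

With the assumed enum values, each element comes back as a boxed int rather than as an OrderStates value, which is the form the generator passes on as a query constant.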

To use the new generator, we also have to create a custom generator registry (we derive from the default one to extend it). We then use the configuration to set the new registry before creating the ISessionFactory:

C#
public class CustomLinqGeneratorsRegistry : DefaultLinqToHqlGeneratorsRegistry
{
    public CustomLinqGeneratorsRegistry()
    {
        RegisterGenerator(ReflectionHelper.GetMethodDefinition(() =>
            ObjectExtensions.In<object>(null, (object[]) null)),
            new InGenerator());
        RegisterGenerator(ReflectionHelper.GetMethodDefinition(() =>
            ObjectExtensions.In<object>(null, (IQueryable<object>) null)),
            new InGenerator());
        RegisterGenerator(ReflectionHelper.GetMethodDefinition(() =>
            ObjectExtensions.NotIn<object>(null, (object[]) null)),
            new InGenerator());
        RegisterGenerator(ReflectionHelper.GetMethodDefinition(() =>
            ObjectExtensions.NotIn<object>(null, (IQueryable<object>) null)),
            new InGenerator());
    }
}
...
configuration.LinqToHqlGeneratorsRegistry<CustomLinqGeneratorsRegistry>();

Now, taking a look at the queries we get, we see that the following Linq queries...

C#
session.Query<Order>()
    .Where(x => x.State.NotIn(OrderStates.Created, OrderStates.Executed))
    .ToArray();
...
session.Query<Category>().Where(x => x.Id.In(1, 2, 3)).ToArray();
...
var categories = session.Query<Category>()
    .Where(x => x.Name.NotIn(
        session.Query<Category>()
            .Where(c => c.Name != "Var2" && c.Id > 100)
            .Select(c => c.Name)))
    .ToArray();

...translate to the following SQL queries:

SQL
select
        order0_.OrderId as OrderId0_,
        order0_.state as state0_,
        order0_.Customer as Customer0_ 
    from
        Orders order0_ 
    where
        not (order0_.state in (1 , 2))
...
select
        category0_.CategoryId as CategoryId4_,
        category0_.Name as Name4_ 
    from
        Category category0_ 
    where
        category0_.CategoryId in (1 , 2 , 3)
...
select
    category0_.CategoryId as CategoryId4_,
    category0_.Name as Name4_ 
from
    Category category0_ 
where
    not (category0_.Name in (select
        category1_.Name 
    from
        Category category1_ 
    where
        category1_.Name<>'Var2' 
        and category1_.CategoryId>100));

These extension methods can now be used with Linq to NHibernate to simplify your queries.

License

This article, along with any associated source code and files, is licensed under The Code Project Open License (CPOL)