Click here to Skip to main content
65,938 articles
CodeProject is changing. Read more.
Articles / Languages / C#

Repository and Unit of Work T4 Template for Entity Framework

4.00/5 (3 votes)
18 Jul 2010 · CPOL · 35.2K
Repository and Unit of Work T4 Template for Entity Framework

Introduction

Two weeks ago, I wrote the post “Revisiting the Repository and Unit of Work Patterns with Entity Framework.” I thought it would be nice to have automatic code generation that builds these patterns without breaking a sweat. So I sat down and created a T4 template that auto-generates the same patterns I showed in that post.

The Code

Keep in mind that the provided T4 template isn't bulletproof, and errors can occur (you can change the implementation as you like). To use it, copy and paste the code into a .tt file. That file must be in the same folder as your Entity Data Model (.edmx) file, because the template searches its own directory for .edmx files and generates the classes from the first one it finds. Alternatively, you can download the T4 template from here.
So here it is:

C#
<#@ template debug="true" hostspecific="true" language="C#" #>
<#@ include file="EF.Utility.CS.ttinclude"#>
<#@ import namespace="System.IO" #>
<#@ output extension=".cs" #>
<#         
// Setup: locate the EDMX next to this template, load its conceptual model,
// and pick the entity container whose sets drive the UnitOfWork output.
if(Errors.HasErrors)
{
    return String.Empty;
}

CodeGenerationTools code = new CodeGenerationTools(this)
	{FullyQualifySystemTypes = true, CamelCaseFields = false};
MetadataLoader loader = new MetadataLoader(this);

// Generic brackets used by the expression blocks in the template text below.
string open = "<";
string close = ">";

// Assign the SourceCsdlPath property instead of declaring a local that shadows it.
SourceCsdlPath = FindEDMXFileName();
if (String.IsNullOrEmpty(SourceCsdlPath))
{
    // Without a model file there is nothing to generate. Report it clearly
    // instead of loading an empty path and failing later on a null container.
    Error("No .edmx file was found in the template's directory.");
    return String.Empty;
}

ReferenceCsdlPaths = new string[] {};
string namespaceName = code.VsNamespaceSuggestion();
ItemCollection = loader.CreateEdmItemCollection
	(SourceCsdlPath, ReferenceCsdlPaths.ToArray());
EntityContainer container = ItemCollection.GetItems<EntityContainer>().FirstOrDefault();
if (container == null)
{
    // The IUnitOfWork/UnitOfWork sections iterate container.BaseEntitySets,
    // so a model without a container cannot be generated.
    Error("The model in " + SourceCsdlPath + " does not define an entity container.");
    return String.Empty;
}
#>
using System;
using System.Collections.Generic;
using System.Data.Objects;
using System.Linq;
using System.Text;
using System.Linq.Expressions;
 
namespace <#=namespaceName#>
{
    /// <summary>
    /// Data-access operations exposed for a single entity type.
    /// </summary>
    /// <typeparam name="T">The entity type.</typeparam>
    public interface IRepository<T> where T : class
    {
        /// <summary>Looks up one entity by its integer key.</summary>
        T GetById(int id);

        /// <summary>Returns every entity of type <typeparamref name="T"/>.</summary>
        IEnumerable<T> GetAll();

        /// <summary>Returns the entities that satisfy <paramref name="filter"/>.</summary>
        IEnumerable<T> Query(Expression<Func<T, bool>> filter);

        /// <summary>Adds <paramref name="entity"/> to the repository.</summary>
        void Add(T entity);

        /// <summary>Removes <paramref name="entity"/> from the repository.</summary>
        void Remove(T entity);
    }
    
    /// <summary>
    /// Base repository that serves <typeparamref name="T"/> entities from an
    /// Entity Framework object set. Derived classes supply <see cref="GetById"/>.
    /// </summary>
    /// <typeparam name="T">The entity type.</typeparam>
    public abstract class Repository<T> : IRepository<T>
                                  where T : class
    {
        #region Members
 
        /// <summary>Object set for <typeparamref name="T"/>, created from the context.</summary>
        protected IObjectSet<T> _objectSet;
 
        #endregion
 
        #region Ctor
 
        /// <summary>
        /// Creates the object set for <typeparamref name="T"/> from <paramref name="context"/>.
        /// </summary>
        /// <param name="context">The context that tracks the entities.</param>
        /// <exception cref="ArgumentNullException"><paramref name="context"/> is null.</exception>
        protected Repository(ObjectContext context)
        {
            // Protected: an abstract class can only be constructed through a
            // derived class (CA1012). The null check turns an opaque
            // NullReferenceException into a clear argument error.
            if (context == null)
            {
                throw new ArgumentNullException("context");
            }

            _objectSet = context.CreateObjectSet<T>();
        }

        #endregion
 
        #region IRepository<T> Members
 
        /// <summary>Returns the object set itself, unfiltered.</summary>
        public IEnumerable<T> GetAll()
        {
            return _objectSet;
        }
 
        /// <summary>Looks up one entity by its integer key; implemented per entity type.</summary>
        public abstract T GetById(int id);
 
        /// <summary>
        /// Applies <paramref name="filter"/> to the object set. The result is a deferred
        /// query (an IQueryable exposed as IEnumerable) that runs when enumerated.
        /// </summary>
        public IEnumerable<T> Query(Expression<Func<T, bool>> filter)
        {
            return _objectSet.Where(filter);
        }
 
        /// <summary>Adds <paramref name="entity"/> to the object set via AddObject.</summary>
        public void Add(T entity)
        {
            _objectSet.AddObject(entity);
        }
 
        /// <summary>Deletes <paramref name="entity"/> from the object set via DeleteObject.</summary>
        public void Remove(T entity)
        {
            _objectSet.DeleteObject(entity);
        }
 
        #endregion
    }
 
<#
    // Emit one concrete repository per entity type, ordered by name so the
    // generated file is stable between runs.
    // NOTE(review): derived types in an inheritance hierarchy get a repository
    // too. If ObjectContext.CreateObjectSet rejects derived types, constructing
    // those repositories will fail - confirm before relying on them.
    foreach (EntityType entity in 
	ItemCollection.GetItems<EntityType>().OrderBy(e => e.Name))
    {
        // GetById(int) can only compare against a single integer-typed key.
        // Emitting that comparison for composite or non-integer keys would
        // produce code that does not compile, so such types get an override
        // that throws NotSupportedException instead.
        EdmMember keyMember = entity.KeyMembers.Count == 1 ? entity.KeyMembers[0] : null;
        PrimitiveType keyType = keyMember == null ? null : keyMember.TypeUsage.EdmType as PrimitiveType;
        bool hasIntegerKey = keyType != null &&
            (keyType.PrimitiveTypeKind == PrimitiveTypeKind.Byte ||
             keyType.PrimitiveTypeKind == PrimitiveTypeKind.Int16 ||
             keyType.PrimitiveTypeKind == PrimitiveTypeKind.Int32 ||
             keyType.PrimitiveTypeKind == PrimitiveTypeKind.Int64);
#>
    
    /// <summary>Repository for <#= entity.Name #> entities.</summary>
    public partial class <#= entity.Name #>Repository : 
		Repository<#=open#><#=entity.Name#><#=close#>
    {
        #region Ctor
 
        public <#= entity.Name #>Repository(ObjectContext context)
               : base(context)
        {
        }
 
        #endregion
 
        #region Methods
 
        public override <#= entity.Name #> GetById(int id)   
        {
<#
        if (hasIntegerKey)
        {
#>
            return _objectSet.SingleOrDefault(e => e.<#= keyMember.Name #> == id);
<#
        }
        else
        {
#>
            throw new NotSupportedException("<#= entity.Name #> does not have a single integer key, so GetById(int) is not supported.");
<#
        }
#>
        }
 
        #endregion        
    }
<# 
    }        
#>
        
  /// <summary>
  /// Groups one repository per entity set of the model behind a single <see cref="Commit"/>.
  /// </summary>
  public interface IUnitOfWork
  {
<#
    // One read-only repository property per entity set in the container.
    foreach (EntitySet entitySet in container.BaseEntitySets.OfType<EntitySet>())
    {
#>
    IRepository<#= open #><#= entitySet.ElementType.Name #><#= close #> <#= entitySet.Name #> { get; }
<#
    }
#>

    /// <summary>Commits the pending changes.</summary>
    void Commit();
  }
 
  /// <summary>
  /// <see cref="IUnitOfWork"/> implementation over an Entity Framework
  /// <see cref="ObjectContext"/>. Each repository is created on first access
  /// and shares the context passed to the constructor.
  /// </summary>
  public partial class UnitOfWork : IUnitOfWork
  {
    #region Members
 
    private readonly ObjectContext _context;
    <#
        // Field names use ToLowerInvariant so they do not depend on the
        // culture of the machine running the transformation.
        foreach (EntitySet set in container.BaseEntitySets.OfType<EntitySet>())
        {    
    #>
    private <#= set.ElementType.Name #>Repository _<#= set.Name.ToLowerInvariant() #>;
    <# 
        }
    #>    
    #endregion
 
    #region Ctor
 
    /// <summary>Creates a unit of work over <paramref name="context"/>.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="context"/> is null.</exception>
    public UnitOfWork(ObjectContext context)
    {
      if (context == null)
      {
        // The one-argument ArgumentNullException overload expects a parameter
        // name, so the name and the message are passed separately.
        throw new ArgumentNullException("context", "context wasn't supplied");
      }
 
      _context = context;
    }
 
    #endregion
 
    #region IUnitOfWork Members
 
    <#
        foreach (EntitySet set in container.BaseEntitySets.OfType<EntitySet>())
        {
            string field = "_" + set.Name.ToLowerInvariant();
    #>
    /// <summary>Repository for the <#= set.Name #> entity set, created on first access.</summary>
    public IRepository<#= open #><#= set.ElementType.Name #><#= close #> <#= set.Name #>
    {
        get
        {
            if (<#= field #> == null)
            {
                <#= field #> = new <#= set.ElementType.Name #>Repository(_context);
            }
            return <#= field #>;
        }
    }
    <# 
        }
    #>    
    
    /// <summary>Saves all changes tracked by the shared context.</summary>
    public void Commit()
    {
      _context.SaveChanges();
    }
 
    #endregion
  }
}
<#+
// Path of the EDMX whose conceptual model drives generation.
public string SourceCsdlPath{ get; set; }
// Metadata loaded from SourceCsdlPath.
public EdmItemCollection ItemCollection{ get; set; }
// Additional CSDL files referenced by the model.
public IEnumerable<string> ReferenceCsdlPaths{ get; set; }
 
// Returns the path of a .edmx file in this template's directory, or
// string.Empty when there is none. Directory.GetFiles does not guarantee any
// order, so candidates are sorted to make the choice repeatable. A warning
// is raised when more than one model could have been used.
string FindEDMXFileName()
{            
    string[] entityFrameworkFiles = Directory.GetFiles
		(Host.ResolvePath(string.Empty), "*.edmx");
    if(entityFrameworkFiles.Length == 0)
    {
        return string.Empty;
    }

    Array.Sort(entityFrameworkFiles, StringComparer.OrdinalIgnoreCase);
    if(entityFrameworkFiles.Length > 1)
    {
        Warning("Found " + entityFrameworkFiles.Length + " .edmx files; generating from " + entityFrameworkFiles[0] + ".");
    }

    return entityFrameworkFiles[0];
}
#>

And this is the code the T4 template generated when I ran it against my test .edmx file:

C#
using System;
using System.Collections.Generic;
using System.Data.Objects;
using System.Linq;
using System.Text;
using System.Linq.Expressions;
 
namespace ConsoleApplication1
{
    /// <summary>
    /// Data-access operations exposed for a single entity type.
    /// </summary>
    /// <typeparam name="T">The entity type.</typeparam>
    public interface IRepository<T> where T : class
    {
        /// <summary>Looks up one entity by its integer key.</summary>
        T GetById(int id);

        /// <summary>Returns every entity of type <typeparamref name="T"/>.</summary>
        IEnumerable<T> GetAll();

        /// <summary>Returns the entities that satisfy <paramref name="filter"/>.</summary>
        IEnumerable<T> Query(Expression<Func<T, bool>> filter);

        /// <summary>Adds <paramref name="entity"/> to the repository.</summary>
        void Add(T entity);

        /// <summary>Removes <paramref name="entity"/> from the repository.</summary>
        void Remove(T entity);
    }
    
    /// <summary>
    /// Base repository that serves <typeparamref name="T"/> entities from an
    /// Entity Framework object set. Derived classes supply <see cref="GetById"/>.
    /// </summary>
    /// <typeparam name="T">The entity type.</typeparam>
    public abstract class Repository<T> : IRepository<T>
                                  where T : class
    {
        #region Members

        /// <summary>Object set for <typeparamref name="T"/>, created from the context.</summary>
        protected IObjectSet<T> _objectSet;

        #endregion

        #region Ctor

        /// <summary>
        /// Creates the object set for <typeparamref name="T"/> from <paramref name="context"/>.
        /// </summary>
        /// <param name="context">The context that tracks the entities.</param>
        /// <exception cref="ArgumentNullException"><paramref name="context"/> is null.</exception>
        protected Repository(ObjectContext context)
        {
            // Protected: an abstract class can only be constructed through a
            // derived class (CA1012). The null check turns an opaque
            // NullReferenceException into a clear argument error.
            if (context == null)
            {
                throw new ArgumentNullException("context");
            }

            _objectSet = context.CreateObjectSet<T>();
        }

        #endregion

        #region IRepository<T> Members

        /// <summary>Returns the object set itself, unfiltered.</summary>
        public IEnumerable<T> GetAll()
        {
            return _objectSet;
        }

        /// <summary>Looks up one entity by its integer key; implemented per entity type.</summary>
        public abstract T GetById(int id);

        /// <summary>
        /// Applies <paramref name="filter"/> to the object set. The result is a deferred
        /// query (an IQueryable exposed as IEnumerable) that runs when enumerated.
        /// </summary>
        public IEnumerable<T> Query(Expression<Func<T, bool>> filter)
        {
            return _objectSet.Where(filter);
        }

        /// <summary>Adds <paramref name="entity"/> to the object set via AddObject.</summary>
        public void Add(T entity)
        {
            _objectSet.AddObject(entity);
        }

        /// <summary>Deletes <paramref name="entity"/> from the object set via DeleteObject.</summary>
        public void Remove(T entity)
        {
            _objectSet.DeleteObject(entity);
        }

        #endregion
    }
 
    /// <summary>
    /// Repository for <see cref="Course"/> entities, keyed by CourseID.
    /// </summary>
    public partial class CourseRepository : Repository<Course>
    {
        /// <summary>Creates the repository over <paramref name="context"/>.</summary>
        public CourseRepository(ObjectContext context)
            : base(context)
        {
        }

        /// <summary>
        /// Returns the course whose CourseID equals <paramref name="id"/>, or null when none matches.
        /// </summary>
        public override Course GetById(int id)
        {
            return _objectSet.SingleOrDefault(course => course.CourseID == id);
        }
    }
    
    /// <summary>
    /// Repository for <see cref="Department"/> entities, keyed by DepartmentID.
    /// </summary>
    public partial class DepartmentRepository : Repository<Department>
    {
        /// <summary>Creates the repository over <paramref name="context"/>.</summary>
        public DepartmentRepository(ObjectContext context)
            : base(context)
        {
        }

        /// <summary>
        /// Returns the department whose DepartmentID equals <paramref name="id"/>, or null when none matches.
        /// </summary>
        public override Department GetById(int id)
        {
            return _objectSet.SingleOrDefault(department => department.DepartmentID == id);
        }
    }
    
    /// <summary>
    /// Repository for <see cref="Enrollment"/> entities, keyed by EnrollmentID.
    /// </summary>
    public partial class EnrollmentRepository : Repository<Enrollment>
    {
        /// <summary>Creates the repository over <paramref name="context"/>.</summary>
        public EnrollmentRepository(ObjectContext context)
            : base(context)
        {
        }

        /// <summary>
        /// Returns the enrollment whose EnrollmentID equals <paramref name="id"/>, or null when none matches.
        /// </summary>
        public override Enrollment GetById(int id)
        {
            return _objectSet.SingleOrDefault(enrollment => enrollment.EnrollmentID == id);
        }
    }
    
    /// <summary>
    /// Repository for <see cref="Person"/> entities, keyed by PersonID.
    /// </summary>
    public partial class PersonRepository : Repository<Person>
    {
        /// <summary>Creates the repository over <paramref name="context"/>.</summary>
        public PersonRepository(ObjectContext context)
            : base(context)
        {
        }

        /// <summary>
        /// Returns the person whose PersonID equals <paramref name="id"/>, or null when none matches.
        /// </summary>
        public override Person GetById(int id)
        {
            return _objectSet.SingleOrDefault(person => person.PersonID == id);
        }
    }
        
  /// <summary>
  /// Groups one repository per entity set of the model behind a single <see cref="Commit"/>.
  /// </summary>
  public interface IUnitOfWork
  {
    /// <summary>Repository for <see cref="Course"/> entities.</summary>
    IRepository<Course> Courses { get; }

    /// <summary>Repository for <see cref="Department"/> entities.</summary>
    IRepository<Department> Departments { get; }

    /// <summary>Repository for <see cref="Enrollment"/> entities.</summary>
    IRepository<Enrollment> Enrollments { get; }

    /// <summary>Repository for <see cref="Person"/> entities.</summary>
    IRepository<Person> People { get; }

    /// <summary>Commits the pending changes.</summary>
    void Commit();
  }
 
  /// <summary>
  /// <see cref="IUnitOfWork"/> implementation over an Entity Framework
  /// <see cref="ObjectContext"/>. Each repository is created on first access
  /// and shares the context passed to the constructor.
  /// </summary>
  public partial class UnitOfWork : IUnitOfWork
  {
    #region Members

    private readonly ObjectContext _context;
    private CourseRepository _courses;
    private DepartmentRepository _departments;
    private EnrollmentRepository _enrollments;
    private PersonRepository _people;

    #endregion

    #region Ctor

    /// <summary>Creates a unit of work over <paramref name="context"/>.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="context"/> is null.</exception>
    public UnitOfWork(ObjectContext context)
    {
      if (context == null)
      {
        // The one-argument ArgumentNullException overload expects a parameter
        // name, so the name and the message are passed separately.
        throw new ArgumentNullException("context", "context wasn't supplied");
      }

      _context = context;
    }

    #endregion

    #region IUnitOfWork Members

    /// <summary>Repository for the Courses entity set, created on first access.</summary>
    public IRepository<Course> Courses
    {
      get
      {
        if (_courses == null)
        {
          _courses = new CourseRepository(_context);
        }
        return _courses;
      }
    }

    /// <summary>Repository for the Departments entity set, created on first access.</summary>
    public IRepository<Department> Departments
    {
      get
      {
        if (_departments == null)
        {
          _departments = new DepartmentRepository(_context);
        }
        return _departments;
      }
    }

    /// <summary>Repository for the Enrollments entity set, created on first access.</summary>
    public IRepository<Enrollment> Enrollments
    {
      get
      {
        if (_enrollments == null)
        {
          _enrollments = new EnrollmentRepository(_context);
        }
        return _enrollments;
      }
    }

    /// <summary>Repository for the People entity set, created on first access.</summary>
    public IRepository<Person> People
    {
      get
      {
        if (_people == null)
        {
          _people = new PersonRepository(_context);
        }
        return _people;
      }
    }

    /// <summary>Saves all changes tracked by the shared context.</summary>
    public void Commit()
    {
      _context.SaveChanges();
    }

    #endregion
  }
}

Enjoy!

License

This article, along with any associated source code and files, is licensed under The Code Project Open License (CPOL)