Introduction
Two weeks ago, I wrote the "Revisiting the Repository and Unit of Work Patterns with Entity Framework" post. I thought it would be nice to have automatic code generation that builds these patterns for me without breaking a sweat. So I sat down and created a T4 Template that auto-generates the same patterns I showed in that post.
The Code
Keep in mind that the provided T4 Template isn't bulletproof and errors can occur (feel free to change the implementation as you like). To use it, copy and paste the code into a .tt file placed in the same folder as your Entity Data Model (.edmx) file; the template looks for .edmx files in its own directory and generates the classes from the first one it finds. Another option is to download the T4 Template from here.
So here it is:
<#@ template debug="true" hostspecific="true" language="C#" #>
<#@ include file="EF.Utility.CS.ttinclude"#>
<#@ import namespace="System.IO" #>
<#@ output extension=".cs" #>
<#
// Stop before generating anything if the host already reported errors
// (for example, the include above could not be processed).
if(Errors.HasErrors)
{
    return String.Empty;
}
CodeGenerationTools code = new CodeGenerationTools(this)
    {FullyQualifySystemTypes = true, CamelCaseFields = false};
MetadataLoader loader = new MetadataLoader(this);
// Generic angle brackets are written through expression blocks below so a
// literal '<' never sits directly against T4 '<#' markup.
string open = "<";
string close = ">";
// Assign the template property (declared in the <#+ #> block) instead of
// shadowing it with a local of the same name.
SourceCsdlPath = FindEDMXFileName();
if(String.IsNullOrEmpty(SourceCsdlPath))
{
    // Without a model there is nothing to generate; report a clear error
    // instead of letting the metadata loader fail on an empty path.
    Error("No .edmx file was found in the template's directory.");
    return String.Empty;
}
ReferenceCsdlPaths = new string[] {};
string namespaceName = code.VsNamespaceSuggestion();
ItemCollection = loader.CreateEdmItemCollection
    (SourceCsdlPath, ReferenceCsdlPaths.ToArray());
EntityContainer container = ItemCollection.GetItems<EntityContainer>().FirstOrDefault();
if(container == null)
{
    // The IUnitOfWork/UnitOfWork sections below enumerate the container's
    // entity sets, so a model without a container cannot be generated.
    Error("The model '" + SourceCsdlPath + "' does not define an entity container.");
    return String.Empty;
}
#>
using System;
using System.Collections.Generic;
using System.Data.Objects;
using System.Linq;
using System.Text;
using System.Linq.Expressions;
namespace <#=namespaceName#>
{
/// <summary>
/// Generic data-access contract implemented by every generated repository.
/// </summary>
/// <typeparam name="T">The entity type; must be a reference type.</typeparam>
public interface IRepository<T> where T : class
{
    /// <summary>Adds <paramref name="entity"/> to the repository.</summary>
    void Add(T entity);

    /// <summary>Removes <paramref name="entity"/> from the repository.</summary>
    void Remove(T entity);

    /// <summary>Returns the entity whose key equals <paramref name="id"/>.</summary>
    T GetById(int id);

    /// <summary>Returns every entity held by the repository.</summary>
    IEnumerable<T> GetAll();

    /// <summary>Returns the entities that satisfy <paramref name="filter"/>.</summary>
    IEnumerable<T> Query(Expression<Func<T, bool>> filter);
}
/// <summary>
/// Base repository that implements <see cref="IRepository{T}"/> on top of an
/// Entity Framework object set; derived classes supply <see cref="GetById"/>.
/// </summary>
/// <typeparam name="T">The entity type; must be a reference type.</typeparam>
public abstract class Repository<T> : IRepository<T>
    where T : class
{
    #region Members
    // Object set for T created from the context given to the constructor;
    // derived repositories query it directly.
    protected IObjectSet<T> _objectSet;
    #endregion
    #region Ctor
    /// <summary>
    /// Creates the object set for <typeparamref name="T"/> from <paramref name="context"/>.
    /// Protected because the class is abstract and only derived repositories construct it.
    /// </summary>
    /// <param name="context">The object context that owns the entities.</param>
    /// <exception cref="ArgumentNullException"><paramref name="context"/> is null.</exception>
    protected Repository(ObjectContext context)
    {
        if (context == null)
        {
            throw new ArgumentNullException("context", "context wasn't supplied");
        }
        _objectSet = context.CreateObjectSet<T>();
    }
    #endregion
    #region IRepository<T> Members
    /// <summary>Returns the underlying object set, which enumerates every <typeparamref name="T"/>.</summary>
    public IEnumerable<T> GetAll()
    {
        return _objectSet;
    }

    /// <summary>Returns the entity whose key equals <paramref name="id"/>; implemented by each derived repository.</summary>
    public abstract T GetById(int id);

    /// <summary>
    /// Filters the object set with <paramref name="filter"/>; the query is not
    /// executed until the result is enumerated.
    /// </summary>
    public IEnumerable<T> Query(Expression<Func<T, bool>> filter)
    {
        return _objectSet.Where(filter);
    }

    /// <summary>Adds <paramref name="entity"/> to the object set (via AddObject).</summary>
    public void Add(T entity)
    {
        _objectSet.AddObject(entity);
    }

    /// <summary>Marks <paramref name="entity"/> for deletion (via DeleteObject).</summary>
    public void Remove(T entity)
    {
        _objectSet.DeleteObject(entity);
    }
    #endregion
}
<#
foreach (EntityType entity in
    ItemCollection.GetItems<EntityType>().OrderBy(e => e.Name))
{
    // The generated lookup compares the key with an int, so it is only
    // emitted for entities with exactly one key member of an integral type
    // (Byte, Int16, Int32, Int64). Other key shapes (composite keys, Guid,
    // string, ...) get a body that throws NotSupportedException instead of
    // code that does not compile or that matches only part of a composite key.
    EdmMember keyMember = entity.KeyMembers.Count == 1 ? entity.KeyMembers[0] : null;
    PrimitiveType keyType = keyMember == null
        ? null
        : keyMember.TypeUsage.EdmType as PrimitiveType;
    bool hasIntegralKey = keyType != null &&
        (keyType.PrimitiveTypeKind == PrimitiveTypeKind.Byte ||
         keyType.PrimitiveTypeKind == PrimitiveTypeKind.Int16 ||
         keyType.PrimitiveTypeKind == PrimitiveTypeKind.Int32 ||
         keyType.PrimitiveTypeKind == PrimitiveTypeKind.Int64);
    // NOTE(review): derived types of an inheritance hierarchy also get a
    // repository here, but CreateObjectSet<T>() presumably rejects types that
    // have no entity set of their own - confirm before using those classes.
#>
public partial class <#= entity.Name #>Repository :
    Repository<#=open#><#=entity.Name#><#=close#>
{
    #region Ctor
    public <#= entity.Name #>Repository(ObjectContext context)
        : base(context)
    {
    }
    #endregion
    #region Methods
    public override <#= entity.Name #> GetById(int id)
    {
<#
    if (hasIntegralKey)
    {
#>
        return _objectSet.SingleOrDefault(e => e.<#= keyMember.Name #> == id);
<#
    }
    else
    {
#>
        throw new NotSupportedException("GetById(int) requires a single integral key, which <#= entity.Name #> does not have.");
<#
    }
#>
    }
    #endregion
}
<#
}
#>
/// <summary>
/// Exposes one repository per entity set of the model, plus Commit().
/// </summary>
public interface IUnitOfWork
{
    #region Methods
<#
    foreach (EntitySet entitySet in container.BaseEntitySets.OfType<EntitySet>())
    {
        string elementTypeName = entitySet.ElementType.Name;
#>
    IRepository<#= open #><#= elementTypeName #><#= close #> <#= entitySet.Name #> { get; }
<#
    }
#>
    void Commit();
    #endregion
}
/// <summary>
/// Unit of work over a single ObjectContext: exposes one lazily created
/// repository per entity set and commits all of their changes at once.
/// </summary>
public partial class UnitOfWork : IUnitOfWork
{
    #region Members
    private readonly ObjectContext _context;
<#
// Field names are the entity set names lower-cased with the invariant
// culture, so the generated identifiers do not depend on the culture of the
// machine running the template (e.g. Turkish 'I' -> dotless 'ı').
// NOTE(review): an entity set named "Context" would produce a second
// `_context` field and the generated class would not compile.
foreach (EntitySet set in container.BaseEntitySets.OfType<EntitySet>())
{
#>
    private <#= set.ElementType.Name #>Repository _<#= set.Name.ToLowerInvariant() #>;
<#
}
#>
    #endregion
    #region Ctor
    /// <summary>Creates a unit of work over <paramref name="context"/>.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="context"/> is null.</exception>
    public UnitOfWork(ObjectContext context)
    {
        if (context == null)
        {
            // The single-string constructor takes a parameter name, not a
            // message, so pass both explicitly.
            throw new ArgumentNullException("context", "context wasn't supplied");
        }
        _context = context;
    }
    #endregion
    #region IUnitOfWork Members
<#
foreach (EntitySet set in container.BaseEntitySets.OfType<EntitySet>())
{
#>
    // Created on first access and cached; no locking is done.
    public IRepository<#= open #><#= set.ElementType.Name #><#= close #> <#= set.Name #>
    {
        get
        {
            if (_<#= set.Name.ToLowerInvariant() #> == null)
            {
                _<#= set.Name.ToLowerInvariant() #> = new <#= set.ElementType.Name #>Repository(_context);
            }
            return _<#= set.Name.ToLowerInvariant() #>;
        }
    }
<#
}
#>
    /// <summary>Saves all pending changes through ObjectContext.SaveChanges().</summary>
    public void Commit()
    {
        _context.SaveChanges();
    }
    #endregion
}
}
<#+
// Template-level members: the setup block assigns these before generation.
public string SourceCsdlPath{ get; set; }
public EdmItemCollection ItemCollection{ get; set; }
public IEnumerable<string> ReferenceCsdlPaths{ get; set; }

// Returns the path of an .edmx file in the directory that
// Host.ResolvePath(string.Empty) resolves to (the template's folder), or an
// empty string when there is none. Directory.GetFiles does not guarantee any
// order, so the candidates are sorted to make the pick repeatable when
// several models share the folder.
string FindEDMXFileName()
{
    string[] entityFrameworkFiles = Directory.GetFiles
        (Host.ResolvePath(string.Empty), "*.edmx");
    if(entityFrameworkFiles.Length == 0)
    {
        return string.Empty;
    }
    Array.Sort(entityFrameworkFiles, StringComparer.OrdinalIgnoreCase);
    return entityFrameworkFiles[0];
}
#>
And this is the code generated by running the T4 Template against my test .edmx file:
using System;
using System.Collections.Generic;
using System.Data.Objects;
using System.Linq;
using System.Text;
using System.Linq.Expressions;
namespace ConsoleApplication1
{
/// <summary>
/// Generic data-access contract implemented by every generated repository.
/// </summary>
/// <typeparam name="T">The entity type; must be a reference type.</typeparam>
public interface IRepository<T> where T : class
{
    /// <summary>Adds <paramref name="entity"/> to the repository.</summary>
    void Add(T entity);

    /// <summary>Removes <paramref name="entity"/> from the repository.</summary>
    void Remove(T entity);

    /// <summary>Returns the entity whose key equals <paramref name="id"/>.</summary>
    T GetById(int id);

    /// <summary>Returns every entity held by the repository.</summary>
    IEnumerable<T> GetAll();

    /// <summary>Returns the entities that satisfy <paramref name="filter"/>.</summary>
    IEnumerable<T> Query(Expression<Func<T, bool>> filter);
}
/// <summary>
/// Base repository that implements <see cref="IRepository{T}"/> on top of an
/// Entity Framework object set; derived classes supply <see cref="GetById"/>.
/// </summary>
/// <typeparam name="T">The entity type; must be a reference type.</typeparam>
public abstract class Repository<T> : IRepository<T>
    where T : class
{
    #region Members
    // Object set for T created from the context given to the constructor;
    // derived repositories query it directly.
    protected IObjectSet<T> _objectSet;
    #endregion
    #region Ctor
    /// <summary>
    /// Creates the object set for <typeparamref name="T"/> from <paramref name="context"/>.
    /// Protected because the class is abstract and only derived repositories construct it.
    /// </summary>
    /// <param name="context">The object context that owns the entities.</param>
    /// <exception cref="ArgumentNullException"><paramref name="context"/> is null.</exception>
    protected Repository(ObjectContext context)
    {
        if (context == null)
        {
            throw new ArgumentNullException("context", "context wasn't supplied");
        }
        _objectSet = context.CreateObjectSet<T>();
    }
    #endregion
    #region IRepository<T> Members
    /// <summary>Returns the underlying object set, which enumerates every <typeparamref name="T"/>.</summary>
    public IEnumerable<T> GetAll()
    {
        return _objectSet;
    }

    /// <summary>Returns the entity whose key equals <paramref name="id"/>; implemented by each derived repository.</summary>
    public abstract T GetById(int id);

    /// <summary>
    /// Filters the object set with <paramref name="filter"/>; the query is not
    /// executed until the result is enumerated.
    /// </summary>
    public IEnumerable<T> Query(Expression<Func<T, bool>> filter)
    {
        return _objectSet.Where(filter);
    }

    /// <summary>Adds <paramref name="entity"/> to the object set (via AddObject).</summary>
    public void Add(T entity)
    {
        _objectSet.AddObject(entity);
    }

    /// <summary>Marks <paramref name="entity"/> for deletion (via DeleteObject).</summary>
    public void Remove(T entity)
    {
        _objectSet.DeleteObject(entity);
    }
    #endregion
}
/// <summary>
/// Repository for <see cref="Course"/> entities.
/// </summary>
public partial class CourseRepository : Repository<Course>
{
    /// <summary>Creates a repository over the Course object set of <paramref name="context"/>.</summary>
    public CourseRepository(ObjectContext context)
        : base(context)
    {
    }

    /// <summary>Looks up a course by its <c>CourseID</c>.</summary>
    /// <returns>The matching course, or null when none exists.</returns>
    public override Course GetById(int id)
    {
        return _objectSet.SingleOrDefault(course => course.CourseID == id);
    }
}
/// <summary>
/// Repository for <see cref="Department"/> entities.
/// </summary>
public partial class DepartmentRepository : Repository<Department>
{
    /// <summary>Creates a repository over the Department object set of <paramref name="context"/>.</summary>
    public DepartmentRepository(ObjectContext context)
        : base(context)
    {
    }

    /// <summary>Looks up a department by its <c>DepartmentID</c>.</summary>
    /// <returns>The matching department, or null when none exists.</returns>
    public override Department GetById(int id)
    {
        return _objectSet.SingleOrDefault(department => department.DepartmentID == id);
    }
}
/// <summary>
/// Repository for <see cref="Enrollment"/> entities.
/// </summary>
public partial class EnrollmentRepository : Repository<Enrollment>
{
    /// <summary>Creates a repository over the Enrollment object set of <paramref name="context"/>.</summary>
    public EnrollmentRepository(ObjectContext context)
        : base(context)
    {
    }

    /// <summary>Looks up an enrollment by its <c>EnrollmentID</c>.</summary>
    /// <returns>The matching enrollment, or null when none exists.</returns>
    public override Enrollment GetById(int id)
    {
        return _objectSet.SingleOrDefault(enrollment => enrollment.EnrollmentID == id);
    }
}
/// <summary>
/// Repository for <see cref="Person"/> entities.
/// </summary>
public partial class PersonRepository : Repository<Person>
{
    /// <summary>Creates a repository over the Person object set of <paramref name="context"/>.</summary>
    public PersonRepository(ObjectContext context)
        : base(context)
    {
    }

    /// <summary>Looks up a person by its <c>PersonID</c>.</summary>
    /// <returns>The matching person, or null when none exists.</returns>
    public override Person GetById(int id)
    {
        return _objectSet.SingleOrDefault(person => person.PersonID == id);
    }
}
/// <summary>
/// Exposes one repository per entity set of the model, plus <see cref="Commit"/>.
/// </summary>
public interface IUnitOfWork
{
    /// <summary>Repository over the Courses entity set.</summary>
    IRepository<Course> Courses { get; }

    /// <summary>Repository over the Departments entity set.</summary>
    IRepository<Department> Departments { get; }

    /// <summary>Repository over the Enrollments entity set.</summary>
    IRepository<Enrollment> Enrollments { get; }

    /// <summary>Repository over the People entity set.</summary>
    IRepository<Person> People { get; }

    /// <summary>Commits the work done through the repositories.</summary>
    void Commit();
}
/// <summary>
/// Unit of work over a single ObjectContext: exposes one lazily created
/// repository per entity set and commits all of their changes at once.
/// </summary>
public partial class UnitOfWork : IUnitOfWork
{
    #region Members
    private readonly ObjectContext _context;
    private CourseRepository _courses;
    private DepartmentRepository _departments;
    private EnrollmentRepository _enrollments;
    private PersonRepository _people;
    #endregion
    #region Ctor
    /// <summary>Creates a unit of work over <paramref name="context"/>.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="context"/> is null.</exception>
    public UnitOfWork(ObjectContext context)
    {
        if (context == null)
        {
            // The single-string constructor takes a parameter name, not a
            // message, so pass both explicitly.
            throw new ArgumentNullException("context", "context wasn't supplied");
        }
        _context = context;
    }
    #endregion
    #region IUnitOfWork Members
    // Each repository below is created on first access and cached; no locking is done.
    public IRepository<Course> Courses
    {
        get
        {
            if (_courses == null)
            {
                _courses = new CourseRepository(_context);
            }
            return _courses;
        }
    }
    public IRepository<Department> Departments
    {
        get
        {
            if (_departments == null)
            {
                _departments = new DepartmentRepository(_context);
            }
            return _departments;
        }
    }
    public IRepository<Enrollment> Enrollments
    {
        get
        {
            if (_enrollments == null)
            {
                _enrollments = new EnrollmentRepository(_context);
            }
            return _enrollments;
        }
    }
    public IRepository<Person> People
    {
        get
        {
            if (_people == null)
            {
                _people = new PersonRepository(_context);
            }
            return _people;
        }
    }
    /// <summary>Saves all pending changes through ObjectContext.SaveChanges().</summary>
    public void Commit()
    {
        _context.SaveChanges();
    }
    #endregion
}
}
Enjoy!