C# Reflection AttributeHelper

AttributeHelper is a collection of methods I’ve developed to help me retrieve Attribute information for Types and methods in C# (through Reflection). I’m publishing it on my blog because it’s too small to create a CodePlex project for, and it’s something I want access to from wherever I am. If you find a use for it too, then that’s great.

It uses caching of attributes to speed up access, because retrieving type information is typically quite time consuming. For caching it uses a trivial Dictionary.

Here’s the code:

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Linq.Expressions;
using System.Reflection;

namespace IdentityAudit.Utilities
{
    /// <summary>
    /// Helpers for reading custom attributes from types and members through reflection.
    /// The result of each <c>GetCustomAttributes(true)</c> call is cached in a static
    /// dictionary keyed by the <see cref="Type"/> or <see cref="MemberInfo"/> it was read from,
    /// so repeated lookups for the same key skip the reflection call.
    /// </summary>
    public static class AttributeHelper
    {
        private static readonly Dictionary<object, List<Attribute>> _attributeCache = new Dictionary<object, List<Attribute>>();

        /// <summary>
        /// The dictionary backing the cache. This class locks on the dictionary instance itself
        /// whenever it reads or writes it, so code touching it directly should take the same lock.
        /// </summary>
        public static Dictionary<object, List<Attribute>> AttributeCache { get { return _attributeCache; } }

        // Types

        /// <summary>Returns all attributes of <typeparamref name="TType"/>.</summary>
        /// <returns>The cached list instance; treat it as read-only.</returns>
        public static List<Attribute> GetTypeAttributes<TType>()
        {
            return GetTypeAttributes(typeof(TType));
        }

        /// <summary>Returns all attributes of <paramref name="type"/> (read with inherit = true).</summary>
        /// <param name="type">The type to inspect.</param>
        /// <returns>The cached list instance; treat it as read-only.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="type"/> is null.</exception>
        public static List<Attribute> GetTypeAttributes(Type type)
        {
            if (type == null)
            {
                // Previously surfaced as an ArgumentNullException for "key" from the dictionary.
                throw new ArgumentNullException("type");
            }

            return LockAndGetAttributes(type, key => ((Type)key).GetCustomAttributes(true));
        }

        /// <summary>
        /// Returns the attributes of <paramref name="type"/> that are of type
        /// <typeparamref name="TAttributeType"/> and satisfy <paramref name="predicate"/>.
        /// </summary>
        /// <param name="type">The type to inspect.</param>
        /// <param name="predicate">Optional filter; null accepts every attribute of the requested type.</param>
        /// <returns>A new list containing the matching attributes (possibly empty).</returns>
        public static List<TAttributeType> GetTypeAttributes<TAttributeType>(Type type, Func<TAttributeType, bool> predicate = null)
        {
            IEnumerable<TAttributeType> matches = GetTypeAttributes(type).OfType<TAttributeType>();

            if (predicate != null)
            {
                matches = matches.Where(predicate);
            }

            return matches.ToList();
        }

        /// <summary>
        /// Returns the attributes of <typeparamref name="TType"/> that are of type
        /// <typeparamref name="TAttributeType"/> and satisfy <paramref name="predicate"/>.
        /// </summary>
        /// <param name="predicate">Optional filter; null accepts every attribute of the requested type.</param>
        public static List<TAttributeType> GetTypeAttributes<TType, TAttributeType>(Func<TAttributeType, bool> predicate = null)
        {
            return GetTypeAttributes<TAttributeType>(typeof(TType), predicate);
        }

        /// <summary>
        /// Returns the first matching attribute of <typeparamref name="TType"/>, or the default
        /// value of <typeparamref name="TAttributeType"/> when there is none.
        /// </summary>
        /// <param name="predicate">Optional filter; null accepts every attribute of the requested type.</param>
        public static TAttributeType GetTypeAttribute<TType, TAttributeType>(Func<TAttributeType, bool> predicate = null)
        {
            return GetTypeAttribute<TAttributeType>(typeof(TType), predicate);
        }

        /// <summary>
        /// Returns the first matching attribute of <paramref name="type"/>, or the default
        /// value of <typeparamref name="TAttributeType"/> when there is none.
        /// </summary>
        /// <param name="type">The type to inspect.</param>
        /// <param name="predicate">Optional filter; null accepts every attribute of the requested type.</param>
        public static TAttributeType GetTypeAttribute<TAttributeType>(Type type, Func<TAttributeType, bool> predicate = null)
        {
            return GetTypeAttributes<TAttributeType>(type, predicate).FirstOrDefault();
        }

        /// <summary>Tells whether <typeparamref name="TType"/> carries a matching attribute.</summary>
        /// <param name="predicate">Optional filter; null accepts every attribute of the requested type.</param>
        public static bool HasTypeAttribute<TType, TAttributeType>(Func<TAttributeType, bool> predicate = null)
        {
            return HasTypeAttribute<TAttributeType>(typeof(TType), predicate);
        }

        /// <summary>Tells whether <paramref name="type"/> carries a matching attribute.</summary>
        /// <param name="type">The type to inspect.</param>
        /// <param name="predicate">Optional filter; null accepts every attribute of the requested type.</param>
        public static bool HasTypeAttribute<TAttributeType>(Type type, Func<TAttributeType, bool> predicate = null)
        {
            return GetTypeAttribute<TAttributeType>(type, predicate) != null;
        }

        // Members and properties

        /// <summary>
        /// Returns all attributes of the member selected by <paramref name="action"/>,
        /// e.g. <c>x => x.SomeProperty</c> or <c>x => x.SomeMethod(0)</c>.
        /// </summary>
        /// <param name="action">Expression whose body is a member access or a method call.</param>
        /// <returns>The cached list instance; treat it as read-only.</returns>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="action"/> is null or does not resolve to a member or method call.
        /// </exception>
        public static List<Attribute> GetMemberAttributes<TType>(Expression<Func<TType, object>> action)
        {
            return GetMemberAttributes(GetMember(action));
        }

        /// <summary>
        /// Returns the attributes of type <typeparamref name="TAttributeType"/> on the member
        /// selected by <paramref name="action"/> that satisfy <paramref name="predicate"/>.
        /// </summary>
        /// <param name="action">Expression whose body is a member access or a method call.</param>
        /// <param name="predicate">Optional filter; null accepts every attribute of the requested type.</param>
        public static List<TAttributeType> GetMemberAttributes<TType, TAttributeType>(
            Expression<Func<TType, object>> action,
            Func<TAttributeType, bool> predicate = null)
            where TAttributeType : Attribute
        {
            return GetMemberAttributes<TAttributeType>(GetMember(action), predicate);
        }

        /// <summary>
        /// Returns the first matching attribute on the member selected by
        /// <paramref name="action"/>, or null when there is none.
        /// </summary>
        /// <param name="action">Expression whose body is a member access or a method call.</param>
        /// <param name="predicate">Optional filter; null accepts every attribute of the requested type.</param>
        public static TAttributeType GetMemberAttribute<TType, TAttributeType>(
            Expression<Func<TType, object>> action,
            Func<TAttributeType, bool> predicate = null)
            where TAttributeType : Attribute
        {
            return GetMemberAttribute<TAttributeType>(GetMember(action), predicate);
        }

        /// <summary>Tells whether the member selected by <paramref name="action"/> carries a matching attribute.</summary>
        /// <param name="action">Expression whose body is a member access or a method call.</param>
        /// <param name="predicate">Optional filter; null accepts every attribute of the requested type.</param>
        public static bool HasMemberAttribute<TType, TAttributeType>(Expression<Func<TType, object>> action, Func<TAttributeType, bool> predicate = null) where TAttributeType : Attribute
        {
            return GetMemberAttribute<TAttributeType>(GetMember(action), predicate) != null;
        }

        // MemberInfo (and PropertyInfo since PropertyInfo inherits from MemberInfo)

        /// <summary>Returns all attributes of <paramref name="memberInfo"/> (read with inherit = true).</summary>
        /// <param name="memberInfo">The member to inspect.</param>
        /// <returns>The cached list instance; treat it as read-only.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="memberInfo"/> is null.</exception>
        public static List<Attribute> GetMemberAttributes(this MemberInfo memberInfo)
        {
            if (memberInfo == null)
            {
                // Same exception type the dictionary lookup used to raise, with a usable message.
                throw new ArgumentNullException(
                    "memberInfo",
                    "No member was supplied; if it came from an expression, the expression body is not a member access or method call.");
            }

            return LockAndGetAttributes(memberInfo, key => ((MemberInfo)key).GetCustomAttributes(true));
        }

        /// <summary>
        /// Returns the attributes of type <typeparamref name="TAttributeType"/> on
        /// <paramref name="memberInfo"/> that satisfy <paramref name="predicate"/>.
        /// </summary>
        /// <param name="memberInfo">The member to inspect.</param>
        /// <param name="predicate">Optional filter; null accepts every attribute of the requested type.</param>
        /// <returns>A new list containing the matching attributes (possibly empty).</returns>
        public static List<TAttributeType> GetMemberAttributes<TAttributeType>(this MemberInfo memberInfo, Func<TAttributeType, bool> predicate = null) where TAttributeType : Attribute
        {
            IEnumerable<TAttributeType> matches = GetMemberAttributes(memberInfo).OfType<TAttributeType>();

            if (predicate != null)
            {
                matches = matches.Where(predicate);
            }

            return matches.ToList();
        }

        /// <summary>
        /// Returns the first matching attribute on <paramref name="memberInfo"/>, or null when there is none.
        /// </summary>
        /// <param name="memberInfo">The member to inspect.</param>
        /// <param name="predicate">Optional filter; null accepts every attribute of the requested type.</param>
        public static TAttributeType GetMemberAttribute<TAttributeType>(this MemberInfo memberInfo, Func<TAttributeType, bool> predicate = null) where TAttributeType : Attribute
        {
            return GetMemberAttributes<TAttributeType>(memberInfo, predicate).FirstOrDefault();
        }

        /// <summary>Tells whether <paramref name="memberInfo"/> carries a matching attribute.</summary>
        /// <param name="memberInfo">The member to inspect.</param>
        /// <param name="predicate">Optional filter; null accepts every attribute of the requested type.</param>
        public static bool HasMemberAttribute<TAttributeType>(this MemberInfo memberInfo, Func<TAttributeType, bool> predicate = null) where TAttributeType : Attribute
        {
            return memberInfo.GetMemberAttribute<TAttributeType>(predicate) != null;
        }

        // Internal stuff

        // Looks the key up in the shared attribute cache, running retrieveValue on a miss
        // and converting its object[] result into a list of Attribute.
        private static List<Attribute> LockAndGetAttributes(object key, Func<object, object[]> retrieveValue)
        {
            return LockAndGet<object, List<Attribute>>(
                _attributeCache,
                key,
                k => retrieveValue(k).Cast<Attribute>().ToList());
        }

        // Method for thread safely executing slow method and storing the result in a dictionary.
        // The slow call runs outside the lock, so two threads may both compute a value for the
        // same key; whichever stores first wins and every caller gets that stored instance.
        private static TValue LockAndGet<TKey, TValue>(Dictionary<TKey, TValue> dictionary, TKey key, Func<TKey, TValue> retrieveValue)
        {
            TValue cached;

            lock (dictionary)
            {
                if (dictionary.TryGetValue(key, out cached))
                {
                    return cached;
                }
            }

            TValue computed = retrieveValue(key);

            lock (dictionary)
            {
                // Re-check: another thread may have stored a value while we were computing.
                if (dictionary.TryGetValue(key, out cached))
                {
                    return cached;
                }

                dictionary.Add(key, computed);
                return computed;
            }
        }

        // Resolves the member or method targeted by the expression body. Returns null when the
        // body (after removing one wrapping unary node) is neither a member access nor a call.
        private static MemberInfo GetMember<T>(Expression<Func<T, object>> expression)
        {
            if (expression == null)
            {
                // Named after the public parameter this value arrives through.
                throw new ArgumentNullException("action");
            }

            Expression body = expression.Body;

            // Value-typed results are boxed to object, which wraps the real node in a unary
            // Convert; reference-typed results appear as the body itself.
            UnaryExpression unaryExpression = body as UnaryExpression;
            if (unaryExpression != null)
            {
                body = unaryExpression.Operand;
            }

            MemberExpression memberExpression = body as MemberExpression;
            if (memberExpression != null)
            {
                return memberExpression.Member;
            }

            // Covers both boxed calls (x => x.IsValid(0)) and reference-returning calls
            // (x => x.ToString()), the latter of which used to resolve to null.
            MethodCallExpression methodCall = body as MethodCallExpression;
            if (methodCall != null)
            {
                return methodCall.Method;
            }

            return null;
        }
    }
}

And here are a number of NUnit tests:

/// <summary>
/// NUnit tests covering the type-based and expression-based lookups of AttributeHelper.
/// </summary>
[TestFixture]
class AttributeHelperTests
{
    // Sample type: one attributed and one plain property, one attributed and one plain method.
    [MyAttribute(Name = "Steve")]
    internal class Attributed
    {
        [MyAttribute(Name = "Bob")]
        public bool HasAttributeProperty { get; set; }

        public bool HasNoAttributeProperty { get; set; }

        [MyAttribute(Name = "Stevie")]
        public bool HasAttributeMember(int x)
        {
            return true;
        }

        public bool HasNoAttributeMember(int x)
        {
            return true;
        }
    }

    // Marker attribute whose Name lets the tests exercise the predicate overloads.
    internal class MyAttribute : Attribute
    {
        public string Name { get; set; }
    }

    [Test]
    public void GetTypeAttributes_works()
    {
        // Unfiltered and filtered-by-type lookups both see the single class-level attribute.
        Assert.That(AttributeHelper.GetTypeAttributes<Attributed>().Count, Is.EqualTo(1));
        Assert.That(AttributeHelper.GetTypeAttributes<Attributed, MyAttribute>().Count, Is.EqualTo(1));

        // The predicate narrows the result by attribute content.
        Assert.That(AttributeHelper.GetTypeAttributes<Attributed, MyAttribute>(attr => attr.Name == "Jane").Count, Is.EqualTo(0));
        Assert.That(AttributeHelper.GetTypeAttributes<Attributed, MyAttribute>(attr => attr.Name == "Steve").Count, Is.EqualTo(1));
    }

    [Test]
    public void GetTypeAttribute_works()
    {
        Assert.That(AttributeHelper.GetTypeAttribute<Attributed, MyAttribute>(), Is.Not.Null);
        Assert.That(AttributeHelper.GetTypeAttribute<Attributed, SequentialAttribute>(), Is.Null);

        Assert.That(AttributeHelper.GetTypeAttribute<Attributed, MyAttribute>(attr => attr.Name == "Jane"), Is.Null);
        Assert.That(AttributeHelper.GetTypeAttribute<Attributed, MyAttribute>(attr => attr.Name == "Steve"), Is.Not.Null);
    }

    [Test]
    public void HasTypeAttribute_works()
    {
        Assert.That(AttributeHelper.HasTypeAttribute<Attributed, MyAttribute>(), Is.True);
        Assert.That(AttributeHelper.HasTypeAttribute<Attributed, SequentialAttribute>(), Is.False);

        Assert.That(AttributeHelper.HasTypeAttribute<Attributed, MyAttribute>(attr => attr.Name == "Jane"), Is.False);
        Assert.That(AttributeHelper.HasTypeAttribute<Attributed, MyAttribute>(attr => attr.Name == "Steve"), Is.True);
    }

    [Test]
    public void GetMemberAttributes_works()
    {
        // Methods selected through a call expression.
        Assert.That(AttributeHelper.GetMemberAttributes<Attributed>(x => x.HasAttributeMember(0)).Count, Is.EqualTo(1));
        Assert.That(AttributeHelper.GetMemberAttributes<Attributed, MyAttribute>(x => x.HasAttributeMember(0)).Count, Is.EqualTo(1));
        Assert.That(AttributeHelper.GetMemberAttributes<Attributed>(x => x.HasNoAttributeMember(0)).Count, Is.EqualTo(0));

        // Properties selected through a member-access expression.
        Assert.That(AttributeHelper.GetMemberAttributes<Attributed>(x => x.HasAttributeProperty).Count, Is.EqualTo(1));
        Assert.That(AttributeHelper.GetMemberAttributes<Attributed, MyAttribute>(x => x.HasAttributeProperty).Count, Is.EqualTo(1));
        Assert.That(AttributeHelper.GetMemberAttributes<Attributed>(x => x.HasNoAttributeProperty).Count, Is.EqualTo(0));

        // Predicate filtering on a method.
        Assert.That(AttributeHelper.GetMemberAttributes<Attributed, MyAttribute>(x => x.HasAttributeMember(0), attr => attr.Name == "Stevie").Count, Is.EqualTo(1));
        Assert.That(AttributeHelper.GetMemberAttributes<Attributed, MyAttribute>(x => x.HasAttributeMember(0), attr => attr.Name == "X").Count, Is.EqualTo(0));

        // Predicate filtering on a property.
        Assert.That(AttributeHelper.GetMemberAttributes<Attributed, MyAttribute>(x => x.HasAttributeProperty, attr => attr.Name == "Bob").Count, Is.EqualTo(1));
        Assert.That(AttributeHelper.GetMemberAttributes<Attributed, MyAttribute>(x => x.HasAttributeProperty, attr => attr.Name == "X").Count, Is.EqualTo(0));
    }

    [Test]
    public void GetMemberAttribute_works()
    {
        Assert.That(AttributeHelper.GetMemberAttribute<Attributed, MyAttribute>(x => x.HasAttributeMember(0)), Is.Not.Null);
        Assert.That(AttributeHelper.GetMemberAttribute<Attributed, MyAttribute>(x => x.HasNoAttributeMember(0)), Is.Null);

        Assert.That(AttributeHelper.GetMemberAttribute<Attributed, MyAttribute>(x => x.HasAttributeMember(0), attr => attr.Name == "Stevie"), Is.Not.Null);
        Assert.That(AttributeHelper.GetMemberAttribute<Attributed, MyAttribute>(x => x.HasAttributeMember(0), attr => attr.Name == "X"), Is.Null);
    }

    [Test]
    public void HasMemberAttribute_works()
    {
        Assert.That(AttributeHelper.HasMemberAttribute<Attributed, MyAttribute>(x => x.HasAttributeMember(0)), Is.True);
        Assert.That(AttributeHelper.HasMemberAttribute<Attributed, MyAttribute>(x => x.HasNoAttributeMember(0)), Is.False);

        Assert.That(AttributeHelper.HasMemberAttribute<Attributed, MyAttribute>(x => x.HasAttributeMember(0), attr => attr.Name == "Stevie"), Is.True);
        Assert.That(AttributeHelper.HasMemberAttribute<Attributed, MyAttribute>(x => x.HasAttributeMember(0), attr => attr.Name == "X"), Is.False);
    }
}