⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 codegenerator.cs

📁 drools 一个开放源码的规则引擎
💻 CS
字号:
using System;
using System.CodeDom;
using org.drools.rule;
using System.Collections.Generic;
using org.drools.spi;
using System.Reflection;

namespace org.drools.semantics.dotnet
{
	/// <summary>
	/// Generates .NET code using CodeDOM. Every compile unit produced here contains one
	/// namespace holding one class whose first member is the generated rule method.
	/// </summary>
	internal class CodeGenerator
	{
		/// <summary>
		/// Builds a compile unit whose generated method returns the value of a boolean expression.
		/// </summary>
		/// <param name="namespaceName">Namespace of the generated class.</param>
		/// <param name="className">Name of the generated class.</param>
		/// <param name="methodName">Name of the generated public method.</param>
		/// <param name="parameters">Declarations that become method parameters; each must have a
		/// <see cref="DotNetObjectType"/> object type.</param>
		/// <param name="expression">Source snippet emitted verbatim as the method's return value.</param>
		/// <param name="importer">Supplies additional namespaces to import; may be null.</param>
		/// <param name="functions">Helper function source added to the class; may be null.</param>
		/// <returns>The populated compile unit.</returns>
		/// <exception cref="CodeGenerationException">Any step of generation failed.</exception>
		public static CodeCompileUnit CreateCondition(string namespaceName, string className, 
			string methodName, Declaration[] parameters, string expression, 
			DotNetImporter importer, DotNetFunctions functions)
		{
			try
			{
				CodeCompileUnit code = GenerateClassCode(namespaceName, className, methodName, parameters,
					typeof(bool));
				GetGeneratedMethod(code).Statements.Add(
					new CodeMethodReturnStatement(new CodeSnippetExpression(expression)));
				return FinishCompileUnit(code, importer, functions);
			}
			catch (Exception e)
			{
				throw new CodeGenerationException("Unable to create condition for expression [" + 
					expression + "]", e);
			}
		}

		/// <summary>
		/// Builds a compile unit whose generated void method executes the consequence snippet.
		/// </summary>
		/// <param name="namespaceName">Namespace of the generated class.</param>
		/// <param name="className">Name of the generated class.</param>
		/// <param name="methodName">Name of the generated public method.</param>
		/// <param name="parameters">Declarations that become method parameters; each must have a
		/// <see cref="DotNetObjectType"/> object type.</param>
		/// <param name="expression">Source snippet emitted verbatim as the method body.</param>
		/// <param name="importer">Supplies additional namespaces to import; may be null.</param>
		/// <param name="functions">Helper function source added to the class; may be null.</param>
		/// <returns>The populated compile unit.</returns>
		/// <exception cref="CodeGenerationException">Any step of generation failed.</exception>
		public static CodeCompileUnit CreateConsequence(string namespaceName, string className,
			string methodName, Declaration[] parameters, string expression,
			DotNetImporter importer, DotNetFunctions functions)
		{
			try
			{
				CodeCompileUnit code = GenerateClassCode(namespaceName, className, methodName, parameters,
					typeof(void));
				// Statements.Add(CodeExpression) wraps the snippet in a CodeExpressionStatement.
				GetGeneratedMethod(code).Statements.Add(new CodeSnippetExpression(expression));
				return FinishCompileUnit(code, importer, functions);
			}
			catch (Exception e)
			{
				throw new CodeGenerationException("Unable to create consequence for expression [" +
					expression + "]", e);
			}
		}

		/// <summary>
		/// Adds helper functions, assembly references and namespace imports to a generated unit.
		/// Functions are added first; they are snippet members, so they are not scanned for types.
		/// </summary>
		private static CodeCompileUnit FinishCompileUnit(CodeCompileUnit code, DotNetImporter importer,
			DotNetFunctions functions)
		{
			code = AddFunctions(code, functions);
			code = AddReferencedAssemblies(code);
			return AddImports(code, importer);
		}

		/// <summary>
		/// Returns the method created by <see cref="GenerateClassCode"/>, which is always the
		/// first member of the unit's only class.
		/// </summary>
		private static CodeMemberMethod GetGeneratedMethod(CodeCompileUnit code)
		{
			return (CodeMemberMethod)code.Namespaces[0].Types[0].Members[0];
		}

		/// <summary>
		/// Creates a compile unit containing one class with one public method. The method's
		/// parameters are the declarations, in order, followed by a <see cref="KnowledgeHelper"/>
		/// parameter named <c>drools</c>.
		/// </summary>
		/// <exception cref="ArgumentException">A declaration's object type is not a
		/// <see cref="DotNetObjectType"/>.</exception>
		private static CodeCompileUnit GenerateClassCode(string namespaceName, string className,
			string methodName, Declaration[] parameters, Type returnType)
		{
			CodeCompileUnit code = new CodeCompileUnit();
			CodeNamespace namespaceCode = new CodeNamespace(namespaceName);
			code.Namespaces.Add(namespaceCode);
			CodeTypeDeclaration classCode = new CodeTypeDeclaration(className);
			namespaceCode.Types.Add(classCode);
			CodeMemberMethod methodCode = new CodeMemberMethod();
			classCode.Members.Add(methodCode);
			methodCode.Name = methodName;
			methodCode.ReturnType = new CodeTypeReference(returnType);
			methodCode.Attributes = MemberAttributes.Public;

			foreach (Declaration d in parameters)
			{
				DotNetObjectType o = d.getObjectType() as DotNetObjectType;
				if (o == null)
				{
					throw new ArgumentException("Parameters must be of type " + typeof(DotNetObjectType).FullName + ".");
				}
				CodeParameterDeclarationExpression parameter = new
					CodeParameterDeclarationExpression(o.Type, d.getIdentifier());
				parameter.Direction = FieldDirection.In;
				methodCode.Parameters.Add(parameter);
			}

			CodeParameterDeclarationExpression droolsParam = new CodeParameterDeclarationExpression(
				typeof(KnowledgeHelper), "drools");
			droolsParam.Direction = FieldDirection.In;
			methodCode.Parameters.Add(droolsParam);
			return code;
		}

		/// <summary>
		/// Appends the helper-function source text, if any, to the generated class as a snippet member.
		/// </summary>
		private static CodeCompileUnit AddFunctions(CodeCompileUnit code, DotNetFunctions functions)
		{
			if (functions != null)
			{
				CodeSnippetTypeMember functionMember = new CodeSnippetTypeMember(functions.getText());
				code.Namespaces[0].Types[0].Members.Add(functionMember);
			}
			return code;
		}

		/// <summary>
		/// Adds namespace imports for the generated methods' return and parameter types, then any
		/// namespaces supplied by <paramref name="importer"/> (which may be null). Each namespace is
		/// imported once. Null or empty names are skipped, because types in the global namespace
		/// have no namespace and an empty import would generate an invalid <c>using</c> directive.
		/// </summary>
		private static CodeCompileUnit AddImports(CodeCompileUnit code, DotNetImporter importer)
		{
			IList<string> imports = new List<string>();

			foreach (CodeTypeReference typeRef in GetSignatureTypes(code))
			{
				AddDistinct(imports, GetNamespace(typeRef));
			}

			if (importer != null)
			{
				foreach (string entry in importer.getImports().toArray(new string[]{}))
				{
					AddDistinct(imports, entry);
				}
			}

			foreach (string import in imports)
			{
				code.Namespaces[0].Imports.Add(new CodeNamespaceImport(import));
			}
			return code;
		}

		/// <summary>
		/// Adds assembly references for every type appearing in the generated methods' signatures,
		/// including generic type arguments and array element types. Assemblies without an on-disk
		/// location (<see cref="Assembly.Location"/> is empty for in-memory assemblies) cannot be
		/// referenced by path and are skipped.
		/// </summary>
		private static CodeCompileUnit AddReferencedAssemblies(CodeCompileUnit code)
		{
			foreach (CodeTypeReference signatureType in GetSignatureTypes(code))
			{
				foreach (CodeTypeReference typeRef in ExpandTypeReference(signatureType))
				{
					string location = GetAssemblyReference(typeRef);
					if (!string.IsNullOrEmpty(location) && !code.ReferencedAssemblies.Contains(location))
					{
						code.ReferencedAssemblies.Add(location);
					}
				}
			}
			return code;
		}

		/// <summary>
		/// Yields, for each <see cref="CodeMemberMethod"/> in the unit's class, its return type
		/// followed by its parameter types in declaration order.
		/// </summary>
		private static IEnumerable<CodeTypeReference> GetSignatureTypes(CodeCompileUnit code)
		{
			foreach (CodeTypeMember member in code.Namespaces[0].Types[0].Members)
			{
				CodeMemberMethod method = member as CodeMemberMethod;
				if (method == null)
				{
					continue;
				}

				yield return method.ReturnType;
				foreach (CodeParameterDeclarationExpression parameter in method.Parameters)
				{
					yield return parameter.Type;
				}
			}
		}

		/// <summary>
		/// Yields <paramref name="typeRef"/> and, recursively, its array element type and generic
		/// type arguments. The same reference may be yielded more than once; callers de-duplicate.
		/// </summary>
		private static IEnumerable<CodeTypeReference> ExpandTypeReference(CodeTypeReference typeRef)
		{
			yield return typeRef;

			if (typeRef.ArrayElementType != null)
			{
				foreach (CodeTypeReference nested in ExpandTypeReference(typeRef.ArrayElementType))
				{
					yield return nested;
				}
			}

			foreach (CodeTypeReference argument in typeRef.TypeArguments)
			{
				foreach (CodeTypeReference nested in ExpandTypeReference(argument))
				{
					yield return nested;
				}
			}
		}

		/// <summary>
		/// Appends <paramref name="value"/> to <paramref name="list"/> unless it is null, empty or
		/// already present.
		/// </summary>
		private static void AddDistinct(IList<string> list, string value)
		{
			if (string.IsNullOrEmpty(value) || list.Contains(value))
			{
				return;
			}
			list.Add(value);
		}

		/// <summary>
		/// Returns the namespace of the referenced type; null for types in the global namespace.
		/// </summary>
		private static string GetNamespace(CodeTypeReference typeRef)
		{
			return SearchAppDomainForType(typeRef.BaseType).Namespace;
		}

		/// <summary>
		/// Returns the file path of the assembly defining the referenced type (empty for
		/// in-memory assemblies).
		/// </summary>
		private static string GetAssemblyReference(CodeTypeReference typeRef)
		{
			return SearchAppDomainForType(typeRef.BaseType).Assembly.Location;
		}

		/// <summary>
		/// Finds a type by full name in the assemblies loaded into the current AppDomain, returning
		/// the first match in <see cref="AppDomain.GetAssemblies"/> order.
		/// </summary>
		/// <exception cref="TypeLoadException">No loaded assembly defines the type.</exception>
		private static Type SearchAppDomainForType(string typeName)
		{
			foreach (Assembly assembly in AppDomain.CurrentDomain.GetAssemblies())
			{
				Type type = assembly.GetType(typeName);
				if (type != null)
				{
					return type;
				}
			}
			throw new TypeLoadException("Unable to find type [" + typeName + "].");
		}
	}
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -