//
// $Id: XmlProcessor.cpp,v 1.17 2005/03/28 15:59:42 bakerj Exp $
//
//****************************************************************************************//
// Copyright (c) 2005, The MITRE Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without modification, are
// permitted provided that the following conditions are met:
//
//     * Redistributions of source code must retain the above copyright notice, this list
//       of conditions and the following disclaimer.
//     * Redistributions in binary form must reproduce the above copyright notice, this 
//       list of conditions and the following disclaimer in the documentation and/or other
//       materials provided with the distribution.
//     * Neither the name of The MITRE Corporation nor the names of its contributors may be
//       used to endorse or promote products derived from this software without specific 
//       prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY 
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
// OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT 
// SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT
// OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
// HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR
// TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
// EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
//****************************************************************************************//
#include "XmlProcessor.h"

//****************************************************************************************//
//								XmlProcessor Class										  //	
//****************************************************************************************//

XmlProcessor::XmlProcessor()
{
	// -----------------------------------------------------------------------
	//	Abstract
	//
	//	Init the XmlProcessor: initialize the Xerces platform utilities.
	//	Throw an XmlProcessorException if there is an error.
	//
	//	'parser' is only created by ParseFile(), but the destructor tests it
	//	against NULL before releasing it, so it must start out as NULL rather
	//	than uninitialized.
	//
	// -----------------------------------------------------------------------

	parser = NULL;

	try 
	{
		XMLPlatformUtils::Initialize();
	}
	catch (const XMLException& toCatch) 
	{
		string errMsg = "Error:  An error occured durring initialization of the xml utilities:\n";
		errMsg.append(XmlCommon::ToString(toCatch.getMessage()));
		errMsg.append("\n");

		throw XmlProcessorException(errMsg);
	}	
}

XmlProcessor::~XmlProcessor()
{
	// -----------------------------------------------------------------------
	//	Abstract
	//
	//	Release the builder stored in 'parser' (if one was created) and then
	//	shut down the Xerces platform utilities.
	//
	//	NOTE(review): only the most recent builder created by ParseFile() is
	//	stored in 'parser', so only that one is released here.
	//
	// -----------------------------------------------------------------------

	// The builder has to be released before XMLPlatformUtils::Terminate().
	if(parser) {
		parser->release();
	}

	XMLPlatformUtils::Terminate();
}


XERCES_CPP_NAMESPACE_QUALIFIER DOMDocument* XmlProcessor::ParseFile(string filePathIn, bool validate)
{
	// -----------------------------------------------------------------------
	//	Abstract
	//
	//	Parse the specified file and return a DOMDocument. 
	//	'filePathIn' should be the complete path to the desired file.
	//
	//	2/25/2004 - Added validate parameter. 
	//		The validate flag is used to indicate whether the xml file 
	//		should be checked with a schema file. 
	//			- validate == (false)	- <DEFAULT> never validate the xml
	//			- validate == (true)	- always validate the xml
	//		When validating an xml file the schema must be in the same 
	//		directory as the file. 
	//
	//	The builder created here is stored in the 'parser' member, the
	//	returned document comes from it, and ~XmlProcessor() releases it.
	//	NOTE(review): each call overwrites 'parser' without releasing the
	//	previous builder, so only the most recent one is released by the
	//	destructor. Releasing it here could also free documents returned by
	//	earlier calls, so that behavior is intentionally left unchanged.
	//
	//	Throws XmlProcessorException if parsing throws, or if the error
	//	handler recorded any warning, error or fatal error.
	//
	// -----------------------------------------------------------------------

	XERCES_CPP_NAMESPACE_QUALIFIER DOMDocument *resultDocument = NULL;

	// Instantiate the DOM parser.
	static const XMLCh gLS[] = { chLatin_L, chLatin_S, chNull };
	DOMImplementation *impl = DOMImplementationRegistry::getDOMImplementation(gLS);
	parser = ((DOMImplementationLS*)impl)->createDOMBuilder(DOMImplementationLS::MODE_SYNCHRONOUS, 0);

	///////////////////////////////////////////////////////
	//	Set features on the builder
	///////////////////////////////////////////////////////

	//	If validating the document namespaces and Schema must be set to true
	if(validate)
	{
		parser->setFeature(XMLUni::fgDOMNamespaces, true);
		parser->setFeature(XMLUni::fgXercesSchema, true);
		parser->setFeature(XMLUni::fgXercesSchemaFullChecking, true);
		//	Treat validation errors as fatal - default is false
		//	The parser, by default will exit after the first fatal error.
		parser->setFeature(XMLUni::fgXercesValidationErrorAsFatal, true); 
		parser->setFeature(XMLUni::fgDOMValidation, true);

	}else
	{
		//	Set all validation features to false
		parser->setFeature(XMLUni::fgDOMNamespaces, false);
		parser->setFeature(XMLUni::fgXercesSchema, false);
		parser->setFeature(XMLUni::fgXercesSchemaFullChecking, false);
		parser->setFeature(XMLUni::fgDOMValidation, false);
	}
	
	//	Don't read in comments 
	parser->setFeature(XMLUni::fgDOMComments, false);
	//	Enable DataType normalization - default is off
	parser->setFeature(XMLUni::fgDOMDatatypeNormalization, true);

	// Create a new DOMErrorHandler and set it to the builder. The handler is
	// owned by this function: it is detached from the builder and deleted
	// before this function returns or throws.
	XmlProcessorErrorHandler *errHandler = new XmlProcessorErrorHandler();
	parser->setErrorHandler(errHandler);

	// Stays empty while parsing succeeds; otherwise holds the message that is
	// thrown once the error handler has been cleaned up.
	string error = "";

	try 
	{
		// reset document pool
		parser->resetDocumentPool();
		resultDocument = parser->parseURI(filePathIn.c_str());
	}
	catch (const XMLException& toCatch) 
	{
		error = "Error while parsing xml file:";
		error.append(filePathIn);
		error.append("\n\tMessage: \n\t");
		error.append(XmlCommon::ToString(toCatch.getMessage()));
	}
	catch (const DOMException& toCatch) 
	{
		error = "Error while parsing xml file:";
		error.append(filePathIn);
		error.append("\n\tMessage: \n\t");
		error.append(XmlCommon::ToString(toCatch.msg));
	}
	catch (...) 
	{
		error = "Error while parsing xml file:";
		error.append(filePathIn);
		error.append("\n\tMessage: \n\tUnknown message");
	}

	// Problems reported through the handler (rather than thrown) also fail the parse.
	if(error.empty() && errHandler->getSawErrors())
	{
		error = "Error while parsing xml file:";
		error.append(errHandler->getErrorMessages());
	}

	// Detach before deleting so the builder never holds a dangling handler.
	parser->setErrorHandler(NULL);
	delete errHandler;

	if(!error.empty())
		throw XmlProcessorException(error);

	return resultDocument;
}

XERCES_CPP_NAMESPACE_QUALIFIER DOMDocument* XmlProcessor::PruneDOMDocument(XERCES_CPP_NAMESPACE_QUALIFIER DOMDocument* ovalDoc)
{
	// -----------------------------------------------------------------------
	//	Abstract
	//
	//	Removes all oval definitions and tests not specified in classType, StatusType and familyType from the DOMDocument.
	//
	// -----------------------------------------------------------------------

	DOMElement *definitions, *software, *configuration, *tests;
	DOMNode *def = NULL;
	DOMNode *test = NULL;
	DOMNodeList *defList, *testList;
	bool removeDefinition = false;
    taVector testContainer;
    TestArray* testArray = NULL;

	// get a vector of all compound tests 
	NodeVector *compoundTestVector = XmlCommon::FindAllNodes(ovalDoc, "compound_test");

	//	get a ptr to the definitions node in the oval document.
	definitions = (DOMElement*)XmlCommon::FindNodeNS(ovalDoc, "definitions");

	//	get a list of the child nodes
	defList = definitions->getElementsByTagName(XMLString::transcode ("definition"));

	//	Loop through all the nodes in defList
	unsigned int index = 0;
	while(index < defList->getLength()) {

		// get next child
		def = defList->item(index++);

		if(def->getNodeType() == DOMNode::ELEMENT_NODE) {
			// reset flag
			removeDefinition = false;
			string ovalID = "";

			try {
				//Get node's Class value
				string classValue = XmlCommon::GetAttributeByName(def, "class");
				ovalID = XmlCommon::GetAttributeByName(def, "id");

				//Compare class value with desired type
				if (Common::CompareClassType(classValue)) {
					//Get Status element value
					DOMNode* status = XmlCommon::FindNodeNS(def, "status");
					if (status == NULL)
						throw XmlProcessorException("Error: Unable to locate the required status element in definition. Oval " + ovalID + " will not be processed.", ERROR_WARN);
					
					string statusValue = XmlCommon::GetDataNodeValue(status);
					//Compare status value with desired type
					if (Common::CompareStatusType(statusValue)) {
						//get family value
						DOMNode* affected = XmlCommon::FindNodeNS(def, "affected");
						if (affected != NULL) {
							//If not null then extract value
							string familyValue = XmlCommon::GetAttributeByName(affected, "family");
							//Compare family value with desired type
							if (Common::CompareFamilyType(familyValue)) {
								//add tests to vector 
								DOMNode* criteria = XmlCommon::FindNodeNS(def, "criteria");
								//check if criteria is not null
								if (criteria != NULL) {
									unsigned int testLoop = 0;
									software = (DOMElement*)XmlCommon::FindNodeNS(criteria, "software");

									//check if software is not null
									if (software != NULL) {
										
										testList = software->getElementsByTagName(XMLString::transcode ("criterion"));
										// Loop through all nodes in the testList
										while(testLoop < testList->getLength()) {
											test = testList->item(testLoop++);
											string testID = XmlCommon::GetAttributeByName(test, "test_ref");
											bool addedTest = addTest(&testContainer, testID);
											if(testID.substr(0,3) == "cmp" && addedTest) 
												processCompoundTest(&testContainer, testID, compoundTestVector);

										} //while loop
									}//software !=null
									configuration = (DOMElement*)XmlCommon::FindNodeNS(criteria, "configuration");
									//check if configuration is not null
									if (configuration != NULL) {
										testList = NULL;
										test = NULL;
										testList = configuration->getElementsByTagName(XMLString::transcode ("criterion"));
										// Loop through all nodes in the testList
										testLoop = 0;
										while(testLoop < testList->getLength()) {
											test = testList->item(testLoop++);
											string testID = XmlCommon::GetAttributeByName(test, "test_ref");
											bool addedTest = addTest(&testContainer, testID);
											if(testID.substr(0,3) == "cmp" && addedTest) 
												processCompoundTest(&testContainer, testID, compoundTestVector);

										} //while loop
									}//configuration != NUL
								}//criteria != NULL
							}// familyValue == family Type
							else removeDefinition = true;
						} // family != NULL
					}// statusValue == statusType
					else removeDefinition = true;
				}// classValue == classType
				else removeDefinition = true;

			} catch(Exception ex) {
				removeDefinition = true;
				
				string errMsg = "";
				if(ovalID.compare("") != 0) {
					errMsg = "An error has occured while selecting a definition to analyze based on specified criteria. Current defintion: "
						   + ovalID
						   + "\n"
						   + ex.GetErrorMessage()
						   + "\n";
				} else {
					errMsg = "An error has occured while selecting a definition to analyze based on specified criteria.\n"
						   + ex.GetErrorMessage()
						   + "\n";
				}

				// Print the error message if in verbose mode
				if(Log::verboseMode) {
					cout << errMsg << endl;
				}
                
				// Log the error message
				Log::WriteLog(errMsg + "\n");

			} catch(...) {
				removeDefinition = true;

				string errMsg = "";
				if(ovalID.compare("") != 0) {
					errMsg = "An error has occured while selecting a definition to analyze based on specified criteria. Current defintion: "
						   + ovalID
						   + "\n";
				} else {
					errMsg = "An error has occured while selecting a definition to analyze based on specified criteria.\n";
				}

				// Print the error message if in verbose mode
				if(Log::verboseMode) {
					cout << errMsg << endl;
				}
                
				// Log the error message
				Log::WriteLog(errMsg);
			}

			if (removeDefinition){
				def = definitions->removeChild(def);
				def->release();
				removeDefinition = false;
				index--;
			}//removeDefinition
		}//def is an element node			
	}//definitions loop
	
	
	
	

	//	get a ptr to the tests node in the oval document.
	tests = (DOMElement*)XmlCommon::FindNodeNS(ovalDoc, "tests");

	//	get a list of the child nodes
	testList = tests->getChildNodes();

    // Prune any tests that are not needed
	for(unsigned int k = 0; k < testList->getLength(); k++) {

		test = testList->item(k);
		//Ensure we are dealing with an ELEMENT_NODE
		if(test->getNodeType() == DOMNode::ELEMENT_NODE) {
			
			//Extract the testID
			string fullTestID = XmlCommon::GetAttributeByName(test, "id");
            
			//Separate test type from test id
            string tempType = fullTestID.substr(0, 3);
            int testID = atoi((fullTestID.substr(4, fullTestID.length()-4)).c_str());
			
			//Find the appropriate array for testType
			bool isFound = false;
			for(unsigned int i = 0; i < testContainer.size(); i++) {
				if (testContainer[i]->checkTestType(tempType)) {
					testArray = testContainer[i];
					isFound = true;
					break;
				}
			}	
			
			// If the test type is not found just remove the test.
			if(!isFound) {
				test = tests->removeChild(test);
				test->release();
			} else {
				// If the test is not in the array remove it.
				if (testArray->getTestBit(testID) == false) {
					test = tests->removeChild(test);
					test->release();
				}
			}
			
		}//testID == ELEMENT_Node
	}//for loop

	return ovalDoc;
}//PruneDOMDocument

void XmlProcessor::processCompoundTest(taVector *testContainer, string testID,  NodeVector *compoundTestVector)
{
	// -----------------------------------------------------------------------
	//	Abstract
	//
	//	Given a vector of test IDs to keep, a compound test ID and the DOM
	//	nodes of the compound tests not yet expanded, record the compound
	//	test's subtests in the vector of test IDs. Subtests that are
	//	themselves compound tests are expanded recursively. The compound test
	//	itself is expected to have been recorded by the caller.
	//
	//	The matching node is erased from compoundTestVector, so each compound
	//	test is expanded at most once. If testID is not found there, the
	//	function returns without doing anything.
	//
	// -----------------------------------------------------------------------

	DOMNode* compoundTest = NULL;
	bool foundCmpTest = false;

	// Set compoundTest to the DOM node of the compound test whose id is testID.
	NodeVector::iterator nodeVectorIterator;
	for(nodeVectorIterator = compoundTestVector->begin(); nodeVectorIterator != compoundTestVector->end(); nodeVectorIterator++) {
		compoundTest = (*nodeVectorIterator);
		string compoundTestID = XmlCommon::GetAttributeByName(compoundTest, "id");
		if(compoundTestID == testID){
			compoundTestVector->erase(nodeVectorIterator);
			foundCmpTest = true;
			break;
		}
	}

	// If the specified test is not found (or was already expanded), skip it.
	if(!foundCmpTest) {
		return;
	}

	// XMLString::transcode() allocates, so the tag buffer is released on
	// every exit path below.
	XMLCh* subtestTag = XMLString::transcode("subtest");

	try {
		//Loop through Subtests and add them to testContainer, recurse on subtests if it is a cmp test
		DOMElement* compoundTestElm = (DOMElement*)compoundTest;
		DOMNodeList* subTestList = compoundTestElm->getElementsByTagName(subtestTag);
		unsigned int subTestLoop = 0;
		while(subTestLoop < subTestList->getLength()) {
			DOMNode* subTest = subTestList->item(subTestLoop++);
			string subTestID = XmlCommon::GetAttributeByName(subTest, "test_ref");
			bool addedTest = addTest(testContainer, subTestID);
			if (subTestID.substr(0,3) == "cmp" && addedTest) {
				processCompoundTest(testContainer, subTestID, compoundTestVector);
			}
		}//subtest while loop
	} catch(...) {
		XMLString::release(&subtestTag);
		throw;
	}

	XMLString::release(&subtestTag);
}//processCompoundTest()

bool XmlProcessor::addTest(taVector *testContainer, string fullTestID) 
{
	// -----------------------------------------------------------------------
	//	Abstract
	//
	//	Record fullTestID in the TestArray that tracks its type. If no
	//	TestArray in testContainer tracks that type yet, create one with new
	//	and append it.
	//
	//	How the id is split:
	//		- the first 3 characters are the test type;
	//		- everything from index 4 on is converted with atoi as the
	//		  numeric id;
	//		- the character at index 3 is ignored (presumably a separator,
	//		  as in "cmp-12").
	//	An id shorter than 4 characters makes std::string::substr throw
	//	std::out_of_range.
	//
	//	Returns true if the id was newly recorded, false if it was already
	//	recorded.
	//
	// -----------------------------------------------------------------------

	const string typePrefix = fullTestID.substr(0, 3);
	const int numericID = atoi((fullTestID.substr(4, fullTestID.length()-4)).c_str());

	// Look for the first TestArray already tracking this type.
	TestArray* bucket = NULL;
	unsigned int pos = 0;
	while(bucket == NULL && pos < testContainer->size()) {
		if((testContainer->at(pos))->checkTestType(typePrefix))
			bucket = testContainer->at(pos);
		++pos;
	}

	// First test of this type: start a new TestArray for it.
	if(bucket == NULL) {
		bucket = new TestArray(typePrefix);
		testContainer->push_back(bucket);
	}

	// Already recorded - nothing new was added.
	if(bucket->getTestBit(numericID))
		return false;

	bucket->flipTestBit(numericID);
	return true;
}//addTest()

//****************************************************************************************//
//								TestArray Class									     	  //	
//****************************************************************************************//

//Constructor and Destructors
TestArray::TestArray(string type)
{
	// Room for ids 0..99 to begin with; flipTestBit() grows it on demand.
	arraysize = 100;
	testBits = new bool[arraysize];

	// No test is marked yet.
	int slot = arraysize;
	while(slot > 0) {
		testBits[--slot] = false;
	}

	testType = type;
}

TestArray::~TestArray()
{
	// Free the bit storage allocated by the constructor or expandArray().
	delete[] testBits;
}

//Keep track of what tests we want to keep
void TestArray::flipTestBit(int testID) 
{
	// Mark testID as a test to keep, growing the storage until it can hold
	// the id. Despite the name, the bit is set to true, not toggled.
	//
	// Negative ids cannot be stored. They are ignored rather than written
	// before the start of testBits (which would be undefined behavior).
	if (testID < 0)
		return;

	while ((arraysize-1) < testID) {
		expandArray();
	}

	testBits[testID] = true;
}

//Ask if testID is true or false
bool TestArray::getTestBit(int testID)
{
	// Report whether testID has been marked by flipTestBit().
	// Ids outside the allocated range are reported as not marked. That
	// includes negative ids, which would otherwise read before the start of
	// testBits.
	if (testID < 0 || (arraysize-1) < testID)
		return false;

	return testBits[testID];
}

//Check the testType of this instance of TestArray
bool TestArray::checkTestType(string testTypeIn)
{
	// True when testTypeIn is exactly the type this array tracks.
	return testTypeIn == testType;
}

//Get the test type
string TestArray::getTestType()
{
	// The type this array tracks, as set by the constructor or setTestType().
	return this->testType;
}

int TestArray::getArraySize()
{
	return arraysize;
}

//Set the testType of this instance of TestArray
void TestArray::setTestType(string type)
{
	// Replace the type this array answers to in checkTestType().
	testType = type;
}

//Expand the array to hold more test values
void TestArray::expandArray()
{
	// Grow the bit storage by 500 entries. Existing bits are preserved and
	// the new entries start out false.
	const int newSize = arraysize + 500;
	bool* grown = new bool[newSize];

	for (int k = 0; k < newSize; k++) {
		grown[k] = (k < arraysize) ? testBits[k] : false;
	}

	delete [] testBits;
	testBits = grown;
	arraysize = newSize;
}

XERCES_CPP_NAMESPACE_QUALIFIER DOMDocument* XmlProcessor::CreateDOMDocument(string root)
{
	// -----------------------------------------------------------------------
	//	Abstract
	//
	//	Create a new DOMDocument whose root element is named 'root'. No
	//	namespace and no document type are used.
	//
	//	The transcoded root name is released once the document exists, and
	//	also if createDocument() throws.
	//
	// -----------------------------------------------------------------------

	// "Core" as an XMLCh literal, so no transcoded buffer has to be freed.
	static const XMLCh gCore[] = { chLatin_C, chLatin_o, chLatin_r, chLatin_e, chNull };
	DOMImplementation* impl = DOMImplementationRegistry::getDOMImplementation(gCore);

	XMLCh *xmlRoot = XMLString::transcode(root.c_str());
	XERCES_CPP_NAMESPACE_QUALIFIER DOMDocument* doc = NULL;
	try
	{
		doc = impl->createDocument(0, xmlRoot, 0);
	}
	catch(...)
	{
		XMLString::release(&xmlRoot);
		throw;
	}
	XMLString::release(&xmlRoot);

	return(doc);
}

XERCES_CPP_NAMESPACE_QUALIFIER DOMDocument* XmlProcessor::CreateDOMDocumentNS(string namespaceURI, string qualifiedName)
{
	// -----------------------------------------------------------------------
	//	Abstract
	//
	//	Create a new DOMDocument whose root element has the given namespace
	//	URI and qualified name. No document type is used.
	//
	//	The transcoded URI and name are released once the document exists,
	//	and also if createDocument() throws.
	//
	// -----------------------------------------------------------------------

	// "Core" as an XMLCh literal, so no transcoded buffer has to be freed.
	static const XMLCh gCore[] = { chLatin_C, chLatin_o, chLatin_r, chLatin_e, chNull };
	DOMImplementation* impl = DOMImplementationRegistry::getDOMImplementation(gCore);

	XMLCh *uri = XMLString::transcode(namespaceURI.c_str());
	XMLCh *name = XMLString::transcode(qualifiedName.c_str());

	XERCES_CPP_NAMESPACE_QUALIFIER DOMDocument* doc = NULL;
	try
	{
		doc = impl->createDocument(uri, name, NULL);
	}
	catch(...)
	{
		XMLString::release(&uri);
		XMLString::release(&name);
		throw;
	}
	XMLString::release(&uri);
	XMLString::release(&name);

	return(doc);
}

void XmlProcessor::WriteDOMDocument(XERCES_CPP_NAMESPACE_QUALIFIER DOMDocument* doc,  string filePath, bool writeToFile)
{
	// -----------------------------------------------------------------------
	//	Abstract
	//
	//	Serialize the DOMDocument (pretty-printed, without a BOM).
	//	When writeToFile is true, filePath is the filename and path of the
	//	XML file that will be written. When it is false, the document is
	//	written to standard output and filePath is ignored.
	//
	//	Throws XmlProcessorException if serialization throws or if
	//	DOMWriter::writeNode() reports failure. The serializer and format
	//	target are freed on every path.
	//
	// -----------------------------------------------------------------------

	DOMWriter *theSerializer = NULL;
	XMLFormatTarget *myFormTarget = NULL;
	bool written = false;

	try
	{
		// get a serializer, an instance of DOMWriter
		XMLCh tempStr[100];
		XMLString::transcode("LS", tempStr, 99);
		DOMImplementation *impl = DOMImplementationRegistry::getDOMImplementation(tempStr);
		theSerializer = ((DOMImplementationLS*)impl)->createDOMWriter();

		// set feature if the serializer supports the feature/mode
		if (theSerializer->canSetFeature(XMLUni::fgDOMWRTSplitCdataSections, true))
			theSerializer->setFeature(XMLUni::fgDOMWRTSplitCdataSections, true);

		if (theSerializer->canSetFeature(XMLUni::fgDOMWRTDiscardDefaultContent, true))
			theSerializer->setFeature(XMLUni::fgDOMWRTDiscardDefaultContent, true);

		if (theSerializer->canSetFeature(XMLUni::fgDOMWRTFormatPrettyPrint, true))
			theSerializer->setFeature(XMLUni::fgDOMWRTFormatPrettyPrint, true);

		if (theSerializer->canSetFeature(XMLUni::fgDOMWRTBOM, false))
			theSerializer->setFeature(XMLUni::fgDOMWRTBOM, false);

		//
		// Plug in a format target to receive the resultant
		// XML stream from the serializer.
		//
		// StdOutFormatTarget prints the resultant XML stream
		// to stdout once it receives any thing from the serializer.
		//
		if (writeToFile)
			myFormTarget = new LocalFileFormatTarget(filePath.c_str());
		else
			myFormTarget = new StdOutFormatTarget();

		//
		// do the serialization through DOMWriter::writeNode();
		// a false return means the serialization failed.
		//
		written = theSerializer->writeNode(myFormTarget, *doc);

		// Clear each pointer before freeing it, so the catch block below never
		// frees the same object twice if the cleanup itself throws.
		DOMWriter *serializerToRelease = theSerializer;
		theSerializer = NULL;
		serializerToRelease->release();

		XMLFormatTarget *targetToDelete = myFormTarget;
		myFormTarget = NULL;
		delete targetToDelete;
	}
	catch(...)
	{
		written = false;

		if (theSerializer != NULL)
			theSerializer->release();
		delete myFormTarget;
	}

	if(!written)
	{
		string error;
		if(writeToFile)
		{
			error.append("Error while writing Document to XML file: ");
			error.append(filePath);
		}else
		{
			error.append("Error while writing Document to screen");
		}

		throw XmlProcessorException(error);
	}
}

//****************************************************************************************//
//								XmlProcessorException Class								  //	
//****************************************************************************************//
XmlProcessorException::XmlProcessorException(string errMsgIn, int severity) : Exception(errMsgIn, severity)
{
	// -----------------------------------------------------------------------
	//	Abstract
	//
	//	Pass the error message and the severity straight through to the
	//	Exception base class constructor; this constructor does nothing else.
	//
	//	NOTE(review): several call sites in this file pass only a message, so
	//	the header presumably declares a default for 'severity' - confirm
	//	which severity that default is.
	//
	// -----------------------------------------------------------------------

}

XmlProcessorException::~XmlProcessorException()
{
	// -----------------------------------------------------------------------
	//	Abstract
	//
	//	Intentionally empty: nothing is allocated by this class in the code
	//	visible in this file.
	//
	// -----------------------------------------------------------------------

}

//****************************************************************************************//
//								XmlProcessorErrorHandler Class							  //	
//****************************************************************************************//
XmlProcessorErrorHandler::XmlProcessorErrorHandler()
{
	// A new handler has seen nothing and has no messages.
	fSawErrors = false;
	errorMessages = string();
}

XmlProcessorErrorHandler::~XmlProcessorErrorHandler()
{
	// Intentionally empty: fSawErrors and errorMessages need no manual cleanup.
}


// ---------------------------------------------------------------------------
//  XmlProcessorErrorHandler: Overrides of the DOM ErrorHandler interface
// ---------------------------------------------------------------------------
bool XmlProcessorErrorHandler::handleError(const DOMError& domError)
{
    fSawErrors = true;
    if (domError.getSeverity() == DOMError::DOM_SEVERITY_WARNING)
        errorMessages.append("\n\tSeverity: Warning");
    else if (domError.getSeverity() == DOMError::DOM_SEVERITY_ERROR)
        errorMessages.append("\n\tSeverity: Error");
    else
        errorMessages.append("\n\tSeverity: Fatal Error");

	string msg =  XmlCommon::ToString(domError.getMessage());
	string file = XmlCommon::ToString(domError.getLocation()->getURI());
	long line = domError.getLocation()->getLineNumber();
	long at = domError.getLocation()->getColumnNumber();


	errorMessages.append("\n\tMessage: " + msg);
	errorMessages.append("\n\tFile: " + file);
	errorMessages.append("\n\tLine " + Common::ToString(line));
	errorMessages.append("\n\tAt char " + Common::ToString(at));
	

    return true;
}

void XmlProcessorErrorHandler::resetErrors()
{
    fSawErrors = false;
}

bool XmlProcessorErrorHandler::getSawErrors() const
{
    return fSawErrors;
}

string XmlProcessorErrorHandler::getErrorMessages() const
{
	// All descriptions appended by handleError(), in the order reported.
	return this->errorMessages;
}

