//
// $Id: Common.cpp,v 4.15 2005/09/20 20:57:01 bakerj Exp $
//
//****************************************************************************************//
// Copyright (c) 2005, The MITRE Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without modification, are
// permitted provided that the following conditions are met:
//
//     * Redistributions of source code must retain the above copyright notice, this list
//       of conditions and the following disclaimer.
//     * Redistributions in binary form must reproduce the above copyright notice, this 
//       list of conditions and the following disclaimer in the documentation and/or other
//       materials provided with the distribution.
//     * Neither the name of The MITRE Corporation nor the names of its contributors may be
//       used to endorse or promote products derived from this software without specific 
//       prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY 
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
// OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT 
// SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT
// OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
// HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR
// TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
// EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
//****************************************************************************************//

#include "Common.h"

// Initialize static variables.
//
// Run-wide defaults for the settings exposed through the static Get*/Set*
// functions in this file. Each value can be overwritten at runtime by the
// matching Set* mutator.
string	Common::dataFile			= "system-characteristics.xml";
string	Common::xmlfile				= "definitions.xml";
string	Common::outputFilename		= "results.xml";
string	Common::variablesFilename	= "variables.xml";
string	Common::xmlfileMD5			= "";
string	Common::startTime			= "";
string	Common::mappingFile			= "mapping.xml";
// Flag string with one character per entry of statusTypes (same order); a '1'
// selects that status (see SetStatusType()).
// NOTE(review): statusVector is NOT derived from this default value - it stays
// empty until SetStatusType() is called explicitly.
string  Common::statusType			= "000110";
string	Common::classType			= "vulnerability";

// The platform family is chosen at compile time.
// NOTE(review): if neither WIN32 nor REDHAT is defined, familyType gets no
// definition here even though GetFamilyType() uses it - presumably the build
// then fails to link. Confirm every supported build defines one of these.
#ifdef WIN32
	string  Common::familyType   = "windows";
#endif

#ifdef REDHAT
	string  Common::familyType  = "redhat";
#endif

bool	Common::generateMD5			= false;
bool	Common::outputToFile		= true;
bool	Common::saveData			= true;
bool	Common::useProvidedData		= false;
bool	Common::verifyXMLfile		= true;
bool	Common::useConfiguration	= true;
bool	Common::useVariableFile		= false;

// Status names selected by SetStatusType(); searched by CompareStatusType().
sVector  Common::statusVector;
// All known status names, indexed in the same order as the characters of
// statusType. NOTE(review): "DEPRECIATED" looks like a misspelling of
// "DEPRECATED", but these strings are what CompareStatusType() matches input
// against - confirm the expected spelling before changing it.
string   Common::statusTypes[6]		= {"INCOMPLETE", "INITIAL SUBMISSION", "DRAFT", "INTERIM", "ACCEPTED", "DEPRECIATED"};

//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~  Accessors  ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//

// Return the current platform family name (compile-time default, or whatever
// SetFamilyType() last stored).
string Common::GetFamilyType()
{
	return Common::familyType;
}

// Return the definition class currently selected (default "vulnerability").
string Common::GetClassType()
{
	return Common::classType;
}
// Return the raw status flag string last stored (default "000110").
string Common::GetStatusType()
{
	return Common::statusType;
}
// Return the path of the system-characteristics data file.
string Common::GetDatafile()
{
	return Common::dataFile;
}

// Report whether an MD5 hash should be generated.
bool Common::GetGenerateMD5()
{
	return Common::generateMD5;
}

// Return the path of the mapping file.
string Common::GetMappingfile()
{
	return Common::mappingFile;
}

// Return the path of the definitions XML file.
string Common::GetXMLfile()
{
	return Common::xmlfile;
}

// Return the stored MD5 string for the definitions file ("" until set).
string Common::GetXMLfileMD5()
{
	return Common::xmlfileMD5;
}

// Return the path of the results output file.
string Common::GetOutputFilename()
{
	return Common::outputFilename;
}

// Report whether results should be written to a file.
bool Common::GetOutputToFile()
{
	return Common::outputToFile;
}

// Report whether collected data should be saved.
bool Common::GetSaveData()
{
	return Common::saveData;
}

// Return the stored start-time string ("" until set).
string Common::GetStartTime()
{
	return Common::startTime;
}

// Report whether previously provided data should be used.
bool Common::GetUseProvidedData()
{
	return Common::useProvidedData;
}

// Report whether the configuration should be used.
bool Common::GetUseConfiguration()
{
	return Common::useConfiguration;
}

// Report whether the variables file should be used.
bool Common::GetUseVariableFile()
{
	return Common::useVariableFile;
}

// Return the path of the variables file.
string Common::GetVariableFilename()
{
	return Common::variablesFilename;
}

// Report whether the definitions XML file should be verified.
bool Common::GetVerifyXMLfile()
{
	return Common::verifyXMLfile;
}

// Return true if statusIn is one of the status names currently held in
// statusVector (which SetStatusType() fills); false otherwise.
bool Common::CompareStatusType(string statusIn)
{
	return find(statusVector.begin(), statusVector.end(), statusIn) != statusVector.end();
}

// Return true if familyIn exactly matches the current family type.
bool Common::CompareFamilyType(string familyIn)
{
	const bool sameFamily = (familyType == familyIn);
	return sameFamily;
}

// Return true if classIn exactly matches the current class type.
bool Common::CompareClassType(string classIn)
{
	const bool sameClass = (classType == classIn);
	return sameClass;
}
	


//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~  Mutators  ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//
// Replace the platform family name used by CompareFamilyType().
void Common::SetFamilyType(string familyIn)
{
	Common::familyType = familyIn;
}

// Select which statuses are accepted by CompareStatusType().
//
// statusIn is a flag string with one character per entry of statusTypes, in
// the same order: a '1' at position i selects statusTypes[i]; any other
// character leaves it unselected. The raw string is stored in statusType and
// the selected names are stored in statusVector.
void Common::SetStatusType(string statusIn)
{
	statusType = statusIn;

	// Rebuild the selection from scratch so repeated calls do not leave stale
	// or duplicate entries behind.
	statusVector.clear();

	// Visit every known status (all entries of statusTypes, including the last
	// one), but never read past the end of a flag string that is too short.
	const string::size_type typeCount = sizeof(statusTypes) / sizeof(statusTypes[0]);
	for (string::size_type index = 0; index < typeCount && index < statusType.length(); index++) {
		if (statusType[index] == '1') {
			statusVector.push_back(statusTypes[index]);
		}
	}
}

// Replace the definition class used by CompareClassType().
void Common::SetClassType(string classIn)
{
	Common::classType = classIn;
}

// Store a new path for the system-characteristics data file.
void Common::SetDataFile(string fileIn)
{
	Common::dataFile = fileIn;
}

// Turn MD5 generation on or off.
void Common::SetGenerateMD5(bool genMD5In)
{
	Common::generateMD5 = genMD5In;
}

// Store a new path for the mapping file.
void Common::SetMappingFile(string fileIn)
{
	Common::mappingFile = fileIn;
}

// Store a new path for the definitions XML file.
void Common::SetXMLfile(string xmlfileIn)
{
	Common::xmlfile = xmlfileIn;
}

// Store the MD5 string for the definitions file.
void Common::SetXMLfileMD5(string xmlfileMD5In)
{
	Common::xmlfileMD5 = xmlfileMD5In;
}

// Store a new path for the results output file.
void Common::SetOutputFilename(string outputFilenameIn)
{
	Common::outputFilename = outputFilenameIn;
}

// Turn writing results to a file on or off.
void Common::SetOutputToFile(bool outputToFileIn)
{
	Common::outputToFile = outputToFileIn;
}

// Turn saving of collected data on or off.
void Common::SetSaveData(bool saveIn)
{
	Common::saveData = saveIn;
}

// Store the start-time string.
void Common::SetStartTime(string startTimeIn)
{
	Common::startTime = startTimeIn;
}

// Turn use of previously provided data on or off.
void Common::SetUseProvidedData(bool useDataIn)
{
	Common::useProvidedData = useDataIn;
}

// Turn use of the configuration on or off.
void Common::SetUseConfiguration(bool useConfigIn)
{
	Common::useConfiguration = useConfigIn;
}

// Turn use of the variables file on or off.
void Common::SetUseVariableFile(bool useVarsIn)
{
	Common::useVariableFile = useVarsIn;
}

// Store a new path for the variables file.
void Common::SetVariableFilename(string varFilenameIn)
{
	Common::variablesFilename = varFilenameIn;
}

// Turn verification of the definitions XML file on or off.
void Common::SetVerifyXMLfile(bool verifyXMLfileIn)
{
	Common::verifyXMLfile = verifyXMLfileIn;
}

//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~  Public Members  ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//
#ifdef WIN32
bool Common::DisableAllPrivileges()
{
	//------------------------------------------------------------------------------------//
	//  ABSTRACT
	//
	//  This function disables all the privileges associated with the current process
	//  token.  If a specific privilege is needed later, it can be enabled by calling
	//  AdjustTokenPrivileges() again.
	//
	//------------------------------------------------------------------------------------//

	HANDLE hToken = NULL;

	// Get a handle to the current process.

	if (OpenProcessToken(GetCurrentProcess(),						// handle to the process
						 TOKEN_ADJUST_PRIVILEGES | TOKEN_QUERY,		// requested access types 
						 &hToken) == FALSE)							// new access token 
	{
		char buffer[33];
		_itoa(GetLastError(), buffer, 10);

		string errorMessage = "";
		errorMessage.append("\nERROR: Unable to get a handle to the current process.  Error # - ");
		errorMessage.append(buffer);
		errorMessage.append("\n");
		cerr << errorMessage;
		Log::WriteLog(errorMessage);

		return false;
	}

	// Disable all the privileges for this token.

	if (AdjustTokenPrivileges(hToken,					// handle to token
							  TRUE,						// disabling option
							  NULL,						// privilege information
							  0,						// size of buffer
							  NULL,						// original state buffer
							  NULL) == FALSE)			// required buffer size
	{
		char buffer[33];
		_itoa(GetLastError(), buffer, 10);

		string errorMessage = "";
		errorMessage.append("\nERROR: Unable to disable token privileges.  Error # - ");
		errorMessage.append(buffer);
		errorMessage.append("\n");
		cerr << errorMessage;
		Log::WriteLog(errorMessage);

		CloseHandle(hToken);
		return false;
	}

	CloseHandle(hToken);

	return true;
}

bool Common::EnablePrivilege(string privilegeIn)
{
	//------------------------------------------------------------------------------------//
	//  ABSTRACT
	//
	//  This gives us a privilege.
	//
	//------------------------------------------------------------------------------------//

	TOKEN_PRIVILEGES tp;
	HANDLE hProcess = NULL;
	HANDLE hAccessToken = NULL;

	hProcess = GetCurrentProcess();

	if(!OpenProcessToken(hProcess,									// handle to the process
						(TOKEN_QUERY | TOKEN_ADJUST_PRIVILEGES),	// requested access types 
						&hAccessToken) == FALSE)					// new access token 
	{
		return false;
	}

	tp.PrivilegeCount = 1;
    tp.Privileges[0].Attributes = SE_PRIVILEGE_ENABLED;
    
	if (LookupPrivilegeValue(NULL, privilegeIn.c_str(), &tp.Privileges[0].Luid) == 0)
	{
		return false;
	}

	if (AdjustTokenPrivileges(hAccessToken, FALSE, &tp, NULL, NULL, NULL) == 0)
	{
		return false;
	}
	 
	if(GetLastError() == ERROR_NOT_ALL_ASSIGNED)
	{
		// The token for the current process does not have the privilege specified. The
		// AdjustTokenPrivileges() function may succeed with this error value even if no
		// privileges were adjusted.  The privilege parameter can specify privileges that
		// the token does not have, without causing the function to fail. In this case, 
		// the function adjusts the privileges that the token does have and ignores the 
		// other privileges so that the function succeeds.

		CloseHandle(hAccessToken);
		return false;
	}
	else
	{
		CloseHandle(hAccessToken);
		return true;
	}
}

string Common::GetErrorMessage(DWORD dwLastError)
{
	//------------------------------------------------------------------------------------//
	//  ABSTRACT
	//
	//  Return the system's message text for the given Win32 error code, or "" if
	//  FormatMessageA() cannot produce one.  Codes in the network range
	//  (NERR_BASE..MAX_NERR) are also looked up in netmsg.dll when it can be loaded.
	//
	//------------------------------------------------------------------------------------//

	DWORD formatFlags = FORMAT_MESSAGE_ALLOCATE_BUFFER
					  | FORMAT_MESSAGE_IGNORE_INSERTS
					  | FORMAT_MESSAGE_FROM_SYSTEM;

	// NULL means "use the system message table only".
	HMODULE messageSource = NULL;

	const bool isNetworkError = (dwLastError >= NERR_BASE && dwLastError <= MAX_NERR);
	if (isNetworkError) {
		messageSource = LoadLibraryEx(TEXT("netmsg.dll"), NULL, LOAD_LIBRARY_AS_DATAFILE);
		if (messageSource != NULL) {
			formatFlags |= FORMAT_MESSAGE_FROM_HMODULE;
		}
	}

	string result = "";

	// With FORMAT_MESSAGE_ALLOCATE_BUFFER the system allocates the text and stores its
	// address through the lpBuffer argument; it must be released with LocalFree().
	LPSTR allocatedText = NULL;
	const DWORD charCount = FormatMessageA(formatFlags,
										   messageSource,
										   dwLastError,
										   MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT),
										   reinterpret_cast<LPSTR>(&allocatedText),
										   0,
										   NULL);
	if (charCount != 0) {
		result = allocatedText;
		LocalFree(allocatedText);
	}

	if (messageSource != NULL) {
		FreeLibrary(messageSource);
	}

	return result;
}


#endif

string Common::GetEnviromentVariable(string envVarIn)
{
	// -----------------------------------------------------------------------
	//	Abstract
	//
	//	Return the value of the environment variable named envVarIn.
	//
	//	Throws CommonException if the variable is not set, or is set to an
	//	empty string.
	//
	// -----------------------------------------------------------------------

	// getenv() returns a null pointer for an unset variable.  Constructing or
	// assigning a std::string from a null pointer is undefined behaviour, so
	// test the pointer before copying it.
	const char* rawValue = getenv(envVarIn.c_str());
	const string value = (rawValue != NULL) ? string(rawValue) : string("");

	//	Check the value
	if(value.empty())
	{
		string errMsg = "Message: Unable to find the value of: " + envVarIn + "\n";
		throw CommonException(errMsg);
	}

	return value;
}

string Common::PadString(string strIn, unsigned int desiredLength)
{
	//------------------------------------------------------------------------------------//
	//  ABSTRACT
	//
	//  Append spaces to strIn until it is desiredLength characters long.  A string that
	//  is already at least that long is returned unchanged (it is never truncated).
	//
	//------------------------------------------------------------------------------------//

	if (strIn.length() < desiredLength) {
		const string::size_type missing = desiredLength - strIn.length();
		strIn.append(missing, ' ');
	}

	return strIn;
}

string Common::SwitchChar(string fixedString, string oldChr, string newChr)
{
	//------------------------------------------------------------------------------------//
	//  ABSTRACT
	//
	//  Return a copy of fixedString in which every occurrence of the single character
	//  oldChr has been replaced by the single character newChr.
	//
	//  Throws CommonException if oldChr or newChr is not exactly one character long.
	//
	//------------------------------------------------------------------------------------//

	if(oldChr.length() != 1 || newChr.length() != 1)
		throw CommonException("Error: (SwitchChar) can only switch strings of length = 1.");

	const char fromChar = oldChr[0];
	const char toChar = newChr[0];

	// Keep positions in string::size_type.  Storing find()'s result in an unsigned int
	// truncates string::npos wherever size_t is wider than int, so the "not found" test
	// never fires and the position is then used out of range.
	string::size_type pos = fixedString.find(fromChar);
	while (pos != string::npos)
	{
		fixedString[pos] = toChar;
		pos = fixedString.find(fromChar, pos + 1);
	}

	return fixedString;
}

#ifdef WIN32
bool Common::GetTextualSid(PSID pSid, LPTSTR* TextualSid)
{
	//------------------------------------------------------------------------------------//
	//
	//  ABSTRACT
	//
	//  Convert a binary SID into its standard string form:
	//
	//    S-R-I-S-S...
	//
	//  where S identifies the string as a SID, R is the revision level, I is the
	//  identifier-authority value and S... is one or more subauthority values.
	//
	//  A SID consists of the revision level of the SID structure, a 48-bit identifier
	//  authority value that identifies the authority that issued the SID, and a variable
	//  number of subauthority / relative identifier (RID) values that uniquely identify
	//  the trustee relative to that authority.
	//
	//  On success *TextualSid receives a buffer allocated with malloc() that holds the
	//  string; ownership passes to the caller, who must release it with free().
	//  Returns false if TextualSid is NULL, pSid is not a valid SID, or the buffer
	//  cannot be allocated.
	//
	//  NOTE:
	//
	//  Windows 2000 and later provide ConvertSidToStringSid / ConvertStringSidToSid for
	//  this conversion.  This hand-written version also works on earlier Windows NT.
	//
	//------------------------------------------------------------------------------------//

	PSID_IDENTIFIER_AUTHORITY psia;
	DWORD dwSubAuthorities;
	DWORD dwSidRev = SID_REVISION;
	DWORD dwCounter;
	DWORD dwSidSize;

	if (TextualSid == NULL) return false;

	// Validate the binary SID.

	if(!IsValidSid(pSid)) return false;

	// Get the identifier authority value from the SID.

	psia = GetSidIdentifierAuthority(pSid);

	// Get the number of subauthorities in the SID.

	dwSubAuthorities = *GetSidSubAuthorityCount(pSid);

	// compute buffer length
	// S-SID_REVISION- + identifierauthority- + subauthorities- + NULL

	dwSidSize=(15 + 12 + (12 * dwSubAuthorities) + 1) * sizeof(TCHAR);

	// allocate memory

	*TextualSid = (LPTSTR)malloc(dwSidSize);
	if(*TextualSid == NULL)
	{
		return false;
	}

	// Add 'S' prefix and revision number to the string.  From here on dwSidSize holds
	// the number of characters written so far.

	dwSidSize = wsprintf(*TextualSid, TEXT("S-%lu-"), dwSidRev);

	// Add SID identifier authority to the string.  Authorities that do not fit in 32
	// bits (Value[0] or Value[1] non-zero) are printed in hex, others in decimal.

	if ((psia->Value[0] != 0) || (psia->Value[1] != 0))
	{
		dwSidSize += wsprintf(*TextualSid + lstrlen(*TextualSid),
							  TEXT("0x%02hx%02hx%02hx%02hx%02hx%02hx"),
							  (USHORT)psia->Value[0],
							  (USHORT)psia->Value[1],
							  (USHORT)psia->Value[2],
							  (USHORT)psia->Value[3],
							  (USHORT)psia->Value[4],
							  (USHORT)psia->Value[5]);
	}
	else
	{
		// Widen each byte to ULONG BEFORE shifting: shifting the promoted (signed) int
		// left by 24 overflows for bytes >= 0x80, which is undefined behaviour.
		dwSidSize += wsprintf(*TextualSid + lstrlen(*TextualSid),
							  TEXT("%lu"),
							  (ULONG)psia->Value[5] +
							  ((ULONG)psia->Value[4] << 8) +
							  ((ULONG)psia->Value[3] << 16) +
							  ((ULONG)psia->Value[2] << 24));
	}

	// Loop through SidSubAuthorities and add them to the string.

	for (dwCounter=0; dwCounter<dwSubAuthorities; dwCounter++)
	{
		dwSidSize += wsprintf(*TextualSid + dwSidSize,
							  TEXT("-%lu"),
							  *GetSidSubAuthority(pSid, dwCounter));
	}

	return true;
}
#endif

string Common::GetTimeStamp()
{
	//------------------------------------------------------------------------------------//
	//  ABSTRACT
	//
	//  Retrieve the current local date/time.  The final output will be in the format:
	//
	//    yyyymmddhhmmss      (e.g. 20050920205701)
	//
	//  An empty string is returned if the local time cannot be determined or formatted.
	//
	//------------------------------------------------------------------------------------//

	char tmpbuf[128];

	// Get the time as a long integer, then convert it to local time.
	time_t tmpTime;
	time(&tmpTime);

	// localtime() returns NULL if the value cannot be converted.
	struct tm *todayTime = localtime(&tmpTime);
	if (todayTime == NULL) {
		return "";
	}

	// Build the time string.  A string literal must be bound to a const pointer.
	const char *format = "%Y%m%d%H%M%S";

	// strftime() null-terminates on success.  It returns 0 on failure, and the buffer
	// contents are then indeterminate, so the buffer must not be used.
	if (strftime(tmpbuf, sizeof(tmpbuf), format, todayTime) == 0) {
		return "";
	}

	return string(tmpbuf);
}

string Common::ToString(int num)
{
	// -----------------------------------------------------------------------
	//	Abstract
	//
	//	Format an int as its decimal text, using the default stream formatting.
	//
	// -----------------------------------------------------------------------
	ostringstream formatter;
	formatter << num;

	const string text = formatter.str();
	return text;
}
string Common::ToString(long num)
{
	// -----------------------------------------------------------------------
	//	Abstract
	//
	//	Format a long as its decimal text, using the default stream formatting.
	//
	// -----------------------------------------------------------------------
	ostringstream formatter;
	formatter << num;

	const string text = formatter.str();
	return text;
}
string Common::ToString(unsigned long num)
{
	// -----------------------------------------------------------------------
	//	Abstract
	//
	//	Format an unsigned long as its decimal text, using the default stream
	//	formatting.
	//
	// -----------------------------------------------------------------------
	ostringstream formatter;
	formatter << num;

	const string text = formatter.str();
	return text;
}


//****************************************************************************************//
//							CommonException Class										  //	
//****************************************************************************************//
CommonException::CommonException(string errMsgIn, int severity)
	: Exception(errMsgIn, severity)
{
	// -----------------------------------------------------------------------
	//	Abstract
	//
	//	Build a CommonException by forwarding both the error message and the
	//	severity to the two-argument Exception base-class constructor.  Any
	//	default severity comes from the class declaration, not from this file.
	//
	// -----------------------------------------------------------------------
}

CommonException::~CommonException()
{
	// -----------------------------------------------------------------------
	//	Abstract
	//
	//	Intentionally empty: no cleanup is performed here.
	//
	// -----------------------------------------------------------------------
}
