[Vm-dev] [commit] r2242 - Add win32 implementation for SqueakSSL plugin.

commits at squeakvm.org commits at squeakvm.org
Mon Jul 19 06:12:15 UTC 2010


Author: andreas
Date: 2010-07-18 23:12:14 -0700 (Sun, 18 Jul 2010)
New Revision: 2242

Added:
   trunk/platforms/win32/plugins/SqueakSSL/
   trunk/platforms/win32/plugins/SqueakSSL/sqWin32SSL.c
Log:
Add win32 implementation for SqueakSSL plugin.

Added: trunk/platforms/win32/plugins/SqueakSSL/sqWin32SSL.c
===================================================================
--- trunk/platforms/win32/plugins/SqueakSSL/sqWin32SSL.c	                        (rev 0)
+++ trunk/platforms/win32/plugins/SqueakSSL/sqWin32SSL.c	2010-07-19 06:12:14 UTC (rev 2242)
@@ -0,0 +1,827 @@
+/* sqWin32SSL.c: SqueakSSL implementation for Windows */
+
+#include <windows.h>
+#include <errno.h>
+#include <malloc.h>
+
+#include "sq.h"
+#include "SqueakSSL.h"
+
+#define SECURITY_WIN32
+#include <security.h>
+#include <schannel.h>
+#include <wincrypt.h>
+
+typedef struct sqSSL {
+	int state;
+	int certFlags;
+	int loglevel;
+
+	char *certName;
+	char *peerName;
+
+	CredHandle sslCred;
+	CtxtHandle sslCtxt;
+	
+	SecBufferDesc sbdIn;
+	SecBufferDesc sbdOut;
+	SecBuffer inbuf[4];
+	SecBuffer outbuf[4];
+
+	/* internal data buffer */
+	char *dataBuf;
+	int dataLen;
+	int dataMax;
+
+	SecPkgContext_StreamSizes sslSizes;
+} sqSSL;
+
+static sqSSL **handleBuf = NULL;
+static sqInt handleMax = 0;
+
+
+/********************************************************************/
+/********************************************************************/
+/********************************************************************/
+
+/* sslFromHandle: Maps a handle to an SSL */
+static sqSSL *sslFromHandle(sqInt handle) {
+	return handle < handleMax ? handleBuf[handle] : NULL;
+}
+
+
+/* sqPrintSBD: Prints a SecurityBuffer for debugging */
+static void sqPrintSBD(char *title, SecBufferDesc sbd) {
+	unsigned int i;
+	printf("%s\n", title);
+	for(i=0; i<sbd.cBuffers; i++) {
+		SecBuffer *buf = sbd.pBuffers + i;
+		printf("\tbuf[%d]: %d (%d bytes) ptr=%x\n", i,buf->BufferType, buf->cbBuffer, (int)buf->pvBuffer);
+	}
+}
+
+/* sqCopyExtraData: Retains any SECBUFFER_EXTRA data. */
+static void sqCopyExtraData(sqSSL *ssl, SecBufferDesc sbd) {
+	unsigned int i;
+	if(sbd.pBuffers[0].BufferType == SECBUFFER_MISSING) {
+		if(ssl->loglevel) printf("sqCopyExtra: Encountered SECBUFFER_MISSING; retaining %d bytes\n", ssl->dataLen);
+		return;
+	}
+	ssl->dataLen = 0;
+	for(i=0; i<sbd.cBuffers; i++) {
+		SecBuffer *buf = sbd.pBuffers + i;
+		if(buf->BufferType == SECBUFFER_EXTRA) {
+			int count = buf->cbBuffer;
+			char *srcPtr = buf->pvBuffer;
+			char *dstPtr = ssl->dataBuf + ssl->dataLen;
+			if(ssl->loglevel) printf("sqCopyExtraData: Retaining %d bytes\n", count);
+			/* I *think* the extra buffers are always in input range.
+			   Make sure that's the case or at least report it if not. */
+			if(srcPtr < dstPtr || (srcPtr + count) > (ssl->dataBuf + ssl->dataMax)) {
+				if(ssl->loglevel) printf("sqCopyExtraDataSSL: Encountered out-of-range extra buffer\n");
+			}
+			if(srcPtr != dstPtr) {
+				/* memmove() not memcpy() since the memory mayoverlap */
+				memmove(dstPtr, srcPtr, count);
+			}
+			ssl->dataLen += count;
+		}
+	}
+}
+
+/* Copies the data from a SecBufferDesc to dstBuf */
+static sqInt sqCopyDescToken(sqSSL *ssl, SecBufferDesc sbd, char *dstBuf, sqInt dstLen) {
+	unsigned int i;
+	int result = 0;
+
+	if(ssl->loglevel) printf("sqCopyDescToken: \n");
+	for(i = 0; i < sbd.cBuffers; i++) {
+		SecBuffer *buf = sbd.pBuffers + i;
+		if(ssl->loglevel) printf("\t type=%d, size=%d\n", buf->BufferType, buf->cbBuffer);
+		if(buf->BufferType == SECBUFFER_TOKEN) {
+			int count = buf->cbBuffer;
+			if(count > dstLen) return SQSSL_BUFFER_TOO_SMALL;
+			memcpy(dstBuf, buf->pvBuffer, count);
+			result += count;
+			dstBuf += count;
+			dstLen -= count;
+			FreeContextBuffer(buf->pvBuffer);
+		}
+		if(buf->BufferType == SECBUFFER_EXTRA) {
+			/* XXXX: Preserve contents for the next round */
+			if(ssl->loglevel) printf("sqCopyDescToken: Unexpectedly encountered SECBUFFER_EXTRA\n");
+		}
+	}
+	return result;
+}
+
+/* Set up the local certificate for SSL */
+static sqInt sqSetupCert(sqSSL *ssl, char *certName, int server) {
+	SCHANNEL_CRED sc_cred = { 0 };
+	SECURITY_STATUS ret;
+	HCERTSTORE hStore;
+	PCCERT_CONTEXT pContext = NULL;
+
+	if(certName) {
+		hStore = CertOpenSystemStore(0, "MY");
+		if(!hStore) {
+			if(ssl->loglevel) printf("sqSetupCert: CertOpenSystemStore failed\n");
+			return 0;
+		}
+		pContext = CertFindCertificateInStore(hStore, X509_ASN_ENCODING | PKCS_7_ASN_ENCODING,
+											  0, CERT_FIND_SUBJECT_STR_A, certName, NULL);
+
+		/* XXXX: Fail? Or just not provide the cert? For now, fail. */
+		if(!pContext) {
+			if(ssl->loglevel) printf("sqSetupCert: CertFindCertitficateInStore failed\n");
+			CertCloseStore(hStore, 0);
+			return 0;
+		}
+	}
+
+	sc_cred.dwVersion = SCHANNEL_CRED_VERSION;
+	sc_cred.dwFlags = SCH_CRED_NO_DEFAULT_CREDS | SCH_CRED_MANUAL_CRED_VALIDATION;
+	sc_cred.grbitEnabledProtocols = server ? SP_PROT_TLS1_SERVER | SP_PROT_SSL3_SERVER | SP_PROT_SSL2_SERVER : 0;
+	sc_cred.dwMinimumCipherStrength = 0;
+	sc_cred.dwMaximumCipherStrength = 0;
+
+	if(pContext) {
+		sc_cred.cCreds = 1;
+		sc_cred.paCred = &pContext;
+	} else {
+		sc_cred.cCreds = 0;
+	}
+
+	ret = AcquireCredentialsHandle(NULL, UNISP_NAME, 
+									server ? SECPKG_CRED_INBOUND : SECPKG_CRED_OUTBOUND, 
+									NULL, &sc_cred, NULL, NULL, &ssl->sslCred, NULL);
+	if(ssl->loglevel) printf("AquireCredentialsHandle returned: %x\n", ret);
+
+	if(pContext) {
+		CertCloseStore(hStore, 0);
+	    CertFreeCertificateContext(pContext);
+	}
+
+	if (ret != SEC_E_OK) {
+		if(ssl->loglevel) printf("AquireCredentialsHandle error: %x\n", ret);
+		return 0;
+	}
+	return 1;
+}
+
+/* sqExtractPeerName: Extract the name from the cert of the remote peer. */
+static int sqExtractPeerName(sqSSL *ssl) {
+	SECURITY_STATUS ret;
+	PCCERT_CONTEXT certHandle = NULL;
+	PCERT_NAME_INFO certInfo = NULL;
+	PCERT_RDN_ATTR certAttr = NULL;
+	DWORD dwSize = 0;
+	char tmpBuf[1024];
+
+	if(ssl->peerName) {
+		free(ssl->peerName);
+		ssl->peerName = NULL;
+	}
+	ret = QueryContextAttributes(&ssl->sslCtxt, SECPKG_ATTR_REMOTE_CERT_CONTEXT, (PVOID)&certHandle);
+	/* No credentials were provided; can't extract peer name */
+	if(ret == SEC_E_NO_CREDENTIALS) return 1;
+
+	if(ret != SEC_E_OK) {
+		if(ssl->loglevel) printf("sqExtractPeerName: QueryContextAttributes failed (code = %x)\n", ret);
+		return 0;
+	}
+
+	/* Figure out the size of the blob */
+	if(!CryptDecodeObject(X509_ASN_ENCODING | PKCS_7_ASN_ENCODING, X509_NAME,
+		certHandle->pCertInfo->Subject.pbData,
+		certHandle->pCertInfo->Subject.cbData,
+		0, NULL, &dwSize)) {
+			if(ssl->loglevel) printf("sqExtractPeerName: CryptDecodeObject failed\n");
+			return 0;
+	}
+	
+	/* Get the contents */
+	certInfo = alloca(dwSize);
+	if(!CryptDecodeObject(X509_ASN_ENCODING | PKCS_7_ASN_ENCODING, X509_NAME, 
+		certHandle->pCertInfo->Subject.pbData,
+		certHandle->pCertInfo->Subject.cbData,
+		0, certInfo, &dwSize)) {
+			if(ssl->loglevel) printf("sqExtractPeerName: CryptDecodeObject failed\n");
+			return 0;
+	}
+
+	/* Fetch the CN from the cert */
+	certAttr = CertFindRDNAttr(szOID_COMMON_NAME, certInfo);
+	if(certAttr == NULL) return 0;
+
+	/* Translate from RDN to string */
+	if(CertRDNValueToStr(CERT_RDN_PRINTABLE_STRING, &certAttr->Value, tmpBuf, sizeof(tmpBuf)) == 0) return 0;
+	ssl->peerName = _strdup(tmpBuf);
+	if(ssl->loglevel) printf("sqExtractPeerName: Peer name is %s\n", ssl->peerName);
+
+	return 1;
+}
+
+/* sqVerifyCert: Verify the validity of the remote certificate */
+static int sqVerifyCert(sqSSL *ssl, int isServer) {
+	SECURITY_STATUS ret;
+	PCCERT_CONTEXT certHandle = NULL;
+	PCCERT_CHAIN_CONTEXT chainContext = NULL;
+	CERT_CHAIN_PARA chainPara;
+	SSL_EXTRA_CERT_CHAIN_POLICY_PARA epp;
+	CERT_CHAIN_POLICY_PARA policyPara;
+	CERT_CHAIN_POLICY_STATUS policyStatus;
+	
+	static LPSTR serverUsage[] = {
+		szOID_PKIX_KP_SERVER_AUTH,
+		szOID_SERVER_GATED_CRYPTO,
+		szOID_SGC_NETSCAPE
+	};
+	static LPSTR clientUsage[] = {
+		szOID_PKIX_KP_CLIENT_AUTH
+	};
+	
+	ret = QueryContextAttributes(&ssl->sslCtxt, SECPKG_ATTR_REMOTE_CERT_CONTEXT, (PVOID)&certHandle);
+	/* No credentials were provided */
+	if(ret == SEC_E_NO_CREDENTIALS) {
+		ssl->certFlags = SQSSL_NO_CERTIFICATE;
+		return 1;
+	}
+
+	memset(&chainPara, 0, sizeof(chainPara));
+	chainPara.cbSize = sizeof(chainPara);
+	chainPara.RequestedUsage.dwType = USAGE_MATCH_TYPE_OR;
+	if(!isServer) {
+		chainPara.RequestedUsage.Usage.cUsageIdentifier = 3;
+		chainPara.RequestedUsage.Usage.rgpszUsageIdentifier = serverUsage;
+	} else {
+		chainPara.RequestedUsage.Usage.cUsageIdentifier = 1;
+		chainPara.RequestedUsage.Usage.rgpszUsageIdentifier = clientUsage;
+	}
+	if(!CertGetCertificateChain(NULL, certHandle, NULL,
+                                certHandle->hCertStore,
+								&chainPara, 0, NULL, &chainContext)) {
+		/* XXXX: Does this mean the other end did not provide a cert? */
+		ssl->certFlags = SQSSL_OTHER_ISSUE;
+		goto done;
+	}
+	
+	memset(&epp, 0, sizeof(epp));
+	epp.cbSize = sizeof(epp);
+	epp.dwAuthType = AUTHTYPE_SERVER;
+	epp.fdwChecks = 0;
+	epp.pwszServerName = NULL;
+
+	memset(&policyPara, 0, sizeof(policyPara));
+	policyPara.cbSize = sizeof(policyPara);
+	policyPara.dwFlags = 0;
+	policyPara.pvExtraPolicyPara = &epp;
+	memset(&policyStatus, 0, sizeof(policyStatus));
+	policyStatus.cbSize = sizeof(policyStatus);
+
+	/* We loop here CertVerifyCertificateChainPolicy() returns only a 
+	   single error even if there is more than one issue with the cert. */
+	ssl->certFlags = 0;
+	while(true) {
+		if (!CertVerifyCertificateChainPolicy(
+           CERT_CHAIN_POLICY_SSL,
+           chainContext,
+           &policyPara,
+           &policyStatus)) {
+				ssl->certFlags |= SQSSL_OTHER_ISSUE;
+				goto done;
+		}
+		switch(policyStatus.dwError) {
+			case SEC_E_OK:
+				goto done;
+			case CERT_E_UNTRUSTEDROOT:
+				if(ssl->certFlags & SQSSL_UNTRUSTED_ROOT) goto done;
+				ssl->certFlags |= SQSSL_UNTRUSTED_ROOT;
+				epp.fdwChecks  |= 0x00000100; /* SECURITY_FLAG_IGNORE_UNKNOWN_CA */
+				break;
+			case CERT_E_EXPIRED:
+				if(ssl->certFlags & SQSSL_CERT_EXPIRED) goto done;
+				ssl->certFlags |= SQSSL_CERT_EXPIRED;
+				epp.fdwChecks  |= 0x00002000;  /* SECURITY_FLAG_IGNORE_CERT_DATE_INVALID */
+				break;
+			case CERT_E_WRONG_USAGE:
+				if(ssl->certFlags & SQSSL_WRONG_USAGE) goto done;
+				ssl->certFlags |= SQSSL_WRONG_USAGE;
+				epp.fdwChecks  |= 0x00000200;   /* SECURITY_FLAG_IGNORE_WRONG_USAGE */
+			case CERT_E_REVOKED:
+				if(ssl->certFlags & SQSSL_CERT_REVOKED) goto done;
+				ssl->certFlags |= SQSSL_CERT_REVOKED;
+				epp.fdwChecks  |= 0x00000080;   /* SECURITY_FLAG_IGNORE_REVOCATION */
+				break;
+			default:
+				ssl->certFlags |= SQSSL_OTHER_ISSUE;
+				goto done;
+		}
+	}
+done:
+	return 1;
+}
+
+/********************************************************************/
+/********************************************************************/
+/********************************************************************/
+
+/* sqCreateSSL: Creates a new SSL instance.
+	Arguments: None.
+	Returns: SSL handle.
+*/
+sqInt sqCreateSSL(void) {
+	sqInt handle;
+	sqSSL *ssl = NULL;
+
+	ssl = calloc(1, sizeof(sqSSL));
+
+	ssl->sbdIn.ulVersion = SECBUFFER_VERSION;
+	ssl->sbdIn.cBuffers = 4;
+	ssl->sbdIn.pBuffers = ssl->inbuf;
+
+	ssl->sbdOut.ulVersion = SECBUFFER_VERSION;
+	ssl->sbdOut.cBuffers = 4;
+	ssl->sbdOut.pBuffers = ssl->outbuf;
+
+	/* Find a free handle */
+	for(handle = 1; handle < handleMax; handle++)
+		if(handleBuf[handle] == NULL) break;
+
+	if(handle >= handleMax) {
+		int i, delta = 100;
+		/* Resize the handle buffer */
+		handleBuf = realloc(handleBuf, (handleMax+delta)*sizeof(void*));
+		for(i = handleMax; i < handleMax+delta; i++)
+			handleBuf[i] = NULL;
+		handleMax += delta;
+	}
+	handleBuf[handle] = ssl;
+	return handle;
+}
+
+/* sqDestroySSL: Destroys an SSL instance.
+	Arguments:
+		handle - the SSL handle
+	Returns: Non-zero if successful.
+*/
+sqInt sqDestroySSL(sqInt handle) {
+	sqSSL *ssl = sslFromHandle(handle);
+	if(ssl == NULL) return 0;
+
+	FreeCredentialsHandle(&ssl->sslCred);
+	DeleteSecurityContext(&ssl->sslCtxt);
+
+	if(ssl->certName) free(ssl->certName);
+	if(ssl->peerName) free(ssl->peerName);
+	if(ssl->dataBuf) free(ssl->dataBuf);
+
+	free(ssl);
+	handleBuf[handle] = NULL;
+	return 1;
+}
+
+/* sqConnectSSL: Start/continue an SSL client handshake.
+	Arguments:
+		handle - the SSL handle
+		srcBuf - the input token sent by the remote peer
+		srcLen - the size of the input token
+		dstBuf - the output buffer for a new token
+		dstLen - the size of the output buffer
+	Returns: The size of the output token or an error code.
+*/
+sqInt sqConnectSSL(sqInt handle, char* srcBuf, sqInt srcLen, char *dstBuf, sqInt dstLen) {
+	SecBufferDesc *sbdIn = NULL;
+	SECURITY_STATUS ret;
+	SCHANNEL_CRED sc_cred = { 0 };
+	ULONG sslFlags, retFlags;
+	sqSSL *ssl = sslFromHandle(handle);
+
+	/* Verify state of session */
+	if(ssl == NULL || (ssl->state != SQSSL_UNUSED && ssl->state != SQSSL_CONNECTING)) {
+		return SQSSL_INVALID_STATE;
+	}
+
+	if(ssl->dataLen + srcLen > ssl->dataMax) {
+		/* resize the data buffer */
+		ssl->dataMax += (srcLen < 4096) ? (4096) : (srcLen+1024);
+		ssl->dataBuf = realloc(ssl->dataBuf, ssl->dataMax);
+		if(!ssl->dataBuf) return SQSSL_OUT_OF_MEMORY;
+	}
+	if(ssl->loglevel) printf("sqConnectSSL: input token %d bytes\n", srcLen);
+	memcpy(ssl->dataBuf + ssl->dataLen, srcBuf, srcLen);
+	ssl->dataLen += srcLen;
+
+	/* Standard flags for SSL connection */
+	sslFlags = 
+		ISC_REQ_ALLOCATE_MEMORY | ISC_REQ_CONFIDENTIALITY | ISC_REQ_EXTENDED_ERROR |
+		ISC_REQ_INTEGRITY | ISC_REQ_REPLAY_DETECT | ISC_REQ_STREAM |
+		ISC_REQ_MANUAL_CRED_VALIDATION;
+
+	/* Set up the input and output buffers */
+	ssl->inbuf[0].BufferType = SECBUFFER_TOKEN;
+	ssl->inbuf[0].cbBuffer = ssl->dataLen;
+	ssl->inbuf[0].pvBuffer = ssl->dataBuf;
+	ssl->inbuf[1].BufferType = SECBUFFER_EMPTY;
+	ssl->inbuf[1].cbBuffer = 0;
+	ssl->inbuf[1].pvBuffer = NULL;
+	ssl->inbuf[2].BufferType = SECBUFFER_EMPTY;
+	ssl->inbuf[2].cbBuffer = 0;
+	ssl->inbuf[2].pvBuffer = NULL;
+	ssl->inbuf[3].BufferType = SECBUFFER_EMPTY;
+	ssl->inbuf[3].cbBuffer = 0;
+	ssl->inbuf[3].pvBuffer = NULL;
+	ssl->sbdIn.cBuffers = 4;
+
+	ssl->outbuf[0].BufferType = SECBUFFER_EMPTY;
+	ssl->outbuf[0].cbBuffer = 0;
+	ssl->outbuf[0].pvBuffer = NULL;
+	ssl->outbuf[1].BufferType = SECBUFFER_EMPTY;
+	ssl->outbuf[1].cbBuffer = 0;
+	ssl->outbuf[1].pvBuffer = NULL;
+	ssl->sbdOut.cBuffers = 2;
+
+	if(ssl->loglevel) printf("sqConnectSSL: Input to InitSecCtxt is %d bytes\n", ssl->dataLen);
+
+	if(ssl->state == SQSSL_UNUSED) {
+		ssl->state = SQSSL_CONNECTING;
+
+		if(!sqSetupCert(ssl, ssl->certName, 0)) 
+			/* FIXME. We need a different error code here. */
+			return SQSSL_GENERIC_ERROR;
+
+		ret = InitializeSecurityContext(&ssl->sslCred, NULL, NULL,
+										sslFlags, 0, 0, NULL, 0, &ssl->sslCtxt,
+										&ssl->sbdOut, &retFlags, NULL);
+	} else {
+		ret = InitializeSecurityContext(&ssl->sslCred, &ssl->sslCtxt, NULL,
+										sslFlags, 0, 0, &ssl->sbdIn, 0, NULL,
+										&ssl->sbdOut, &retFlags, NULL);
+	}
+
+	if(ssl->loglevel) printf("InitializeSecurityContext returned: %x\n", ret);
+
+	if(ssl->loglevel) sqPrintSBD("Input Buffers:", ssl->sbdIn);
+	if(ssl->loglevel) sqPrintSBD("Output Buffers:", ssl->sbdOut);
+
+	if(ret != SEC_E_OK) {
+		int count;
+		/* Handle various failure conditions */
+		switch(ret) {
+			case SEC_I_CONTINUE_NEEDED:
+				/* Send contents back to peer and come back with more data */
+				count = sqCopyDescToken(ssl, ssl->sbdOut, dstBuf, dstLen);
+				/* Sanity checks for buffers */
+				if(ssl->inbuf[0].BufferType != SECBUFFER_TOKEN) {
+					if(ssl->loglevel) printf("sqConnectSSL: Unexpected buffer[0].BufferType -- %d\n", ssl->inbuf[0].BufferType);
+				}
+				if(ssl->inbuf[2].BufferType != SECBUFFER_EMPTY) {
+					if(ssl->loglevel) printf("sqConnectSSL: Unexpected buffer[2].BufferType -- %d\n", ssl->inbuf[0].BufferType);
+				}
+
+				/* If there is SECBUFFER_EXTRA in the input we need to retain it */
+				if(ssl->inbuf[1].BufferType == SECBUFFER_EXTRA) {
+					int extra = ssl->inbuf[1].cbBuffer;
+					if(ssl->loglevel) printf("sqConnectSSL: Retaining %d token bytes\n", extra);
+					memmove(ssl->dataBuf, ssl->dataBuf + (ssl->dataLen - extra), extra);
+					ssl->dataLen = extra;
+				} else ssl->dataLen = 0;
+
+				/* Don't return zero (SQSSL_OK) when more data is needed */
+				return count ? count : SQSSL_NEED_MORE_DATA;
+			default:
+				if(ssl->loglevel) printf("Unexpected return code %d\n", ret);
+				return SQSSL_GENERIC_ERROR;
+		}
+	}
+
+	/* TODO: Look at retFlags */
+	ssl->state = SQSSL_CONNECTED;
+	sqCopyExtraData(ssl, ssl->sbdOut);
+    ret = QueryContextAttributes(&ssl->sslCtxt, SECPKG_ATTR_STREAM_SIZES, &ssl->sslSizes);
+	if(ssl->loglevel) printf("sqConnectSSL: Maximum message size is %d bytes\n", ssl->sslSizes.cbMaximumMessage);
+
+	/* Extract the peer name */
+	sqExtractPeerName(ssl);
+
+	/* Verify the certificate (sets certFlags) */
+	sqVerifyCert(ssl, false);
+
+	return SQSSL_OK;
+}
+
+/* sqAcceptSSL: Start/continue an SSL server handshake.
+	Arguments:
+		handle - the SSL handle
+		srcBuf - the input token sent by the remote peer
+		srcLen - the size of the input token
+		dstBuf - the output buffer for a new token
+		dstLen - the size of the output buffer
+	Returns: The size of the output token or an error code.
+*/
+sqInt sqAcceptSSL(sqInt handle, char* srcBuf, sqInt srcLen, char *dstBuf, sqInt dstLen) {
+	SECURITY_STATUS ret;
+	SCHANNEL_CRED sc_cred = { 0 };
+ 	ULONG sslFlags, retFlags;
+	sqSSL *ssl = sslFromHandle(handle);
+
+	/* Verify state of session */
+	if(ssl == NULL || (ssl->state != SQSSL_UNUSED && ssl->state != SQSSL_ACCEPTING)) {
+		return SQSSL_INVALID_STATE;
+	}
+
+	/* Standard flags for SSL connection */
+	sslFlags = 
+		ASC_REQ_ALLOCATE_MEMORY | ASC_REQ_CONFIDENTIALITY | ASC_REQ_EXTENDED_ERROR |
+		ASC_REQ_INTEGRITY | ASC_REQ_REPLAY_DETECT | ASC_REQ_STREAM;
+
+	ssl->inbuf[0].BufferType = SECBUFFER_TOKEN;
+	ssl->inbuf[0].cbBuffer = srcLen;
+	ssl->inbuf[0].pvBuffer = srcBuf;
+	ssl->inbuf[1].BufferType = SECBUFFER_EMPTY;
+	ssl->inbuf[1].cbBuffer = 0;
+	ssl->inbuf[1].pvBuffer = NULL;
+
+	ssl->sbdIn.cBuffers = 2;
+
+	ssl->outbuf[0].BufferType = SECBUFFER_EMPTY;
+	ssl->outbuf[0].cbBuffer = 0;
+	ssl->outbuf[0].pvBuffer = NULL;
+	ssl->outbuf[1].BufferType = SECBUFFER_EMPTY;
+	ssl->outbuf[1].cbBuffer = 0;
+	ssl->outbuf[1].pvBuffer = NULL;
+
+	ssl->sbdOut.cBuffers = 2;
+
+	if(ssl->state == SQSSL_UNUSED) {
+		ssl->state = SQSSL_ACCEPTING;
+
+		if(!sqSetupCert(ssl, ssl->certName, 1)) 
+			return SQSSL_GENERIC_ERROR;
+
+		ret = AcceptSecurityContext(&ssl->sslCred, NULL, &ssl->sbdIn, sslFlags,
+									SECURITY_NATIVE_DREP, &ssl->sslCtxt, &ssl->sbdOut,
+									&retFlags, NULL);
+	} else {
+		ret = AcceptSecurityContext(&ssl->sslCred, &ssl->sslCtxt, &ssl->sbdIn, sslFlags,
+									SECURITY_NATIVE_DREP, &ssl->sslCtxt, &ssl->sbdOut,
+									&retFlags, NULL);
+	}
+
+	if(ssl->loglevel) printf("AcceptSecurityContext returned: %x\n", ret);
+
+	if(ret != SEC_E_OK) {
+		/* Handle various failure conditions */
+		switch(ret) {
+			case SEC_I_CONTINUE_NEEDED:
+				/* Send contents back to peer and come back with more data */
+				return sqCopyDescToken(ssl, ssl->sbdOut, dstBuf, dstLen);
+			default:
+				if(ssl->loglevel) printf("Unexpected return code %d\n", ret);
+				return SQSSL_GENERIC_ERROR;
+		}
+	}
+
+	/* TODO: Look at retFlags */
+	ssl->state = SQSSL_CONNECTED;
+    ret = QueryContextAttributes(&ssl->sslCtxt, SECPKG_ATTR_STREAM_SIZES, &ssl->sslSizes);
+	if(ssl->loglevel) printf("sqAcceptSSL: Maximum message size is %d bytes\n", ssl->sslSizes.cbMaximumMessage);
+
+	/* Extract the peer name */
+	sqExtractPeerName(ssl);
+
+	/* Verify the certificate (sets certFlags) */
+	sqVerifyCert(ssl, true);
+
+	return sqCopyDescToken(ssl, ssl->sbdOut, dstBuf, dstLen);
+}
+
+/* sqEncryptSSL: Encrypt data for SSL transmission.
+	Arguments:
+		handle - the SSL handle
+		srcBuf - the unencrypted input data
+		srcLen - the size of the input data
+		dstBuf - the output buffer for the encrypted contents
+		dstLen - the size of the output buffer
+	Returns: The size of the output generated or an error code.
+*/
+sqInt sqEncryptSSL(sqInt handle, char* srcBuf, sqInt srcLen, char *dstBuf, sqInt dstLen) {
+	SECURITY_STATUS ret;
+	sqInt total;
+	sqSSL *ssl = sslFromHandle(handle);
+
+	if(ssl == NULL || ssl->state != SQSSL_CONNECTED) return SQSSL_INVALID_STATE;
+
+	if(ssl->loglevel) printf("sqEncryptSSL: Encrypting %d bytes\n", srcLen);
+
+	if(srcLen > (int)ssl->sslSizes.cbMaximumMessage) 
+		return SQSSL_INPUT_TOO_LARGE;
+
+	ssl->inbuf[0].BufferType = SECBUFFER_STREAM_HEADER;
+	ssl->inbuf[0].cbBuffer = ssl->sslSizes.cbHeader;
+	ssl->inbuf[0].pvBuffer = dstBuf;
+
+	ssl->inbuf[1].BufferType = SECBUFFER_DATA;
+	ssl->inbuf[1].cbBuffer = srcLen;
+	ssl->inbuf[1].pvBuffer = dstBuf + ssl->inbuf[0].cbBuffer;
+
+	ssl->inbuf[2].BufferType = SECBUFFER_STREAM_TRAILER;
+	ssl->inbuf[2].cbBuffer = ssl->sslSizes.cbTrailer;
+	ssl->inbuf[2].pvBuffer = dstBuf + ssl->inbuf[0].cbBuffer + ssl->inbuf[1].cbBuffer;
+
+	ssl->inbuf[3].BufferType = SECBUFFER_EMPTY;
+	ssl->inbuf[3].cbBuffer = 0;
+	ssl->inbuf[3].pvBuffer = NULL;  
+
+	ssl->sbdIn.cBuffers = 4;
+
+	/* Check to ensure that encrypted contents fits dstBuf.
+	   Fail with BUFFER_TOO_SMALL to allow caller to retry. */
+	total = ssl->inbuf[0].cbBuffer + ssl->inbuf[1].cbBuffer + ssl->inbuf[2].cbBuffer;
+	if(dstLen < total) return SQSSL_BUFFER_TOO_SMALL;
+
+	memcpy(ssl->inbuf[1].pvBuffer, srcBuf, srcLen);
+
+	if(ssl->loglevel) printf("Header: %d; Data: %d; Trailer: %d\n", 
+			ssl->inbuf[0].cbBuffer, ssl->inbuf[1].cbBuffer, ssl->inbuf[2].cbBuffer);
+
+	ret = EncryptMessage(&ssl->sslCtxt, 0, &ssl->sbdIn, 0);
+
+	if (ret != SEC_E_OK) {
+		if(ssl->loglevel) printf("EncryptMessage returned: %x\n", ret);
+		return SQSSL_GENERIC_ERROR;
+	}
+
+	/* Return total amount of encrypted contents.
+	   Must recompute total here since trailer may be overestimated */
+	total = ssl->inbuf[0].cbBuffer + ssl->inbuf[1].cbBuffer + ssl->inbuf[2].cbBuffer;
+	return total;
+}
+
+/* sqDecryptSSL: Decrypt data for SSL transmission.
+	Arguments:
+		handle - the SSL handle
+		srcBuf - the encrypted input data
+		srcLen - the size of the input data
+		dstBuf - the output buffer for the decrypted contents
+		dstLen - the size of the output buffer
+	Returns: The size of the output generated or an error code.
+*/
+sqInt sqDecryptSSL(sqInt handle, char* srcBuf, sqInt srcLen, char *dstBuf, sqInt dstLen) {
+	int i, total;
+	SECURITY_STATUS ret;
+	sqSSL *ssl = sslFromHandle(handle);
+
+	if(ssl == NULL || ssl->state != SQSSL_CONNECTED) return SQSSL_INVALID_STATE;
+
+	if(ssl->dataLen + srcLen > ssl->dataMax) {
+		/* resize the read buffer */
+		ssl->dataMax += (srcLen < 4096) ? (4096) : (srcLen+1024);
+		ssl->dataBuf = realloc(ssl->dataBuf, ssl->dataMax);
+		if(!ssl->dataBuf) return SQSSL_OUT_OF_MEMORY;
+	}
+	if(ssl->loglevel) printf("sqDecryptSSL: Input data %d bytes\n", srcLen);
+	memcpy(ssl->dataBuf + ssl->dataLen, srcBuf, srcLen);
+	ssl->dataLen += srcLen;
+
+	if(ssl->loglevel) printf("sqDecryptSSL: Decrypting %d bytes\n", ssl->dataLen);
+
+	ssl->inbuf[0].BufferType = SECBUFFER_DATA;
+	ssl->inbuf[0].cbBuffer = ssl->dataLen;
+	ssl->inbuf[0].pvBuffer = ssl->dataBuf;
+
+	ssl->inbuf[1].BufferType = SECBUFFER_EMPTY;
+	ssl->inbuf[1].cbBuffer = 0;
+	ssl->inbuf[1].pvBuffer = NULL;
+
+	ssl->inbuf[2].BufferType = SECBUFFER_EMPTY;
+	ssl->inbuf[2].cbBuffer = 0;
+	ssl->inbuf[2].pvBuffer = NULL;  
+
+	ssl->inbuf[3].BufferType = SECBUFFER_EMPTY;
+	ssl->inbuf[3].cbBuffer = 0;
+	ssl->inbuf[3].pvBuffer = NULL;  
+
+	ssl->sbdIn.cBuffers = 4;
+	ret = DecryptMessage(&ssl->sslCtxt, &ssl->sbdIn, 0, 0);
+
+	/* Copy the result into destination buffer */
+	total = 0;
+	for(i=0;i<4;i++) {
+		int buftype = ssl->inbuf[i].BufferType;
+		int count = ssl->inbuf[i].cbBuffer;
+		char *buffer = ssl->inbuf[i].pvBuffer;
+		if(ssl->loglevel) printf("buf[%d]: %d (%d bytes) ptr=%x\n", i,buftype, count, (int)buffer);
+		if(buftype == SECBUFFER_DATA) {
+			if(count > dstLen) return SQSSL_BUFFER_TOO_SMALL;
+			memcpy(dstBuf, buffer, count);
+			dstBuf += count;
+			total += count;
+			dstLen -= count;
+		}
+	}
+
+	/* We ran out of steam. Hopefully this was because we've produced
+	   a bunch'o bits from the decryption. */
+	if(ret == SEC_E_OK || ret == SEC_E_INCOMPLETE_MESSAGE) {
+		/* Retain any remaining extra buffers and return output */
+		sqCopyExtraData(ssl, ssl->sbdIn);
+		/* Return the total number of bytes decrypted */
+		return total;		
+	}
+
+	if(ssl->loglevel) printf("DecryptMessage returned: %x\n", ret);
+	return SQSSL_GENERIC_ERROR;
+}
+
+/* sqGetStringPropertySSL: Retrieve a string property from SSL.
+	Arguments:
+		handle - the ssl handle
+		propID - the property id to retrieve
+	Returns: The string value of the property.
+*/
+char* sqGetStringPropertySSL(sqInt handle, int propID) {
+	sqSSL *ssl = sslFromHandle(handle);
+
+	if(ssl == NULL) return NULL;
+	switch(propID) {
+		case SQSSL_PROP_PEERNAME:  return ssl->peerName;
+		case SQSSL_PROP_CERTNAME:  return ssl->certName;
+		default:
+			if(ssl->loglevel) printf("sqGetStringPropertySSL: Unknown property ID %d\n", propID);
+			return NULL;
+	}
+	return NULL;
+}
+
+/* sqSetStringPropertySSL: Set a string property in SSL.
+	Arguments:
+		handle - the ssl handle
+		propID - the property id to retrieve
+		propName - the property string
+		propLen  - the length of the property string
+	Returns: Non-zero if successful.
+*/
+sqInt sqSetStringPropertySSL(sqInt handle, int propID, char *propName, sqInt propLen) {
+	sqSSL *ssl = sslFromHandle(handle);
+	char *property = NULL;
+
+	if(ssl == NULL) return 0;
+
+	if(propLen) {
+	  property = calloc(1, propLen+1);
+	  memcpy(property, propName, propLen);
+	};
+
+	if(ssl->loglevel) printf("sqSetStringPropertySSL(%d): %s\n", propID, property);
+
+	switch(propID) {
+		case SQSSL_PROP_CERTNAME: ssl->certName = property; break;
+		default: 
+			if(ssl->loglevel) printf("sqSetStringPropertySSL: Unknown property ID %d\n", propID);
+			return 0;
+	}
+	return 1;
+}
+
+/* sqGetIntPropertySSL: Retrieve an integer property from SSL.
+	Arguments:
+		handle - the ssl handle
+		propID - the property id to retrieve
+	Returns: The integer value of the property.
+*/
+int sqGetIntPropertySSL(sqInt handle, int propID) {
+	sqSSL *ssl = sslFromHandle(handle);
+
+	if(ssl == NULL) return 0;
+	switch(propID) {
+		case SQSSL_PROP_SSLSTATE: return ssl->state;
+		case SQSSL_PROP_CERTSTATE: return ssl->certFlags;
+		case SQSSL_PROP_VERSION: return 1;
+		case SQSSL_PROP_LOGLEVEL: return ssl->loglevel;
+		default:
+			if(ssl->loglevel) printf("sqGetIntPropertySSL: Unknown property ID %d\n", propID);
+			return 0;
+	}
+	return 0;
+}
+
+/* sqSetIntPropertySSL: Set an integer property in SSL.
+	Arguments:
+		handle - the ssl handle
+		propID - the property id to retrieve
+		propValue - the property value
+	Returns: Non-zero if successful.
+*/
+sqInt sqSetIntPropertySSL(sqInt handle, sqInt propID, sqInt propValue) {
+	sqSSL *ssl = sslFromHandle(handle);
+	if(ssl == NULL) return 0;
+
+	switch(propID) {
+		case SQSSL_PROP_LOGLEVEL: ssl->loglevel = propValue; break;
+		default:
+			if(ssl->loglevel) printf("sqSetIntPropertySSL: Unknown property ID %d\n", propID);
+			return 0;
+	}
+	return 0;
+}



More information about the Vm-dev mailing list