Browse Source

OpenSSL: Add wrapper struct for tls_init() result

This new struct tls_data is needed to store per-tls_init() information
in the followup commits.

Signed-off-by: Jouni Malinen <j@w1.fi>
Jouni Malinen 9 years ago
parent
commit
bd9b8b2b68
1 changed files with 76 additions and 50 deletions
  1. 76 50
      src/crypto/tls_openssl.c

+ 76 - 50
src/crypto/tls_openssl.c

@@ -86,6 +86,10 @@ struct tls_context {
 static struct tls_context *tls_global = NULL;
 
 
+struct tls_data {
+	SSL_CTX *ssl;
+};
+
 struct tls_connection {
 	struct tls_context *context;
 	SSL_CTX *ssl_ctx;
@@ -746,6 +750,7 @@ static int tls_engine_load_dynamic_opensc(const char *opensc_so_path)
 
 void * tls_init(const struct tls_config *conf)
 {
+	struct tls_data *data;
 	SSL_CTX *ssl;
 	struct tls_context *context;
 	const char *ciphers;
@@ -810,7 +815,11 @@ void * tls_init(const struct tls_config *conf)
 	}
 	tls_openssl_ref_count++;
 
-	ssl = SSL_CTX_new(SSLv23_method());
+	data = os_zalloc(sizeof(*data));
+	if (data)
+		ssl = SSL_CTX_new(SSLv23_method());
+	else
+		ssl = NULL;
 	if (ssl == NULL) {
 		tls_openssl_ref_count--;
 		if (context != tls_global)
@@ -821,6 +830,7 @@ void * tls_init(const struct tls_config *conf)
 		}
 		return NULL;
 	}
+	data->ssl = ssl;
 
 	SSL_CTX_set_options(ssl, SSL_OP_NO_SSLv2);
 	SSL_CTX_set_options(ssl, SSL_OP_NO_SSLv3);
@@ -839,7 +849,7 @@ void * tls_init(const struct tls_config *conf)
 		if (tls_engine_load_dynamic_opensc(conf->opensc_engine_path) ||
 		    tls_engine_load_dynamic_pkcs11(conf->pkcs11_engine_path,
 						   conf->pkcs11_module_path)) {
-			tls_deinit(ssl);
+			tls_deinit(data);
 			return NULL;
 		}
 	}
@@ -853,17 +863,18 @@ void * tls_init(const struct tls_config *conf)
 		wpa_printf(MSG_ERROR,
 			   "OpenSSL: Failed to set cipher string '%s'",
 			   ciphers);
-		tls_deinit(ssl);
+		tls_deinit(data);
 		return NULL;
 	}
 
-	return ssl;
+	return data;
 }
 
 
 void tls_deinit(void *ssl_ctx)
 {
-	SSL_CTX *ssl = ssl_ctx;
+	struct tls_data *data = ssl_ctx;
+	SSL_CTX *ssl = data->ssl;
 	struct tls_context *context = SSL_CTX_get_app_data(ssl);
 	if (context != tls_global)
 		os_free(context);
@@ -883,6 +894,8 @@ void tls_deinit(void *ssl_ctx)
 		os_free(tls_global);
 		tls_global = NULL;
 	}
+
+	os_free(data);
 }
 
 
@@ -1058,7 +1071,8 @@ static void tls_msg_cb(int write_p, int version, int content_type,
 
 struct tls_connection * tls_connection_init(void *ssl_ctx)
 {
-	SSL_CTX *ssl = ssl_ctx;
+	struct tls_data *data = ssl_ctx;
+	SSL_CTX *ssl = data->ssl;
 	struct tls_connection *conn;
 	long options;
 	struct tls_context *context = SSL_CTX_get_app_data(ssl);
@@ -1066,7 +1080,7 @@ struct tls_connection * tls_connection_init(void *ssl_ctx)
 	conn = os_zalloc(sizeof(*conn));
 	if (conn == NULL)
 		return NULL;
-	conn->ssl_ctx = ssl_ctx;
+	conn->ssl_ctx = ssl;
 	conn->ssl = SSL_new(ssl);
 	if (conn->ssl == NULL) {
 		tls_show_errors(MSG_INFO, __func__,
@@ -1641,9 +1655,9 @@ static int tls_verify_cb(int preverify_ok, X509_STORE_CTX *x509_ctx)
 
 
 #ifndef OPENSSL_NO_STDIO
-static int tls_load_ca_der(void *_ssl_ctx, const char *ca_cert)
+static int tls_load_ca_der(struct tls_data *data, const char *ca_cert)
 {
-	SSL_CTX *ssl_ctx = _ssl_ctx;
+	SSL_CTX *ssl_ctx = data->ssl;
 	X509_LOOKUP *lookup;
 	int ret = 0;
 
@@ -1673,11 +1687,12 @@ static int tls_load_ca_der(void *_ssl_ctx, const char *ca_cert)
 #endif /* OPENSSL_NO_STDIO */
 
 
-static int tls_connection_ca_cert(void *_ssl_ctx, struct tls_connection *conn,
+static int tls_connection_ca_cert(struct tls_data *data,
+				  struct tls_connection *conn,
 				  const char *ca_cert, const u8 *ca_cert_blob,
 				  size_t ca_cert_blob_len, const char *ca_path)
 {
-	SSL_CTX *ssl_ctx = _ssl_ctx;
+	SSL_CTX *ssl_ctx = data->ssl;
 	X509_STORE *store;
 
 	/*
@@ -1812,7 +1827,7 @@ static int tls_connection_ca_cert(void *_ssl_ctx, struct tls_connection *conn,
 			tls_show_errors(MSG_WARNING, __func__,
 					"Failed to load root certificates");
 			if (ca_cert &&
-			    tls_load_ca_der(ssl_ctx, ca_cert) == 0) {
+			    tls_load_ca_der(data, ca_cert) == 0) {
 				wpa_printf(MSG_DEBUG, "OpenSSL: %s - loaded "
 					   "DER format CA certificate",
 					   __func__);
@@ -1821,7 +1836,7 @@ static int tls_connection_ca_cert(void *_ssl_ctx, struct tls_connection *conn,
 		} else {
 			wpa_printf(MSG_DEBUG, "TLS: Trusted root "
 				   "certificate(s) loaded");
-			tls_get_errors(ssl_ctx);
+			tls_get_errors(data);
 		}
 #else /* OPENSSL_NO_STDIO */
 		wpa_printf(MSG_DEBUG, "OpenSSL: %s - OPENSSL_NO_STDIO",
@@ -1838,8 +1853,10 @@ static int tls_connection_ca_cert(void *_ssl_ctx, struct tls_connection *conn,
 }
 
 
-static int tls_global_ca_cert(SSL_CTX *ssl_ctx, const char *ca_cert)
+static int tls_global_ca_cert(struct tls_data *data, const char *ca_cert)
 {
+	SSL_CTX *ssl_ctx = data->ssl;
+
 	if (ca_cert) {
 		if (SSL_CTX_load_verify_locations(ssl_ctx, ca_cert, NULL) != 1)
 		{
@@ -1867,7 +1884,8 @@ int tls_global_set_verify(void *ssl_ctx, int check_crl)
 	int flags;
 
 	if (check_crl) {
-		X509_STORE *cs = SSL_CTX_get_cert_store(ssl_ctx);
+		struct tls_data *data = ssl_ctx;
+		X509_STORE *cs = SSL_CTX_get_cert_store(data->ssl);
 		if (cs == NULL) {
 			tls_show_errors(MSG_INFO, __func__, "Failed to get "
 					"certificate store when enabling "
@@ -2028,9 +2046,12 @@ static int tls_connection_client_cert(struct tls_connection *conn,
 }
 
 
-static int tls_global_client_cert(SSL_CTX *ssl_ctx, const char *client_cert)
+static int tls_global_client_cert(struct tls_data *data,
+				  const char *client_cert)
 {
 #ifndef OPENSSL_NO_STDIO
+	SSL_CTX *ssl_ctx = data->ssl;
+
 	if (client_cert == NULL)
 		return 0;
 
@@ -2064,7 +2085,7 @@ static int tls_passwd_cb(char *buf, int size, int rwflag, void *password)
 
 
 #ifdef PKCS12_FUNCS
-static int tls_parse_pkcs12(SSL_CTX *ssl_ctx, SSL *ssl, PKCS12 *p12,
+static int tls_parse_pkcs12(struct tls_data *data, SSL *ssl, PKCS12 *p12,
 			    const char *passwd)
 {
 	EVP_PKEY *pkey;
@@ -2095,7 +2116,7 @@ static int tls_parse_pkcs12(SSL_CTX *ssl_ctx, SSL *ssl, PKCS12 *p12,
 			if (SSL_use_certificate(ssl, cert) != 1)
 				res = -1;
 		} else {
-			if (SSL_CTX_use_certificate(ssl_ctx, cert) != 1)
+			if (SSL_CTX_use_certificate(data->ssl, cert) != 1)
 				res = -1;
 		}
 		X509_free(cert);
@@ -2107,7 +2128,7 @@ static int tls_parse_pkcs12(SSL_CTX *ssl_ctx, SSL *ssl, PKCS12 *p12,
 			if (SSL_use_PrivateKey(ssl, pkey) != 1)
 				res = -1;
 		} else {
-			if (SSL_CTX_use_PrivateKey(ssl_ctx, pkey) != 1)
+			if (SSL_CTX_use_PrivateKey(data->ssl, pkey) != 1)
 				res = -1;
 		}
 		EVP_PKEY_free(pkey);
@@ -2146,7 +2167,7 @@ static int tls_parse_pkcs12(SSL_CTX *ssl_ctx, SSL *ssl, PKCS12 *p12,
 		res = 0;
 #else /* OPENSSL_VERSION_NUMBER >= 0x10002000L */
 #if OPENSSL_VERSION_NUMBER >= 0x10001000L
-		SSL_CTX_clear_extra_chain_certs(ssl_ctx);
+		SSL_CTX_clear_extra_chain_certs(data->ssl);
 #endif /* OPENSSL_VERSION_NUMBER >= 0x10001000L */
 		while ((cert = sk_X509_pop(certs)) != NULL) {
 			X509_NAME_oneline(X509_get_subject_name(cert), buf,
@@ -2157,7 +2178,8 @@ static int tls_parse_pkcs12(SSL_CTX *ssl_ctx, SSL *ssl, PKCS12 *p12,
 			 * There is no SSL equivalent for the chain cert - so
 			 * always add it to the context...
 			 */
-			if (SSL_CTX_add_extra_chain_cert(ssl_ctx, cert) != 1) {
+			if (SSL_CTX_add_extra_chain_cert(data->ssl, cert) != 1)
+			{
 				res = -1;
 				break;
 			}
@@ -2169,15 +2191,15 @@ static int tls_parse_pkcs12(SSL_CTX *ssl_ctx, SSL *ssl, PKCS12 *p12,
 	PKCS12_free(p12);
 
 	if (res < 0)
-		tls_get_errors(ssl_ctx);
+		tls_get_errors(data);
 
 	return res;
 }
 #endif  /* PKCS12_FUNCS */
 
 
-static int tls_read_pkcs12(SSL_CTX *ssl_ctx, SSL *ssl, const char *private_key,
-			   const char *passwd)
+static int tls_read_pkcs12(struct tls_data *data, SSL *ssl,
+			   const char *private_key, const char *passwd)
 {
 #ifdef PKCS12_FUNCS
 	FILE *f;
@@ -2196,7 +2218,7 @@ static int tls_read_pkcs12(SSL_CTX *ssl_ctx, SSL *ssl, const char *private_key,
 		return -1;
 	}
 
-	return tls_parse_pkcs12(ssl_ctx, ssl, p12, passwd);
+	return tls_parse_pkcs12(data, ssl, p12, passwd);
 
 #else /* PKCS12_FUNCS */
 	wpa_printf(MSG_INFO, "TLS: PKCS12 support disabled - cannot read "
@@ -2206,7 +2228,7 @@ static int tls_read_pkcs12(SSL_CTX *ssl_ctx, SSL *ssl, const char *private_key,
 }
 
 
-static int tls_read_pkcs12_blob(SSL_CTX *ssl_ctx, SSL *ssl,
+static int tls_read_pkcs12_blob(struct tls_data *data, SSL *ssl,
 				const u8 *blob, size_t len, const char *passwd)
 {
 #ifdef PKCS12_FUNCS
@@ -2219,7 +2241,7 @@ static int tls_read_pkcs12_blob(SSL_CTX *ssl_ctx, SSL *ssl,
 		return -1;
 	}
 
-	return tls_parse_pkcs12(ssl_ctx, ssl, p12, passwd);
+	return tls_parse_pkcs12(data, ssl, p12, passwd);
 
 #else /* PKCS12_FUNCS */
 	wpa_printf(MSG_INFO, "TLS: PKCS12 support disabled - cannot parse "
@@ -2290,13 +2312,13 @@ static int tls_connection_engine_client_cert(struct tls_connection *conn,
 }
 
 
-static int tls_connection_engine_ca_cert(void *_ssl_ctx,
+static int tls_connection_engine_ca_cert(struct tls_data *data,
 					 struct tls_connection *conn,
 					 const char *ca_cert_id)
 {
 #ifndef OPENSSL_NO_ENGINE
 	X509 *cert;
-	SSL_CTX *ssl_ctx = _ssl_ctx;
+	SSL_CTX *ssl_ctx = data->ssl;
 	X509_STORE *store;
 
 	if (tls_engine_get_cert(conn, ca_cert_id, &cert))
@@ -2362,14 +2384,14 @@ static int tls_connection_engine_private_key(struct tls_connection *conn)
 }
 
 
-static int tls_connection_private_key(void *_ssl_ctx,
+static int tls_connection_private_key(struct tls_data *data,
 				      struct tls_connection *conn,
 				      const char *private_key,
 				      const char *private_key_passwd,
 				      const u8 *private_key_blob,
 				      size_t private_key_blob_len)
 {
-	SSL_CTX *ssl_ctx = _ssl_ctx;
+	SSL_CTX *ssl_ctx = data->ssl;
 	char *passwd;
 	int ok;
 
@@ -2415,7 +2437,7 @@ static int tls_connection_private_key(void *_ssl_ctx,
 			break;
 		}
 
-		if (tls_read_pkcs12_blob(ssl_ctx, conn->ssl, private_key_blob,
+		if (tls_read_pkcs12_blob(data, conn->ssl, private_key_blob,
 					 private_key_blob_len, passwd) == 0) {
 			wpa_printf(MSG_DEBUG, "OpenSSL: PKCS#12 as blob --> "
 				   "OK");
@@ -2448,7 +2470,7 @@ static int tls_connection_private_key(void *_ssl_ctx,
 			   __func__);
 #endif /* OPENSSL_NO_STDIO */
 
-		if (tls_read_pkcs12(ssl_ctx, conn->ssl, private_key, passwd)
+		if (tls_read_pkcs12(data, conn->ssl, private_key, passwd)
 		    == 0) {
 			wpa_printf(MSG_DEBUG, "OpenSSL: Reading PKCS#12 file "
 				   "--> OK");
@@ -2487,9 +2509,11 @@ static int tls_connection_private_key(void *_ssl_ctx,
 }
 
 
-static int tls_global_private_key(SSL_CTX *ssl_ctx, const char *private_key,
+static int tls_global_private_key(struct tls_data *data,
+				  const char *private_key,
 				  const char *private_key_passwd)
 {
+	SSL_CTX *ssl_ctx = data->ssl;
 	char *passwd;
 
 	if (private_key == NULL)
@@ -2511,7 +2535,7 @@ static int tls_global_private_key(SSL_CTX *ssl_ctx, const char *private_key,
 	    SSL_CTX_use_PrivateKey_file(ssl_ctx, private_key,
 					SSL_FILETYPE_PEM) != 1 &&
 #endif /* OPENSSL_NO_STDIO */
-	    tls_read_pkcs12(ssl_ctx, NULL, private_key, passwd)) {
+	    tls_read_pkcs12(data, NULL, private_key, passwd)) {
 		tls_show_errors(MSG_INFO, __func__,
 				"Failed to load private key");
 		os_free(passwd);
@@ -2606,7 +2630,7 @@ static int tls_connection_dh(struct tls_connection *conn, const char *dh_file)
 }
 
 
-static int tls_global_dh(SSL_CTX *ssl_ctx, const char *dh_file)
+static int tls_global_dh(struct tls_data *data, const char *dh_file)
 {
 #ifdef OPENSSL_NO_DH
 	if (dh_file == NULL)
@@ -2615,6 +2639,7 @@ static int tls_global_dh(SSL_CTX *ssl_ctx, const char *dh_file)
 		   "dh_file specified");
 	return -1;
 #else /* OPENSSL_NO_DH */
+	SSL_CTX *ssl_ctx = data->ssl;
 	DH *dh;
 	BIO *bio;
 
@@ -2778,7 +2803,7 @@ static int openssl_get_keyblock_size(SSL *ssl)
 #endif /* CONFIG_FIPS */
 
 
-static int openssl_tls_prf(void *tls_ctx, struct tls_connection *conn,
+static int openssl_tls_prf(struct tls_connection *conn,
 			   const char *label, int server_random_first,
 			   int skip_keyblock, u8 *out, size_t out_len)
 {
@@ -2946,7 +2971,7 @@ int tls_connection_prf(void *tls_ctx, struct tls_connection *conn,
 	if (conn == NULL)
 		return -1;
 	if (server_random_first || skip_keyblock)
-		return openssl_tls_prf(tls_ctx, conn, label,
+		return openssl_tls_prf(conn, label,
 				       server_random_first, skip_keyblock,
 				       out, out_len);
 	ssl = conn->ssl;
@@ -2956,7 +2981,7 @@ int tls_connection_prf(void *tls_ctx, struct tls_connection *conn,
 		return 0;
 	}
 #endif
-	return openssl_tls_prf(tls_ctx, conn, label, server_random_first,
+	return openssl_tls_prf(conn, label, server_random_first,
 			       skip_keyblock, out, out_len);
 }
 
@@ -3633,6 +3658,7 @@ static int ocsp_status_cb(SSL *s, void *arg)
 int tls_connection_set_params(void *tls_ctx, struct tls_connection *conn,
 			      const struct tls_connection_params *params)
 {
+	struct tls_data *data = tls_ctx;
 	int ret;
 	unsigned long err;
 	int can_pkcs11 = 0;
@@ -3708,10 +3734,9 @@ int tls_connection_set_params(void *tls_ctx, struct tls_connection *conn,
 		return -1;
 
 	if (engine_id && ca_cert_id) {
-		if (tls_connection_engine_ca_cert(tls_ctx, conn,
-						  ca_cert_id))
+		if (tls_connection_engine_ca_cert(data, conn, ca_cert_id))
 			return TLS_SET_PARAMS_ENGINE_PRV_VERIFY_FAILED;
-	} else if (tls_connection_ca_cert(tls_ctx, conn, params->ca_cert,
+	} else if (tls_connection_ca_cert(data, conn, params->ca_cert,
 					  params->ca_cert_blob,
 					  params->ca_cert_blob_len,
 					  params->ca_path))
@@ -3729,7 +3754,7 @@ int tls_connection_set_params(void *tls_ctx, struct tls_connection *conn,
 		wpa_printf(MSG_DEBUG, "TLS: Using private key from engine");
 		if (tls_connection_engine_private_key(conn))
 			return TLS_SET_PARAMS_ENGINE_PRV_VERIFY_FAILED;
-	} else if (tls_connection_private_key(tls_ctx, conn,
+	} else if (tls_connection_private_key(data, conn,
 					      params->private_key,
 					      params->private_key_passwd,
 					      params->private_key_blob,
@@ -3783,7 +3808,7 @@ int tls_connection_set_params(void *tls_ctx, struct tls_connection *conn,
 
 #ifdef HAVE_OCSP
 	if (params->flags & TLS_CONN_REQUEST_OCSP) {
-		SSL_CTX *ssl_ctx = tls_ctx;
+		SSL_CTX *ssl_ctx = data->ssl;
 		SSL_set_tlsext_status_type(conn->ssl, TLSEXT_STATUSTYPE_ocsp);
 		SSL_CTX_set_tlsext_status_cb(ssl_ctx, ocsp_resp_cb);
 		SSL_CTX_set_tlsext_status_arg(ssl_ctx, conn);
@@ -3802,7 +3827,7 @@ int tls_connection_set_params(void *tls_ctx, struct tls_connection *conn,
 
 	conn->flags = params->flags;
 
-	tls_get_errors(tls_ctx);
+	tls_get_errors(data);
 
 	return 0;
 }
@@ -3811,7 +3836,8 @@ int tls_connection_set_params(void *tls_ctx, struct tls_connection *conn,
 int tls_global_set_params(void *tls_ctx,
 			  const struct tls_connection_params *params)
 {
-	SSL_CTX *ssl_ctx = tls_ctx;
+	struct tls_data *data = tls_ctx;
+	SSL_CTX *ssl_ctx = data->ssl;
 	unsigned long err;
 
 	while ((err = ERR_get_error())) {
@@ -3819,11 +3845,11 @@ int tls_global_set_params(void *tls_ctx,
 			   __func__, ERR_error_string(err, NULL));
 	}
 
-	if (tls_global_ca_cert(ssl_ctx, params->ca_cert) ||
-	    tls_global_client_cert(ssl_ctx, params->client_cert) ||
-	    tls_global_private_key(ssl_ctx, params->private_key,
+	if (tls_global_ca_cert(data, params->ca_cert) ||
+	    tls_global_client_cert(data, params->client_cert) ||
+	    tls_global_private_key(data, params->private_key,
 				   params->private_key_passwd) ||
-	    tls_global_dh(ssl_ctx, params->dh_file)) {
+	    tls_global_dh(data, params->dh_file)) {
 		wpa_printf(MSG_INFO, "TLS: Failed to set global parameters");
 		return -1;
 	}