Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save mumumu/26b5fa85dc5ca31ba1f306dad3fa5573 to your computer and use it in GitHub Desktop.
Save mumumu/26b5fa85dc5ca31ba1f306dad3fa5573 to your computer and use it in GitHub Desktop.
diff --git a/tcpd/libcouriertls.c b/tcpd/libcouriertls.c
index 199015e..b24ccc3 100644
--- a/tcpd/libcouriertls.c
+++ b/tcpd/libcouriertls.c
@@ -333,6 +333,7 @@ static void load_dh_params(SSL_CTX *ctx, const char *filename,
}
static int read_certfile(SSL_CTX *ctx, const char *filename,
+ const char *private_key_filename,
int *cert_file_flags)
{
const struct tls_info *info=SSL_CTX_get_app_data(ctx);
@@ -353,36 +354,69 @@ static int read_certfile(SSL_CTX *ctx, const char *filename,
return (1);
}
-static int process_certfile(SSL_CTX *ctx, const char *certfile, const char *ip,
+/*
+** Return 1 if filename names a file this process can read, 0 otherwise.
+** A NULL filename (nothing configured) is reported as unreadable instead
+** of being passed to access(), which requires a valid pathname.
+*/
+static int check_readable_file(const char *filename)
+{
+ if (!filename)
+ return 0;
+ return (access(filename, R_OK) == 0) ? 1 : 0;
+}
+
+/*
+** Build the candidate path "<filename>.<ip>" and return it if readable.
+**
+** Returns a malloc()ed path that the caller must free(), or NULL when
+** filename is NULL or empty, ip is NULL, or the candidate file is not
+** readable. Exits the process if memory cannot be allocated.
+*/
+static char *get_ip_concated_readable_file(SSL_CTX *ctx, const char *filename, const char *ip)
+{
+ char *test_file;
+ const struct tls_info *info;
+
+ if (!filename || !*filename || !ip) return NULL;
+
+ info=SSL_CTX_get_app_data(ctx);
+ test_file= malloc(strlen(filename)+strlen(ip)+2);
+ if (!test_file)
+ {
+ nonsslerror(info, "malloc");
+ exit(1);
+ }
+
+ strcpy(test_file, filename);
+ strcat(test_file, ".");
+ strcat(test_file, ip);
+
+ if (check_readable_file(test_file))
+ return test_file;
+
+ /* Not readable: release the candidate path instead of leaking it. */
+ free(test_file);
+ return NULL;
+}
+
+/*
+** Load the server certificate, and optionally a separate private key,
+** by calling func(ctx, cert, key, cert_file_flags).
+**
+** When ip is non-empty, "<certfile>.<ip>" is preferred if readable. It is
+** paired with "<private_key_file>.<ip>" when that is readable, otherwise
+** with NULL; the generic private_key_file is not combined with an
+** IP-specific certificate. An IPv4-mapped IPv6 address ("::ffff:a.b.c.d")
+** is retried as plain "a.b.c.d".
+**
+** Otherwise certfile is used together with private_key_file, or with NULL
+** when private_key_file is NULL or not readable.
+**
+** Returns whatever func returns.
+**
+** NOTE(review): this patch changes only read_certfile's parameter list and
+** none of its body; confirm it actually loads private_key_filename.
+*/
+static int process_certfile(SSL_CTX *ctx, const char *certfile,
+ const char *private_key_file,
+ const char *ip,
int (*func)(SSL_CTX *, const char *,
+ const char *,
int *),
int *cert_file_flags)
{
if (ip && *ip)
{
char *test_file;
+ char *test_private_key_file;
+ /* IPv4-mapped IPv6 address: look files up by the dotted IPv4 form. */
if (strncmp(ip, "::ffff:", 7) == 0 && strchr(ip, '.'))
- return (process_certfile(ctx, certfile, ip+7, func, cert_file_flags));
-
- test_file= malloc(strlen(certfile)+strlen(ip)+2);
-
- strcpy(test_file, certfile);
- strcat(test_file, ".");
- strcat(test_file, ip);
+ return (process_certfile(ctx, certfile, private_key_file, ip+7, func, cert_file_flags));
- if (access(test_file, R_OK) == 0)
+ /* Each lookup yields a malloc()ed path or NULL; both are freed below. */
+ test_file= get_ip_concated_readable_file(ctx, certfile, ip);
+ test_private_key_file = get_ip_concated_readable_file(ctx, private_key_file, ip);
+ if (test_file != NULL)
{
- int rc= (*func)(ctx, test_file,
+ int rc= (*func)(ctx, test_file, test_private_key_file,
cert_file_flags);
free(test_file);
+ free(test_private_key_file);
+
return rc;
}
free(test_file);
+ free(test_private_key_file);
}
- return (*func)(ctx, certfile, cert_file_flags);
+ /*
+ ** Fall back to the generic files. private_key_file is NULL when no key
+ ** file is configured, and access() must not be given a NULL path.
+ */
+ if (private_key_file && !check_readable_file(private_key_file))
+ private_key_file=NULL;
+ return (*func)(ctx, certfile, private_key_file, cert_file_flags);
}
static int client_cert_cb(ssl_handle ssl, X509 **x509, EVP_PKEY **pkey)
@@ -495,30 +529,26 @@ static int client_cert_cb(ssl_handle ssl, X509 **x509, EVP_PKEY **pkey)
static SSL_CTX *tls_create_int(int isserver, const struct tls_info *info,
int internal);
-static int server_cert_cb(ssl_handle ssl, int *ad, void *arg)
-{
#ifdef HAVE_OPENSSL_SNI
- struct tls_info *info=(struct tls_info *)SSL_get_app_data(ssl);
- const char *servername=SSL_get_servername(ssl,
- TLSEXT_NAMETYPE_host_name);
- const char *certfile=safe_getenv(info, "TLS_CERTFILE");
- int cert_file_flags=0;
- char *buffer;
+/*
+** Build a candidate path from filename, a ".", and a suffix derived from
+** servername by the copy loop below, and return it if that file is
+** readable.
+**
+** Returns a malloc()ed path that the caller must free(), or NULL when
+** filename is NULL or empty, servername is NULL, or the file is not
+** readable. Exits the process if memory cannot be allocated.
+**
+** An empty filename is rejected because safe_getenv() results may be ""
+** (tls_create_int treats that as "not configured"); otherwise a file named
+** ".<servername>" relative to the working directory would be probed.
+*/
+static char *get_servername_concated_readable_file(const char *filename,
+ const char *servername,
+ struct tls_info *info)
+{
+ char *filename_buffer;
char *p;
- if (!servername || !certfile)
- return SSL_TLSEXT_ERR_OK;
+ if (!filename || !*filename || !servername) return NULL;
- buffer=malloc(strlen(certfile)+strlen(servername)+2);
- if (!buffer)
+ filename_buffer=malloc(strlen(filename)+strlen(servername)+2);
+ if (!filename_buffer)
{
nonsslerror(info, "malloc");
exit(1);
}
- strcat(strcpy(buffer, certfile), ".");
+ strcat(strcpy(filename_buffer, filename), ".");
- p=buffer + strlen(buffer);
+ p=filename_buffer + strlen(filename_buffer);
while ((*p=*servername) != 0)
{
@@ -527,8 +557,31 @@ static int server_cert_cb(ssl_handle ssl, int *ad, void *arg)
++p;
++servername;
}
+ if (check_readable_file(filename_buffer))
+ return filename_buffer;
+
+ /* Not readable: release the candidate path instead of leaking it. */
+ free(filename_buffer);
+ return NULL;
+}
+#endif
+
+static int server_cert_cb(ssl_handle ssl, int *ad, void *arg)
+{
+#ifdef HAVE_OPENSSL_SNI
+ struct tls_info *info=(struct tls_info *)SSL_get_app_data(ssl);
+ const char *servername=SSL_get_servername(ssl,
+ TLSEXT_NAMETYPE_host_name);
+ const char *certfile=safe_getenv(info, "TLS_CERTFILE");
+ const char *private_keyfile=safe_getenv(info, "TLS_PRIVATE_KEYFILE");
+ int cert_file_flags=0;
+ char *cert_file_buffer;
+ char *private_keyfile_buffer;
- if (access(buffer, R_OK) == 0)
+ if (!servername || !certfile)
+ return SSL_TLSEXT_ERR_OK;
+
+ cert_file_buffer = get_servername_concated_readable_file(certfile, servername, info);
+ private_keyfile_buffer = get_servername_concated_readable_file(private_keyfile, servername, info);
+ if (cert_file_buffer != NULL)
{
SSL_CTX *orig_ctx=SSL_get_SSL_CTX(ssl);
SSL_CTX *temp_ctx=tls_create_int(1, info, 1);
@@ -541,7 +594,7 @@ static int server_cert_cb(ssl_handle ssl, int *ad, void *arg)
exit(1);
}
SSL_set_SSL_CTX(ssl, temp_ctx);
- rc=read_certfile(orig_ctx, buffer, &cert_file_flags);
+ rc=read_certfile(orig_ctx, cert_file_buffer, private_keyfile_buffer, &cert_file_flags);
SSL_set_SSL_CTX(ssl, orig_ctx);
tls_destroy(temp_ctx);
if (!rc)
@@ -551,7 +604,8 @@ static int server_cert_cb(ssl_handle ssl, int *ad, void *arg)
exit(1);
}
}
- free(buffer);
+ free(cert_file_buffer);
+ free(private_keyfile_buffer);
#endif
return SSL_TLSEXT_ERR_OK;
@@ -571,6 +625,7 @@ SSL_CTX *tls_create_int(int isserver, const struct tls_info *info,
int session_timeout=atoi(safe_getenv(info, "TLS_TIMEOUT"));
const char *dhparamsfile=safe_getenv(info, "TLS_DHPARAMS");
const char *certfile=safe_getenv(info, "TLS_CERTFILE");
+ const char *private_keyfile=safe_getenv(info, "TLS_PRIVATE_KEYFILE");
const char *s;
struct stat stat_buf;
const char *peer_cert_dir=NULL;
@@ -588,6 +643,9 @@ SSL_CTX *tls_create_int(int isserver, const struct tls_info *info,
if (!*certfile)
certfile=NULL;
+ if (!*private_keyfile)
+ private_keyfile=NULL;
+
if (!*dhparamsfile)
dhparamsfile=NULL;
@@ -697,7 +755,7 @@ SSL_CTX *tls_create_int(int isserver, const struct tls_info *info,
if (dhparamsfile)
load_dh_params(ctx, dhparamsfile, &cert_file_flags);
- if (certfile && !process_certfile(ctx, certfile, s,
+ if (certfile && !process_certfile(ctx, certfile, private_keyfile, s,
read_certfile,
&cert_file_flags))
{
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment