From 72f0e01f5d6d0f75686b049ee11b9bd5aa7c10f9 Mon Sep 17 00:00:00 2001 From: Mark Andrews Date: Fri, 13 Dec 2019 13:58:47 +1100 Subject: [PATCH] Address dns_zt_asyncload races by properly using isc_reference_*. --- lib/dns/zt.c | 100 +++++++++++++++++++++++++++++++++------------------ 1 file changed, 65 insertions(+), 35 deletions(-) diff --git a/lib/dns/zt.c b/lib/dns/zt.c index 69e1f10e6f..5b410b1831 100644 --- a/lib/dns/zt.c +++ b/lib/dns/zt.c @@ -285,36 +285,70 @@ load(dns_zone_t *zone, void *paramsv) { return (result); } +static void +call_loaddone(dns_zt_t *zt) { + dns_zt_allloaded_t loaddone = zt->loaddone; + void *loaddone_arg = zt->loaddone_arg; + + /* + * Set zt->loaddone, zt->loaddone_arg and zt->loadparams to NULL + * before calling loaddone. + */ + zt->loaddone = NULL; + zt->loaddone_arg = NULL; + + isc_mem_put(zt->mctx, zt->loadparams, sizeof(struct zt_load_params)); + zt->loadparams = NULL; + + /* + * Call the callback last. + */ + if (loaddone != NULL) { + loaddone(loaddone_arg); + } +} + isc_result_t dns_zt_asyncload(dns_zt_t *zt, bool newonly, - dns_zt_allloaded_t alldone, void *arg) { + dns_zt_allloaded_t alldone, void *arg) +{ isc_result_t result; - int pending; + uint_fast32_t loads_pending; REQUIRE(VALID_ZT(zt)); + + /* + * Obtain a reference to zt->loads_pending so that asyncload can + * safely decrement both zt->references and zt->loads_pending + * without going to zero. + */ + loads_pending = isc_refcount_increment0(&zt->loads_pending); + INSIST(loads_pending == 0); + + /* + * Only one dns_zt_asyncload call at a time should be active so + * these pointers should be NULL. They are set back to NULL + * before the zt->loaddone (alldone) is called in call_loaddone. + */ + INSIST(zt->loadparams == NULL); + INSIST(zt->loaddone == NULL); + INSIST(zt->loaddone_arg == NULL); + zt->loadparams = isc_mem_get(zt->mctx, sizeof(struct zt_load_params)); zt->loadparams->dl = doneloading; zt->loadparams->newonly = newonly; + zt->loaddone = alldone; + zt->loaddone_arg = arg; - RWLOCK(&zt->rwlock, isc_rwlocktype_write); - - INSIST(isc_refcount_current(&zt->loads_pending) == 0); - + RWLOCK(&zt->rwlock, isc_rwlocktype_read); result = dns_zt_apply(zt, false, NULL, asyncload, zt); + RWUNLOCK(&zt->rwlock, isc_rwlocktype_read); - pending = isc_refcount_current(&zt->loads_pending); - - if (pending != 0) { - zt->loaddone = alldone; - zt->loaddone_arg = arg; - } - - RWUNLOCK(&zt->rwlock, isc_rwlocktype_write); - - if (pending == 0) { - isc_mem_put(zt->mctx, zt->loadparams, sizeof(struct zt_load_params)); - zt->loadparams = NULL; - alldone(arg); + /* + * Have all the loads completed? + */ + if (isc_refcount_decrement(&zt->loads_pending) == 1) { + call_loaddone(zt); } return (result); @@ -332,14 +366,20 @@ asyncload(dns_zone_t *zone, void *zt_) { REQUIRE(zone != NULL); isc_refcount_increment(&zt->references); - isc_refcount_increment(&zt->loads_pending); - result = dns_zone_asyncload(zone, zt->loadparams->newonly, *zt->loadparams->dl, zt); + result = dns_zone_asyncload(zone, zt->loadparams->newonly, + *zt->loadparams->dl, zt); if (result != ISC_R_SUCCESS) { - - isc_refcount_decrement(&zt->references); - isc_refcount_decrement(&zt->loads_pending); + uint_fast32_t oldref; + /* + * Caller is holding a reference to zt->loads_pending + * and zt->references so these can't decrement to zero. + */ + oldref = isc_refcount_decrement(&zt->loads_pending); + INSIST(oldref > 1); + oldref = isc_refcount_decrement(&zt->references); + INSIST(oldref > 1); } return (ISC_R_SUCCESS); } @@ -528,8 +568,6 @@ dns_zt_apply(dns_zt_t *zt, bool stop, isc_result_t *sub, */ static isc_result_t doneloading(dns_zt_t *zt, dns_zone_t *zone, isc_task_t *task) { - dns_zt_allloaded_t alldone = NULL; - void *arg = NULL; UNUSED(zone); UNUSED(task); @@ -537,15 +575,7 @@ doneloading(dns_zt_t *zt, dns_zone_t *zone, isc_task_t *task) { REQUIRE(VALID_ZT(zt)); if (isc_refcount_decrement(&zt->loads_pending) == 1) { - alldone = zt->loaddone; - arg = zt->loaddone_arg; - zt->loaddone = NULL; - zt->loaddone_arg = NULL; - isc_mem_put(zt->mctx, zt->loadparams, sizeof(struct zt_load_params)); - zt->loadparams = NULL; - if (alldone != NULL) { - alldone(arg); - } + call_loaddone(zt); } if (isc_refcount_decrement(&zt->references) == 1) {