Merge fortanix#299
299: Limit memory usage for broadcast events r=jethrogb a=mzohreva

The elusive memory leak that I was chasing was not an actual leak, but was rather caused by a pile-up of events broadcast to all TCSs. This PR fixes the issue by removing the unbounded channels used for events.

Co-authored-by: Mohsen Zohrevandi <[email protected]>
bors[bot] and mzohreva authored Nov 10, 2020
2 parents 5759959 + d3a5bcc commit 798f26b
Showing 2 changed files with 105 additions and 102 deletions.
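
The heart of the fix is visible in the diff below: each TCS's unbounded event channel is replaced by a fixed set of counting semaphores, one per event set, saturating at a cap so concurrent senders cannot grow memory without bound. Here is a minimal sketch of that counting pattern, assuming tokio 1.x's `Semaphore` API (the PR itself targets tokio 0.2, whose semaphore API differs slightly; the type name and cap below are illustrative, not the PR's exact code):

```rust
// Sketch of the bounded event counter this PR introduces (PendingEvents
// in the diff), written against tokio 1.x's Semaphore.
use tokio::sync::Semaphore;

struct BoundedEventCount {
    count: Semaphore,
}

impl BoundedEventCount {
    // Cap well below the semaphore's permit limit so that several threads
    // pushing concurrently cannot overflow it (the diff uses
    // usize::MAX >> 5 for the same reason).
    const CAP: usize = usize::MAX >> 5;

    fn new() -> Self {
        Self { count: Semaphore::new(0) }
    }

    // Record one occurrence, saturating at CAP instead of growing without
    // bound the way an unbounded channel would.
    fn push(&self) {
        if self.count.available_permits() < Self::CAP {
            self.count.add_permits(1);
        }
    }

    // Consume one pending occurrence without blocking.
    fn take(&self) -> bool {
        match self.count.try_acquire() {
            Ok(permit) => {
                permit.forget(); // permanently remove the consumed permit
                true
            }
            Err(_) => false,
        }
    }

    // Block until an occurrence is pending, then consume it.
    async fn wait(&self) {
        let permit = self.count.acquire().await.expect("semaphore closed");
        permit.forget();
    }
}
```

Saturating in push() trades exact counts for a hard memory bound, which is safe here because a waiter only needs to know that at least one event of a given set is pending.
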
2 changes: 1 addition & 1 deletion enclave-runner/Cargo.toml
@@ -34,7 +34,7 @@ openssl = { version = "0.10", optional = true } # Apache-2.0
crossbeam = "0.7.1" # MIT/Apache-2.0
num_cpus = "1.10.0" # MIT/Apache-2.0
tokio = { version = "0.2", features = ["full"] } # MIT
futures = { version = "0.3", features = ["compat", "io-compat"] }
futures = { version = "0.3", features = ["compat", "io-compat"] } # MIT/Apache-2.0

[features]
default = ["crypto-openssl"]
205 changes: 104 additions & 101 deletions enclave-runner/src/usercalls/mod.rs
@@ -14,7 +14,7 @@ use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll, Waker};
use std::thread::{self, JoinHandle};
use std::time::{self, Duration, Instant};
use std::time::{self, Duration};
use std::{cmp, fmt, str};

use failure::{self, bail};
@@ -31,6 +31,7 @@ use tokio::io::{AsyncRead, AsyncWrite};
use tokio::stream::Stream as TokioStream;
use tokio::sync::broadcast;
use tokio::sync::mpsc as async_mpsc;
use tokio::sync::Semaphore;

use fortanix_sgx_abi::*;
use ipc_queue::{self, DescriptorGuard, Identified, QueueEvent};
@@ -496,7 +497,6 @@ impl fmt::Pointer for TcsAddress {

struct StoppedTcs {
tcs: ErasedTcs,
event_queue: futures::channel::mpsc::UnboundedReceiver<u8>,
}

struct IOHandlerInput<'tcs> {
@@ -506,7 +506,11 @@ struct IOHandlerInput<'tcs> {
}

struct PendingEvents {
counts: [u32; Self::EV_MAX],
// The Semaphores are basically counting how many times each event set has
// been sent. The count is decreased when an event set is consumed through
// `take` or `wait_for` methods by calling `SemaphorePermit::forget`.
counts: [Semaphore; Self::EV_MAX],
abort: Semaphore,
}

impl PendingEvents {
@@ -520,34 +524,83 @@ impl PendingEvents {
const _ERROR_IF_TOO_BIG: u64 = u64::MAX + (Self::EV_MAX_U64 - (Self::_EV_MAX_U16 as u64));

fn new() -> Self {
Self { counts: Default::default() }
PendingEvents {
counts: [
Semaphore::new(0), Semaphore::new(0), Semaphore::new(0), Semaphore::new(0),
Semaphore::new(0), Semaphore::new(0), Semaphore::new(0), Semaphore::new(0),
],
abort: Semaphore::new(0),
}
}

fn take(&mut self, event_mask: u8) -> Option<u8> {
assert!((event_mask as usize) < Self::EV_MAX);
fn take(&self, event_mask: u64) -> Option<u64> {
assert!(event_mask < Self::EV_MAX_U64);

if let Ok(_) = self.abort.try_acquire() {
return Some(EV_ABORT);
}

for i in (1..Self::EV_MAX).rev() {
let ev = i as u8;
if (ev & event_mask) != 0 && self.counts[i] > 0 {
self.counts[i] -= 1;
return Some(ev);
let ev = i as u64;
if (ev & event_mask) != 0 {
if let Ok(permit) = self.counts[i].try_acquire() {
permit.forget();
return Some(ev);
}
}
}
None
}

fn push(&mut self, event: u8) {
debug_assert!(event != 0 && (event as usize) < Self::EV_MAX);
let i = event as usize;
if let Some(val) = self.counts[i].checked_add(1) {
self.counts[i] = val;
async fn wait_for(&self, event_mask: u64) -> u64 {
assert!(event_mask < Self::EV_MAX_U64);

if let Ok(_) = self.abort.try_acquire() {
return EV_ABORT;
}

let it = std::iter::once((EV_ABORT, &self.abort))
.chain(self.counts.iter().enumerate().map(|(ev, sem)| (ev as u64, sem)).filter(|&(ev, _)| ev & event_mask != 0))
.map(|(ev, sem)| sem.acquire().map(move |permit| (ev, permit)).boxed());

let ((ev, permit), _, _) = futures::future::select_all(it).await;

// Abort should take precedence if it happens concurrently with another
// event. The abort semaphore may not have been selected.
if let Ok(_) = self.abort.try_acquire() {
return EV_ABORT;
}
if ev != EV_ABORT {
permit.forget();
}
ev
}

fn push(&self, event: u64) {
debug_assert!(event != 0 && event < Self::EV_MAX_U64);
let index = event as usize;
// add_permits() panics if the permit limit is exceeded.
// NOTE: [the documentation] incorrectly specifies the maximum to be
// `usize::MAX >> 3`, while the actual maximum is `usize::MAX >> 4`.
// It's possible to have multiple threads pushing the same event
// concurrently, hence the smaller bound.
//
// [the documentation]: https://docs.rs/tokio/0.2.22/tokio/sync/struct.Semaphore.html#method.add_permits
const MAX_PERMITS: usize = usize::MAX >> 5;
if self.counts[index].available_permits() < MAX_PERMITS {
self.counts[index].add_permits(1);
}
}

fn abort(&self) {
if self.abort.available_permits() == 0 {
self.abort.add_permits(usize::MAX >> 5);
}
}
}

struct RunningTcs {
pending_events: PendingEvents,
event_queue: futures::channel::mpsc::UnboundedReceiver<u8>,
tcs_address: TcsAddress,
mode: EnclaveEntry,
}

@@ -591,7 +644,7 @@ struct FifoGuards {

pub(crate) struct EnclaveState {
kind: EnclaveKind,
event_queues: FnvHashMap<TcsAddress, futures::channel::mpsc::UnboundedSender<u8>>,
event_queues: FnvHashMap<TcsAddress, PendingEvents>,
fds: Mutex<FnvHashMap<Fd, Arc<AsyncFileDesc>>>,
last_fd: AtomicUsize,
exiting: AtomicBool,
@@ -633,22 +686,20 @@ impl Work {

impl EnclaveState {
fn event_queue_add_tcs(
event_queues: &mut FnvHashMap<TcsAddress, futures::channel::mpsc::UnboundedSender<u8>>,
event_queues: &mut FnvHashMap<TcsAddress, PendingEvents>,
tcs: ErasedTcs,
) -> StoppedTcs {
let (send, recv) = futures::channel::mpsc::unbounded();
if event_queues.insert(tcs.address(), send).is_some() {
if event_queues.insert(tcs.address(), PendingEvents::new()).is_some() {
panic!("duplicate TCS address: {:p}", tcs.address())
}
StoppedTcs {
tcs,
event_queue: recv,
}
}

fn new(
kind: EnclaveKind,
mut event_queues: FnvHashMap<TcsAddress, futures::channel::mpsc::UnboundedSender<u8>>,
mut event_queues: FnvHashMap<TcsAddress, PendingEvents>,
usercall_ext: Option<Box<dyn UsercallExtension>>,
threads_vector: Vec<ErasedTcs>,
forward_panics: bool,
@@ -860,10 +911,7 @@ impl EnclaveState {
let fut = async move {
let ret = match state.mode {
EnclaveEntry::Library => {
enclave_clone.threads_queue.push(StoppedTcs {
tcs,
event_queue: state.event_queue,
});
enclave_clone.threads_queue.push(StoppedTcs { tcs });
Ok((v1, v2))
}
EnclaveEntry::ExecutableMain => Err(EnclaveAbort::MainReturned),
@@ -876,10 +924,7 @@
// If the enclave is in the exit-state, threads are no
// longer able to be launched
if !enclave_clone.exiting.load(Ordering::SeqCst) {
enclave_clone.threads_queue.push(StoppedTcs {
tcs,
event_queue: state.event_queue,
});
enclave_clone.threads_queue.push(StoppedTcs { tcs });
}
Ok((0, 0))
}
@@ -974,8 +1019,7 @@ impl EnclaveState {

let main_work = Work {
tcs: RunningTcs {
event_queue: main.event_queue,
pending_events: PendingEvents::new(),
tcs_address: main.tcs.address(),
mode: EnclaveEntry::ExecutableMain,
},
entry: CoEntry::Initial(main.tcs, argv as _, argc as _, 0, 0, 0),
@@ -1061,9 +1105,8 @@ impl EnclaveState {
let thread = enclave.threads_queue.pop().expect("threads queue empty");
let work = Work {
tcs: RunningTcs {
event_queue: thread.event_queue,
tcs_address: thread.tcs.address(),
mode: EnclaveEntry::Library,
pending_events: PendingEvents::new(),
},
entry: CoEntry::Initial(thread.tcs, p1, p2, p3, p4, p5),
};
@@ -1093,8 +1136,8 @@ impl EnclaveState {
fn abort_all_threads(&self) {
self.exiting.store(true, Ordering::SeqCst);
// wake other threads
for queue in self.event_queues.values() {
let _ = queue.unbounded_send(EV_ABORT as _);
for pending_events in self.event_queues.values() {
pending_events.abort();
}
}
}
@@ -1404,23 +1447,18 @@ impl<'tcs> IOHandlerInput<'tcs> {

let ret = self.work_sender.send(Work {
tcs: RunningTcs {
pending_events: PendingEvents::new(),
event_queue: new_tcs.event_queue,
tcs_address: new_tcs.tcs.address(),
mode: EnclaveEntry::ExecutableNonMain,
},
entry: CoEntry::Initial(new_tcs.tcs, 0, 0, 0, 0, 0),
});
match ret {
Ok(()) => Ok(()),
Err(e) => {
let event_queue = e.0.tcs.event_queue;
let entry = e.0.entry;
match entry {
CoEntry::Initial(tcs, _, _ ,_, _, _) => {
self.enclave.threads_queue.push(StoppedTcs {
tcs,
event_queue,
});
self.enclave.threads_queue.push(StoppedTcs { tcs });
},
_ => unreachable!(),
};
@@ -1438,20 +1476,20 @@ impl<'tcs> IOHandlerInput<'tcs> {
EnclaveAbort::Exit { panic }
}

fn check_event_set(set: u64) -> IoResult<u8> {
fn check_event_set(set: u64) -> IoResult<()> {
const EV_ALL: u64 = EV_USERCALLQ_NOT_FULL | EV_RETURNQ_NOT_EMPTY | EV_UNPARK;
if (set & !EV_ALL) != 0 {
return Err(IoErrorKind::InvalidInput.into());
}

assert!((EV_ALL | EV_ABORT) <= u8::max_value().into());
assert!((EV_ALL & EV_ABORT) == 0);
Ok(set as u8)
Ok(())
}

#[inline(always)]
async fn wait(&mut self, event_mask: u64, timeout: u64) -> IoResult<u64> {
let event_mask = Self::check_event_set(event_mask)?;
Self::check_event_set(event_mask)?;

let timeout = match timeout {
WAIT_NO | WAIT_INDEFINITE => timeout,
@@ -1468,74 +1506,39 @@
// TODO: https://github.com/fortanix/rust-sgx/issues/290
let tcs = self.tcs.as_mut().ok_or(io::Error::from(io::ErrorKind::Other))?;

let mut ret = tcs.pending_events.take(event_mask);

if ret.is_none() {
let start = Instant::now();
loop {
let ev = match timeout {
WAIT_INDEFINITE => tcs.event_queue.next().await.ok_or(()),
WAIT_NO => match tcs.event_queue.try_next() {
Ok(Some(ev)) => Ok(ev),
Ok(None) => Err(()),
Err(_) => break,
},
timeout => {
let remaining = match Duration::from_nanos(timeout).checked_sub(start.elapsed()) {
None => break,
Some(ref duration) if duration.as_nanos() == 0 => break,
Some(duration) => duration,
};
match tokio::time::timeout(remaining, tcs.event_queue.next()).await {
Ok(Some(ev)) => Ok(ev),
Ok(None) => Err(()),
Err(_) => break, // timed out
}
}
}.expect("TCS event queue disconnected unexpectedly");
let pending_events = self.enclave.event_queues.get(&tcs.tcs_address).expect("invalid tcs address");

if (ev & (EV_ABORT as u8)) != 0 {
// dispatch will make sure this is not returned to enclave
return Err(IoErrorKind::Other.into());
}
let ret = match timeout {
WAIT_NO => pending_events.take(event_mask),
WAIT_INDEFINITE => Some(pending_events.wait_for(event_mask).await),
n => tokio::time::timeout(Duration::from_nanos(n), pending_events.wait_for(event_mask)).await.ok(),
};

if (ev & event_mask) != 0 {
ret = Some(ev);
break;
} else {
tcs.pending_events.push(ev);
}
if let Some(ev) = ret {
if (ev & EV_ABORT) != 0 {
// dispatch will make sure this is not returned to enclave
return Err(IoErrorKind::Other.into());
}
return Ok(ev.into());
}

if let Some(ret) = ret {
Ok(ret.into())
} else {
Err(if timeout == WAIT_NO { IoErrorKind::WouldBlock } else { IoErrorKind::TimedOut }.into())
}
Err(if timeout == WAIT_NO { IoErrorKind::WouldBlock } else { IoErrorKind::TimedOut }.into())
}

#[inline(always)]
fn send(&self, event_set: u64, target: Option<Tcs>) -> IoResult<()> {
let event_set = Self::check_event_set(event_set)?;
Self::check_event_set(event_set)?;

if event_set == 0 {
return Err(IoErrorKind::InvalidInput.into());
}

if let Some(tcs) = target {
let tcs = TcsAddress(tcs.as_ptr() as _);
let queue = self
.enclave
.event_queues
.get(&tcs)
.ok_or(IoErrorKind::InvalidInput)?;
queue
.unbounded_send(event_set)
.expect("TCS event queue disconnected");
let pending_events = self.enclave.event_queues.get(&tcs).ok_or(IoErrorKind::InvalidInput)?;
pending_events.push(event_set);
} else {
for queue in self.enclave.event_queues.values() {
let _ = queue.unbounded_send(event_set);
for pending_events in self.enclave.event_queues.values() {
pending_events.push(event_set);
}
}

@@ -1684,8 +1687,8 @@ impl ipc_queue::AsyncSynchronizer for QueueSynchronizer {
// When the enclave needs to wait on a queue, it executes the wait() usercall synchronously,
// specifying EV_USERCALLQ_NOT_FULL, EV_RETURNQ_NOT_EMPTY, or both in the event_mask.
// Userspace will wake any or all threads waiting on the appropriate event when it is triggered.
for queue in self.enclave.event_queues.values() {
let _ = queue.unbounded_send(ev as _);
for pending_events in self.enclave.event_queues.values() {
pending_events.push(ev as _);
}
}
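
The wait_for() hunk above races the abort semaphore against every selected event semaphore with futures::future::select_all, then re-checks abort so that an abort arriving concurrently with another event still wins. A self-contained sketch of that pattern, assuming tokio 1.x and futures 0.3 (wait_any and its signature are illustrative, not the PR's exact code):

```rust
// Sketch of wait_for()'s select-with-abort-precedence pattern. `events`
// must be non-empty plus the abort entry, since select_all panics on an
// empty iterator. Returns None on abort, Some(index) for the event won.
use futures::FutureExt;
use tokio::sync::Semaphore;

async fn wait_any(events: &[Semaphore], abort: &Semaphore) -> Option<usize> {
    // Fast path: abort already signaled. The temporary permit is dropped
    // immediately, so abort stays signaled for other waiters.
    if abort.try_acquire().is_ok() {
        return None;
    }
    // Race the abort semaphore against every event semaphore. `None`
    // marks the abort future; `Some(i)` marks event i.
    let futs = std::iter::once(abort.acquire().map(|p| (None, p)).boxed())
        .chain(events.iter().enumerate().map(|(i, sem)| {
            sem.acquire().map(move |p| (Some(i), p)).boxed()
        }));
    let ((which, permit), _, _) = futures::future::select_all(futs).await;
    // Re-check abort: it takes precedence even if an event fired
    // concurrently. Dropping `permit` returns it, so nothing is lost.
    if which.is_none() || abort.try_acquire().is_ok() {
        return None;
    }
    // Consume the event we won; forget() keeps the count decremented.
    permit.expect("semaphore closed").forget();
    which
}
```
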
}
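
Likewise, the rewritten wait() usercall replaces the old manual polling loop with three cases keyed on the ABI timeout value, delegating bounded waits to tokio::time::timeout. A sketch under the same assumptions, with a single Semaphore standing in for one PendingEvents entry and the two constants mirroring fortanix_sgx_abi:

```rust
// Sketch of how the rewritten wait() maps ABI timeout values onto the
// non-blocking, indefinite, and bounded paths; tokio 1.x assumed.
use std::time::Duration;
use tokio::sync::Semaphore;

const WAIT_NO: u64 = 0; // mirrors fortanix_sgx_abi::WAIT_NO
const WAIT_INDEFINITE: u64 = !0; // mirrors fortanix_sgx_abi::WAIT_INDEFINITE

// Returns true if an event was consumed, false on would-block or timeout.
async fn wait_one(pending: &Semaphore, timeout: u64) -> bool {
    match timeout {
        // Non-blocking poll: succeed only if an event is already pending.
        WAIT_NO => match pending.try_acquire() {
            Ok(permit) => { permit.forget(); true }
            Err(_) => false,
        },
        // Block until an event arrives.
        WAIT_INDEFINITE => {
            pending.acquire().await.expect("semaphore closed").forget();
            true
        }
        // Bounded wait: timeout() yields Err once the deadline elapses.
        n => match tokio::time::timeout(Duration::from_nanos(n), pending.acquire()).await {
            Ok(res) => { res.expect("semaphore closed").forget(); true }
            Err(_) => false,
        },
    }
}
```
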
