Skip to content

Commit

Permalink
[sledge] Add support for returning a result from thread_join
Browse files Browse the repository at this point in the history
Summary:
Previously the function initially executed by created threads could
not return a result. This diff adds support for passing such results
to callers of thread_join.

Differential Revision: D32511370

fbshipit-source-id: 92ae9cd23
  • Loading branch information
jberdine authored and facebook-github-bot committed Nov 19, 2021
1 parent 799d85d commit 3e1f8e3
Show file tree
Hide file tree
Showing 11 changed files with 153 additions and 49 deletions.
5 changes: 5 additions & 0 deletions sledge/cli/domain_itv.ml
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,11 @@ let retn tid _ freturn {areturn; caller_q} callee_q =
| Some aret, None -> exec_kill tid aret caller_q
| None, _ -> caller_q

type term_code = unit [@@deriving compare, sexp_of]

let term _ _ _ _ = ()
let move_term_code _ _ () q = q

(** map actuals to formals (via temporary registers), stash constraints on
caller-local variables. Note that this exploits the non-relational-ness
of Box to ignore all variables other than the formal/actual params/
Expand Down
8 changes: 4 additions & 4 deletions sledge/model/llair_intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@ __attribute__((noreturn)) void __llair_unreachable();

typedef int thread_t;

typedef void (*thread_create_routine)(void*);
typedef int (*thread_create_routine)(void*);

thread_t sledge_thread_create(thread_create_routine entry, void* arg);

void sledge_thread_join(thread_t thread);
int sledge_thread_join(thread_t thread);

typedef int error_t;
#define OK 0
Expand All @@ -54,9 +54,9 @@ thread_create(thread_t** t, thread_create_routine entry, void* arg)
}

error_t
thread_join(thread_t* t)
thread_join(thread_t* thread, int* ret_code)
{
sledge_thread_join(*t);
*ret_code = sledge_thread_join(*thread);
return OK;
}

Expand Down
67 changes: 41 additions & 26 deletions sledge/src/control.ml
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ module Make (Config : Config) (D : Domain) (Queue : Queue) = struct
(** Representation of a single thread, including identity and scheduling
state *)
module Thread : sig
type t = Runnable of IP.t | Terminated of ThreadID.t
type t = Runnable of IP.t | Terminated of ThreadID.t * D.term_code
[@@deriving equal, sexp_of]

val compare : t Ord.t
Expand All @@ -380,15 +380,15 @@ module Make (Config : Config) (D : Domain) (Queue : Queue) = struct
(** Because [ip] needs to include [tid], this is represented as a sum of
products, but it may be more natural to think in terms of the
isomorphic representation using a product of a sum such as
[(Runnable of ... | Terminated ...) * ThreadID.t]. *)
type t = Runnable of IP.t | Terminated of ThreadID.t
[(Runnable of ... | Terminated of ...) * ThreadID.t]. *)
type t = Runnable of IP.t | Terminated of ThreadID.t * D.term_code
[@@deriving sexp_of]

let pp ppf = function
| Runnable ip -> IP.pp ppf ip
| Terminated tid -> Format.fprintf ppf "T%i" tid
| Terminated (tid, _) -> Format.fprintf ppf "T%i" tid

let id = function Runnable {tid} -> tid | Terminated tid -> tid
let id = function Runnable {tid} -> tid | Terminated (tid, _) -> tid

(* Note: Threads.inactive relies on comparing tid last *)
let compare_aux compare_tid x y =
Expand All @@ -402,7 +402,8 @@ module Make (Config : Config) (D : Domain) (Queue : Queue) = struct
<?> (compare_tid, x.tid, y.tid)
| Runnable _, _ -> -1
| _, Runnable _ -> 1
| Terminated x_tid, Terminated y_tid -> compare_tid x_tid y_tid
| Terminated (x_tid, x_tc), Terminated (y_tid, y_tc) ->
D.compare_term_code x_tc y_tc <?> (compare_tid, x_tid, y_tid)

let compare = compare_aux ThreadID.compare
let equal = [%compare.equal: t]
Expand All @@ -422,7 +423,7 @@ module Make (Config : Config) (D : Domain) (Queue : Queue) = struct
val init : t
val create : Llair.block -> t -> ThreadID.t * t
val after_step : Thread.t -> t -> t * inactive
val join : ThreadID.t -> t -> t option
val join : ThreadID.t -> t -> (D.term_code * t) option
val fold : t -> 's -> f:(Thread.t -> 's -> 's) -> 's
end = struct
module M = Map.Make (ThreadID)
Expand Down Expand Up @@ -467,7 +468,7 @@ module Make (Config : Config) (D : Domain) (Queue : Queue) = struct

let join tid threads =
match M.find tid threads with
| Some (Thread.Terminated _) -> Some (M.remove tid threads)
| Some (Thread.Terminated (_, tc)) -> Some (tc, M.remove tid threads)
| _ ->
[%Trace.info " prune join of non-terminated thread: %i" tid] ;
None
Expand Down Expand Up @@ -515,8 +516,11 @@ module Make (Config : Config) (D : Domain) (Queue : Queue) = struct
<?> (compare_tid, x_t.tid, y_t.tid)
| {dst= Runnable _}, _ -> -1
| _, {dst= Runnable _} -> 1
| {dst= Terminated x_tid}, {dst= Terminated y_tid} ->
Llair.Block.compare x.src y.src <?> (compare_tid, x_tid, y_tid)
| {dst= Terminated (x_tid, x_tc)}, {dst= Terminated (y_tid, y_tc)}
->
Llair.Block.compare x.src y.src
<?> (D.compare_term_code, x_tc, y_tc)
<?> (compare_tid, x_tid, y_tid)

let compare = compare_aux ThreadID.compare
let equal = [%compare.equal: t]
Expand Down Expand Up @@ -922,9 +926,10 @@ module Make (Config : Config) (D : Domain) (Queue : Queue) = struct
{ams with ctrl= {ams.ctrl with stk}; state= retn_state}
wl
| None ->
if Config.function_summaries then summarize exit_state |> ignore ;
summarize exit_state |> ignore ;
let tc = D.term tid formals freturn exit_state in
Work.add ~retreating:false
{ams with ctrl= {dst= Terminated tid; src= block}}
{ams with ctrl= {dst= Terminated (tid, tc); src= block}}
wl )
|>
[%Trace.retn fun {pf} _ -> pf ""]
Expand Down Expand Up @@ -960,12 +965,18 @@ module Make (Config : Config) (D : Domain) (Queue : Queue) = struct
[%Trace.info " infeasible %a@\n@[%a@]" Llair.Exp.pp cond D.pp state] ;
wl

let exec_thread_create reg {Llair.name; formals; freturn; entry; locals}
actual return ({ctrl= {tid}; state; threads} as ams) wl =
let exec_thread_create areturn
{Llair.name; formals; freturn; entry; locals} actual return
({ctrl= {tid}; state; threads} as ams) wl =
let child, threads = Threads.create entry threads in
let state =
let child = Llair.Exp.integer (Llair.Reg.typ reg) (Z.of_int child) in
D.exec_move tid (IArray.of_ (reg, child)) state
match areturn with
| None -> state
| Some reg ->
let child =
Llair.Exp.integer (Llair.Reg.typ reg) (Z.of_int child)
in
D.exec_move tid (IArray.of_ (reg, child)) state
in
let state, _ =
let globals = Domain_used_globals.by_function Config.globals name in
Expand All @@ -975,11 +986,17 @@ module Make (Config : Config) (D : Domain) (Queue : Queue) = struct
in
exec_jump return {ams with state; threads} wl

let exec_thread_join thread return ({ctrl= {tid}; state; threads} as ams)
wl =
let exec_thread_join thread areturn return
({ctrl= {tid}; state; threads} as ams) wl =
List.fold (D.resolve_int tid state thread) wl ~f:(fun join_tid wl ->
match Threads.join join_tid threads with
| Some threads -> exec_jump return {ams with threads} wl
| Some (term_code, threads) ->
let state =
match areturn with
| None -> state
| Some reg -> D.move_term_code tid reg term_code state
in
exec_jump return {ams with state; threads} wl
| None -> wl )

let resolve_callee (pgm : Llair.program) tid callee state =
Expand Down Expand Up @@ -1009,17 +1026,15 @@ module Make (Config : Config) (D : Domain) (Queue : Queue) = struct
~name:jump.dst.lbl ) )
jump ams wl )
| Call ({callee; actuals; areturn; return} as call) -> (
match
(Llair.Function.name callee.name, IArray.to_array actuals, areturn)
with
| "sledge_thread_create", [|callee; arg|], Some reg -> (
match (Llair.Function.name callee.name, IArray.to_array actuals) with
| "sledge_thread_create", [|callee; arg|] -> (
match resolve_callee pgm tid callee state with
| [] -> exec_skip_func areturn return ams wl
| callees ->
List.fold callees wl ~f:(fun callee wl ->
exec_thread_create reg callee arg return ams wl ) )
| "sledge_thread_join", [|thread|], None ->
exec_thread_join thread return ams wl
exec_thread_create areturn callee arg return ams wl ) )
| "sledge_thread_join", [|thread|] ->
exec_thread_join thread areturn return ams wl
| _ -> exec_call call ams wl )
| ICall ({callee; areturn; return} as call) -> (
match resolve_callee pgm tid callee state with
Expand Down
7 changes: 7 additions & 0 deletions sledge/src/domain_intf.ml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,13 @@ module type Domain = sig
-> t
-> t

type term_code [@@deriving compare, sexp_of]

val term :
ThreadID.t -> Llair.Reg.t iarray -> Llair.Reg.t option -> t -> term_code

val move_term_code : ThreadID.t -> Llair.Reg.t -> term_code -> t -> t

val resolve_callee :
(string -> Llair.func option)
-> ThreadID.t
Expand Down
8 changes: 8 additions & 0 deletions sledge/src/domain_relation.ml
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,14 @@ module Make (State_domain : State_domain_sig) = struct
|>
[%Trace.retn fun {pf} -> pf "%a" pp]

type term_code = State_domain.term_code [@@deriving compare, sexp_of]

let term tid formals freturn (_, current) =
State_domain.term tid formals freturn current

let move_term_code tid reg code (entry, current) =
(entry, State_domain.move_term_code tid reg code current)

let dnf (entry, current) =
State_domain.Set.fold (State_domain.dnf current) Set.empty
~f:(fun c rs -> Set.add (entry, c) rs)
Expand Down
20 changes: 20 additions & 0 deletions sledge/src/domain_sh.ml
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,26 @@ let retn tid formals freturn {areturn; unshadow; frame} q =
|>
[%Trace.retn fun {pf} -> pf "%a" pp]

type term_code = Term.t option [@@deriving compare, sexp_of]

let term tid formals freturn q =
let* freturn = freturn in
let formals =
Var.Set.of_iter (Iter.map ~f:(X.reg tid) (IArray.to_iter formals))
in
let freturn = X.reg tid freturn in
let xs, q = Sh.bind_exists q ~wrt:Var.Set.empty in
let outscoped = Var.Set.union formals (Var.Set.of_ freturn) in
let xs = Var.Set.union xs outscoped in
let retn_val_cls = Context.class_of q.ctx (Term.var freturn) in
List.find retn_val_cls ~f:(fun retn_val ->
Var.Set.disjoint xs (Term.fv retn_val) )

let move_term_code tid reg code q =
match code with
| Some retn_val -> Exec.move q (IArray.of_ (X.reg tid reg, retn_val))
| None -> q

let resolve_callee lookup tid ptr (q : Sh.t) =
Context.class_of q.ctx (X.term tid ptr)
|> List.find_map ~f:(X.lookup_func lookup)
Expand Down
5 changes: 5 additions & 0 deletions sledge/src/domain_unit.ml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ let call ~summaries:_ _ ?child:_ ~globals:_ ~actuals:_ ~areturn:_ ~formals:_
let recursion_beyond_bound = `skip
let post _ _ _ () = ()
let retn _ _ _ _ _ = ()

type term_code = unit [@@deriving compare, sexp_of]

let term _ _ _ _ = ()
let move_term_code _ _ () () = ()
let dnf () = Set.of_ ()
let resolve_callee _ _ _ _ = []

Expand Down
5 changes: 5 additions & 0 deletions sledge/src/domain_used_globals.ml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ let enter_scope _ _ state = state
let recursion_beyond_bound = `skip
let post _ _ _ state = state
let retn _ _ _ from_call post = Llair.Global.Set.union from_call post

type term_code = unit [@@deriving compare, sexp_of]

let term _ _ _ _ = ()
let move_term_code _ _ () q = q
let dnf t = Set.of_ t

let used_globals exp s =
Expand Down
36 changes: 27 additions & 9 deletions sledge/test/analyze/cqueue.c
Original file line number Diff line number Diff line change
Expand Up @@ -125,31 +125,31 @@ mark_free(queue_t* const q, const uint32_t k)
atomic_store(&q->own[k], PROD);
}

static void
static int
produce_thread_run(void* const arg)
{
if (arg == NULL) {
return;
return 0;
}
queue_t* const q = (queue_t*)arg;
const uint32_t d = cct_random_between(1, 100);
const uint32_t idx = start_enqueue(q);
q->dat[idx] = d;
mark_ready(q, idx);
return;
return d;
}

static void
static int
consume_thread_run(void* const arg)
{
if (arg == NULL) {
return;
return 0;
}
queue_t* const q = (queue_t*)arg;
const uint32_t idx = start_dequeue(q);
const uint32_t d = q->dat[idx];
mark_free(q, idx);
return;
return d;
}

#define NUM_PRODUCE_THREADS 2
Expand All @@ -165,18 +165,27 @@ main(void)
{
void* test_mem_ptr = __llair_alloc(num_bytes_to_allocate());
queue_t* test_queue = queue_init(test_mem_ptr);
error_t status;
int thread_ret;
int32_t total_produce = 0;
int32_t num_produced = 0;
int32_t total_consume = 0;
int32_t num_consumed = 0;
thread_t* produce_threads[NUM_PRODUCE_THREADS];
thread_t* consume_threads[NUM_CONSUME_THREADS];
error_t status;

for (uint32_t i = 0; i < NUM_PRODUCE_THREADS; i++) {
status =
thread_create(&produce_threads[i], &produce_thread_run, test_queue);
assert(OK == status && "Failed to create thread");
}
for (uint32_t i = 0; i < NUM_PRODUCE_THREADS; i++) {
status = thread_join(produce_threads[i]);
status = thread_join(produce_threads[i], &thread_ret);
assert(OK == status && "Failed to join thread");
total_produce += thread_ret;
if (thread_ret != 0) {
++num_produced;
}
}

for (uint32_t i = 0; i < NUM_CONSUME_THREADS; i++) {
Expand All @@ -185,10 +194,19 @@ main(void)
assert(OK == status && "Failed to create thread");
}
for (uint32_t i = 0; i < NUM_CONSUME_THREADS; i++) {
status = thread_join(consume_threads[i]);
status = thread_join(consume_threads[i], &thread_ret);
assert(OK == status && "Failed to join thread");
total_consume += thread_ret;
if (thread_ret != 0) {
++num_consumed;
}
}

assert(num_produced - num_consumed ==
NUM_PRODUCE_THREADS - NUM_CONSUME_THREADS &&
"Number of remaining elements = #produce threads - #consume threads");
assert(total_produce >= total_consume &&
"sum of produced elements >= sum of consumed elements");
if (NUM_PRODUCE_THREADS == NUM_CONSUME_THREADS) {
assert(queue_is_empty(test_queue) && "Non-empty queue");
}
Expand Down
9 changes: 5 additions & 4 deletions sledge/test/analyze/thread.c
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,19 @@

int count = 0;

void
int
child_routine(void* arg)
{
count++;
return count++;
}

int
main()
{
thread_t* child;
error_t err = thread_create(&child, &child_routine, NULL);
count++;
err = thread_join(child);
int ret_code;
err = thread_join(child, &ret_code);
count += ret_code;
return count;
}
Loading

0 comments on commit 3e1f8e3

Please sign in to comment.