compiler: implement @branchHint, replacing @setCold
Implements the accepted proposal to introduce `@branchHint`. This
builtin is permitted as the first statement of a block if that block is
the direct body of any of the following:

* a function (*not* a `test`)
* either branch of an `if`
* the RHS of a `catch` or `orelse`
* a `switch` prong
* an `or` or `and` expression
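
As an illustration of the placements listed above, here is a minimal sketch. It assumes a Zig toolchain that already includes this change; `fatalCode` and `classify` are hypothetical names, and only four of the listed positions are shown.

```zig
const std = @import("std");

// Function body: the hint describes the function as a whole.
fn fatalCode(code: u32) u32 {
    @branchHint(.cold);
    return code | 0x8000_0000;
}

fn classify(x: i32, fallback: ?i32) i32 {
    // Either branch of an `if`.
    if (x < 0) {
        @branchHint(.unlikely);
        return -1;
    } else {
        @branchHint(.likely);
    }

    // RHS of `orelse`; the RHS of `catch` takes the same form.
    const base = fallback orelse {
        @branchHint(.cold);
        return 0;
    };

    // A `switch` prong.
    switch (x) {
        0 => {
            @branchHint(.unpredictable);
            return base;
        },
        else => return base + x,
    }
}
```

In every case the hint is the first statement of the block; it only informs the optimizer and should not change what the code computes.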

It lowers to the ZIR instruction `extended(branch_hint(...))`. When Sema
encounters this instruction, it sets `sema.branch_hint` appropriately,
and `zirCondBr` etc. are expected to reset this value as necessary. The
state is on `Sema` rather than `Block` to make it automatically
propagate up non-conditional blocks without special handling. If
`@panic` is reached, the branch hint is set to `.cold` if none was
already set; similarly, error branches get a hint of `.unlikely` if no
hint is explicitly provided. If a condition is comptime-known, `cold`
hints from the taken branch are allowed to propagate up, but other hints
are discarded. This is because a `likely`/`unlikely` hint just indicates
the direction this branch is likely to go, which is redundant
information when the branch is known at comptime; but `cold` hints
indicate that control flow is unlikely to ever reach this branch,
meaning if the branch is always taken from its parent, then the parent
is also unlikely to ever be reached.
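
The defaulting and propagation rules above can be modelled in a few lines. This is only a sketch of the rules as described, not the actual `Sema` code: the three function names are hypothetical, and "no hint set" is represented as `null`.

```zig
const std = @import("std");
const BranchHint = std.builtin.BranchHint;

/// Hypothetical model: when a condition is comptime-known, only a `.cold`
/// hint from the taken branch propagates to the parent; all others are dropped.
fn propagateFromComptimeBranch(taken: ?BranchHint) ?BranchHint {
    const hint = taken orelse return null;
    if (hint == .cold) return .cold;
    return null;
}

/// Hypothetical model: reaching `@panic` sets `.cold` unless a hint was already set.
fn hintAfterPanic(current: ?BranchHint) BranchHint {
    return current orelse .cold;
}

/// Hypothetical model: an error branch defaults to `.unlikely` unless hinted explicitly.
fn errorBranchHint(explicit: ?BranchHint) BranchHint {
    return explicit orelse .unlikely;
}
```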

This branch information is stored in AIR `cond_br` and `switch_br`. In
addition, `try` and `try_ptr` instructions have variants `try_cold` and
`try_ptr_cold` which indicate that the error case is cold (rather than
just unlikely); this is reachable through e.g. `errdefer unreachable` or
`errdefer @panic("")`.
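
For illustration, a function of the following shape is the kind of code the paragraph above refers to: every error path runs into a panic, so each `try` would presumably be emitted as the cold variant. `parseDigit` and `sumDigits` are hypothetical names, and whether a given `try` really becomes `try_cold` is decided by the compiler, not by this sketch.

```zig
const std = @import("std");

fn parseDigit(c: u8) error{NotADigit}!u8 {
    if (c < '0' or c > '9') return error.NotADigit;
    return c - '0';
}

/// The caller promises that `text` holds only digits, so any error path panics.
fn sumDigits(text: []const u8) error{NotADigit}!u32 {
    // Runs whenever an error is about to be returned from this function.
    errdefer @panic("sumDigits: non-digit input");
    var total: u32 = 0;
    for (text) |c| total += try parseDigit(c);
    return total;
}
```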

A new API `unwrapSwitch` is introduced to `Air` to make it more
convenient to access `switch_br` instructions. In time, I plan to update
all AIR instructions to be accessed via an `unwrap` method which returns
a convenient tagged union a la `InternPool.indexToKey`.

The LLVM backend lowers branch hints for conditional branches and
switches as follows:

* If any branch is marked `unpredictable`, the instruction is marked
  `!unpredictable`.
* Any branch which is marked as `cold` gets a
  `llvm.assume(i1 true) [ "cold"() ]` call to mark the code path cold.
* If any branch is marked `likely` or `unlikely`, branch weight metadata
  is attached with `!prof`. Likely branches get a weight of 2000, and
  unlikely branches a weight of 1. In `switch` statements, un-annotated
  branches get a weight of 1000 as a "middle ground" hint, since there
  could be likely *and* unlikely *and* un-annotated branches.
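
The weight scheme above can be sketched as a small model. This is not the backend code itself: the helper names are hypothetical, and treating `.cold` and `.unpredictable` branches like un-annotated ones when computing a weight is an assumption, since the description only gives weights for likely, unlikely and un-annotated branches.

```zig
const std = @import("std");
const BranchHint = std.builtin.BranchHint;

/// Hypothetical helper: true if the instruction would be marked `!unpredictable`.
fn anyUnpredictable(hints: []const BranchHint) bool {
    for (hints) |h| {
        if (h == .unpredictable) return true;
    }
    return false;
}

/// Hypothetical helper: true if `!prof` branch weights would be attached.
fn wantsWeights(hints: []const BranchHint) bool {
    for (hints) |h| {
        if (h == .likely or h == .unlikely) return true;
    }
    return false;
}

/// Hypothetical helper: the weight of one `switch` branch.
fn switchWeight(hint: BranchHint) u32 {
    return switch (hint) {
        .likely => 2000,
        .unlikely => 1,
        // Assumption: anything that is not likely/unlikely gets the
        // "middle ground" weight used for un-annotated branches.
        .none, .cold, .unpredictable => 1000,
    };
}
```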

For functions, a `cold` hint corresponds to the `cold` function
attribute, and other hints are currently ignored -- as far as I can tell
LLVM doesn't really have a way to lower them. (Ideally, we would want
the branch hint given in the function to propagate to call sites.)

The compiler and standard library do not yet use this new builtin.

Resolves: ziglang#21148
mlugg committed Aug 26, 2024
1 parent 72e0080 commit 457c94d
Showing 25 changed files with 1,127 additions and 563 deletions.
19 changes: 19 additions & 0 deletions lib/std/builtin.zig
@@ -675,6 +675,25 @@ pub const ExternOptions = struct {
     is_thread_local: bool = false,
 };

+/// This data structure is used by the Zig language code generation and
+/// therefore must be kept in sync with the compiler implementation.
+pub const BranchHint = enum(u3) {
+    /// Equivalent to no hint given.
+    none,
+    /// This branch of control flow is more likely to be reached than its peers.
+    /// The optimizer should optimize for reaching it.
+    likely,
+    /// This branch of control flow is less likely to be reached than its peers.
+    /// The optimizer should optimize for not reaching it.
+    unlikely,
+    /// This branch of control flow is unlikely to *ever* be reached.
+    /// The optimizer may place it in a different page of memory to optimize other branches.
+    cold,
+    /// It is difficult to predict whether this branch of control flow will be reached.
+    /// The optimizer should avoid branching behavior with expensive mispredictions.
+    unpredictable,
+};
+
 /// This enum is set by the compiler and communicates which compiler backend is
 /// used to produce machine code.
 /// Think carefully before deciding to observe this value. Nearly all code should
144 changes: 91 additions & 53 deletions lib/std/zig/AstGen.zig

Large diffs are not rendered by default.

5 changes: 4 additions & 1 deletion lib/std/zig/AstRlAnnotate.zig
@@ -829,6 +829,10 @@ fn builtinCall(astrl: *AstRlAnnotate, block: ?*Block, ri: ResultInfo, node: Ast.
     }
     switch (info.tag) {
         .import => return false,
+        .branch_hint => {
+            _ = try astrl.expr(args[0], block, ResultInfo.type_only);
+            return false;
+        },
         .compile_log, .TypeOf => {
             for (args) |arg_node| {
                 _ = try astrl.expr(arg_node, block, ResultInfo.none);
@@ -907,7 +911,6 @@ fn builtinCall(astrl: *AstRlAnnotate, block: ?*Block, ri: ResultInfo, node: Ast.
         .fence,
         .set_float_mode,
         .set_align_stack,
-        .set_cold,
         .type_info,
         .work_item_id,
         .work_group_size,
18 changes: 9 additions & 9 deletions lib/std/zig/BuiltinFn.zig
@@ -14,6 +14,7 @@ pub const Tag = enum {
     bit_offset_of,
     int_from_bool,
     bit_size_of,
+    branch_hint,
     breakpoint,
     disable_instrumentation,
     mul_add,
@@ -82,7 +83,6 @@ pub const Tag = enum {
     return_address,
     select,
     set_align_stack,
-    set_cold,
     set_eval_branch_quota,
     set_float_mode,
     set_runtime_safety,
@@ -256,6 +256,14 @@ pub const list = list: {
                 .param_count = 1,
             },
         },
+        .{
+            "@branchHint",
+            .{
+                .tag = .branch_hint,
+                .param_count = 1,
+                .illegal_outside_function = true,
+            },
+        },
         .{
             "@breakpoint",
             .{
@@ -744,14 +752,6 @@ pub const list = list: {
                 .illegal_outside_function = true,
             },
         },
-        .{
-            "@setCold",
-            .{
-                .tag = .set_cold,
-                .param_count = 1,
-                .illegal_outside_function = true,
-            },
-        },
         .{
             "@setEvalBranchQuota",
             .{
12 changes: 7 additions & 5 deletions lib/std/zig/Zir.zig
@@ -1546,7 +1546,7 @@ pub const Inst = struct {
                 => false,

                 .extended => switch (data.extended.opcode) {
-                    .fence, .set_cold, .breakpoint, .disable_instrumentation => true,
+                    .fence, .branch_hint, .breakpoint, .disable_instrumentation => true,
                     else => false,
                 },
             };
@@ -1954,9 +1954,6 @@ pub const Inst = struct {
         /// Implement builtin `@setAlignStack`.
         /// `operand` is payload index to `UnNode`.
         set_align_stack,
-        /// Implements `@setCold`.
-        /// `operand` is payload index to `UnNode`.
-        set_cold,
         /// Implements the `@errorCast` builtin.
         /// `operand` is payload index to `BinNode`. `lhs` is dest type, `rhs` is operand.
         error_cast,
@@ -2051,6 +2048,10 @@ pub const Inst = struct {
         /// `operand` is `src_node: i32`.
         /// `small` is an `Inst.BuiltinValue`.
         builtin_value,
+        /// Provide a `@branchHint` for the current block.
+        /// `operand` is payload index to `UnNode`.
+        /// `small` is unused.
+        branch_hint,

         pub const InstData = struct {
             opcode: Extended,
@@ -3142,6 +3143,7 @@ pub const Inst = struct {
         export_options,
         extern_options,
         type_info,
+        branch_hint,
         // Values
         calling_convention_c,
         calling_convention_inline,
@@ -3962,7 +3964,6 @@ fn findDeclsInner(
                 .fence,
                 .set_float_mode,
                 .set_align_stack,
-                .set_cold,
                 .error_cast,
                 .await_nosuspend,
                 .breakpoint,
@@ -3986,6 +3987,7 @@ fn findDeclsInner(
                 .closure_get,
                 .field_parent_ptr,
                 .builtin_value,
+                .branch_hint,
                 => return,

                 // `@TypeOf` has a body.
115 changes: 109 additions & 6 deletions src/Air.zig
@@ -433,13 +433,18 @@ pub const Inst = struct {
         /// In the case of non-error, control flow proceeds to the next instruction
         /// after the `try`, with the result of this instruction being the unwrapped
         /// payload value, as if `unwrap_errunion_payload` was executed on the operand.
+        /// The error branch is considered to have a branch hint of `.unlikely`.
         /// Uses the `pl_op` field. Payload is `Try`.
         @"try",
+        /// Same as `try` except the error branch hint is `.cold`.
+        try_cold,
         /// Same as `try` except the operand is a pointer to an error union, and the
         /// result is a pointer to the payload. Result is as if `unwrap_errunion_payload_ptr`
         /// was executed on the operand.
         /// Uses the `ty_pl` field. Payload is `TryPtr`.
         try_ptr,
+        /// Same as `try_ptr` except the error branch hint is `.cold`.
+        try_ptr_cold,
         /// Notes the beginning of a source code statement and marks the line and column.
         /// Result type is always void.
         /// Uses the `dbg_stmt` field.
@@ -1116,11 +1121,20 @@ pub const Call = struct {
 pub const CondBr = struct {
     then_body_len: u32,
     else_body_len: u32,
+    branch_hints: BranchHints,
+    pub const BranchHints = packed struct(u32) {
+        true: std.builtin.BranchHint,
+        false: std.builtin.BranchHint,
+        _: u26 = 0,
+    };
 };

 /// Trailing:
-/// * 0. `Case` for each `cases_len`
-/// * 1. the else body, according to `else_body_len`.
+/// * 0. `BranchHint` for each `cases_len + 1`. bit-packed into `u32`
+/// elems such that each `u32` contains up to 10x `BranchHint`.
+/// LSBs are first case. Final hint is `else`.
+/// * 1. `Case` for each `cases_len`
+/// * 2. the else body, according to `else_body_len`.
 pub const SwitchBr = struct {
     cases_len: u32,
     else_body_len: u32,
@@ -1380,6 +1394,7 @@ pub fn typeOfIndex(air: *const Air, inst: Air.Inst.Index, ip: *const InternPool)
         .ptr_add,
         .ptr_sub,
         .try_ptr,
+        .try_ptr_cold,
         => return datas[@intFromEnum(inst)].ty_pl.ty.toType(),

         .not,
@@ -1500,7 +1515,7 @@ pub fn typeOfIndex(air: *const Air, inst: Air.Inst.Index, ip: *const InternPool)
             return air.typeOf(extra.lhs, ip);
         },

-        .@"try" => {
+        .@"try", .try_cold => {
             const err_union_ty = air.typeOf(datas[@intFromEnum(inst)].pl_op.operand, ip);
             return Type.fromInterned(ip.indexToKey(err_union_ty.ip_index).error_union_type.payload_type);
         },
@@ -1524,9 +1539,8 @@ pub fn extraData(air: Air, comptime T: type, index: usize) struct { data: T, end
     inline for (fields) |field| {
         @field(result, field.name) = switch (field.type) {
             u32 => air.extra[i],
-            Inst.Ref => @as(Inst.Ref, @enumFromInt(air.extra[i])),
-            i32 => @as(i32, @bitCast(air.extra[i])),
-            InternPool.Index => @as(InternPool.Index, @enumFromInt(air.extra[i])),
+            InternPool.Index, Inst.Ref => @enumFromInt(air.extra[i]),
+            i32, CondBr.BranchHints => @bitCast(air.extra[i]),
             else => @compileError("bad field type: " ++ @typeName(field.type)),
         };
         i += 1;
@@ -1593,7 +1607,9 @@ pub fn mustLower(air: Air, inst: Air.Inst.Index, ip: *const InternPool) bool {
         .cond_br,
         .switch_br,
         .@"try",
+        .try_cold,
         .try_ptr,
+        .try_ptr_cold,
         .dbg_stmt,
         .dbg_inline_block,
         .dbg_var_ptr,
@@ -1796,4 +1812,91 @@ pub fn mustLower(air: Air, inst: Air.Inst.Index, ip: *const InternPool) bool {
     };
 }

+pub const UnwrappedSwitch = struct {
+    air: *const Air,
+    operand: Inst.Ref,
+    cases_len: u32,
+    else_body_len: u32,
+    branch_hints_start: u32,
+    cases_start: u32,
+
+    /// Asserts that `case_idx < us.cases_len`.
+    pub fn getHint(us: UnwrappedSwitch, case_idx: u32) std.builtin.BranchHint {
+        assert(case_idx < us.cases_len);
+        return us.getHintInner(case_idx);
+    }
+    pub fn getElseHint(us: UnwrappedSwitch) std.builtin.BranchHint {
+        return us.getHintInner(us.cases_len);
+    }
+    fn getHintInner(us: UnwrappedSwitch, idx: u32) std.builtin.BranchHint {
+        const bag = us.air.extra[us.branch_hints_start..][idx / 10];
+        const bits: u3 = @truncate(bag >> @intCast(3 * (idx % 10)));
+        return @enumFromInt(bits);
+    }
+
+    pub fn iterateCases(us: UnwrappedSwitch) CaseIterator {
+        return .{
+            .air = us.air,
+            .cases_len = us.cases_len,
+            .else_body_len = us.else_body_len,
+            .next_case = 0,
+            .extra_index = us.cases_start,
+        };
+    }
+    pub const CaseIterator = struct {
+        air: *const Air,
+        cases_len: u32,
+        else_body_len: u32,
+        next_case: u32,
+        extra_index: u32,
+
+        pub fn next(it: *CaseIterator) ?Case {
+            if (it.next_case == it.cases_len) return null;
+            const idx = it.next_case;
+            it.next_case += 1;
+
+            const extra = it.air.extraData(SwitchBr.Case, it.extra_index);
+            var extra_index = extra.end;
+            const items: []const Inst.Ref = @ptrCast(it.air.extra[extra_index..][0..extra.data.items_len]);
+            extra_index += items.len;
+            const body: []const Inst.Index = @ptrCast(it.air.extra[extra_index..][0..extra.data.body_len]);
+            extra_index += body.len;
+            it.extra_index = @intCast(extra_index);
+
+            return .{
+                .idx = idx,
+                .items = items,
+                .body = body,
+            };
+        }
+        /// Only valid to call once all cases have been iterated, i.e. `next` returns `null`.
+        /// Returns the body of the "default" (`else`) case.
+        pub fn elseBody(it: *CaseIterator) []const Inst.Index {
+            assert(it.next_case == it.cases_len);
+            return @ptrCast(it.air.extra[it.extra_index..][0..it.else_body_len]);
+        }
+        pub const Case = struct {
+            idx: u32,
+            items: []const Inst.Ref,
+            body: []const Inst.Index,
+        };
+    };
+};
+
+pub fn unwrapSwitch(air: *const Air, switch_inst: Inst.Index) UnwrappedSwitch {
+    const inst = air.instructions.get(@intFromEnum(switch_inst));
+    assert(inst.tag == .switch_br);
+    const pl_op = inst.data.pl_op;
+    const extra = air.extraData(SwitchBr, pl_op.payload);
+    const hint_bag_count = std.math.divCeil(usize, extra.data.cases_len + 1, 10) catch unreachable;
+    return .{
+        .air = air,
+        .operand = pl_op.operand,
+        .cases_len = extra.data.cases_len,
+        .else_body_len = extra.data.else_body_len,
+        .branch_hints_start = @intCast(extra.end),
+        .cases_start = @intCast(extra.end + hint_bag_count),
+    };
+}
+
 pub const typesFullyResolved = @import("Air/types_resolved.zig").typesFullyResolved;
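
The hint packing used by `SwitchBr` (ten 3-bit hints per `u32`, least significant bits first) can be tried out in isolation. The following is a standalone sketch, not compiler code: `bagCount`, `packHints` and `unpackHint` are hypothetical helpers that mirror the arithmetic in `getHintInner` and `unwrapSwitch`.

```zig
const std = @import("std");
const BranchHint = std.builtin.BranchHint;

// 10 hints * 3 bits = 30 of the 32 bits in each `u32` "bag".
const hints_per_bag = 10;

/// Number of `u32` bags needed to hold `hint_count` hints (ceiling division).
fn bagCount(hint_count: usize) usize {
    return (hint_count + hints_per_bag - 1) / hints_per_bag;
}

/// Packs `hints` into `bags`, first hint in the least significant bits.
fn packHints(bags: []u32, hints: []const BranchHint) void {
    std.debug.assert(bags.len >= bagCount(hints.len));
    @memset(bags, 0);
    for (hints, 0..) |hint, idx| {
        const shift: u5 = @intCast(3 * (idx % hints_per_bag));
        bags[idx / hints_per_bag] |= @as(u32, @intFromEnum(hint)) << shift;
    }
}

/// Reads back the hint at `idx`, using the same indexing as `getHintInner`.
fn unpackHint(bags: []const u32, idx: usize) BranchHint {
    const shift: u5 = @intCast(3 * (idx % hints_per_bag));
    const bits: u3 = @truncate(bags[idx / hints_per_bag] >> shift);
    return @enumFromInt(bits);
}
```

With this layout, a switch with 11 cases plus `else` has 12 hints and therefore needs two bags, which is why `unwrapSwitch` rounds `cases_len + 1` up to a multiple of 10.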
31 changes: 9 additions & 22 deletions src/Air/types_resolved.zig
@@ -344,7 +344,7 @@ fn checkBody(air: Air, body: []const Air.Inst.Index, zcu: *Zcu) bool {
                 if (!checkRef(data.pl_op.operand, zcu)) return false;
             },

-            .@"try" => {
+            .@"try", .try_cold => {
                 const extra = air.extraData(Air.Try, data.pl_op.payload);
                 if (!checkRef(data.pl_op.operand, zcu)) return false;
                 if (!checkBody(
@@ -354,7 +354,7 @@ fn checkBody(air: Air, body: []const Air.Inst.Index, zcu: *Zcu) bool {
                 )) return false;
             },

-            .try_ptr => {
+            .try_ptr, .try_ptr_cold => {
                 const extra = air.extraData(Air.TryPtr, data.ty_pl.payload);
                 if (!checkType(data.ty_pl.ty.toType(), zcu)) return false;
                 if (!checkRef(extra.data.ptr, zcu)) return false;
@@ -381,27 +381,14 @@ fn checkBody(air: Air, body: []const Air.Inst.Index, zcu: *Zcu) bool {
             },

             .switch_br => {
-                const extra = air.extraData(Air.SwitchBr, data.pl_op.payload);
-                if (!checkRef(data.pl_op.operand, zcu)) return false;
-                var extra_index = extra.end;
-                for (0..extra.data.cases_len) |_| {
-                    const case = air.extraData(Air.SwitchBr.Case, extra_index);
-                    extra_index = case.end;
-                    const items: []const Air.Inst.Ref = @ptrCast(air.extra[extra_index..][0..case.data.items_len]);
-                    extra_index += case.data.items_len;
-                    for (items) |item| if (!checkRef(item, zcu)) return false;
-                    if (!checkBody(
-                        air,
-                        @ptrCast(air.extra[extra_index..][0..case.data.body_len]),
-                        zcu,
-                    )) return false;
-                    extra_index += case.data.body_len;
+                const switch_br = air.unwrapSwitch(inst);
+                if (!checkRef(switch_br.operand, zcu)) return false;
+                var it = switch_br.iterateCases();
+                while (it.next()) |case| {
+                    for (case.items) |item| if (!checkRef(item, zcu)) return false;
+                    if (!checkBody(air, case.body, zcu)) return false;
                 }
-                if (!checkBody(
-                    air,
-                    @ptrCast(air.extra[extra_index..][0..extra.data.else_body_len]),
-                    zcu,
-                )) return false;
+                if (!checkBody(air, it.elseBody(), zcu)) return false;
             },

             .assembly => {