Commit: Fix a couple bugs (ml-explore#1161)
* fix jit reduce for RMS norm

* make strides a single buffer

* better eval error message

* fix compiling with inf and bf16

* fix cpu compile with bf16
awni authored May 28, 2024
1 parent a87ef5b commit e7a2a3d
Showing 9 changed files with 59 additions and 27 deletions.
1 change: 1 addition & 0 deletions mlx/backend/common/make_compiled_preamble.sh
@@ -28,6 +28,7 @@ const char* get_kernel_preamble() {
 return R"preamble(
 $INCLUDES
 $CONTENT
+using namespace mlx::core;
 using namespace mlx::core::detail;
 )preamble";
 }
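Note: this preamble is what the CPU compile path builds JIT-generated kernels against, so without `using namespace mlx::core;` unqualified type names such as `bfloat16_t` in the generated source could fail to resolve ("fix cpu compile with bf16" above). A minimal sketch of the kind of program this unblocks — the exact failing case is not shown in the commit, so treat this as an illustrative repro, not a regression test:

import mlx.core as mx

def fn(x):
    return x * 2.0 + 1.0

x = mx.array([1.0, 2.0], dtype=mx.bfloat16)
# Compiled CPU execution JIT-builds a C++ kernel whose source mentions
# bfloat16_t; the added `using namespace mlx::core;` lets that name resolve.
with mx.stream(mx.cpu):
    print(mx.compile(fn)(x))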
43 changes: 29 additions & 14 deletions mlx/backend/metal/compiled.cpp
@@ -56,12 +56,15 @@ inline void build_kernel(
     } else {
       add_indices = true;
       os << " device const " << get_type_string(x.dtype()) << "* " << xname
-          << " [[buffer(" << cnt++ << ")]]," << std::endl
-          << " constant const size_t* " << xname << "_strides [[buffer("
-          << cnt++ << ")]]," << std::endl;
+          << " [[buffer(" << cnt++ << ")]]," << std::endl;
     }
   }
+
+  if (add_indices) {
+    os << " constant const size_t* in_strides [[buffer(" << cnt++
+       << ")]],\n";
+  }
 
   // Add the output arguments
   for (auto& x : outputs) {
     os << " device " << get_type_string(x.dtype()) << "* "
@@ -110,31 +113,38 @@
 }
 
   // Read the inputs in tmps
-  for (auto& x : inputs) {
+  int nc_in_count = 0;
+  for (int i = 0; i < inputs.size(); ++i) {
+    auto& x = inputs[i];
     auto& xname = namer.get_name(x);
 
     if (is_constant(x)) {
-      os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = ";
+      auto type_str = get_type_string(x.dtype());
+      os << " auto tmp_" << xname << " = static_cast<"
+         << get_type_string(x.dtype()) << ">(";
       print_constant(os, x);
-      os << ";" << std::endl;
+      os << ");" << std::endl;
     } else if (is_scalar(x)) {
       os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
          << xname << "[0];" << std::endl;
     } else if (contiguous) {
       os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
          << xname << "[index];" << std::endl;
     } else if (!dynamic_dims) {
+      int offset = nc_in_count * ndim;
       os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
          << xname << "[";
-      os << "index_0 * " << xname << "_strides[0]";
+      os << "index_0 * " << "in_strides[" << offset << "]";
       for (int i = 1; i < ndim; i++) {
-        os << " + index_" << i << " * " << xname << "_strides[" << i << "]";
+        os << " + index_" << i << " * " << "in_strides[" << offset + i << "]";
       }
       os << "];" << std::endl;
+      nc_in_count++;
     } else {
       os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
-         << xname << "[elem_to_loc(index, output_shape, " << xname
-         << "_strides, ndim)];" << std::endl;
+         << xname << "[elem_to_loc(index, output_shape, in_strides + "
+         << nc_in_count * ndim << ", ndim)];" << std::endl;
+      nc_in_count++;
     }
   }

@@ -296,20 +306,25 @@ void Compiled::eval_gpu(
   // Put the inputs in
   int cnt = 0;
   int stride_idx = 1; // idx 0 is the output strides
+  std::vector<size_t> in_strides;
   for (int i = 0; i < inputs.size(); i++) {
     if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
       continue;
     }
     auto& x = inputs[i];
     compute_encoder.set_input_array(x, cnt++);
     if (!contiguous && !is_scalar(x)) {
-      compute_encoder->setBytes(
-          strides[stride_idx].data(),
-          strides[stride_idx].size() * sizeof(size_t),
-          cnt++);
+      in_strides.insert(
+          in_strides.end(),
+          strides[stride_idx].begin(),
+          strides[stride_idx].end());
       stride_idx++;
     }
   }
+  if (!in_strides.empty()) {
+    compute_encoder->setBytes(
+        in_strides.data(), in_strides.size() * sizeof(size_t), cnt++);
+  }
 
   compiled_allocate_outputs(
       inputs, outputs, inputs_, constant_ids_, contiguous, true);
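Note: the net effect of the compiled.cpp changes ("make strides a single buffer" above) is that all non-contiguous input strides now travel in one packed buffer instead of one Metal buffer per input, and each input indexes into it at offset `nc_in_count * ndim`. A small Python sketch of that layout and the offset arithmetic the generated kernel performs (names here are illustrative, not mlx API):

# Pack per-input stride vectors into one flat buffer; non-contiguous input k
# finds its strides at offset k * ndim (mirrors nc_in_count * ndim above).
per_input_strides = [[8, 4, 1], [16, 4, 1]]  # example: two 3-D inputs
ndim = 3
in_strides = [s for strides in per_input_strides for s in strides]

def element_offset(indices, k):
    # Linear element offset the kernel computes for input k at a multi-index.
    base = k * ndim
    return sum(idx * in_strides[base + i] for i, idx in enumerate(indices))

assert element_offset((1, 2, 3), 1) == 1 * 16 + 2 * 4 + 3 * 1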
5 changes: 4 additions & 1 deletion mlx/backend/metal/jit_kernels.cpp
@@ -259,11 +259,14 @@ MTL::ComputePipelineState* get_reduce_init_kernel(
 MTL::ComputePipelineState* get_reduce_kernel(
     metal::Device& d,
     const std::string& kernel_name,
+    const std::string& op_name,
     const array& in,
     const array& out) {
   std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
   auto lib = d.get_library(lib_name);
   if (lib == nullptr) {
+    std::string op_type = op_name;
+    op_type[0] = std::toupper(op_name[0]);
     bool non_atomic = out.dtype() == int64 || out.dtype() == uint64;
     std::ostringstream kernel_source;
     kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce()
@@ -273,7 +276,7 @@ MTL::ComputePipelineState* get_reduce_kernel(
         lib_name,
         get_type_string(in.dtype()),
         get_type_string(out.dtype()),
-        op_name(out));
+        op_type);
     lib = d.get_library(lib_name, kernel_source.str());
   }
   return d.get_kernel(kernel_name, lib);
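Note: the new `op_name` parameter matters because, when the library is missing, the JIT path ("fix jit reduce for RMS norm" above) instantiates the reduce kernel from source and derives the op struct name by capitalizing the first letter of the runtime op name. This is also why `min_`/`max_` are renamed to `min`/`max` in reduce.metal below: the trailing underscore broke that derivation. An illustrative sketch of the mapping (plain Python mirroring the C++ above, not mlx API):

def op_type(op_name: str) -> str:
    # Mirrors `op_type[0] = std::toupper(op_name[0]);` above.
    return op_name[0].upper() + op_name[1:]

assert op_type("sum") == "Sum"
assert op_type("min") == "Min"    # matches the Min op after the rename
assert op_type("min_") == "Min_"  # the old name had no matching op struct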
1 change: 1 addition & 0 deletions mlx/backend/metal/kernels.h
@@ -76,6 +76,7 @@ MTL::ComputePipelineState* get_reduce_init_kernel(
 MTL::ComputePipelineState* get_reduce_kernel(
     metal::Device& d,
     const std::string& kernel_name,
+    const std::string& op_name,
     const array& in,
     const array& out);
 
4 changes: 2 additions & 2 deletions mlx/backend/metal/kernels/reduce.metal
@@ -38,8 +38,8 @@
 #define instantiate_reduce_ops(inst_f, type_f) \
   type_f(inst_f, sum, Sum) \
   type_f(inst_f, prod, Prod) \
-  type_f(inst_f, min_, Min) \
-  type_f(inst_f, max_, Max)
+  type_f(inst_f, min, Min) \
+  type_f(inst_f, max, Max)
 
 // Special case for bool reductions
 #define instantiate_reduce_from_types_helper( \
1 change: 1 addition & 0 deletions mlx/backend/metal/nojit_kernels.cpp
@@ -98,6 +98,7 @@ MTL::ComputePipelineState* get_reduce_init_kernel(
 MTL::ComputePipelineState* get_reduce_kernel(
     metal::Device& d,
     const std::string& kernel_name,
+    const std::string&,
     const array&,
     const array&) {
   return d.get_kernel(kernel_name);
14 changes: 7 additions & 7 deletions mlx/backend/metal/reduce.cpp
@@ -46,7 +46,7 @@ void all_reduce_dispatch(
     kernel_name += "NoAtomics";
   }
   kernel_name += "_reduce_" + op_name + type_to_name(in);
-  auto kernel = get_reduce_kernel(d, kernel_name, in, out);
+  auto kernel = get_reduce_kernel(d, kernel_name, op_name, in, out);
 
   compute_encoder->setComputePipelineState(kernel);
 
@@ -175,7 +175,7 @@ void row_reduce_general_dispatch(
   kname << "rowGeneral" << small_desc << "_reduce_" << op_name
         << type_to_name(in);
 
-  auto kernel = get_reduce_kernel(d, kname.str(), in, out);
+  auto kernel = get_reduce_kernel(d, kname.str(), op_name, in, out);
   compute_encoder->setComputePipelineState(kernel);
 
   // Get dispatch grid dims
@@ -342,7 +342,7 @@ void strided_reduce_general_dispatch(
   if (reduction_size * non_col_reductions < 16) {
     // Select kernel
     auto kernel = get_reduce_kernel(
-        d, "colSmall_reduce_" + op_name + type_to_name(in), in, out);
+        d, "colSmall_reduce_" + op_name + type_to_name(in), op_name, in, out);
     compute_encoder->setComputePipelineState(kernel);
 
     // Select block dims
@@ -384,7 +384,7 @@ void strided_reduce_general_dispatch(
       kernel_name += "NoAtomics";
     }
     kernel_name += "_reduce_" + op_name + type_to_name(in);
-    auto kernel = get_reduce_kernel(d, kernel_name, in, out);
+    auto kernel = get_reduce_kernel(d, kernel_name, op_name, in, out);
 
     compute_encoder->setComputePipelineState(kernel);
 
@@ -501,7 +501,7 @@ void strided_reduce_general_dispatch(
     std::string kernel_name =
         "rowGeneralNoAtomics_reduce_" + op_name + type_to_name(intermediate);
     auto row_reduce_kernel =
-        get_reduce_kernel(d, kernel_name, intermediate, out);
+        get_reduce_kernel(d, kernel_name, op_name, intermediate, out);
 
     compute_encoder->setComputePipelineState(row_reduce_kernel);
     compute_encoder.set_input_array(intermediate, 0);
@@ -573,10 +573,10 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
       op_name = out.dtype() == bool_ ? "and" : "prod";
       break;
     case Reduce::Min:
-      op_name = out.dtype() == bool_ ? "and" : "min_";
+      op_name = out.dtype() == bool_ ? "and" : "min";
       break;
     case Reduce::Max:
-      op_name = out.dtype() == bool_ ? "or" : "max_";
+      op_name = out.dtype() == bool_ ? "or" : "max";
       break;
   }
 
10 changes: 7 additions & 3 deletions mlx/transforms.cpp
@@ -88,9 +88,13 @@ array eval_impl(std::vector<array> outputs, bool async) {
           " transformations like compile or vmap is not allowed.");
     }
     throw std::runtime_error(
-        "[eval] Attempting to eval an array without a primitive. "
-        "This may be a bug, please file an issue here: "
-        " https://github.com/ml-explore/mlx/issues.");
+        "[eval] Attempting to eval an array without a primitive.\n"
+        "If you are compiling a function, make sure all the inputs "
+        "and outputs are captured:\n"
+        "https://ml-explore.github.io/mlx/build/html/usage/compile.html#pure-functions.\n"
+        "If you are not using compile, this may be a bug. "
+        "Please file an issue here:\n"
+        "https://github.com/ml-explore/mlx/issues.");
   }
   if (a.primitive().stream() != in.primitive().stream()) {
     needs_signal.insert(in.id());
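Note: the clearer message targets a common compile pitfall — arrays created inside a compiled function but neither returned nor declared as captured outputs remain placeholder tracers with no primitive, and evaluating them raises this error. A hedged sketch of the failure mode and its fix, based on the pure-functions guidance linked above:

import mlx.core as mx

state = []

def fn(x):
    state.append(x + 1)  # side effect: `state` holds a tracer after tracing
    return x * 2

y = mx.compile(fn)(mx.array(1.0))
mx.eval(y)        # fine
# mx.eval(state)  # raises "[eval] Attempting to eval an array without a primitive..."

# Fix: declare the captured output so compile tracks it.
fn_ok = mx.compile(fn, outputs=state)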
7 changes: 7 additions & 0 deletions python/tests/test_compile.py
@@ -704,6 +704,13 @@ def fn(x):
         self.assertEqual(y1.item(), y2.item())
         self.assertEqual(y1.item(), 6)
 
+    def test_inf_constant(self):
+        def fn(x):
+            return mx.where(mx.isinf(x), 0, 1)
+
+        x = mx.array([0, float("inf"), 1], dtype=mx.bfloat16)
+        self.assertTrue(mx.array_equal(mx.compile(fn)(x), fn(x)))
+
 
 if __name__ == "__main__":
     unittest.main()
