forked from ml-explore/mlx
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdevice.cpp
44 lines (36 loc) · 1.16 KB
/
device.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
// Copyright © 2023 Apple Inc.
#include <sstream>
#include <pybind11/pybind11.h>
#include "mlx/device.h"
#include "mlx/utils.h"
namespace py = pybind11;
using namespace py::literals;
using namespace mlx::core;
/// Registers the Python bindings for `Device` and `Device::DeviceType` on
/// module `m`, along with the module-level `default_device()` and
/// `set_default_device(device)` functions.
void init_device(py::module_& m) {
  // Declare the Device class up front so the type is already registered
  // when the enum's `__eq__` overload below (which takes a `Device`) is
  // bound; its members are attached further down.
  auto device_class = py::class_<Device>(m, "Device");

  py::enum_<Device::DeviceType>(m, "DeviceType")
      .value("cpu", Device::DeviceType::cpu)
      .value("gpu", Device::DeviceType::gpu)
      // Also expose `cpu` / `gpu` directly in the enclosing module scope.
      .export_values()
      // Support `DeviceType == Device`. `py::prepend()` puts this overload
      // ahead of the enum's built-in `__eq__`, which stays in the overload
      // chain as the fallback for every other right-hand-side type.
      .def(
          "__eq__",
          [](const Device::DeviceType& d1, const Device& d2) {
            // d1 is converted to a Device for the comparison (presumably via
            // Device's converting constructor in mlx/device.h).
            return d1 == d2;
          },
          py::prepend());

  device_class.def(py::init<Device::DeviceType, int>(), "type"_a, "index"_a = 0)
      .def_readonly("type", &Device::type)
      .def(
          "__repr__",
          [](const Device& d) {
            std::ostringstream os;
            os << d;
            return os.str();
          })
      // `py::is_operator()` makes pybind11 return NotImplemented, instead of
      // raising TypeError, when the right-hand side is neither a Device nor
      // implicitly convertible to one (e.g. `device == None`). Python then
      // falls back to its default comparison, so `==` yields False and `!=`
      // yields True for unrelated objects.
      // NOTE(review): pybind11 sets `__hash__` to None when `__eq__` is
      // defined without `__hash__`, so Device instances are unhashable —
      // confirm that is intended.
      .def(
          "__eq__",
          [](const Device& d1, const Device& d2) { return d1 == d2; },
          py::is_operator());

  // Allow a DeviceType wherever a Device argument is expected, e.g.
  // `Device(...) == cpu` or `set_default_device(cpu)`.
  py::implicitly_convertible<Device::DeviceType, Device>();

  m.def("default_device", &default_device);
  m.def("set_default_device", &set_default_device, "device"_a);
}