
Commit 3139722

crcrpar authored and pytorchmergebot committed
[foreach][mta] Inplace maximum and minimum (pytorch#82523)
### Description

Implement `torch._foreach_maximum_` and `torch._foreach_minimum_`, mainly so that `_multi_tensor_adam` and `_multi_tensor_adamw` with `amsgrad=True` correctly update their `max_exp_avg_sqs`.

### Issue

- pytorch#78807
- pytorch#81894
- pytorch#81348
- pytorch#81705
- pytorch#58833
- pytorch#68041

### Testing

Updated `test_foreach.py::TestForeach::_minmax_test` to compare the outputs of `_foreach_maximum_` (and `_foreach_minimum_`) against those of `[torch.maximum(a, b) for a, b in zip(tensors1, tensors2)]`.

cc @ngimel @albanD @mikaylagawarecki

Pull Request resolved: pytorch#82523
Approved by: https://github.com/albanD
1 parent 9647bec commit 3139722
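For orientation, here is a minimal sketch (not part of the diff) of the semantics the new in-place variants are tested against: each tensor in the first list is overwritten with the elementwise maximum, matching the reference loop from the Testing section above.

```python
import torch

xs = [torch.randn(3) for _ in range(4)]
ys = [torch.randn(3) for _ in range(4)]

# Reference, as in the updated test: out-of-place per-tensor maximum.
expected = [torch.maximum(a, b) for a, b in zip(xs, ys)]

torch._foreach_maximum_(xs, ys)  # mutates each tensor in xs in place

for x, e in zip(xs, expected):
    assert torch.equal(x, e)
```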

File tree: 7 files changed (+69 −12)

aten/src/ATen/native/ForeachOpsKernels.cpp (+10)

```diff
@@ -199,6 +199,9 @@ FOREACH_POINTWISE_OP_SCALAR(addcmul);
 FOREACH_POINTWISE_OP_SCALARLIST(addcdiv);
 FOREACH_POINTWISE_OP_SCALARLIST(addcmul);
 
+// NOTE(crcrpar): It didn't seem feasible to use `self[i]` as both the first and the last
+// argument of `maximum_out` and `minimum_out`, so, tentatively, the result is computed
+// into a temporary and then copied to `self[i]`.
 #define FOREACH_MAXIMUM_MINIMUM_OP(NAME) \
 std::vector<Tensor> foreach_tensor_##NAME##_slow(TensorList tensors1, TensorList tensors2) { \
   check_foreach_api_restrictions(tensors1, tensors2); \
@@ -211,6 +214,13 @@ std::vector<Tensor> foreach_tensor_##NAME##_slow(TensorList tensors1, TensorList
   \
   return result; \
 } \
+void foreach_tensor_##NAME##_slow_(TensorList self, TensorList other) { \
+  check_foreach_api_restrictions(self, other); \
+  for (const auto i : c10::irange(self.size())) { \
+    const auto tmp = at::NAME(self[i], other[i]); \
+    self[i].copy_(tmp, /* non_blocking */ true); \
+  } \
+}
 
 FOREACH_MAXIMUM_MINIMUM_OP(maximum)
 FOREACH_MAXIMUM_MINIMUM_OP(minimum)
```
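In Python terms, the macro-generated slow path above behaves like the following sketch (the function and variable names here are illustrative, not from the kernel): compute the out-of-place result into a temporary, then copy it back into `self[i]`, as the NOTE explains.

```python
import torch

def foreach_maximum_slow_(self_list, other_list):
    # Mirror of the C++ loop: out-of-place result first, then copy back,
    # because self[i] cannot serve as both input and output of maximum_out.
    for s, o in zip(self_list, other_list):
        tmp = torch.maximum(s, o)
        s.copy_(tmp)
```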

aten/src/ATen/native/cuda/ForeachPointwiseOp.cu (+28 −3)

```diff
@@ -188,20 +188,45 @@ std::vector<Tensor> foreach_tensor_##NAME##_cuda(TensorList tensors1, TensorList
   tensor_lists.emplace_back(std::move(vec_res)); \
   \
   AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, tensors1[0].scalar_type(), "foreach_maximum_minimum_op_cuda", [&]() { \
-    using opmath_t = at::opmath_type<scalar_t>; \
+      using opmath_t = at::opmath_type<scalar_t>; \
       auto op = [] GPU_LAMBDA (opmath_t a, opmath_t b) -> opmath_t { \
         opmath_t c = a OP b ? a : b; \
         if (_isnan(a)) { \
           c = a; \
         } \
         return c;}; \
       multi_tensor_apply<3>(tensor_lists, \
-                            PointwiseOpListFunctor<scalar_t, 3>(), \
-                            op); \
+                            BinaryOpListAlphaFunctor<scalar_t, 3, 2, 2>(), \
+                            op, \
+                            opmath_t(1)); \
   }); \
   \
   return tensor_lists[2]; \
 } \
+\
+void foreach_tensor_##NAME##_cuda_(TensorList self, TensorList other) { \
+  check_foreach_api_restrictions(self, other); \
+  if (!can_use_fast_route({self, other}) || has_bool_tensor(self)) { \
+    return at::native::foreach_tensor_##NAME##_slow_(self, other); \
+  } \
+  \
+  AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, self[0].scalar_type(), "foreach_maximum_minimum_op_cuda_", \
+                             [&]() { \
+    using opmath_t = at::opmath_type<scalar_t>; \
+    std::vector<std::vector<at::Tensor>> tensor_lists{self.vec(), other.vec()}; \
+    auto op = [] GPU_LAMBDA (opmath_t a, opmath_t b) -> opmath_t { \
+      opmath_t c = a OP b ? a : b; \
+      if (_isnan(a)) { \
+        c = a; \
+      } \
+      return c; \
+    }; \
+    multi_tensor_apply<2>(tensor_lists, \
+                          BinaryOpListAlphaFunctor<scalar_t, 2, 2, 0>(), \
+                          op, \
+                          opmath_t(1)); \
+  }); \
+} \
 
 FOREACH_MAXIMUM_MINIMUM_OP(maximum, >)
 FOREACH_MAXIMUM_MINIMUM_OP(minimum, <)
```
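The GPU lambda keeps `a` whenever it is NaN, so NaNs propagate exactly as they do through `torch.maximum` and `torch.minimum`. A quick check of that behavior, assuming a CUDA device is available (a sketch, not from the test suite):

```python
import torch

a = torch.tensor([1.0, float("nan")], device="cuda")
b = torch.tensor([2.0, 0.0], device="cuda")

torch._foreach_maximum_([a], [b])
print(a)  # tensor([2., nan], device='cuda:0'): the NaN in `a` is kept
```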

aten/src/ATen/native/native_functions.yaml (+16)

```diff
@@ -9412,13 +9412,29 @@
     CPU: foreach_tensor_maximum_slow
     CUDA: foreach_tensor_maximum_cuda
 
+- func: _foreach_maximum_.List(Tensor(a!)[] self, Tensor[] other) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensors are on different devices
+  variants: function
+  dispatch:
+    CPU: foreach_tensor_maximum_slow_
+    CUDA: foreach_tensor_maximum_cuda_
+  autogen: _foreach_maximum.List_out
+
 - func: _foreach_minimum.List(Tensor[] self, Tensor[] other) -> Tensor[]
   device_check: NoCheck   # foreach kernels fall back to slow path when tensors are on different devices
   variants: function
   dispatch:
     CPU: foreach_tensor_minimum_slow
     CUDA: foreach_tensor_minimum_cuda
 
+- func: _foreach_minimum_.List(Tensor(a!)[] self, Tensor[] other) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensors are on different devices
+  variants: function
+  dispatch:
+    CPU: foreach_tensor_minimum_slow_
+    CUDA: foreach_tensor_minimum_cuda_
+  autogen: _foreach_minimum.List_out
+
 - func: _foreach_norm.Scalar(Tensor[] self, Scalar ord=2) -> Tensor[]
   device_check: NoCheck   # foreach kernels fall back to slow path when tensors are on different devices
   variants: function
```
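Per the schema, `Tensor(a!)[] self` marks the first list as mutated and `-> ()` means the Python binding returns `None`, unlike the out-of-place `.List` variant, which returns a fresh `Tensor[]`. A small sketch of the difference:

```python
import torch

xs = [torch.zeros(2)]
ys = [torch.ones(2)]

out = torch._foreach_maximum(xs, ys)   # out-of-place: returns new tensors
ret = torch._foreach_maximum_(xs, ys)  # in-place: mutates xs, returns None

assert ret is None
assert torch.equal(xs[0], torch.ones(2))
assert torch.equal(out[0], torch.ones(2))
```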

test/test_foreach.py (+10 −4)

```diff
@@ -97,7 +97,7 @@ def is_cuda(self):
     # note(mkozuki): It might be the case that the expected number of `cudaLaunchKernel`s
     # is greater than 1 once foreach functions internally separate their input `TensorList`s by
     # devices & dtypes into vectors of tensors.
-    def _get_funcs(self, op, n_expected_cudaLaunchKernels):
+    def _get_funcs(self, op, n_expected_cudaLaunchKernels: int):
         return (
             ForeachFuncWrapper(op.method_variant, n_expected_cudaLaunchKernels),
             RegularFuncWrapper(op.ref),
@@ -370,11 +370,17 @@ def test_unary_slowpath(self, device, dtype, op):
         for N in N_values:
             self._test_unary(device, dtype, op, N, is_fastpath=False)
 
+    # note(crcrpar): `torch.maximum` and `torch.minimum` support an `out` argument but have
+    # no in-place variants, so compare `inplace_op`'s results with `ref`'s outputs.
     def _minmax_test(self, opinfo, inputs, is_fastpath, n_expected_cudaLaunchKernels):
-        op, ref, _, _ = self._get_funcs(opinfo, n_expected_cudaLaunchKernels)
-        self.assertEqual(ref(inputs), op(inputs, self.is_cuda, is_fastpath))
+        op, ref, inplace_op, _ = self._get_funcs(opinfo, n_expected_cudaLaunchKernels)
+        expected = ref(inputs)
+        self.assertEqual(expected, op(inputs, self.is_cuda, is_fastpath))
+
+        inplace_inputs = [[t.clone() for t in inputs[0]], inputs[1]]
+        inplace_op(inplace_inputs, self.is_cuda, is_fastpath)
+        self.assertEqual(expected, inplace_inputs[0])
 
-    # note(mkozuki): in-place of foreach_minimum and foreach_maximum aren't implemented.
     @ops(foreach_minmax_op_db)
     def test_minmax_fastpath(self, device, dtype, op):
         for N in N_values:
```

torch/optim/adam.py (+2 −2)

```diff
@@ -377,7 +377,7 @@ def _multi_tensor_adam(params: List[Tensor],
 
         if amsgrad:
             # Maintains the maximum of all 2nd moment running avg. till now
-            max_exp_avg_sqs = torch._foreach_maximum(max_exp_avg_sqs, exp_avg_sqs)  # type: ignore[assignment]
+            torch._foreach_maximum_(max_exp_avg_sqs, exp_avg_sqs)  # type: ignore[assignment]
 
             # Use the max. for normalizing running avg. of gradient
             max_exp_avg_sq_sqrt = torch._foreach_sqrt(max_exp_avg_sqs)
@@ -405,7 +405,7 @@ def _multi_tensor_adam(params: List[Tensor],
 
         if amsgrad:
             # Maintains the maximum of all 2nd moment running avg. till now
-            max_exp_avg_sqs = torch._foreach_maximum(max_exp_avg_sqs, exp_avg_sqs)  # type: ignore[assignment]
+            torch._foreach_maximum_(max_exp_avg_sqs, exp_avg_sqs)
 
             # Use the max. for normalizing running avg. of gradient
             max_exp_avg_sq_sqrt = torch._foreach_sqrt(max_exp_avg_sqs)
```
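This is the heart of the fix: the old out-of-place call rebound the local name `max_exp_avg_sqs` to freshly allocated tensors, so the tensors held in the optimizer state were never updated; the in-place call writes through to those state tensors. A minimal sketch of the aliasing difference (values are illustrative):

```python
import torch

max_exp_avg_sqs = [torch.tensor([0.5])]
exp_avg_sqs = [torch.tensor([0.9])]
state_tensor = max_exp_avg_sqs[0]  # alias, as held by the optimizer state

# Out-of-place op returns new tensors; the aliased state tensor stays stale.
new_list = torch._foreach_maximum(max_exp_avg_sqs, exp_avg_sqs)
assert torch.equal(state_tensor, torch.tensor([0.5]))  # unchanged!

# In-place op updates the state tensor itself.
torch._foreach_maximum_(max_exp_avg_sqs, exp_avg_sqs)
assert torch.equal(state_tensor, torch.tensor([0.9]))
```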

torch/optim/adamw.py (+2 −2)

```diff
@@ -369,7 +369,7 @@ def _multi_tensor_adamw(params: List[Tensor],
 
         if amsgrad:
             # Maintains the maximum of all 2nd moment running avg. till now
-            max_exp_avg_sqs = torch._foreach_maximum(max_exp_avg_sqs, exp_avg_sqs)  # type: ignore[assignment]
+            torch._foreach_maximum_(max_exp_avg_sqs, exp_avg_sqs)
 
             # Use the max. for normalizing running avg. of gradient
             max_exp_avg_sq_sqrt = torch._foreach_sqrt(max_exp_avg_sqs)
@@ -397,7 +397,7 @@ def _multi_tensor_adamw(params: List[Tensor],
 
         if amsgrad:
             # Maintains the maximum of all 2nd moment running avg. till now
-            max_exp_avg_sqs = torch._foreach_maximum(max_exp_avg_sqs, exp_avg_sqs)  # type: ignore[assignment]
+            torch._foreach_maximum_(max_exp_avg_sqs, exp_avg_sqs)
 
             # Use the max. for normalizing running avg. of gradient
             max_exp_avg_sq_sqrt = torch._foreach_sqrt(max_exp_avg_sqs)
```

torch/testing/_internal/common_methods_invocations.py (+1 −1)

```diff
@@ -6726,7 +6726,7 @@ def sample_inputs_foreach(self, device, dtype, N, *, noncontiguous=False, same_s
 def get_foreach_method_names(name):
     # get torch inplace reference function
     op_name = "_foreach_" + name
-    inplace_op_name = "_foreach_" + name + "_"
+    inplace_op_name = op_name + "_"
 
     op = getattr(torch, op_name, None)
     inplace_op = getattr(torch, inplace_op_name, None)
```
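A quick sketch of what the refactored helper resolves for `name = "maximum"` (the in-place attribute only exists once this PR lands):

```python
import torch

name = "maximum"
op_name = "_foreach_" + name
inplace_op_name = op_name + "_"  # reuses op_name instead of rebuilding the string

op = getattr(torch, op_name, None)                  # torch._foreach_maximum
inplace_op = getattr(torch, inplace_op_name, None)  # torch._foreach_maximum_
assert op is not None and inplace_op is not None
```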
