-
Notifications
You must be signed in to change notification settings - Fork 104
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【新算子】- cholesky 算子开发 #1005
Comments
PR 一起贴出来吧 |
pr链接:#1018 |
设计文档已完成 |
已完成complex类型batch=1时的cholesky分解,并在pr中更新了commit |
已安排寒武纪同学 review 文档。@dglr 更新下进展吧:预计开发&自测完成时间 |
目前已完成complex类型batch=1情况的开发,正在进行性能优化,预计6.17前完成性能优化;然后开始编写多batch情况,预计6.24前完成多batch情况的编写及优化;7.1之前完成自测。 |
本周完成了多batch下矩阵分解代码的编写,现在基本调通,但部分规模下还有bug。 |
当前进度:上述精度问题已解决。目前进展:已完成float类型的测试,这周末将完成complex float类型的测试 |
测试使用的json如下 |
测试使用的compute.py如下: import torch
from nonmlu_ops.base import *
import logging
import copy
import os
@registerTensorList("cholesky")
class CholeskyTensorList(TensorList):
    """Tensor list registered for the ``cholesky`` op.

    Adds nothing of its own; all behavior is inherited from ``TensorList``.
    """
def print_matrix(A):
    """Print a 2-D matrix, or every matrix of a 3-D batch, to stdout.

    Each element is formatted with ``"{:.3}"`` and followed by a space. Every
    row ends with ``print("\\n")`` (i.e. a blank line after the row). For a
    3-D input another ``print("\\n")`` follows each matrix. Inputs of any
    other rank print nothing.

    Rows and columns are taken from the actual array shape, so non-square
    matrices are printed in full.

    Args:
        A: array-like with ``ndim`` that iterates row by row (e.g. a
            ``numpy.ndarray`` or ``torch.Tensor``).
    """

    def _emit(matrix):
        # One printed line per row; iterate real rows/cols instead of assuming
        # a square shape.
        for row in matrix:
            for value in row:
                print("{:.3}".format(value), end=" ")
            print("\n")

    if A.ndim == 3:
        for matrix in A:
            _emit(matrix)
            print("\n")
    elif A.ndim == 2:
        _emit(A)
def set_complex_data(data_node, complex_tensor):
    """Store a complex tensor on ``data_node`` as separate real/imag arrays.

    The tensor is first copied to host memory and converted to NumPy.

    Args:
        data_node: object exposing ``setComplexData(real, imag)``.
        complex_tensor: torch tensor on any device.
    """
    host_array = complex_tensor.cpu().numpy()
    data_node.setComplexData(host_array.real, host_array.imag)
def set_diag_imag_one(input_tensor):
    """Set the imaginary part of every diagonal element to exactly 1, in place.

    The diagonal is taken over the last two dimensions, so this handles a
    single matrix (2-D) as well as any batch of matrices. Tensors with fewer
    than two dimensions are left untouched. Real parts are never modified.

    Args:
        input_tensor: complex torch tensor, modified in place.

    Raises:
        RuntimeError: from torch, if ``input_tensor`` is not complex (``.imag``
            is undefined for real dtypes).
    """
    if input_tensor.dim() < 2:
        return
    # Assign through the diagonal's imag view. Arithmetic like
    # `x + (1j - x.imag * 1j)` can round away from 1, and it turns an
    # Inf/NaN imaginary part into NaN in both the real and imaginary parts
    # (inf * 0 = NaN).
    torch.diagonal(input_tensor, dim1=-2, dim2=-1).imag.fill_(1)
@registerOp("cholesky")
class CholeskyOp(OpTest):
    """Build inputs and GPU baselines for the ``cholesky`` operator.

    ``compute`` does the following:

    1. Turns the case's input data into ``A = L @ L^H + n * I``, where ``L``
       is the lower triangle of the input and ``n`` the matrix size.
    2. Writes ``A`` back as the op input.
    3. Factors ``A`` with ``torch.linalg.cholesky`` in working precision and
       stores the result on the output tensor.
    4. Factors the same ``A`` promoted to double precision as the reference.
    5. Records the ``diff_utils.Evaluator`` metrics between the two.

    2-D (single) and 3-D (batched) inputs are supported.
    """

    # Round-robin counter spreading successive compute() calls across the
    # visible CUDA devices. The name is kept as-is because it is a public
    # class attribute.
    compute_cout = 0

    # True -> store the reconstructed product (L @ L^H or U^H @ U). False
    # (the default) stores the factor with its strict triangle mirrored into
    # the other half.
    _RESULT_MUL = False

    # Thresholds forwarded to diff_utils.Evaluator.
    _HALF_DYNAMIC_THRESHOLD = 1e-3
    _FLOAT_DYNAMIC_THRESHOLD = 1e-5

    def __init__(self, tensorlist, params):
        super().__init__(tensorlist, params)
        # Factor as A = U^H U (upper) instead of A = L L^H (lower).
        self.upper_ = self.params_.get("upper", False)

    def compute(self):
        """Generate the op input, the working-precision output and diff metrics.

        Raises:
            RuntimeError: if no CUDA device is visible.
            ValueError: if the input is neither 2-D nor 3-D.
        """
        # Provide a default device list without clobbering one the user set.
        os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0,1,2,3,4,5,6,7")
        gpu_count = torch.cuda.device_count()
        if gpu_count == 0:
            raise RuntimeError("cholesky compute() requires at least one CUDA device")
        device = CholeskyOp.compute_cout % gpu_count
        CholeskyOp.compute_cout += 1
        logging.debug(
            "cholesky: %d GPU(s) visible (CUDA_VISIBLE_DEVICES=%s), using cuda:%d",
            gpu_count, os.environ.get("CUDA_VISIBLE_DEVICES"), device)

        input_tensor = self.tensor_list_.getInputTensor(0)
        output_tensor = self.tensor_list_.getOutputTensor(0)
        if input_tensor.getDataType().isComplex():
            self._compute_complex(input_tensor, output_tensor, device)
        else:
            self._compute_real(input_tensor, output_tensor, device)
        # Hand cached allocator blocks back to the driver once per case.
        torch.cuda.empty_cache()

    def _compute_real(self, input_tensor, output_tensor, device):
        """Real path: working-precision factorization vs. a float64 reference."""
        input_data = torch.tensor(input_tensor.getData()).cuda(device)
        # Keep the lower triangle. This matches torch.tril for finite data,
        # but NaN/Inf in the strict upper triangle still propagate (x - x).
        lower = input_data - torch.triu(input_data, diagonal=1)
        del input_data
        matrix = self._make_positive_definite(lower, torch.float32)
        del lower
        input_tensor.setData(matrix.cpu().numpy())

        factor = torch.linalg.cholesky(matrix, upper=self.upper_)
        output_tensor.setData(self._format_factor(factor).cpu().numpy())
        del factor

        # Reference: factor exactly the matrix the op receives, promoted to
        # float64.
        factor_fp64 = torch.linalg.cholesky(matrix.double(), upper=self.upper_)
        del matrix
        base_node = DataNode("double")
        base_node.setData(self._format_factor(factor_fp64).cpu().numpy())
        del factor_fp64

        eva = diff_utils.Evaluator(base_node, output_tensor.getDataNode(),
                                   self._HALF_DYNAMIC_THRESHOLD,
                                   self._FLOAT_DYNAMIC_THRESHOLD)
        diff1 = eva.computeDiff1()
        diff2 = eva.computeDiff2()
        diff_3_2 = eva.computeDiff3_2(10.0)
        self._report_diff(output_tensor, diff1, diff2, diff_3_2)

    def _compute_complex(self, input_tensor, output_tensor, device):
        """Complex path: complex64 factorization vs. a complex128 reference."""
        input_real, input_imag = input_tensor.getComplexData()
        # Lower triangle of the input. Its diagonal is purely real because
        # the imaginary part uses only the strict lower triangle (k=-1).
        lower_np = np.tril(input_real) + 1j * np.tril(input_imag, k=-1)
        lower = torch.tensor(lower_np, dtype=torch.complex64).cuda(device)
        del lower_np
        logging.debug("cholesky: complex input shape %s", tuple(lower.shape))
        matrix = self._make_positive_definite(lower, torch.complex64)
        del lower
        set_complex_data(input_tensor, matrix)

        factor = torch.linalg.cholesky(matrix, upper=self.upper_)
        result = self._format_factor(factor)
        del factor
        set_diag_imag_one(result)
        set_complex_data(output_tensor, result)
        del result

        # Reference: factor exactly the complex64 matrix the op receives,
        # promoted to complex128. This is the same approach as the real path.
        # Rebuilding A from L in complex128 would compare against a slightly
        # different matrix.
        factor_128 = torch.linalg.cholesky(matrix.to(torch.complex128),
                                           upper=self.upper_)
        del matrix
        reference = self._format_factor(factor_128)
        del factor_128
        set_diag_imag_one(reference)
        base_node = DataNode("complex128")
        set_complex_data(base_node, reference)
        del reference

        eva = diff_utils.Evaluator(base_node, output_tensor.getDataNode(),
                                   self._HALF_DYNAMIC_THRESHOLD,
                                   self._FLOAT_DYNAMIC_THRESHOLD)
        diff_3_2 = eva.computeDiff3_2(1.0)
        diff1 = eva.computeDiff1()
        diff2 = eva.computeDiff2()
        self._report_diff(output_tensor, diff1, diff2, diff_3_2)

    @staticmethod
    def _make_positive_definite(lower, eye_dtype):
        """Return ``lower @ lower^H + n * I`` for a 2-D or batched 3-D ``lower``.

        ``n`` is the row count of each matrix. ``eye_dtype`` is the dtype of
        the identity; normal type promotion applies when adding it.

        Raises:
            ValueError: if ``lower`` is neither 2-D nor 3-D.
        """
        if lower.dim() not in (2, 3):
            raise ValueError(
                "cholesky expects a 2-D or 3-D input, got shape "
                f"{tuple(lower.shape)}")
        size = lower.size(-2)
        # .conj() is a no-op for real tensors, so one formula covers both paths.
        gram = torch.matmul(lower, lower.transpose(-2, -1).conj())
        eye = torch.eye(size, dtype=eye_dtype, device=lower.device)
        return gram + size * eye

    def _format_factor(self, factor):
        """Convert a (possibly batched) Cholesky factor into the stored layout.

        If ``_RESULT_MUL`` is set, the product L @ L^H (or U^H @ U) is
        returned. Otherwise the factor's strict triangle is mirrored into the
        opposite triangle via a plain transpose, without conjugation.
        """
        if self._RESULT_MUL:
            factor_h = factor.transpose(-2, -1).conj()
            if self.upper_:
                return torch.matmul(factor_h, factor)
            return torch.matmul(factor, factor_h)
        if self.upper_:
            strict = torch.triu(factor, diagonal=1)
        else:
            strict = torch.tril(factor, diagonal=-1)
        return factor + strict.transpose(-2, -1)

    @staticmethod
    def _report_diff(output_tensor, diff1, diff2, diff_3_2):
        """Print the diff metrics and record them on ``output_tensor``."""
        print("diff1: ", diff1)
        print("diff2: ", diff2)
        print("diff_3_2: ", diff_3_2)
        output_tensor.setDiff(diff1, diff2, -1, diff_3_2, -1)
@registerProtoWriter("cholesky")
class CholeskyProtoWriter(MluOpProtoWriter):
    """Serialize the ``cholesky`` op parameters into the test-case proto."""

    def dumpOpParam2Node(self):
        """Copy the ``upper`` flag into ``proto_node_.cholesky_param``.

        ``upper`` defaults to ``False`` when the case does not set it. This
        matches the default ``CholeskyOp`` uses when generating the baseline,
        and avoids passing ``None`` to the proto field.
        """
        cholesky_param_node = self.proto_node_.cholesky_param
        cholesky_param_node.upper = self.op_params_.get("upper", False)
|
测试json中也需要包含nan/inf的测试场景 |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
开发计划可参考以下节点:
The text was updated successfully, but these errors were encountered: