
[New Operator] - cholesky operator development #1005

Open
PetrelYy opened this issue Apr 19, 2024 · 16 comments
Assignees: dglr
Labels: ICT, New Op (Contribute a new operator)

@PetrelYy (Collaborator)

The development plan can follow these milestones:

  1. Design document writing, xx.xx~xx.xx
  2. Development and self-testing, xx.xx~xx.xx
  3. Open the PR/MR, xx.xx~xx.xx
  4. Review (3 approvals), xx.xx~xx.xx
  5. Maintainer merges
dglr self-assigned this on Apr 19, 2024
PetrelYy added the New Op (Contribute a new operator) and ICT labels on Apr 23, 2024
@dglr (Collaborator) commented May 14, 2024

Current progress: multi-batch computation for the float type is complete, with runtime within 10x of the V100. Performance is as follows:
[image: performance comparison table]

@PetrelYy (Collaborator, Author)

Please post the PR link here as well.

@dglr (Collaborator) commented May 20, 2024

PR link: #1018

@dglr (Collaborator) commented May 24, 2024

The design document is complete.

@dglr (Collaborator) commented Jun 7, 2024

Cholesky decomposition for the complex type with batch=1 is complete, and a new commit has been pushed to the PR.

@dglr (Collaborator) commented Jun 7, 2024

Correctness of the complex-type Cholesky decomposition is not yet fully stable; for larger matrix sizes the correctness check occasionally fails.
Current measured performance:
[image: performance comparison table]

@PetrelYy (Collaborator, Author)

Cambricon colleagues have been assigned to review the design document.

@dglr please update the status: expected completion dates for development & self-testing.

@dglr (Collaborator) commented Jun 12, 2024

Development of the complex type for batch=1 is complete and performance optimization is underway; optimization should finish before 6.17. Work on the multi-batch case will follow, with coding and optimization expected before 6.24, and self-testing completed before 7.1.

@dglr (Collaborator) commented Jun 14, 2024

After the last meeting, PyTorch's Cholesky decomposition was added as a runtime baseline. The current runtime comparison is as follows, in microseconds:
[image: runtime comparison table]
Of the sizes above, 128 x 128 and 3000 x 3000 are the sizes required by the task specification.
The current performance bottleneck lies in matrix multiplication. Taking the 3072 x 3072 size as an example, profiling gives the following result:
[image: profiling breakdown]
The red boxes in the figure mark the calls into the underlying matrix-multiply kernels. Because there is no native cgemm implementation, cgemm is currently assembled from 4 sgemm calls.
Matrix multiplication accounts for roughly 70% of the total time, more than 5300 microseconds, which already exceeds 10x the runtime of cuSOLVER and PyTorch.
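For reference, the 4-sgemm composition follows the standard identity (Ar + i*Ai)(Br + i*Bi) = (Ar*Br - Ai*Bi) + i*(Ar*Bi + Ai*Br). Below is a minimal PyTorch sketch of that composition; the function name is illustrative only, not the actual kernel entry point.

import torch

def cgemm_via_4_sgemm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Complex matmul assembled from four real matmuls:
    # (Ar + i*Ai) @ (Br + i*Bi) = (Ar@Br - Ai@Bi) + i*(Ar@Bi + Ai@Br)
    ar, ai = a.real.contiguous(), a.imag.contiguous()
    br, bi = b.real.contiguous(), b.imag.contiguous()
    cr = ar @ br - ai @ bi  # two real GEMMs for the real part
    ci = ar @ bi + ai @ br  # two real GEMMs for the imaginary part
    return torch.complex(cr, ci)

a = torch.randn(128, 128, dtype=torch.complex64)
b = torch.randn(128, 128, dtype=torch.complex64)
assert torch.allclose(cgemm_via_4_sgemm(a, b), a @ b, atol=1e-3)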

@dglr (Collaborator) commented Jun 21, 2024

This week the multi-batch matrix decomposition code was completed; it is mostly working now, but some sizes still have bugs.

@dglr (Collaborator) commented Jun 28, 2024

This week, performance tuning of the multi-batch complex float decomposition was completed. The figure below compares runtimes:
[image: runtime comparison table]
With batch=32 and size 128 above, the bottleneck is again matrix multiplication. Profiling gives the following result:
[image: profiling breakdown]
The red boxes mark the calls into the underlying matrix-multiply kernels; as before, with no native cgemm implementation, cgemm is assembled from 4 sgemm calls. Matrix multiplication accounts for roughly 60% of the total time, more than 2000 microseconds, which already exceeds 10x the runtime of cuSOLVER and PyTorch.

@dglr (Collaborator) commented Jul 5, 2024

A precision problem was found while testing the cholesky operator. Cholesky decomposition involves square roots, division, and a large number of accumulated multiply-add operations, and judging by the results the final output precision is unsatisfactory.
The test computation flow is: generate random numbers -> zero the upper-triangular part so the matrix becomes lower triangular -> multiply the matrix by its own transpose, which guarantees the result admits a Cholesky decomposition -> hand the result to the device for computation.
With an input of batch=1 and side length 32, the error of the MLU result (relative to the theoretically correct result) is:
[image: MLU error report]
For the same input computed with PyTorch 2.3.1+cu121, the error is:
[image: PyTorch error report]
This error grows rapidly with the matrix size.
When the input size grows to 64, the MLU computation produces nan, while PyTorch raises an error:
[image: PyTorch error message]
So the current conclusion is that some matrices which are theoretically Cholesky-decomposable cannot be decomposed in practice because of precision issues.
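For concreteness, a minimal PyTorch sketch of the generation flow above (standalone; the size * I diagonal shift matches what the compute.py posted below adds to keep the matrix well conditioned):

import torch

def make_spd_input(size, batch=1, device="cpu"):
    # random matrix -> keep the lower triangle -> L @ L^T is symmetric positive
    # semidefinite by construction; the size * I shift makes it strictly
    # positive definite and better conditioned.
    x = torch.rand(batch, size, size, device=device)
    L = torch.tril(x)
    A = L @ L.transpose(1, 2)
    return A + size * torch.eye(size, device=device)

A = make_spd_input(32)
torch.linalg.cholesky(A)  # expected to succeed without producing nan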

@dglr (Collaborator) commented Jul 12, 2024

Current status: the precision problem above has been resolved. Testing of the float type is complete; testing of the complex float type will be finished this weekend.

@dglr (Collaborator) commented Jul 23, 2024

The json used for testing is attached:
mannul_shape_3000.json

@dglr (Collaborator) commented Jul 23, 2024

The compute.py used for testing is as follows:

from nonmlu_ops.base import *  # assumed to provide registerTensorList, registerOp, DataNode, diff_utils, etc.
import torch
import numpy as np  # np is used below but was previously only available via the star import
import logging
import copy
import os


@registerTensorList("cholesky")
class CholeskyTensorList(TensorList):
    pass

def print_matrix(A):
    if A.ndim == 3:
        batch = A.shape[0]
        size = A.shape[1]
        for i in range(batch):
            for j in range(size):
                for k in range(size):
                    print("{:.3}".format(A[i][j][k]), end=" ")
                print("\n")
            print("\n")
    elif A.ndim == 2:
        size = A.shape[0]
        for i in range(size):
            for j in range(size):
                print("{:.3}".format(A[i][j]),end=" ")
            print("\n")

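# Splits a complex tensor into separate real/imag numpy arrays for the
# DataNode complex-data API.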
def set_complex_data(data_node, complex_tensor):
    cpu_array = complex_tensor.cpu().numpy()
    cpu_real = np.real(cpu_array)
    cpu_imag = np.imag(cpu_array)
    data_node.setComplexData(cpu_real, cpu_imag)

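# Sets the imaginary part of every diagonal entry to exactly 1. It is applied
# to both the device result and the fp64 baseline, so the diagonal imaginary
# part (mathematically zero for a Cholesky factor) cannot affect the diff.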
def set_diag_imag_one(input_tensor):
    if input_tensor.dim() == 2:
        diag_indices = torch.arange(input_tensor.size(0), device=input_tensor.device)
        input_tensor[diag_indices, diag_indices] += 1j - input_tensor[diag_indices, diag_indices].imag * 1j

    elif input_tensor.dim() == 3:
        batch_size, n, _ = input_tensor.size()
        for i in range(batch_size):
            diag_indices = torch.arange(n, device=input_tensor.device)
            input_tensor[i, diag_indices, diag_indices] += 1j - input_tensor[i, diag_indices, diag_indices].imag * 1j


@registerOp("cholesky")
class CholeskyOp(OpTest):
    def __init__(self,tensorlist,params):
        super().__init__(tensorlist,params)
        self.upper_ = self.params_.get("upper", False)

    compute_count = 0
    
    def compute(self):
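        # Round-robin the baseline computation across the visible GPUs so that
        # parallel test processes do not all land on device 0.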
        os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3,4,5,6,7'
        gpu_count = torch.cuda.device_count()
        print("gpu_count:")
        print(torch.cuda.device_count())
        print("os visible:",os.environ.get('CUDA_VISIBLE_DEVICES'))
        cuda_count = CholeskyOp.compute_count % gpu_count
        print('now gpu:',cuda_count)
        CholeskyOp.compute_count += 1

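        # When True, compare the reconstructed product L @ L^T (or U^H @ U)
        # instead of the factor; the default compares the factor with its
        # strict triangle mirrored to the other half.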
        result_mul = False

        input_tensor = self.tensor_list_.getInputTensor(0)
        output_tensor = self.tensor_list_.getOutputTensor(0)
        input_is_complex = input_tensor.getDataType().isComplex()
        upper = self.upper_
        

        if not input_is_complex:
            input_data = torch.tensor(input_tensor.getData()).cuda(cuda_count)
            input_data_fp64 = input_data.type(torch.float64).cuda(cuda_count)
            upper_triangle = torch.triu(input_data, diagonal=1)
            L_matrix = input_data - upper_triangle
            
            del input_data
            del upper_triangle
            torch.cuda.empty_cache()
            
            batch = 1
            size = L_matrix.size(1)
            if L_matrix.dim() == 2:
                U_matrix = L_matrix.transpose(0, 1)
                A = torch.mm(L_matrix, U_matrix)
                del L_matrix
                torch.cuda.empty_cache()
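                # Shift the diagonal by size * I so that A is strictly
                # positive definite and well conditioned.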
                eye = size * torch.eye(size, dtype=torch.float32).cuda(cuda_count)
                A = A + eye
            elif L_matrix.dim() == 3:
                U_matrix = L_matrix.transpose(1, 2)
                A = torch.bmm(L_matrix, U_matrix)
                batch = L_matrix.size(0)
                del L_matrix
                torch.cuda.empty_cache()
                eye = size * torch.eye(size, dtype=torch.float32).expand(batch, -1, -1).cuda(cuda_count)
                A = A + eye
            else:
                exit()
            
            del eye
            del U_matrix
            torch.cuda.empty_cache()

            input_tensor.setData(A.cpu().numpy())

            result_L = torch.linalg.cholesky(A,upper=upper)
            torch.cuda.empty_cache()


            if result_mul:
                if not upper:
                    if result_L.dim() == 2:
                        result = torch.matmul(result_L, result_L.transpose(0, 1))
                    else:
                        result = torch.bmm(result_L, result_L.transpose(1, 2))
                else:
                    if result_L.dim() == 2:
                        result = torch.matmul(result_L.transpose(0, 1),result_L)
                    else:
                        result = torch.bmm(result_L.transpose(1, 2),result_L)

            else:

                if not upper:
                    lower_triangle = torch.tril(result_L, diagonal=-1)
                    if result_L.dim() == 2:
                        result = result_L + lower_triangle.transpose(0, 1)
                    else:
                        result = result_L + lower_triangle.transpose(1, 2)
                else:
                    upper_triangle = torch.triu(result_L, diagonal=1)
                    if result_L.dim() == 2:
                        result = result_L + upper_triangle.transpose(0, 1)
                    else:
                        result = result_L + upper_triangle.transpose(1, 2)
            
            output_result = result.cpu().numpy()
            output_tensor.setData(output_result)
            
            del result_L
            del result
            torch.cuda.empty_cache()
            

            upper_triangle_fp64 = torch.triu(input_data_fp64, diagonal=1)
            L_matrix_fp64 = input_data_fp64 - upper_triangle_fp64
            del upper_triangle_fp64
            del input_data_fp64
            torch.cuda.empty_cache()
            
            if L_matrix_fp64.dim() == 2:
                U_matrix_fp64 = L_matrix_fp64.transpose(0, 1)
                A_fp64 = torch.mm(L_matrix_fp64, U_matrix_fp64)
                del L_matrix_fp64
                del U_matrix_fp64
                A_fp64 = A_fp64 + torch.eye(size, dtype=torch.float64).cuda(cuda_count) * size
            elif L_matrix_fp64.dim() == 3:
                U_matrix_fp64 = L_matrix_fp64.transpose(1, 2)
                A_fp64 = torch.bmm(L_matrix_fp64, U_matrix_fp64)
                del L_matrix_fp64
                del U_matrix_fp64
                A_fp64 = A_fp64 + torch.eye(size, dtype=torch.float64).expand(batch, -1, -1).cuda(cuda_count) * size
            
            torch.cuda.empty_cache()
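            # Note: this cast supersedes the fp64 reconstruction above; the
            # reference decomposition therefore runs on exactly the same
            # matrix A (cast to double) that was written back to the input.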
            A_fp64 = A.double()
            
            

            result_L_fp64 = torch.linalg.cholesky(A_fp64,upper=upper)
            torch.cuda.empty_cache()


            base_node = DataNode("double")

            if result_mul:
                if not upper:
                    if result_L_fp64.dim() == 2:
                        result = torch.matmul(result_L_fp64, result_L_fp64.transpose(0, 1))
                    else:
                        result = torch.bmm(result_L_fp64, result_L_fp64.transpose(1, 2))
                else:
                    if result_L_fp64.dim() == 2:
                        result = torch.matmul(result_L_fp64.transpose(0, 1),result_L_fp64)
                    else:
                        result = torch.bmm(result_L_fp64.transpose(1, 2),result_L_fp64)
                
            else:
                if not upper:
                    lower_triangle = torch.tril(result_L_fp64, diagonal=-1)
                    if result_L_fp64.dim() == 2:
                        result = result_L_fp64 + lower_triangle.transpose(0, 1)
                    else:
                        result = result_L_fp64 + lower_triangle.transpose(1, 2)
                    del lower_triangle
                else:
                    upper_triangle = torch.triu(result_L_fp64, diagonal=1)
                    if result_L_fp64.dim() == 2:
                        result = result_L_fp64 + upper_triangle.transpose(0, 1)
                    else:
                        result = result_L_fp64 + upper_triangle.transpose(1, 2)
                    del upper_triangle

            output_result = result.cpu().numpy()
            base_node.setData(output_result)
            del result
            del result_L_fp64
            torch.cuda.empty_cache()



            
            half_dynamic_threshold = 1e-3
            float_dynamic_threshold = 1e-5
            eva = diff_utils.Evaluator(base_node, output_tensor.getDataNode(), half_dynamic_threshold, float_dynamic_threshold)
            diff1 = eva.computeDiff1()
            diff2 = eva.computeDiff2()
            diff_3_2 = eva.computeDiff3_2(10.0)
            print("diff1: ", diff1)
            print("diff2: ", diff2)
            print("diff_3_2: ", diff_3_2)
            output_tensor.setDiff(diff1, diff2, -1, diff_3_2, -1)

        else:
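            # Complex path: build a Hermitian positive definite input
            # A = L @ L^H + size * I, mirroring the real-valued path above.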
            input_real_data, input_imag_data = input_tensor.getComplexData()
            upper_triangle_real = np.tril(input_real_data)
            upper_triangle_imag = np.tril(input_imag_data, k=-1)
            complex_numpy_array = upper_triangle_real + 1j * upper_triangle_imag
            del upper_triangle_real
            del upper_triangle_imag
        
            input_data = torch.tensor(complex_numpy_array, dtype=torch.complex64).cuda(cuda_count)
            del complex_numpy_array

            input_data_complex128 = input_data.type(torch.complex128).cuda(cuda_count)
            upper_triangle_complex64 = torch.triu(input_data, diagonal=1)
            L_matrix_complex64 = input_data - upper_triangle_complex64

            del input_data
            del upper_triangle_complex64
            torch.cuda.empty_cache()

            batch = 1
            size = L_matrix_complex64.size(1)
            print("Tensor shape:", L_matrix_complex64.shape)
            if L_matrix_complex64.dim() == 2:
                U_matrix_complex64 = L_matrix_complex64.transpose(0, 1).conj()
                A_complex64 = torch.mm(L_matrix_complex64, U_matrix_complex64)
                del L_matrix_complex64
                torch.cuda.empty_cache()
                eye_complex64 = torch.eye(size, dtype=torch.complex64).cuda(cuda_count) * size
                A_complex64 = A_complex64 + eye_complex64
            elif L_matrix_complex64.dim() == 3:
                U_matrix_complex64 = L_matrix_complex64.transpose(1, 2).conj()
                A_complex64 = torch.bmm(L_matrix_complex64, U_matrix_complex64)
                batch = L_matrix_complex64.size(0)
                del L_matrix_complex64
                torch.cuda.empty_cache()
                eye_complex64 = torch.eye(size, dtype=torch.complex64).expand(batch, -1, -1).cuda(cuda_count) * size
                A_complex64 = A_complex64 + eye_complex64
            else:
                exit()

            del eye_complex64
            del U_matrix_complex64
            torch.cuda.empty_cache()

            set_complex_data(input_tensor, A_complex64)


            result_L_complex64 = torch.linalg.cholesky(A_complex64,upper=upper)

            del A_complex64
            torch.cuda.empty_cache()
            
            if result_mul:
                if not upper:
                    if result_L_complex64.dim() == 2:
                        result = torch.mm(result_L_complex64, result_L_complex64.transpose(0, 1).conj())
                    else:
                        result = torch.bmm(result_L_complex64, result_L_complex64.transpose(1, 2).conj())
                else:
                    if result_L_complex64.dim() == 2:
                        result = torch.mm(result_L_complex64.transpose(0, 1).conj(),result_L_complex64)
                    else:
                        result = torch.bmm(result_L_complex64.transpose(1, 2).conj(),result_L_complex64)
                
            else:
                if not upper:
                    lower_triangle = torch.tril(result_L_complex64, diagonal=-1)
                    if result_L_complex64.dim() == 2:
                        result = result_L_complex64 + lower_triangle.transpose(0, 1)
                    else:
                        result = result_L_complex64 + lower_triangle.transpose(1, 2)
                    del lower_triangle
                else:
                    upper_triangle = torch.triu(result_L_complex64, diagonal=1)
                    if result_L_complex64.dim() == 2:
                        result = result_L_complex64 + upper_triangle.transpose(0, 1)
                    else:
                        result = result_L_complex64 + upper_triangle.transpose(1, 2)
                    del upper_triangle
            set_diag_imag_one(result)
            set_complex_data(output_tensor, result)

            del result_L_complex64
            del result
            torch.cuda.empty_cache()


            L_matrix_complex128 = input_data_complex128

            del input_data_complex128
            torch.cuda.empty_cache()

            
            if L_matrix_complex128.dim() == 2:
                A_complex128 = torch.mm(L_matrix_complex128, L_matrix_complex128.transpose(0, 1).conj()) + torch.eye(size, dtype=torch.complex128).cuda(cuda_count) * size
                del L_matrix_complex128
                torch.cuda.empty_cache()
            elif L_matrix_complex128.dim() == 3:
                A_complex128 = torch.bmm(L_matrix_complex128, L_matrix_complex128.transpose(1, 2).conj())
                del L_matrix_complex128
                torch.cuda.empty_cache()
                A_complex128 = A_complex128 + torch.eye(size, dtype=torch.complex128).expand(batch, -1, -1).cuda(cuda_count) * size

            result_L_complex128 = torch.linalg.cholesky(A_complex128,upper=upper)

            del A_complex128
            torch.cuda.empty_cache()

            base_node = DataNode("complex128")

            if result_mul:
                if not upper:
                    if result_L_complex128.dim() == 2:
                        result = torch.mm(result_L_complex128, result_L_complex128.transpose(0, 1).conj())
                    else:
                        result = torch.bmm(result_L_complex128, result_L_complex128.transpose(1, 2).conj())
                else:
                    if result_L_complex128.dim() == 2:
                        result = torch.mm(result_L_complex128.transpose(0, 1).conj(),result_L_complex128)
                    else:
                        result = torch.bmm(result_L_complex128.transpose(1, 2).conj(),result_L_complex128)   
            else:
                if not upper:
                    lower_triangle = torch.tril(result_L_complex128, diagonal=-1)
                    if result_L_complex128.dim() == 2:
                        result = result_L_complex128 + lower_triangle.transpose(0, 1)
                    else:
                        result = result_L_complex128 + lower_triangle.transpose(1, 2)
                    del lower_triangle
                else:
                    upper_triangle = torch.triu(result_L_complex128, diagonal=1)
                    if result_L_complex128.dim() == 2:
                        result = result_L_complex128 + upper_triangle.transpose(0, 1)
                    else:
                        result = result_L_complex128 + upper_triangle.transpose(1, 2)
                    del upper_triangle
            set_diag_imag_one(result)

            set_complex_data(base_node, result)
            del result_L_complex128
            del result
            torch.cuda.empty_cache()



            half_dynamic_threshold = 1e-3
            float_dynamic_threshold = 1e-5
            eva = diff_utils.Evaluator(base_node, output_tensor.getDataNode(), half_dynamic_threshold, float_dynamic_threshold)
            diff_3_2 = eva.computeDiff3_2(1.0)
            diff1 = eva.computeDiff1()
            diff2 = eva.computeDiff2()
            print("diff1: ", diff1)
            print("diff2: ", diff2)
            print("diff_3_2: ", diff_3_2)
            output_tensor.setDiff(diff1, diff2, -1, diff_3_2, -1)

        # Note: deleting keys from locals() is a no-op in CPython, so the
        # original cleanup loop here did nothing; rely on the explicit del
        # statements above and free any cached GPU memory before returning.
        torch.cuda.empty_cache()



@registerProtoWriter("cholesky")
class CholeskyProtoWriter(MluOpProtoWriter):
    def dumpOpParam2Node(self):
        cholesky_param_node = self.proto_node_.cholesky_param
        cholesky_param_node.upper = self.op_params_.get("upper", False)



@ArtIntAI (Collaborator) commented Sep 4, 2024

The test json should also cover nan/inf input scenarios.
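As a rough illustration only (the actual json schema of the test generator is not shown in this thread, so this is a hypothetical PyTorch-side sketch rather than the project's real harness), nan/inf cases could be injected like this:

import torch

# Hypothetical sketch: start from a well-conditioned SPD matrix and poison
# single entries to exercise the operator's nan/inf handling.
size = 32
L = torch.tril(torch.rand(size, size))
A = L @ L.T + size * torch.eye(size)

A_nan = A.clone()
A_nan[0, 0] = float("nan")  # a nan in the input should surface in the output

A_inf = A.clone()
A_inf[1, 1] = float("inf")  # an inf likewise must not be silently swallowed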
