Skip to content

Commit

Permalink
[FEATURE] OpInfo Metric in Pytorch namespace
Browse files Browse the repository at this point in the history
Currently, the OpInfo metric prints:
OpIndex, OpName, TotalExecutionTime, MaxKernelIndex, MaxKernelName,
MaxKernelExecutionTime, kernelNumTracer, kernelNumCounter,
TotalDramRead, TotalDramWrite, TotalSFOp, TotalElapsedCycles

In addition, we found the reason why the mapping mechanism was invalid
and fixed it.

Signed-off-by: YushuoEdge <[email protected]>
  • Loading branch information
YushuoEdge committed Apr 26, 2022
1 parent a1e9913 commit 50d1e60
Show file tree
Hide file tree
Showing 12 changed files with 346 additions and 100 deletions.
141 changes: 137 additions & 4 deletions src/amanda/profiler/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.patches import Polygon
from utils import findTopKKernelTracer, findTopKKernelCounter
from utils import findTopK

def drawRoofline(hardwareTFlops, hardwareIntensity, X, Y):
maxright=0
Expand Down Expand Up @@ -88,7 +88,7 @@ def kernelInfoTracer(opList, timeList, apiList, rtList):

# Find Top-K kernel according to kernel execution time
k = min(20, len(kernelList))
ansListTracer = findTopKKernelTracer(infoList, k)
ansListTracer = findTopK(infoList, k, key = 6)
resTracer = pd.DataFrame(ansListTracer)
resTracer.columns = ['OpIndex', 'OpName', 'KernelIndex', 'KernelType', 'KernelName', 'LaunchKernelTime(ns)', 'KernelExecutionTime(ns)']
print("Number of tracer records: " + str(len(infoList)))
Expand Down Expand Up @@ -137,7 +137,7 @@ def kernelInfoCounter(dataList, flopCount=True):

# Find Top-K kernel according to kernel execution time
k = min(20, len(infoList))
ansListCounter = findTopKKernelCounter(infoList, k)
ansListCounter = findTopK(infoList, k, key = 3)
resCounter = pd.DataFrame(ansListCounter)
if flopCount == True:
resCounter.columns = ['OpIndex', 'OpName', 'KernelIndex', 'cyclesElapsed', 'dramRead', 'dramWrite', 'spAdd', 'spFma', 'spMul']
Expand Down Expand Up @@ -181,4 +181,137 @@ def kernelRoofline(supplyInfo, countData):
intensity = spOp / (dramRead + dramWrite)
kernelX.append(intensity)

drawRoofline(hardwareTFlops, hardwareIntensity, kernelX, kernelY)
drawRoofline(hardwareTFlops, hardwareIntensity, kernelX, kernelY)



# TOP-20 Tracer OP Information [key: executionTime], return ansListTracer for mapping
# Information collected by tracer: OpIndex, OpName, TotalExecutionTime, MaxKernelIndex, MaxKernelName, MaxKernelExecutionTime, kernelNumTracer
def opInfoTracer(opList, startTimeList, endTimeList, apiList, rtList):
    """Build one tracer summary row per op and report the top rows by op execution time.

    Kernel launches are attributed to ops by comparing the start time of each
    ``cudaLaunchKernel_v7000`` record with the op start times. Kernel activity
    records are then consumed in order, one per attributed launch, to find the
    longest kernel of every op.

    Args:
        opList: op names; indexed in lockstep with the two time lists.
        startTimeList: start time of each op.
        endTimeList: end time of each op.
        apiList: records exposing ``name`` and ``startTime``; only those named
            ``cudaLaunchKernel_v7000`` are used.
        rtList: records exposing ``kind``, ``name`` and ``durationTime``; only
            those of kind ``KERNEL`` or ``CONC KERNEL`` are used.

    Returns:
        The rows returned by ``findTopK(infoList, k, 2)`` with ``k <= 20``.
        Every row is ``[OpIndex, OpName, opExecutionTime, MaxKernelIndex,
        MaxKernelName, MaxKernelExecutionTime, KernelNumTracer]``; an op with
        no kernel record gets ``-1``, ``"NONE"`` and ``0`` for the three
        max-kernel fields. ``MaxKernelIndex`` is relative to the op's first
        kernel.

    Side effects:
        Prints every op name with its kernel count, prints the result table,
        and writes ``./Experiments/opInfoTracer_result.csv``.
    """
    # Filter cudaLaunchKernel and kernel
    launchKernelApiList = [x for x in apiList if x.name == "cudaLaunchKernel_v7000"]
    kernelList = [x for x in rtList if x.kind == "KERNEL" or x.kind == "CONC KERNEL"]

    # Calculate number of kernels for each op
    kernelNumList = []
    kernelCount = 0
    opIndex = 0
    timeListLen = len(startTimeList)

    # Ops that begin, and are followed by another op, before the first launch
    # own no kernel. The guard keeps a trace without any launch record from
    # indexing an empty list.
    if launchKernelApiList:
        firstLaunchTime = launchKernelApiList[0].startTime
        while (opIndex < timeListLen - 1
               and startTimeList[opIndex] < firstLaunchTime
               and startTimeList[opIndex + 1] < firstLaunchTime):
            kernelNumList.append(0)
            opIndex += 1

    for launch in launchKernelApiList:
        # Close every op whose successor started before this launch.
        while opIndex < timeListLen - 1 and launch.startTime > startTimeList[opIndex + 1]:
            kernelNumList.append(kernelCount)
            opIndex += 1
            kernelCount = 0
        kernelCount += 1
    kernelNumList.append(kernelCount)

    # Ops after the last launch own no kernel.
    while opIndex < timeListLen - 1:
        kernelNumList.append(0)
        opIndex += 1

    for opName, kernelNum in zip(opList, kernelNumList):
        print(opName)
        print(kernelNum)

    # Get information for each op
    kernelOffset = 0
    infoList = []
    for i, opName in enumerate(opList):
        executionTime = endTimeList[i] - startTimeList[i]
        kernelNum = kernelNumList[i]
        # Slicing clamps at the end of kernelList, so a trace with fewer
        # kernel records than launches cannot raise IndexError.
        opKernels = kernelList[kernelOffset:kernelOffset + kernelNum]

        if opKernels:
            # max() keeps the first maximal element, and the index is relative
            # to this op's first kernel even when every duration is 0.
            maxIndex = max(range(len(opKernels)), key=lambda j: opKernels[j].durationTime)
            maxKernel = opKernels[maxIndex]
            infoList.append([i, opName, executionTime, maxIndex, maxKernel.name, maxKernel.durationTime, kernelNum])
        else:
            infoList.append([i, opName, executionTime, -1, "NONE", 0, kernelNum])
        kernelOffset += kernelNum

    # TOP-20, use executionTime as key
    k = min(20, len(infoList))
    ansListTracer = findTopK(infoList, k, 2)
    # Passing the names to the constructor also works for an empty result,
    # where assigning .columns afterwards raises a length mismatch.
    resTracer = pd.DataFrame(
        ansListTracer,
        columns=['OpIndex', 'OpName', 'opExecutionTime', 'MaxKernelIndex', 'MaxKernelName', 'MaxKernelExecutionTime(ns)', 'KernelNumTracer'],
    )
    print("Number of tracer records: " + str(len(infoList)))
    print(resTracer)
    resTracer.to_csv("./Experiments/opInfoTracer_result.csv", index=False, sep=',')

    return ansListTracer




def opInfoCounter(dataList, flopCount=True):
    """Aggregate counter records per op and report the top ops by elapsed cycles.

    ``dataList`` is scanned for marker records whose ``rangeName`` is
    ``"NEW OP"``. The records between two markers belong to the first marker's
    op and are read in groups of ``numValue`` consecutive records per kernel:

    * ``flopCount`` true (6 per kernel): dramRead, dramWrite, spAdd, spFma,
      spMul, cyclesElapsed.
    * ``flopCount`` false (3 per kernel): dramRead, dramWrite, cyclesElapsed.

    Every ``gpuValue`` is divided by 1e6 before being summed, and the sums are
    rounded to two decimals. A trailing incomplete group is ignored.

    Args:
        dataList: records exposing ``rangeName``, ``metricName`` and
            ``gpuValue``.
        flopCount: whether the single-precision op counters were collected.

    Returns:
        One row per op, in op order, so that position ``i`` holds the op with
        ``OpIndex`` ``i``: ``[OpIndex, OpName, KernelNumCounter,
        TotalCyclesElapsed, TotalDramRead, TotalDramWrite]`` plus
        ``TotalFlopCount`` when ``flopCount`` is true. ``OpName`` is the
        ``metricName`` of the marker record. The list is empty when
        ``dataList`` holds no marker.

    Side effects:
        Prints the result table and writes
        ``./Experiments/opInfoCounter_result.csv``.
    """
    # Number of consecutive records each kernel contributes.
    numValue = 6 if flopCount else 3

    # Filter Count Data, aggregate the information
    opPosition = [i for i, record in enumerate(dataList) if record.rangeName == "NEW OP"]

    infoList = []
    for i, start in enumerate(opPosition):
        # The op's records run up to the next marker, or to the end of the
        # data for the last op.
        end = opPosition[i + 1] if i + 1 < len(opPosition) else len(dataList)
        kernelNum = (end - start - 1) // numValue

        totalDramRead = 0
        totalDramWrite = 0
        totalCyclesElapsed = 0
        totalFlopCount = 0
        for j in range(kernelNum):
            kernelPosition = start + 1 + j * numValue
            totalDramRead += dataList[kernelPosition].gpuValue / 1e6
            totalDramWrite += dataList[kernelPosition + 1].gpuValue / 1e6
            if flopCount:
                spAdd = dataList[kernelPosition + 2].gpuValue / 1e6
                spFma = dataList[kernelPosition + 3].gpuValue / 1e6
                spMul = dataList[kernelPosition + 4].gpuValue / 1e6
                # One FMA is counted as two operations.
                totalFlopCount += spAdd + spMul + spFma * 2
                totalCyclesElapsed += dataList[kernelPosition + 5].gpuValue / 1e6
            else:
                totalCyclesElapsed += dataList[kernelPosition + 2].gpuValue / 1e6

        row = [i, dataList[start].metricName, kernelNum, round(totalCyclesElapsed, 2), round(totalDramRead, 2), round(totalDramWrite, 2)]
        if flopCount:
            row.append(round(totalFlopCount, 2))
        infoList.append(row)

    columns = ['OpIndex', 'OpName', 'KernelNumCounter', 'TotalCyclesElapsed', 'TotalDramRead', 'TotalDramWrite']
    if flopCount:
        columns.append('TotalFlopCount')

    # Find the Top-K ops according to total elapsed cycles (column 3).
    # NOTE(review): findTopK lives in utils and is not visible here; it gets a
    # shallow copy so the returned infoList keeps its OpIndex order even if
    # findTopK reorders its argument - confirm against utils.findTopK.
    k = min(20, len(infoList))
    ansListCounter = findTopK(list(infoList), k, key=3)
    # Passing the names to the constructor also works for an empty result,
    # where assigning .columns afterwards raises a length mismatch.
    resCounter = pd.DataFrame(ansListCounter, columns=columns)

    print("Number of counter records: " + str(len(infoList)))
    print(resCounter)
    resCounter.to_csv("./Experiments/opInfoCounter_result.csv", index=False, sep=',')

    return infoList
3 changes: 1 addition & 2 deletions src/amanda/profiler/pytorch/amanda_counter.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,12 @@ def __init__(self, filePath="kernel_metrics.txt", kindFlag=0) -> None:

def forward_instrumentation(self, context: amanda.OpContext):
op = context.get_op()

self.opCount += 1
self.opList.append(op.__name__)

# if self.opCount > 10:
# return

self.opList.append(op.__name__)
context.insert_before_op(
self.start_profiling,
opName = op.__name__
Expand Down
6 changes: 3 additions & 3 deletions src/amanda/profiler/pytorch/amanda_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,12 @@ def __init__(self, filePath="activity_records.txt", kindFlag=0) -> None:

def forward_instrumentation(self, context: amanda.OpContext):
op = context.get_op()

self.opCount += 1
self.opList.append(op.__name__)

# if self.opCount > 10:
# return


self.opList.append(op.__name__)
context.insert_before_op(
self.init_trace,
)
Expand All @@ -34,6 +33,7 @@ def init_trace(self, *input):
self.tracer.initTrace()

def finish_trace(self, *output):
self.tracer.activityFlushAll()
self.tracer.finishTrace()

def setKindFlag(self, kindFlag):
Expand Down
25 changes: 25 additions & 0 deletions src/amanda/profiler/pytorch/opInfo_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import amanda
import torch
import torchvision
from profiler import Profiler

def main():
    """Run ResNet-50 once under the counter tool and once under the tracer tool, then show the OpInfo results."""
    target = "cuda"

    net = torchvision.models.resnet50().to(target)
    batch = torch.rand((32, 3, 227, 227)).to(target)

    metricName = "OpInfo"
    prof = Profiler(metricName)
    prof.setConfigs(metric=metricName, supplyInfo=[])

    # First pass: counter instrumentation.
    with amanda.tool.apply(prof.counter):
        output = net(batch)
    # Second pass: tracer instrumentation.
    with amanda.tool.apply(prof.tracer):
        output = net(batch)

    prof.showResults()


if __name__ == "__main__":
    main()
15 changes: 14 additions & 1 deletion src/amanda/profiler/pytorch/profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from utils import setConfigsMetric
from metrics import kernelRoofline
from torchMetrics import kernelInfo
from torchMetrics import kernelInfo, opInfo

class Profiler():
def __init__(self, metric) -> None:
Expand All @@ -26,6 +26,7 @@ def __init__(self, metric) -> None:

def setConfigs(self, metric, supplyInfo, onlineOnly=False, offlineOnly=False):
self.__metric = metric
self.tracer.activityFlushAll()
self.tracer.clearData()
self.counter.clearData()

Expand All @@ -42,6 +43,7 @@ def setConfigs(self, metric, supplyInfo, onlineOnly=False, offlineOnly=False):

def showResults(self):
if self.__metric == "KernelInfo":
self.tracer.activityFlushAll()
self.opList = self.tracer.opList
self.startTimeList = self.tracer.getStartTimeLists()
self.traceDataRt = self.tracer.getTraceDataRt()
Expand All @@ -55,6 +57,17 @@ def showResults(self):
assert len(self.supplyInfo) == 3, "Please provide correct hardware parameters"
kernelRoofline(self.supplyInfo, self.countData)
return

if self.__metric == "OpInfo":
self.tracer.activityFlushAll()
self.opList = self.tracer.opList
self.startTimeList = self.tracer.getStartTimeLists()
self.endTimeList = self.tracer.getEndTimeLists()
self.traceDataApi = self.tracer.getTraceDataApi()
self.traceDataRt = self.tracer.getTraceDataRt()
self.countData = self.counter.getCountData()
opInfo(self.opList, self.startTimeList, self.endTimeList, self.traceDataApi, self.traceDataRt, self.countData)
return

sys.exit("Profiler.Metric: " + self.__metric + " not supported")

Expand Down
85 changes: 54 additions & 31 deletions src/amanda/profiler/pytorch/torchMetrics.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pandas as pd
from metrics import kernelInfoTracer, kernelInfoCounter
from metrics import kernelInfoTracer, kernelInfoCounter, opInfoTracer, opInfoCounter


# Now Information: OpIndex, OpName, KernelIndex, KernelName, cudaLaunchKernelDuration, kernelExecutionDuration, dramRead, dramWrite, SFOp, elapsedCycles
Expand All @@ -8,35 +8,58 @@ def kernelInfo(opList, timeList, apiList, rtList, dataList):
ansListTracer = kernelInfoTracer(opList, timeList, apiList, rtList)
opPosition, opKernelNum = kernelInfoCounter(dataList)

# We can not do mapping now.
# k = min(20, len(ansListTracer))
k = min(20, len(ansListTracer))

# Mapping Tracer data and Counter data, add counter data to TOP-K tracer ansList
# for i in range(k):
# if opKernelNum[ansListTracer[i][0]] == 0:
# ansListTracer[i].append(-1)
# ansListTracer[i].append(-1)
# ansListTracer[i].append(-1)
# ansListTracer[i].append(-1)
# continue

# # kernelPosition = opPosition[opIndex] + matricsNum * kernelIndex + 1
# kernelPosition = opPosition[ansListTracer[i][0]] + 6 * ansListTracer[i][2] + 1
# dramRead = dataList[kernelPosition].gpuValue / 1e6
# dramWrite = dataList[kernelPosition+1].gpuValue / 1e6
# spAdd = dataList[kernelPosition+2].gpuValue
# spFma = dataList[kernelPosition+3].gpuValue
# spMul = dataList[kernelPosition+4].gpuValue
# spOp = (spAdd + spMul + spFma * 2) / 1e6
# cyclesElapsed = dataList[kernelPosition+5].gpuValue / 1e6

# ansListTracer[i].append(round(dramRead,2))
# ansListTracer[i].append(round(dramWrite,2))
# ansListTracer[i].append(round(spOp,2))
# ansListTracer[i].append(round(cyclesElapsed,2))

# res = pd.DataFrame(ansListTracer)
# res.columns = ['OpIndex', 'OpName', 'KernelIndex', 'KernelType', 'KernelName', 'LaunchKernelTime(ns)', 'KernelExecutionTime(ns)', 'DramRead(MB)',
# 'DramWrite(MB)', 'SpOps(M)', 'CyclesElapsed(M)']
# print(res)
# res.to_csv("./Experiments/kernelInfo_result.csv", index=False, sep=',')
for i in range(k):
if opKernelNum[ansListTracer[i][0]] == 0:
ansListTracer[i].append(-1)
ansListTracer[i].append(-1)
ansListTracer[i].append(-1)
ansListTracer[i].append(-1)
continue

# kernelPosition = opPosition[opIndex] + matricsNum * kernelIndex + 1
kernelPosition = opPosition[ansListTracer[i][0]] + 6 * ansListTracer[i][2] + 1
dramRead = dataList[kernelPosition].gpuValue / 1e6
dramWrite = dataList[kernelPosition+1].gpuValue / 1e6
spAdd = dataList[kernelPosition+2].gpuValue
spFma = dataList[kernelPosition+3].gpuValue
spMul = dataList[kernelPosition+4].gpuValue
spOp = (spAdd + spMul + spFma * 2) / 1e6
cyclesElapsed = dataList[kernelPosition+5].gpuValue / 1e6

ansListTracer[i].append(round(dramRead,2))
ansListTracer[i].append(round(dramWrite,2))
ansListTracer[i].append(round(spOp,2))
ansListTracer[i].append(round(cyclesElapsed,2))

res = pd.DataFrame(ansListTracer)
res.columns = ['OpIndex', 'OpName', 'KernelIndex', 'KernelType', 'KernelName', 'LaunchKernelTime(ns)', 'KernelExecutionTime(ns)', 'DramRead(MB)',
'DramWrite(MB)', 'SpOps(M)', 'CyclesElapsed(M)']
print(res)
res.to_csv("./Experiments/kernelInfo_result.csv", index=False, sep=',')


# Now Information: OpIndex, OpName, opExecutionTime, MaxKernelIndex, MaxKernelName, MaxKernelExecutionTime, KernelNumTracer, KernelNumCounter, TotalCyclesElapsed, TotalDramRead, TotalDramWrite, TotalFlopCount
def opInfo(opList, startTimeList, endTimeList, apiList, rtList, dataList):
    """Join the tracer and counter summaries of each top op into one table.

    The rows returned by ``opInfoTracer`` are extended in place with the
    counter columns of the op that has the same ``OpIndex``. The merged table
    is printed and written to ``./Experiments/opInfo_result.csv``.

    Args:
        opList: op names, forwarded to ``opInfoTracer``.
        startTimeList: op start times, forwarded to ``opInfoTracer``.
        endTimeList: op end times, forwarded to ``opInfoTracer``.
        apiList: API trace records, forwarded to ``opInfoTracer``.
        rtList: runtime trace records, forwarded to ``opInfoTracer``.
        dataList: counter records, forwarded to ``opInfoCounter``.

    Returns:
        None.
    """
    # Information collected by tracer: OpIndex, OpName, opExecutionTime, MaxKernelIndex, MaxKernelName, MaxKernelExecutionTime, KernelNumTracer
    ansListTracer = opInfoTracer(opList, startTimeList, endTimeList, apiList, rtList)
    # Information collected by counter: OpIndex, OpName, KernelNumCounter, TotalCyclesElapsed, TotalDramRead, TotalDramWrite, TotalFlopCount
    opInfoCounterList = opInfoCounter(dataList, flopCount=True)

    # Counter columns appended to every tracer row (everything after OpIndex and OpName).
    counterColumns = ['KernelNumCounter', 'TotalCyclesElapsed', 'TotalDramRead', 'TotalDramWrite', 'TotalFlopCount']

    # Every row is merged so that all rows end up with the same width.
    for row in ansListTracer:
        opIndex = row[0]
        if 0 <= opIndex < len(opInfoCounterList):
            # The counter list is positional: entry i describes the op with OpIndex i.
            row += opInfoCounterList[opIndex][2:]
        else:
            # The counter pass recorded no data for this op; pad with -1, the
            # placeholder kernelInfo uses for missing counter values.
            row += [-1] * len(counterColumns)
        print(row)

    # Passing the names to the constructor also works for an empty result,
    # where assigning .columns afterwards raises a length mismatch.
    res = pd.DataFrame(
        ansListTracer,
        columns=['OpIndex', 'OpName', 'opExecutionTime', 'MaxKernelIndex', 'MaxKernelName', 'MaxKernelExecutionTime(ns)', 'KernelNumTracer'] + counterColumns,
    )
    print(res)
    res.to_csv("./Experiments/opInfo_result.csv", index=False, sep=',')
4 changes: 2 additions & 2 deletions src/amanda/profiler/tensorflow/amanda_counter.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ def forward_instrumentation(self, context: amanda.OpContext):
if not tensor.dtype._is_ref_dtype
]

if len(op_outputs) != 0 and len(op_inputs) != 0:
# if len(op_outputs) != 0 and len(op_inputs) != 0 and op.name.find("Relu") != -1:
# if len(op_outputs) != 0 and len(op_inputs) != 0:
if len(op_outputs) != 0 and len(op_inputs) != 0 and op.name.find("conv2d/Conv2D") != -1:
self.opList.append(op.name)
context.insert_before_op(
self.start_profile,
Expand Down
4 changes: 2 additions & 2 deletions src/amanda/profiler/tensorflow/amanda_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ def forward_instrumentation(self, context: amanda.OpContext):
if not tensor.dtype._is_ref_dtype
]

if len(op_outputs) != 0 and len(op_inputs) != 0:
# if len(op_outputs) != 0 and len(op_inputs) != 0 and op.name.find("Relu") != -1:
# if len(op_outputs) != 0 and len(op_inputs) != 0:
if len(op_outputs) != 0 and len(op_inputs) != 0 and op.name.find("conv2d/Conv2D") != -1:
self.opList.append(op.name)
context.insert_before_op(
self.init_trace,
Expand Down
Loading

0 comments on commit 50d1e60

Please sign in to comment.