
Commit

[WIP] Upgrading cometpy to new version of COMET with decoupled indexlabels

pthomadakis committed Mar 22, 2024
1 parent 74bec7b commit bed0194
Showing 4 changed files with 543 additions and 480 deletions.
143 changes: 64 additions & 79 deletions frontends/numpy-scipy/cometpy/MLIRGen/builders.py
@@ -191,20 +191,14 @@ class TensorSumBuilder:
         undefined=jinja2.StrictUndefined,
     )
 
-    def __init__(self, lhs, operators, tensors_shapes, label_map):
-        self.lhs = lhs
-        self.tensors_shapes = tensors_shapes
-        output_type = "tensor<{}xf64>".format("x".join(str(label_map[v][0]) for v in self.tensors_shapes[-1]))
-        input_type = []
-        self.input_type = "tensor<{}xf64>".format("x".join(str(label_map[v][0]) if label_map[v][1] == DENSE else '?' for v in self.tensors_shapes[0]))
+    def __init__(self, data): # lhs, operators, tensors_shapes, label_map):
+        self.lhs = data["out_id"]
+        self.input_type = "tensor<{}xf64>".format("x".join(str(v) for v in data["shapes"][0]))
 
-        self.operators = self.operators = "({})".format(",".join("%t"+str(v) for v in operators))
+        self.operators = "({})".format(",".join("%t"+str(v) for v in data["operands"]))
 
     def build_op(self):
         output_type = "f64"
-        # for t in self.tensors_shapes[:-1]:
-        #     input_type.append("tensor<{}xf64>".format("x".join(str(v) for v in t)))
-        # input_type = ",".join(input_type)
 
         return self.tensor_sum_wrapper_text.render(
             lhs = self.lhs,
@@ -223,13 +217,11 @@ class SetOp_Builder:
         undefined=jinja2.StrictUndefined,
     )
 
-    def __init__(self, in_tensor, target, tensors_shapes, label_map, beta) :
-        self.target = target
-        self.in_tensor = in_tensor
-        self.tensors_shapes =[]
-        for l in tensors_shapes:
-            self.tensors_shapes.append([ label_map[lbl][0] if label_map[lbl][1] == DENSE else '?' for lbl in l ] )
-        self.beta = "{:e}".format(beta)
+    def __init__(self, data):# in_tensor, target, tensors_shapes, label_map, beta) :
+        self.target = data["lhs"]
+        self.in_tensor = data["rhs"]
+        self.tensors_shapes = data["shapes"]
+        self.beta = "{:e}".format(data["beta"])
 
 
     def build_op(self):
@@ -332,88 +324,85 @@ class ArithOp_Builder:
         undefined=jinja2.StrictUndefined,
     )
 
-    def __init__(self, dest, input_tensors:list, tc_indices, formats: list, tensors_shapes, opr_type, label_map, mask=None, mask_type="none", mask_lbls = None, semiring=None, beta=0):
-                # dimslbls_to_map:list, input_array_dims_lbls:list,
-                # target_dims_lbls:list,tensor_types:list,tc_indices:list,opr_type:str,op:str, formats:list) -> None:
+    def __init__(self, data):
 
-        self.dest = dest
-        self.operators = "{}".format(",".join("%t"+str(v) for v in input_tensors)+","+",".join("%i"+str(vv) for v in tensors_shapes for vv in v))
-        self.tc_indices = tc_indices
-        # self.dimslbls_to_map = dimslbls_to_map
-        # self.input_array_dims_lbls = input_array_dims_lbls
-        # self.target_dims_lbls = target_dims_lbls
+        self.mask = None
+        self.mask_type = "None"
+        self.mask_shape = None
+        self.semiring = None
 
 
+        self.dest = data["out_id"]
+        self.operators = "{}".format(",".join("%t"+str(v) for v in data["operands"])+","+",".join("%i"+str(vv) for v in data["op_ilabels"] for vv in v))
         self.tensors_shapes =[]
-        for l in tensors_shapes:
-            self.tensors_shapes.append([ label_map[lbl][0] if label_map[lbl][1] == DENSE else '?' for lbl in l ] )
+        self.op_ilabels = data["op_ilabels"]
+        for l in data["shapes"]:
+            self.tensors_shapes.append([ str(lbl) for lbl in l ] )
-        # self.tensors_shapes = [label_map[lbl][0] for lbl in tensors_shapes]
-        self.opr_type = opr_type
+        self.opr_type = data["op_type"]
         # self.op = op
-        self.formats = formats
-        self.mask = mask
-        self.mask_type = mask_type
-        if mask_lbls != None:
-            self.mask_shape = [ label_map[lbl][0] if label_map[lbl][1] == DENSE else '?' for lbl in mask_lbls ]
-            self.operators+=",%t"+str(self.mask)
-        else:
-            self.mask_shape = None
-        self.semiring = semiring
-        self.beta = "{:e}".format(beta)
 
+        self.formats = data["formats"]
+        if "mask" in data:
+            self.mask = data["mask"][0]
+            self.mask_type = data["mask"][1]
+            if data["mask"][2] != None:
+                self.mask_shape = [ str(lbl) for lbl in data["mask"][2] ]
+                self.operators+=",%t"+str(self.mask)
+        if "semiring" in data:
+            self.semiring = data["semiring"]
+        self.beta = "{:e}".format(data["beta"])
 
     def build_op(self):
-        output_type = "tensor<{}xf64>".format("x".join(str(v) for v in self.tensors_shapes[-1]))
         input_type = []
         for t in self.tensors_shapes[:-1]:
             input_type.append("tensor<{}xf64>".format("x".join(str(v) for v in t)))
         for t in self.tensors_shapes:
             for v in t:
-                input_type.append("!ta.range")
+                input_type.append("!ta.indexlabel")
         input_type = ",".join(input_type)
         if self.mask_shape != None:
             input_type += ",tensor<{}xf64>".format("x".join(str(v) for v in self.mask_shape))
         # beta_val = ArithOp_Builder.get_beta_val(self.op)
 
-        ops = self.tc_indices.split(',')
         iMap = {}
-        i = 0
-        if len(ops) > 1:
-            op1 = ops[0]
-            op2, res = ops[1].split('->')
-        else:
-            op2 = []
-            op1,res = ops[0].split('->')
+        vMap = {}
         indexing_map = []
+        i = 0
         temp = []
-        for l in op1:
-        # if l not in iMap:
+        for k, l in enumerate(self.op_ilabels[0]):
            iMap[l] = i
+           vMap[l] = self.tensors_shapes[0][k]
            temp.append(i)
            i+=1
 
        indexing_map.append(temp)
 
-        if len(ops) > 1:
+        if len(self.op_ilabels) > 2:
            temp = []
-            for l in op2:
+            for k, l in enumerate(self.op_ilabels[1]):
                if l not in iMap:
                    iMap[l] = i
+                   print( self.tensors_shapes[1][k])
+                   vMap[l] = self.tensors_shapes[1][k]
                    temp.append(i)
                    i+=1
                else:
                    temp.append(iMap[l])
 
            indexing_map.append(temp)
        temp = []
-        for l in res:
+        for l in self.op_ilabels[-1]:
            temp.append(iMap[l])
        indexing_map.append(temp)
        indexing_maps = []
 
+        output_type = "tensor<{}xf64>".format("x".join(str(vMap[v]) for v in self.op_ilabels[-1]))
+        print(output_type)
 
        for imap in indexing_map:
            indexing_maps.append("affine_map<({})->({})>".format(",".join(["d"+str(l) for l in range(i)]) , ",".join(["d"+str(l) for l in imap])))
 
        indexing_maps = str(indexing_maps).replace("'","")
 
 
        # Tensor contraction
        if self.opr_type == 'c':
            semiring = "plusxy_times"
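
(Aside: the rewritten block above derives everything from data["op_ilabels"] instead of re-parsing a tc_indices string. iMap assigns each index label a loop dimension, vMap records its concrete extent, and together they feed the affine maps and the output tensor type. Below is a self-contained sketch of the same bookkeeping for a matmul-shaped case; the labels and extents are illustrative, not from the commit.)

# Standalone sketch of the iMap/vMap bookkeeping above, for a
# matmul-shaped op: (i,k) x (k,j) -> (i,j). Labels/extents are made up.
op_ilabels = [["i", "k"], ["k", "j"], ["i", "j"]]
tensors_shapes = [["8", "16"], ["16", "4"], ["8", "4"]]

iMap, vMap, indexing_map, i = {}, {}, [], 0
temp = []
for k, l in enumerate(op_ilabels[0]):   # first operand fixes the leading dims
    iMap[l] = i
    vMap[l] = tensors_shapes[0][k]
    temp.append(i)
    i += 1
indexing_map.append(temp)

if len(op_ilabels) > 2:                 # second operand, if any
    temp = []
    for k, l in enumerate(op_ilabels[1]):
        if l not in iMap:
            iMap[l] = i
            vMap[l] = tensors_shapes[1][k]
            temp.append(i)
            i += 1
        else:
            temp.append(iMap[l])
    indexing_map.append(temp)

indexing_map.append([iMap[l] for l in op_ilabels[-1]])  # result tensor

maps = ["affine_map<({})->({})>".format(
            ",".join("d" + str(d) for d in range(i)),
            ",".join("d" + str(d) for d in m)) for m in indexing_map]
print(maps)
# ['affine_map<(d0,d1,d2)->(d0,d1)>', 'affine_map<(d0,d1,d2)->(d1,d2)>',
#  'affine_map<(d0,d1,d2)->(d0,d2)>']
print("tensor<{}xf64>".format("x".join(vMap[v] for v in op_ilabels[-1])))
# tensor<8x4xf64>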
@@ -542,24 +531,19 @@ class Tensor_Decl_Builder:
         undefined=jinja2.StrictUndefined,
     )
 
-    def __init__(self, lhs, decl_vars:list, input_shape: str, format, dtype, label_map, is_input)->None:
-        self.lhs = lhs
-        self.inputtype = "tensor<{}x{}>".format("x".join(str(label_map[v][0]) if label_map[v][1] == DENSE else '?' for v in input_shape), dtype)
-        self.decl_vars = decl_vars
-        self.format = format
-        self.is_input = is_input
+    def __init__(self, data)->None:
+        self.lhs = data["id"]
+        self.inputtype = "tensor<{}x{}>".format("x".join(str(v) for v in data["shape"]), data["value_type"])
+        # self.decl_vars = data["dimsSSA"]
+        self.decl_vars = []
+        self.format = data["format"]
+        self.is_input = data["is_input"]
 
 
     def build_tensor(self):
-        dims_tuple = "({})".format(",".join("%i"+str(v) for v in self.decl_vars))
-        ranges_tuple = "({})".format(",".join(["!ta.range"]* len(self.decl_vars)))
-        # ranges_tuple = "("
-        # for i in range(len(self.decl_vars)-1):
-        #     dims_tuple += self.decl_vars[i] + ","
-        #     ranges_tuple += "!ta.range,"
-
-        #     dims_tuple += self.decl_vars[-1] + ")"
-        #     ranges_tuple += "!ta.range)"
+        dims_tuple = "({})".format(",".join("%d"+str(v) for v in self.decl_vars))
+        ranges_tuple = "({})".format(",".join(["index"]* len(self.decl_vars)))
 
         if not self.format == DENSE:
             where = "_from_file"
             format = '"{}" , temporal_tensor = false'.format(self.formats[self.format])
@@ -599,13 +583,14 @@ class PrintBuilder:
         undefined=jinja2.StrictUndefined,
     )
 
-    def __init__(self, operand, input_labels, dtype, label_map):
-        self.operand = operand[0]
-        self.outtype = "x".join(str(label_map[v][0]) if label_map[v][1] == DENSE else '?' for v in input_labels[0])
-        if len(self.outtype) > 0:
-            self.outtype = "tensor<{}x{}>".format(self.outtype, dtype)
+    def __init__(self, data): #operand, input_labels, dtype, label_map):
+        self.operand = data["operands"][0]
+        self.outtype = "x".join(str(v) for v in data["shapes"][0])
+        if len(data["shapes"][0])==1 and data["shapes"][0][0] == 1:
+            self.outtype = data["value_type"]
         else:
-            self.outtype = dtype
+            self.outtype = "tensor<{}x{}>".format(self.outtype, data["value_type"])
 
     def build_op(self):
         return self.tensor_print_text.render(
             tensor = self.operand,
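(Aside: across all of the builders above, the positional constructor arguments and the label_map indirection are gone; each builder now reads a single data dictionary whose shapes are already concrete. A minimal sketch of how a caller might drive the new interface, assuming a cometpy checkout where these modules import cleanly; the ids, shapes, and beta value are illustrative.)

from cometpy.MLIRGen.builders import TensorSumBuilder, SetOp_Builder

# Shapes arrive as concrete dimensions, not label_map lookups.
tsum = TensorSumBuilder({
    "out_id": 3,          # result tensor id, rendered as %t3
    "operands": [1],      # input tensor ids, rendered as %t1
    "shapes": [[8, 16]],  # per-operand shapes, already resolved
})
print(tsum.operators)     # (%t1)
print(tsum.input_type)    # tensor<8x16xf64>

setop = SetOp_Builder({
    "lhs": 3,
    "rhs": 2,
    "shapes": [[8, 16]],
    "beta": 0.0,
})
print(setop.beta)         # 0.000000e+00
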
19 changes: 14 additions & 5 deletions frontends/numpy-scipy/cometpy/MLIRGen/lowering.py
@@ -53,8 +53,8 @@
 def cleanup():
     for f in files_to_cleanup:
         if os.path.exists(f):
-            os.remove(f)
-            # pass
+            # os.remove(f)
+            pass
 atexit.register(cleanup)
 
 class memref_i64(Structure):
@@ -139,15 +139,19 @@ def comment_unneeded_dense(input_, arg_vals):
             for j in range(len(input[:i])):
                 if cast + " = memref.cast" in input[j]:
                     out = input[j][input[j].find("%alloc") : input[j].find(":")].lstrip().strip()
-                    replace['%arg'+str(len(arg_vals))] = out
 
+                    start = input[j].find(":")
+                    end = input[j][start:].find("to")
+                    replace['%arg'+str(len(arg_vals))] = out +" " + input[j][start:start+end].lstrip().strip()
                     for k in range(len(input[:j])):
                         if out+" = memref.alloc(" in input[k]:
                             # input[k] = "//from dense" + input[k]
                             input[k] = ""
         elif allocs_needed > 0 and "memref.alloc" in input[i]:
             allocs_needed = allocs_needed - 1
             a = input[i][input[i].find("%") : input[i].find("=")].lstrip().strip()
-            replace['%arg'+str(indexes[0])] = a
+            start = input[i].rfind(":")
+            replace['%arg'+str(indexes[0])] = a +" " + input[i][start:].lstrip().strip()
             indexes = indexes[1:]
             # input[i] = "//from dense" + input[i]
             input[i] = ""
@@ -161,7 +165,12 @@ def comment_unneeded_dense(input_, arg_vals):
                     break
 
     for v in replace:
-        input[1] = input[1].replace(v, replace[v])
+        start = input[1].find(v)
+        end = input[1][start:].find(",")
+        if end == -1 :
+            end = input[1][start:].find(")")
+        repl = input[1][start:start+end]
+        input[1] = input[1].replace(repl, replace[v])
 
     output = ""
 
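(Aside: the net effect of the two hunks above is that replace now maps an %argN to both an SSA name and its memref type, and the final loop splices out the whole "%argN: type" span of the function signature rather than just the name. A standalone sketch of that string surgery on a made-up MLIR header line; the header and replacement values are illustrative.)

# Illustrative reproduction of the signature-rewrite logic from
# comment_unneeded_dense; the MLIR line below is a made-up example.
header = "func.func @main(%arg0: memref<?xf64>, %arg1: memref<8x16xf64>) {"
replace = {"%arg0": "%alloc_2 : memref<8xf64>"}  # name + type, as recorded above

for v in replace:
    start = header.find(v)              # position of "%arg0"
    end = header[start:].find(",")      # distance to the next argument
    if end == -1:                       # last argument: ")" delimits instead
        end = header[start:].find(")")
    repl = header[start:start + end]    # "%arg0: memref<?xf64>"
    header = header.replace(repl, replace[v])

print(header)
# func.func @main(%alloc_2 : memref<8xf64>, %arg1: memref<8x16xf64>) {
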
2 changes: 2 additions & 0 deletions frontends/numpy-scipy/cometpy/cfg.py
@@ -0,0 +1,2 @@
+comet_path = '/Users/thom895/local/comet/COMET/build'
+llvm_path = '/Users/thom895/local/comet/COMET/llvm/build'
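
(Aside: the new cfg.py hardcodes build directories on the author's machine, so any other checkout presumably has to edit these two lines to point at its own COMET and LLVM builds before the frontend can lower anything. The paths below are examples only.)

# cfg.py: point these at your local builds (example paths)
comet_path = '/path/to/COMET/build'
llvm_path = '/path/to/COMET/llvm/build'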
