Skip to content

Commit

Permalink
updated
Browse files Browse the repository at this point in the history
Signed-off-by: Anhforth <[email protected]>
  • Loading branch information
Anhforth committed Feb 15, 2023
1 parent 2ba795e commit c714eb6
Showing 1 changed file with 1 addition and 27 deletions.
28 changes: 1 addition & 27 deletions flagai/data/dataset/superglue/pvp.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def __init__(self,
self.max_dec_seq_length = 16
self._is_multi_token = is_multi_token
self.max_segment_length = max_segment_length
# self.task_mask = args.task_mask
self.task_mask = args.task_mask
self.continuous_prompt = args.continuous_prompt
self.prefix_prompt = args.prefix_prompt
# if self.continuous_prompt:
Expand Down Expand Up @@ -176,29 +176,6 @@ def insert_tokens(parts, num_prompt_tokens, avg_prompt_tokens):

new_parts_a = insert_tokens(parts_a, num_prompt_tokens, avg_prompt_tokens)
new_parts_b = insert_tokens(parts_b, num_prompt_tokens, avg_prompt_tokens)
# new_parts_a, new_parts_b = [], []
# for part in parts_a:
# if part is None:
# if num_prompt_tokens > 0:
# if num_prompt_tokens >= avg_prompt_tokens:
# new_parts_a.append(avg_prompt_tokens)
# num_prompt_tokens -= avg_prompt_tokens
# else:
# new_parts_a.append(num_prompt_tokens)
# num_prompt_tokens = 0
# else:
# new_parts_a.append(part)
# for part in parts_b:
# if part is None:
# if num_prompt_tokens > 0:
# if num_prompt_tokens >= avg_prompt_tokens:
# new_parts_b.append(avg_prompt_tokens)
# num_prompt_tokens -= avg_prompt_tokens
# else:
# new_parts_b.append(num_prompt_tokens)
# num_prompt_tokens = 0
# else:
# new_parts_b.append(part)

return new_parts_a, new_parts_b

Expand Down Expand Up @@ -237,8 +214,6 @@ def encode_input(raw_parts):
parts.append((x, s))
return parts



parts_a = encode_input(
raw_parts_a) # Encode part a from text to token ids
if self.prefix_prompt > 0:
Expand Down Expand Up @@ -370,7 +345,6 @@ def encode_input(raw_parts):
else:
this_parts_a, this_parts_b = copy.deepcopy(
parts_a), copy.deepcopy(parts_b)

self.num_truncated += self.truncate(
this_parts_a,
this_parts_b,
Expand Down

0 comments on commit c714eb6

Please sign in to comment.