Skip to content

Commit

Permalink
initial release of xlnet
Browse files Browse the repository at this point in the history
  • Loading branch information
zihangdai committed Jun 19, 2019
1 parent 280d714 commit 93dfce7
Show file tree
Hide file tree
Showing 22 changed files with 10,923 additions and 0 deletions.
372 changes: 372 additions & 0 deletions README.md

Large diffs are not rendered by default.

148 changes: 148 additions & 0 deletions classifier_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
from absl import flags

import re
import numpy as np

import tensorflow as tf
from data_utils import SEP_ID, CLS_ID

FLAGS = flags.FLAGS

SEG_ID_A = 0
SEG_ID_B = 1
SEG_ID_CLS = 2
SEG_ID_SEP = 3
SEG_ID_PAD = 4

class PaddingInputExample(object):
"""Fake example so the num input examples is a multiple of the batch size.
When running eval/predict on the TPU, we need to pad the number of examples
to be a multiple of the batch size, because the TPU requires a fixed batch
size. The alternative is to drop the last batch, which is bad because it means
the entire output data won't be generated.
We use this class instead of `None` because treating `None` as padding
battches could cause silent errors.
"""


class InputFeatures(object):
"""A single set of features of data."""

def __init__(self,
input_ids,
input_mask,
segment_ids,
label_id,
is_real_example=True):
self.input_ids = input_ids
self.input_mask = input_mask
self.segment_ids = segment_ids
self.label_id = label_id
self.is_real_example = is_real_example


def _truncate_seq_pair(tokens_a, tokens_b, max_length):
"""Truncates a sequence pair in place to the maximum length."""

# This is a simple heuristic which will always truncate the longer sequence
# one token at a time. This makes more sense than truncating an equal percent
# of tokens from each, since if one sequence is very short then each token
# that's truncated likely contains more information than a longer sequence.
while True:
total_length = len(tokens_a) + len(tokens_b)
if total_length <= max_length:
break
if len(tokens_a) > len(tokens_b):
tokens_a.pop()
else:
tokens_b.pop()


def convert_single_example(ex_index, example, label_list, max_seq_length,
tokenize_fn):
"""Converts a single `InputExample` into a single `InputFeatures`."""

if isinstance(example, PaddingInputExample):
return InputFeatures(
input_ids=[0] * max_seq_length,
input_mask=[1] * max_seq_length,
segment_ids=[0] * max_seq_length,
label_id=0,
is_real_example=False)

if label_list is not None:
label_map = {}
for (i, label) in enumerate(label_list):
label_map[label] = i

tokens_a = tokenize_fn(example.text_a)
tokens_b = None
if example.text_b:
tokens_b = tokenize_fn(example.text_b)

if tokens_b:
# Modifies `tokens_a` and `tokens_b` in place so that the total
# length is less than the specified length.
# Account for two [SEP] & one [CLS] with "- 3"
_truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
else:
# Account for one [SEP] & one [CLS] with "- 2"
if len(tokens_a) > max_seq_length - 2:
tokens_a = tokens_a[:max_seq_length - 2]

tokens = []
segment_ids = []
for token in tokens_a:
tokens.append(token)
segment_ids.append(SEG_ID_A)
tokens.append(SEP_ID)
segment_ids.append(SEG_ID_A)

if tokens_b:
for token in tokens_b:
tokens.append(token)
segment_ids.append(SEG_ID_B)
tokens.append(SEP_ID)
segment_ids.append(SEG_ID_B)

tokens.append(CLS_ID)
segment_ids.append(SEG_ID_CLS)

input_ids = tokens

# The mask has 0 for real tokens and 1 for padding tokens. Only real
# tokens are attended to.
input_mask = [0] * len(input_ids)

# Zero-pad up to the sequence length.
if len(input_ids) < max_seq_length:
delta_len = max_seq_length - len(input_ids)
input_ids = [0] * delta_len + input_ids
input_mask = [1] * delta_len + input_mask
segment_ids = [SEG_ID_PAD] * delta_len + segment_ids

assert len(input_ids) == max_seq_length
assert len(input_mask) == max_seq_length
assert len(segment_ids) == max_seq_length

if label_list is not None:
label_id = label_map[example.label]
else:
label_id = example.label
if ex_index < 5:
tf.logging.info("*** Example ***")
tf.logging.info("guid: %s" % (example.guid))
tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
tf.logging.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
tf.logging.info("label: {} (id = {})".format(example.label, label_id))

feature = InputFeatures(
input_ids=input_ids,
input_mask=input_mask,
segment_ids=segment_ids,
label_id=label_id)
return feature



Loading

0 comments on commit 93dfce7

Please sign in to comment.