Skip to content

Commit

Permalink
Lint failure adding annotations to `torchtnt/examples/mingpt/char_dataset.py` (pytorch#516)
Browse files Browse the repository at this point in the history

Summary: Pull Request resolved: pytorch#516

Reviewed By: galrotem

Differential Revision: D48643967

fbshipit-source-id: 2a3a377107d72db103d549fe777e39051d5b0fc7
  • Loading branch information
ananthsub authored and facebook-github-bot committed Aug 24, 2023
1 parent fe48759 commit 56cd6cb
Showing 1 changed file with 11 additions and 20 deletions.
31 changes: 11 additions & 20 deletions examples/mingpt/char_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Dict, Tuple

import fsspec
import torch
Expand All @@ -17,18 +18,14 @@

@dataclass
class DataConfig:
    """Configuration for building a CharDataset.

    The first three fields are required (no defaults): callers must say
    where the corpus lives, how long each sample is, and how to split it.
    """

    path: str  # fsspec-compatible path to the raw text corpus
    block_size: int  # context length (characters per training sample)
    train_split: float  # fraction of the data used for training
    truncate: float = 1.0  # keep only this leading fraction of the corpus


class CharDataset(Dataset):
# pyre-fixme[3]: Return type must be annotated.
def __init__(self, data_cfg: DataConfig):
def __init__(self, data_cfg: DataConfig) -> None:
print(data_cfg.path)
data = fsspec.open(data_cfg.path).open().read().decode("utf-8")
data = data[: int(len(data) * data_cfg.truncate)]
Expand All @@ -37,24 +34,18 @@ def __init__(self, data_cfg: DataConfig):
data_size, vocab_size = len(data), len(chars)
print("Data has %d characters, %d unique." % (data_size, vocab_size))

# pyre-fixme[4]: Attribute must be annotated.
self.stoi = {ch: i for i, ch in enumerate(chars)}
# pyre-fixme[4]: Attribute must be annotated.
self.itos = {i: ch for i, ch in enumerate(chars)}
# pyre-fixme[4]: Attribute must be annotated.
self.block_size = data_cfg.block_size
# pyre-fixme[4]: Attribute must be annotated.
self.vocab_size = vocab_size
self.stoi: Dict[str, int] = {ch: i for i, ch in enumerate(chars)}
self.itos: Dict[int, str] = {i: ch for i, ch in enumerate(chars)}
self.block_size: int = data_cfg.block_size
self.vocab_size: int = vocab_size
# pyre-fixme[4]: Attribute must be annotated.
self.data = data

# pyre-fixme[3]: Return type must be annotated.
def __len__(self):
def __len__(self) -> int:
return len(self.data) - self.block_size

# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def __getitem__(self, idx):
def __getitem__(self, idx) -> Tuple[torch.Tensor, torch.Tensor]:
# grab a chunk of (block_size + 1) characters from the data
chunk = self.data[idx : idx + self.block_size + 1]
# encode every character to an integer
Expand Down

0 comments on commit 56cd6cb

Please sign in to comment.