diff --git a/examples/common/text_data.py b/examples/common/text_data.py index 7e878542a..f5166491d 100644 --- a/examples/common/text_data.py +++ b/examples/common/text_data.py @@ -62,6 +62,8 @@ class StreamingTextDataset(StreamingDataset): ``False``. shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1s``. shuffle_seed (int): Seed for Deterministic data shuffling. Defaults to ``9176``. + text_column (str): name of text column. Defaults to ``text``. + tokens_column (str): name of tokens column. Defaults to ``tokens``. """ def __init__(self, @@ -84,6 +86,8 @@ def __init__(self, shuffle: bool = False, shuffle_algo: str = 'py1s', shuffle_seed: int = 9176, + text_column: str = 'text', + tokens_column: str = 'tokens', **kwargs: Dict[str, Any]): group_method = kwargs.pop('group_method', None) @@ -128,6 +132,8 @@ def __init__(self, ) self.tokenizer = tokenizer self.max_seq_len = max_seq_len + self.text_column = text_column + self.tokens_column = tokens_column # How to tokenize a text sample to a token sample def _tokenize(self, text_sample): @@ -136,26 +142,26 @@ def _tokenize(self, text_sample): raise RuntimeError( 'If tokenizing on-the-fly, tokenizer must have a pad_token_id') - return self.tokenizer(text_sample['text'], + return self.tokenizer(text_sample[self.text_column], truncation=True, padding='max_length', max_length=self.max_seq_len) def _read_binary_tokenized_sample(self, sample): return torch.from_numpy( - np.frombuffer(sample['tokens'], + np.frombuffer(sample[self.tokens_column], dtype=np.int64)[:self.max_seq_len].copy()) # How to process a sample def __getitem__(self, idx: int) -> Dict[str, Any]: sample = super().__getitem__(idx) - if 'text' in sample: + if self.text_column in sample: token_sample = self._tokenize(sample) - elif 'tokens' in sample: + elif self.tokens_column in sample: token_sample = self._read_binary_tokenized_sample(sample) else: raise RuntimeError( - 'StreamingTextDataset needs samples to have a `text` or `tokens` column' + f'StreamingTextDataset needs samples to have a `{self.text_column}` or `{self.tokens_column}` column' ) return token_sample @@ -266,6 +272,8 @@ def build_text_dataloader( shuffle=cfg.dataset.get('shuffle', False), shuffle_algo=cfg.dataset.get('shuffle_algo', 'py1s'), shuffle_seed=cfg.dataset.get('shuffle_seed', 9176), + text_column=cfg.dataset.get('text_column', 'text'), + tokens_column=cfg.dataset.get('tokens_column', 'tokens'), ) mlm_probability = cfg.dataset.get('mlm_probability', None) @@ -325,6 +333,14 @@ def build_text_dataloader( type=int, default=32, help='max sequence length to test') + parser.add_argument('--text_column', + type=str, + default='text', + help='name of column that contains text') + parser.add_argument('--tokens_column', + type=str, + default='tokens', + help='name of column that contains tokens') args = parser.parse_args() @@ -343,6 +359,8 @@ def build_text_dataloader( 'split': args.split, 'shuffle': False, 'max_seq_len': args.max_seq_len, + 'text_column': args.text_column, + 'tokens_column': args.tokens_column, 'keep_zip': True, # in case we need compressed files after testing }, 'drop_last': False,