| | import warnings |
| | from typing import List, Optional, Union |
| |
|
| | import torch |
| | import torch.distributed as dist |
| | from torch import nn |
| | from transformers import BatchEncoding |
| | from transformers.generation.logits_process import ( |
| | LogitsProcessorList, |
| | ) |
| | from transformers.generation.stopping_criteria import ( |
| | StoppingCriteriaList, |
| | validate_stopping_criteria, |
| | ) |
| |
|
| | from transformers.generation.utils import SampleOutput, SampleEncoderDecoderOutput, SampleDecoderOnlyOutput |
| |
|
def sample(
    self,
    input_ids: torch.LongTensor,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    logits_warper: Optional[LogitsProcessorList] = None,
    max_length: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[Union[int, List[int]]] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_scores: Optional[bool] = None,
    return_dict_in_generate: Optional[bool] = None,
    synced_gpus: Optional[bool] = False,
    **model_kwargs,
) -> Union[SampleOutput, torch.LongTensor]:
    """Multinomial sampling loop that re-tokenizes after every generated token.

    Variant of ``GenerationMixin.sample``: after each step the running sequence
    is decoded and re-encoded with ``self.tokenizer(..., return_ngram_sequences=True)``
    so that ``label_gram_{2,3,4}_sequence`` features can be fed back into the
    model as extra keyword arguments on the next forward pass.

    Args:
        input_ids: Either a plain ``LongTensor`` of prompt token ids, or a
            dict/``BatchEncoding`` whose ``input_ids`` entry is the prompt and
            whose remaining entries (minus ``attention_mask``) are passed to the
            model as extra n-gram keyword arguments.
        logits_processor / logits_warper: Applied to the last-position logits,
            in that order, before sampling.
        stopping_criteria: Extra stop conditions checked on the full sequence.
        max_length: Deprecated; converted into a ``MaxLengthCriteria``.
        pad_token_id / eos_token_id: Default to ``self.generation_config``;
            finished rows are padded with ``pad_token_id``.
        output_*: Toggle collection of scores / attentions / hidden states
            (only kept when ``return_dict_in_generate`` is true).
        return_dict_in_generate: Return a ``SampleOutput`` instead of a tensor.
        synced_gpus: Keep looping until *all* ranks are finished (needed for
            e.g. ZeRO-3 sharded forward passes); coordinated via ``all_reduce``.
        **model_kwargs: Forwarded to ``self._update_model_kwargs_for_generation``.

    Returns:
        ``SampleEncoderDecoderOutput`` / ``SampleDecoderOnlyOutput`` when
        ``return_dict_in_generate`` is set, otherwise the generated
        ``torch.LongTensor`` of token ids.

    Raises:
        ValueError: if ``eos_token_id`` is set but ``pad_token_id`` is not.
    """
    # Accept a dict/BatchEncoding prompt: split the token ids off and keep the
    # remaining entries (except attention_mask) as extra model kwargs.
    # isinstance (not `type(...) in [...]`) so BatchEncoding subclasses work;
    # build a filtered copy instead of `del`-ing keys so the caller's encoding
    # is not mutated and a missing attention_mask is not a KeyError.
    if isinstance(input_ids, (dict, BatchEncoding)):
        encoding = input_ids
        input_ids = encoding["input_ids"]
        ngram_sequences = {
            key: value
            for key, value in encoding.items()
            if key not in ("input_ids", "attention_mask")
        }
    else:
        ngram_sequences = {}

    # Resolve defaults from self.generation_config, mirroring the stock
    # transformers sample() boilerplate.
    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
    if max_length is not None:
        warnings.warn(
            "`max_length` is deprecated in this function, use"
            " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
            UserWarning,
        )
        stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
    logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
    pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
    eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]

    eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
    output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
    output_attentions = (
        output_attentions if output_attentions is not None else self.generation_config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
    )
    return_dict_in_generate = (
        return_dict_in_generate
        if return_dict_in_generate is not None
        else self.generation_config.return_dict_in_generate
    )

    # Per-step collections, only allocated when they will be returned.
    scores = () if (return_dict_in_generate and output_scores) else None
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

    # For encoder-decoder models the encoder pass already happened; grab its
    # attentions/hidden states from model_kwargs for the output object.
    if return_dict_in_generate and self.config.is_encoder_decoder:
        encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
        encoder_hidden_states = (
            model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
        )

    # 1 = still generating, 0 = hit EOS; one flag per batch row.
    unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)

    this_peer_finished = False  # only meaningful when synced_gpus is True

    while True:
        if synced_gpus:
            # Under synced_gpus every rank must keep running forward passes
            # until all ranks are done. Sum a 0/1 "still running" flag across
            # ranks; when the sum is 0, everyone has finished.
            this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
            dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
            if this_peer_finished_flag.item() == 0.0:
                break

        # NOTE(review): only input_ids is fed to the model — no attention_mask
        # and no past key values, so every step recomputes the full sequence.
        # Presumably intentional because the sequence is re-tokenized below;
        # confirm against the model's forward signature.
        model_inputs = {"input_ids": input_ids}

        outputs = self(
            **model_inputs,
            return_dict=True,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            **ngram_sequences,
        )

        if synced_gpus and this_peer_finished:
            # Dummy pass just to stay in sync with the other ranks.
            continue

        next_token_logits = outputs.logits[:, -1, :]

        # Processors (e.g. repetition penalty) first, then warpers
        # (temperature / top-k / top-p), matching upstream ordering.
        next_token_scores = logits_processor(input_ids, next_token_logits)
        next_token_scores = logits_warper(input_ids, next_token_scores)

        if return_dict_in_generate:
            if output_scores:
                scores += (next_token_scores,)
            if output_attentions:
                decoder_attentions += (
                    (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                )
                if self.config.is_encoder_decoder:
                    cross_attentions += (outputs.cross_attentions,)

            if output_hidden_states:
                decoder_hidden_states += (
                    (outputs.decoder_hidden_states,)
                    if self.config.is_encoder_decoder
                    else (outputs.hidden_states,)
                )

        # Draw one token per row from the warped distribution.
        probs = nn.functional.softmax(next_token_scores, dim=-1)
        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

        # Finished rows keep emitting pad_token_id instead of sampled tokens.
        if eos_token_id is not None:
            if pad_token_id is None:
                raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

        # Append the sampled token, then round-trip through the tokenizer to
        # regenerate the n-gram label sequences for the next forward pass.
        # NOTE(review): batch_decode(...)[0] keeps only the first sequence, so
        # this loop effectively assumes batch size 1 — confirm with callers.
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        decoded = self.tokenizer.batch_decode(input_ids)[0]
        encoded = self.tokenizer(
            decoded, return_tensors="pt", return_ngram_sequences=True
        )
        input_ids = encoded.input_ids.to(self.device)

        # Collect whichever n-gram label sequences the tokenizer produced.
        ngram_sequences = {
            key: encoded[key].to(self.device)
            for key in ("label_gram_2_sequence", "label_gram_3_sequence", "label_gram_4_sequence")
            if key in encoded
        }

        model_kwargs = self._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
        )

        # A row is finished once it has produced any of the EOS ids: compare
        # next_tokens against every eos id and AND the results via prod().
        if eos_token_id_tensor is not None:
            unfinished_sequences = unfinished_sequences.mul(
                next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
            )

        # Stop when all rows finished or a stopping criterion fires; under
        # synced_gpus we only mark ourselves done and keep looping (see top).
        if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
            if not synced_gpus:
                break
            else:
                this_peer_finished = True

    if return_dict_in_generate:
        if self.config.is_encoder_decoder:
            return SampleEncoderDecoderOutput(
                sequences=input_ids,
                scores=scores,
                encoder_attentions=encoder_attentions,
                encoder_hidden_states=encoder_hidden_states,
                decoder_attentions=decoder_attentions,
                cross_attentions=cross_attentions,
                decoder_hidden_states=decoder_hidden_states,
            )
        else:
            return SampleDecoderOnlyOutput(
                sequences=input_ids,
                scores=scores,
                attentions=decoder_attentions,
                hidden_states=decoder_hidden_states,
            )
    else:
        return input_ids
| |
|