 66 |  66 |  from huggingface_hub.utils._token import get_token
 67 |  67 |  pass
 68 |  68 |  from triton import __version__ as triton_version
    |  69 | +BlockDiagonalCausalMask = xformers.attn_bias.BlockDiagonalCausalMask if xformers is not None else None
    |  70 | +
 69 |  71 |
 70 |  72 |  def original_apply_qkv(self, X):
 71 |  73 |      Q = self.q_proj(X)

@@ -330,7 +332,7 @@ def fast_layernorm_compiled(layernorm, X):
330 | 332 |  def LlamaAttention_fast_forward(
331 | 333 |      self,
332 | 334 |      hidden_states: torch.Tensor,
333 |     | -    causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None,
    | 335 | +    causal_mask: Optional[BlockDiagonalCausalMask] = None,
334 | 336 |      attention_mask: Optional[torch.Tensor] = None,
335 | 337 |      position_ids: Optional[torch.LongTensor] = None,
336 | 338 |      past_key_value: Optional[Tuple[torch.Tensor]] = None,

@@ -538,7 +540,7 @@ def LlamaDecoderLayer_fast_forward(
538 | 540 |  def LlamaModel_fast_forward(
539 | 541 |      self,
540 | 542 |      input_ids: torch.LongTensor,
541 |     | -    causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None,
    | 543 | +    causal_mask: Optional[BlockDiagonalCausalMask] = None,
542 | 544 |      attention_mask: Optional[torch.Tensor] = None,
543 | 545 |      position_ids: Optional[torch.LongTensor] = None,
544 | 546 |      past_key_values: Optional[List[torch.FloatTensor]] = None,

@@ -942,7 +944,7 @@ def CausalLM_fast_forward(fast_forward_inference):
942 | 944 |  def _CausalLM_fast_forward(
943 | 945 |      self,
944 | 946 |      input_ids: torch.LongTensor = None,
945 |     | -    causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None,
    | 947 | +    causal_mask: Optional[BlockDiagonalCausalMask] = None,
946 | 948 |      attention_mask: Optional[torch.Tensor] = None,
947 | 949 |      position_ids: Optional[torch.LongTensor] = None,
948 | 950 |      past_key_values: Optional[List[torch.FloatTensor]] = None,
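The change above defines a module-level fallback alias so that the `Optional[...]` annotations no longer dereference `xformers.attn_bias` at import time, which would raise an `AttributeError` when `xformers` has been set to `None`. Below is a minimal sketch of that pattern, not the project's actual import logic: the try/except guard, the `attention_fast_forward` stub, and the `xformers.ops.fmha.attn_bias` import path (where recent xformers releases expose `BlockDiagonalCausalMask`) are assumptions made for illustration.

```python
from typing import Optional

import torch

# xformers is an optional dependency; guard the import so this module still
# loads when it is missing (import path assumed for recent xformers releases).
try:
    from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
except ImportError:
    # Fallback alias: Optional[None] evaluates to NoneType, so the annotations
    # below can still be built without touching the missing package.
    BlockDiagonalCausalMask = None


# Hypothetical signature mirroring the ones patched in the diff.
def attention_fast_forward(
    hidden_states: torch.Tensor,
    causal_mask: Optional[BlockDiagonalCausalMask] = None,
    attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # Real attention logic omitted; the point is that this definition
    # executes whether or not xformers is installed.
    return hidden_states
```

Because `Optional[None]` simply collapses to `NoneType`, the function definitions evaluate cleanly when xformers is absent, whereas `Optional[xformers.attn_bias.BlockDiagonalCausalMask]` would fail as soon as the module was imported.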