# NOTE(review): This is NOT runnable Python — it is a scraped GitHub *unified diff*
# of a Python module. Each line carries fused old/new line numbers (e.g. "10991104"
# = old line 1099 / new line 1104), "+" marks added lines, "-" removed lines, and
# "@@ ... @@" hunk headers elide unchanged lines. Comments below annotate the diff;
# the diff text itself is preserved byte-for-byte.
10911091
10921092import re
10931093import time
# New dependencies introduced by this commit: torch (tensor concat), typing
# (new Union[Dict, None] return annotation), and transformers (AutoTokenizer
# type hint on the new `tokenizer` parameter).
1094+ import torch
1095+ from typing import Union , Dict
1096+ from transformers import AutoTokenizer
1097+
10941098from hf_mini .filter import SensitiveInforRM
# The module-level security filter instance is removed here (old line 1095) and
# re-added below (new line 1101) — a pure move past the new import block.
1095- is_security = SensitiveInforRM ()
10961099
# Old signature (removed): returned a plain prompt *string*.
1097- def input_wrapper (code_string , later_code : str = "" , path : str = "" ) -> str :
10981100
1101+ is_security = SensitiveInforRM ()
1102+
# New signature: now requires a `tokenizer` and an SPM pad token, and returns the
# tokenizer's encoding dict (input_ids / attention_mask) instead of a string —
# None on a failed security check. Note this is a breaking interface change for
# existing callers (extra required first argument, different return type).
1103+ def input_wrapper (tokenizer : AutoTokenizer , code_string : str , later_code : str = "" , path : str = "" , pad_token : str = "☺" ) -> Union [Dict ,None ]:
# Timed security screening of all three text inputs via is_security.
# (Typo preserved from the original: "_sequerity" / "sequerity".)
10991104 start = time .time ()
11001105 _sequerity = True
11011106 for i in [code_string , later_code , path ]:
# Hunk header: old lines 1104-1110 / new 1109-1115 follow; the loop body that
# flips _sequerity (between the `for` above and the `break` below) is elided
# from this view.
@@ -1104,7 +1109,7 @@ def input_wrapper(code_string, later_code: str = "", path: str = "") -> str:
11041109 break
11051110 print (f"Done inputs checking with { (time .time ()- start ) * 1000 :.2f} ms" , flush = True )
11061111 if not _sequerity :
# Behavioral change: failure sentinel was "" (falsy string), now None —
# consistent with the new Union[Dict, None] annotation.
1107- return ""
1112+ return None
11081113
# Build the optional file-path description `p` from the path's extension.
11091114 extension_pattern = re .compile (r"(\.\w+)$" )
11101115 p = ""
# Hunk header: new lines 1116-1123 (the extension match and the lookup that
# binds `lang`, referenced just below) are NOT shown in this diff view.
@@ -1119,4 +1124,18 @@ def input_wrapper(code_string, later_code: str = "", path: str = "") -> str:
11191124 des = LANGUAGE_WRAPPER .get (lang , "" )
11201125 if len (des ) > 0 and "<AIX-SPE>" in des :
11211126 p = des .replace ("<AIX-SPE>" , f"the file path is: { path } " ) + "\n "
# Old behavior (removed): return one FIM-formatted string
# <s> PRE POST{later_code} MIDDLE{p}{code_string} for the caller to tokenize.
1122- return f"<s>▁<AIX-SPAN-PRE>▁<AIX-SPAN-POST>{ later_code } ▁<AIX-SPAN-MIDDLE>{ p } { code_string } "
1127+
# New behavior: tokenize the three FIM segments separately and concatenate ids,
# instead of tokenizing one big string.
1128+ # SPM
# Measure how many tokens the pad token occupies so it can be sliced back off.
1129+ pad_ids = tokenizer (pad_token , return_tensors = "pt" , return_token_type_ids = False )
1130+ pad_len = len (pad_ids ["input_ids" ][0 ])
1131+ pre_code_ids = tokenizer ("<s>▁<AIX-SPAN-PRE>▁<AIX-SPAN-POST>" , return_tensors = "pt" , return_token_type_ids = False )
1132+
# later_code is tokenized with pad_token prepended, then the pad's tokens are
# sliced off — presumably so the SPM tokenizer treats later_code as
# mid-sequence text rather than sequence-start (no BOS / no leading-whitespace
# normalization). TODO(review): confirm against the model's tokenizer config.
1133+ later_code_ids = tokenizer (pad_token + later_code , return_tensors = "pt" , return_token_type_ids = False )
1134+ later_code_ids ["input_ids" ] = later_code_ids ["input_ids" ][:,pad_len :]
1135+ later_code_ids ["attention_mask" ] = later_code_ids ["attention_mask" ][:,pad_len :]
1136+
# Final order along dim=1: [PRE/POST markers] + [later_code] + [MIDDLE + p +
# code_string] — the same segment order the removed f-string produced.
# NOTE(review): code_string_ids is mutated in place to hold the concatenated
# result, so any other keys the tokenizer returned still describe only the
# last segment.
1137+ code_string_ids = tokenizer (f"▁<AIX-SPAN-MIDDLE>{ p } { code_string } " , return_tensors = "pt" , return_token_type_ids = False )
1138+ code_string_ids ["input_ids" ] = torch .cat ([pre_code_ids ["input_ids" ], later_code_ids ["input_ids" ], code_string_ids ["input_ids" ]], dim = 1 )
1139+ code_string_ids ["attention_mask" ] = torch .cat ([pre_code_ids ["attention_mask" ], later_code_ids ["attention_mask" ], code_string_ids ["attention_mask" ]], dim = 1 )
1140+
1141+ return code_string_ids
# NOTE(review): web-scrape residue from the commit page footer, not source code:
0 commit comments