@@ -101,23 +101,20 @@ def sft_trainer_prepare_dataset(function_name, function):
101101
# Ignore mean_token_accuracy since it needs logits
# We override it directly with our version
def sft_trainer_compute_loss(function_name, function):
    """Supply replacement source code for the trainer's ``compute_loss``.

    Args:
        function_name: Name of the method currently being considered.
        function: Current value for that method. It is handed back
            untouched unless ``function_name`` is ``"compute_loss"``.

    Returns:
        ``function`` unchanged for any other method name. For
        ``"compute_loss"``, the source text of the nested ``compute_loss``
        template below, exactly as ``inspect.getsource`` reports it (the
        ``def`` line keeps its four-space indentation).
    """
    if function_name != "compute_loss": return function

    # Template only: this function is never called from here, only its source
    # text is extracted. It forwards whatever the parent ``compute_loss``
    # returns instead of unpacking a ``(loss, outputs)`` pair, because the
    # parent returns the bare loss when ``return_outputs`` is False.
    def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch = None):
        outputs = super().compute_loss(
            model,
            inputs,
            return_outputs = return_outputs,
            num_items_in_batch = num_items_in_batch,
        )
        return outputs
    pass

    function = inspect.getsource(compute_loss)
    return function
pass
# Add the compute_loss replacement hook to the "sft_trainer" entry of RL_FUNCTIONS
RL_FUNCTIONS["sft_trainer"].append(sft_trainer_compute_loss)
0 commit comments