diff --git a/code4me-server/src/api.py b/code4me-server/src/api.py index 8c6630f..64c376a 100644 --- a/code4me-server/src/api.py +++ b/code4me-server/src/api.py @@ -49,12 +49,15 @@ def autocomplete(): unique_predictions_set = set() def predict_model(model: Model) -> List[str]: - return model.value[1](left_context, right_context) - - results = Parallel(n_jobs=os.cpu_count(), prefer="threads")(delayed(predict_model)(model) for model in Model) - for model, model_predictions in zip(Model, results): - predictions[model.name] = model_predictions - unique_predictions_set.update(model_predictions) + prediction = model.value[1](left_context, right_context) + predictions[model.name] = prediction + unique_predictions_set.update(prediction) + + try: + Parallel(n_jobs=os.cpu_count(), prefer="threads", timeout=1)(delayed(predict_model)(model) for model in Model) + except TimeoutError: + # timeout is fine - predictions that were fast enough have been recorded + pass t_after = datetime.now() unique_predictions = list(unique_predictions_set)