1+ import os
12import logging
23import asyncio
34import yaml
@@ -15,8 +16,17 @@ async def get_model_list(self, required_mem: int) -> None:
1516 try :
1617 if not self ._sys_client :
1718 self ._sys_client = SYSClient (host = self .host , port = self .port )
19+ current_dir = os .path .dirname (os .path .abspath (__file__ ))
20+ parent_dir = os .path .dirname (current_dir )
21+ config_path = os .path .join (parent_dir , "config" , "config.yaml" )
1822
19- with open ('config/config.yaml' , 'r' ) as f :
23+ with open (config_path , 'r' ) as f :
24+ config = yaml .safe_load (f )
25+ config ['models' ] = {}
26+ with open (config_path , 'w' ) as f :
27+ yaml .safe_dump (config , f , default_flow_style = False , sort_keys = False )
28+
29+ with open (config_path , 'r' ) as f :
2030 config = yaml .safe_load (f )
2131 models_config = config .get ('models' , {})
2232 model_list = await self ._get_model_list ()
@@ -84,7 +94,7 @@ async def get_model_list(self, required_mem: int) -> None:
8494
8595 models_config [mode ] = new_entry
8696 config ['models' ] = models_config
87- with open ('config/config.yaml' , 'w' ) as f :
97+ with open (config_path , 'w' ) as f :
8898 yaml .safe_dump (config , f , default_flow_style = False , sort_keys = False )
8999
90100 except Exception as e :
0 commit comments