diff --git a/comfyui_to_python.py b/comfyui_to_python.py index 33e6e53..e3b8465 100644 --- a/comfyui_to_python.py +++ b/comfyui_to_python.py @@ -11,9 +11,15 @@ import black -from utils import import_custom_nodes, find_path, add_comfyui_directory_to_sys_path, add_extra_model_paths, get_value_at_index - -sys.path.append('../') +from utils import ( + import_custom_nodes, + find_path, + add_comfyui_directory_to_sys_path, + add_extra_model_paths, + get_value_at_index, +) + +sys.path.append("../") from nodes import NODE_CLASS_MAPPINGS @@ -40,7 +46,7 @@ def read_json_file(file_path: str) -> dict: """ try: - with open(file_path, 'r') as file: + with open(file_path, "r") as file: data = json.load(file) return data @@ -59,7 +65,9 @@ def read_json_file(file_path: str) -> dict: # Format the list of JSON files as a string json_files_str = "\n".join(json_files) - raise FileNotFoundError(f"\n\nFile not found: {file_path}. JSON files in the directory:\n{json_files_str}") + raise FileNotFoundError( + f"\n\nFile not found: {file_path}. JSON files in the directory:\n{json_files_str}" + ) except json.JSONDecodeError: raise ValueError(f"Invalid JSON format in file: {file_path}") @@ -83,7 +91,7 @@ def write_code_to_file(file_path: str, code: str) -> None: os.makedirs(directory) # Save the code to a .py file - with open(file_path, 'w') as file: + with open(file_path, "w") as file: file.write(code) @@ -135,7 +143,7 @@ def _dfs(self, key: str) -> None: """ # Mark the node as visited. self.visited[key] = True - inputs = self.data[key]['inputs'] + inputs = self.data[key]["inputs"] # Loop over each input key. for input_key, val in inputs.items(): # If the value is a list and the first item in the list has not been visited yet, @@ -153,11 +161,15 @@ def _load_special_functions_first(self) -> None: """ # Iterate over each key in the data to check for loader keys. for key in self.data: - class_def = self.node_class_mappings[self.data[key]['class_type']]() + class_def = self.node_class_mappings[self.data[key]["class_type"]]() # Check if the class is a loader class or meets specific conditions. - if (class_def.CATEGORY == 'loaders' or - class_def.FUNCTION in ['encode'] or - not any(isinstance(val, list) for val in self.data[key]['inputs'].values())): + if ( + class_def.CATEGORY == "loaders" + or class_def.FUNCTION in ["encode"] + or not any( + isinstance(val, list) for val in self.data[key]["inputs"].values() + ) + ): self.is_special_function = True # If the key has not been visited, perform a DFS from that key. if key not in self.visited: @@ -182,7 +194,12 @@ def __init__(self, node_class_mappings: Dict, base_node_class_mappings: Dict): self.node_class_mappings = node_class_mappings self.base_node_class_mappings = base_node_class_mappings - def generate_workflow(self, load_order: List, filename: str = 'generated_code_workflow.py', queue_size: int = 10) -> str: + def generate_workflow( + self, + load_order: List, + filename: str = "generated_code_workflow.py", + queue_size: int = 10, + ) -> str: """Generate the execution code based on the load order. Args: @@ -195,7 +212,12 @@ def generate_workflow(self, load_order: List, filename: str = 'generated_code_wo str: Generated execution code as a string. """ # Create the necessary data structures to hold imports and generated code - import_statements, executed_variables, special_functions_code, code = set(['NODE_CLASS_MAPPINGS']), {}, [], [] + import_statements, executed_variables, special_functions_code, code = ( + set(["NODE_CLASS_MAPPINGS"]), + {}, + [], + [], + ) # This dictionary will store the names of the objects that we have already initialized initialized_objects = {} @@ -204,16 +226,18 @@ def generate_workflow(self, load_order: List, filename: str = 'generated_code_wo for idx, data, is_special_function in load_order: # Generate class definition and inputs from the data - inputs, class_type = data['inputs'], data['class_type'] + inputs, class_type = data["inputs"], data["class_type"] class_def = self.node_class_mappings[class_type]() # If the class hasn't been initialized yet, initialize it and generate the import statements if class_type not in initialized_objects: # No need to use preview image nodes since we are executing the script in a terminal - if class_type == 'PreviewImage': + if class_type == "PreviewImage": continue - class_type, import_statement, class_code = self.get_class_info(class_type) + class_type, import_statement, class_code = self.get_class_info( + class_type + ) initialized_objects[class_type] = self.clean_variable_name(class_type) if class_type in self.base_node_class_mappings.keys(): import_statements.add(import_statement) @@ -222,29 +246,58 @@ def generate_workflow(self, load_order: List, filename: str = 'generated_code_wo special_functions_code.append(class_code) # Get all possible parameters for class_def - class_def_params = self.get_function_parameters(getattr(class_def, class_def.FUNCTION)) + class_def_params = self.get_function_parameters( + getattr(class_def, class_def.FUNCTION) + ) # Remove any keyword arguments from **inputs if they are not in class_def_params - inputs = {key: value for key, value in inputs.items() if key in class_def_params} + inputs = { + key: value for key, value in inputs.items() if key in class_def_params + } # Deal with hidden variables - if 'unique_id' in class_def_params: - inputs['unique_id'] = random.randint(1, 2**64) + if "unique_id" in class_def_params: + inputs["unique_id"] = random.randint(1, 2**64) # Create executed variable and generate code - executed_variables[idx] = f'{self.clean_variable_name(class_type)}_{idx}' + executed_variables[idx] = f"{self.clean_variable_name(class_type)}_{idx}" inputs = self.update_inputs(inputs, executed_variables) - + if is_special_function: - special_functions_code.append(self.create_function_call_code(initialized_objects[class_type], class_def.FUNCTION, executed_variables[idx], is_special_function, **inputs)) + special_functions_code.append( + self.create_function_call_code( + initialized_objects[class_type], + class_def.FUNCTION, + executed_variables[idx], + is_special_function, + **inputs, + ) + ) else: - code.append(self.create_function_call_code(initialized_objects[class_type], class_def.FUNCTION, executed_variables[idx], is_special_function, **inputs)) + code.append( + self.create_function_call_code( + initialized_objects[class_type], + class_def.FUNCTION, + executed_variables[idx], + is_special_function, + **inputs, + ) + ) # Generate final code by combining imports and code, and wrap them in a main function - final_code = self.assemble_python_code(import_statements, special_functions_code, code, queue_size, custom_nodes) + final_code = self.assemble_python_code( + import_statements, special_functions_code, code, queue_size, custom_nodes + ) return final_code - def create_function_call_code(self, obj_name: str, func: str, variable_name: str, is_special_function: bool, **kwargs) -> str: + def create_function_call_code( + self, + obj_name: str, + func: str, + variable_name: str, + is_special_function: bool, + **kwargs, + ) -> str: """Generate Python code for a function call. Args: @@ -257,15 +310,15 @@ def create_function_call_code(self, obj_name: str, func: str, variable_name: str Returns: str: The generated Python code. """ - args = ', '.join(self.format_arg(key, value) for key, value in kwargs.items()) + args = ", ".join(self.format_arg(key, value) for key, value in kwargs.items()) # Generate the Python code - code = f'{variable_name} = {obj_name}.{func}({args})\n' + code = f"{variable_name} = {obj_name}.{func}({args})\n" # If the code contains dependencies and is not a loader or encoder, indent the code because it will be placed inside # of a for loop if not is_special_function: - code = f'\t{code}' + code = f"\t{code}" return code @@ -279,16 +332,23 @@ def format_arg(self, key: str, value: any) -> str: Returns: str: Formatted argument as a string. """ - if key == 'noise_seed' or key == 'seed': - return f'{key}=random.randint(1, 2**64)' + if key == "noise_seed" or key == "seed": + return f"{key}=random.randint(1, 2**64)" elif isinstance(value, str): value = value.replace("\n", "\\n").replace('"', "'") return f'{key}="{value}"' - elif isinstance(value, dict) and 'variable_name' in value: + elif isinstance(value, dict) and "variable_name" in value: return f'{key}={value["variable_name"]}' - return f'{key}={value}' - - def assemble_python_code(self, import_statements: set, speical_functions_code: List[str], code: List[str], queue_size: int, custom_nodes=False) -> str: + return f"{key}={value}" + + def assemble_python_code( + self, + import_statements: set, + speical_functions_code: List[str], + code: List[str], + queue_size: int, + custom_nodes=False, + ) -> str: """Generates the final code string. Args: @@ -303,29 +363,54 @@ def assemble_python_code(self, import_statements: set, speical_functions_code: L """ # Get the source code of the utils functions as a string func_strings = [] - for func in [get_value_at_index, find_path, add_comfyui_directory_to_sys_path, add_extra_model_paths]: - func_strings.append(f'\n{inspect.getsource(func)}') + for func in [ + get_value_at_index, + find_path, + add_comfyui_directory_to_sys_path, + add_extra_model_paths, + ]: + func_strings.append(f"\n{inspect.getsource(func)}") # Define static import statements required for the script - static_imports = ['import os', 'import random', 'import sys', 'from typing import Sequence, Mapping, Any, Union', - 'import torch'] + func_strings + ['\n\nadd_comfyui_directory_to_sys_path()\nadd_extra_model_paths()\n'] + static_imports = ( + [ + "import os", + "import random", + "import sys", + "from typing import Sequence, Mapping, Any, Union", + "import torch", + ] + + func_strings + + ["\n\nadd_comfyui_directory_to_sys_path()\nadd_extra_model_paths()\n"] + ) # Check if custom nodes should be included if custom_nodes: - static_imports.append(f'\n{inspect.getsource(import_custom_nodes)}\n') - custom_nodes = 'import_custom_nodes()\n\t' + static_imports.append(f"\n{inspect.getsource(import_custom_nodes)}\n") + custom_nodes = "import_custom_nodes()\n\t" else: - custom_nodes = '' + custom_nodes = "" # Create import statements for node classes - imports_code = [f"from nodes import {', '.join([class_name for class_name in import_statements])}" ] + imports_code = [ + f"from nodes import {', '.join([class_name for class_name in import_statements])}" + ] # Assemble the main function code, including custom nodes if applicable - main_function_code = "def main():\n\t" + f'{custom_nodes}with torch.inference_mode():\n\t\t' + '\n\t\t'.join(speical_functions_code) \ - + f'\n\n\t\tfor q in range({queue_size}):\n\t\t' + '\n\t\t'.join(code) + main_function_code = ( + "def main():\n\t" + + f"{custom_nodes}with torch.inference_mode():\n\t\t" + + "\n\t\t".join(speical_functions_code) + + f"\n\n\t\tfor q in range({queue_size}):\n\t\t" + + "\n\t\t".join(code) + ) # Concatenate all parts to form the final code - final_code = '\n'.join(static_imports + imports_code + ['', main_function_code, '', 'if __name__ == "__main__":', '\tmain()']) + final_code = "\n".join( + static_imports + + imports_code + + ["", main_function_code, "", 'if __name__ == "__main__":', "\tmain()"] + ) # Format the final code according to PEP 8 using the Black library final_code = black.format_str(final_code, mode=black.Mode()) return final_code - + def get_class_info(self, class_type: str) -> Tuple[str, str, str]: """Generates and returns necessary information about class type. @@ -338,12 +423,12 @@ def get_class_info(self, class_type: str) -> Tuple[str, str, str]: import_statement = class_type variable_name = self.clean_variable_name(class_type) if class_type in self.base_node_class_mappings.keys(): - class_code = f'{variable_name} = {class_type.strip()}()' + class_code = f"{variable_name} = {class_type.strip()}()" else: class_code = f'{variable_name} = NODE_CLASS_MAPPINGS["{class_type}"]()' return class_type, import_statement, class_code - + @staticmethod def clean_variable_name(class_type: str) -> str: """ @@ -357,14 +442,14 @@ def clean_variable_name(class_type: str) -> str: """ # Convert to lowercase and replace spaces with underscores clean_name = class_type.lower().strip().replace("-", "_").replace(" ", "_") - + # Remove characters that are not letters, numbers, or underscores - clean_name = re.sub(r'[^a-z0-9_]', '', clean_name) - + clean_name = re.sub(r"[^a-z0-9_]", "", clean_name) + # Ensure that it doesn't start with a number if clean_name[0].isdigit(): clean_name = "_" + clean_name - + return clean_name def get_function_parameters(self, func: Callable) -> List: @@ -377,9 +462,11 @@ def get_function_parameters(self, func: Callable) -> List: List: A list containing the names of the function's parameters. """ signature = inspect.signature(func) - parameters = {name: param.default if param.default != param.empty else None - for name, param in signature.parameters.items()} - return list(parameters.keys()) + parameters = { + name: param.default if param.default != param.empty else None + for name, param in signature.parameters.items() + } + return list(parameters.keys()) def update_inputs(self, inputs: Dict, executed_variables: Dict) -> Dict: """Update inputs based on the executed variables. @@ -392,10 +479,15 @@ def update_inputs(self, inputs: Dict, executed_variables: Dict) -> Dict: Dict: Updated inputs dictionary. """ for key in inputs.keys(): - if isinstance(inputs[key], list) and inputs[key][0] in executed_variables.keys(): - inputs[key] = {'variable_name': f"get_value_at_index({executed_variables[inputs[key][0]]}, {inputs[key][1]})"} + if ( + isinstance(inputs[key], list) + and inputs[key][0] in executed_variables.keys() + ): + inputs[key] = { + "variable_name": f"get_value_at_index({executed_variables[inputs[key][0]]}, {inputs[key][1]})" + } return inputs - + class ComfyUItoPython: """Main workflow to generate Python code from a workflow_api.json file. @@ -408,7 +500,13 @@ class ComfyUItoPython: base_node_class_mappings (Dict): Base mappings of node classes. """ - def __init__(self, input_file: str, output_file: str, queue_size: int = 10, node_class_mappings: Dict = NODE_CLASS_MAPPINGS): + def __init__( + self, + input_file: str, + output_file: str, + queue_size: int = 10, + node_class_mappings: Dict = NODE_CLASS_MAPPINGS, + ): """Initialize the ComfyUItoPython class with the given parameters. Args: @@ -441,8 +539,12 @@ def execute(self): load_order = load_order_determiner.determine_load_order() # Step 4: Generate the workflow code - code_generator = CodeGenerator(self.node_class_mappings, self.base_node_class_mappings) - generated_code = code_generator.generate_workflow(load_order, filename=self.output_file, queue_size=self.queue_size) + code_generator = CodeGenerator( + self.node_class_mappings, self.base_node_class_mappings + ) + generated_code = code_generator.generate_workflow( + load_order, filename=self.output_file, queue_size=self.queue_size + ) # Step 5: Write the generated code to a file FileHandler.write_code_to_file(self.output_file, generated_code) @@ -450,11 +552,13 @@ def execute(self): print(f"Code successfully generated and written to {self.output_file}") -if __name__ == '__main__': +if __name__ == "__main__": # Update class parameters here - input_file = 'workflow_api.json' - output_file = 'workflow_api.py' + input_file = "workflow_api.json" + output_file = "workflow_api.py" queue_size = 10 # Convert ComfyUI workflow to Python - ComfyUItoPython(input_file=input_file, output_file=output_file, queue_size=queue_size) + ComfyUItoPython( + input_file=input_file, output_file=output_file, queue_size=queue_size + )