1111
1212
1313class Rebase (GitSimBaseCommand ):
14- def __init__ (self , branch : str ):
14+ def __init__ (self , branch : str , rebase_merges : bool ):
1515 super ().__init__ ()
1616 self .branch = branch
17+ self .rebase_merges = rebase_merges
1718
1819 try :
1920 git .repo .fun .rev_parse (self .repo , self .branch )
@@ -38,11 +39,10 @@ def __init__(self, branch: str):
3839 self .alt_colors = {
3940 0 : [m .BLUE_B , m .BLUE_E ],
4041 1 : [m .PURPLE_B , m .PURPLE_E ],
41- 2 : [m .RED_B , m .RED_E ],
42- 3 : [m .GREEN_B , m .GREEN_E ],
42+ 2 : [m .GOLD_B , m .GOLD_E ],
43+ 3 : [m .TEAL_B , m .TEAL_E ],
4344 4 : [m .MAROON_B , m .MAROON_E ],
44- 5 : [m .GOLD_B , m .GOLD_E ],
45- 6 : [m .TEAL_B , m .TEAL_E ],
45+ 5 : [m .GREEN_B , m .GREEN_E ],
4646 }
4747
4848 def construct (self ):
@@ -79,30 +79,30 @@ def construct(self):
7979 head_commit = self .get_commit ()
8080 default_commits = {}
8181 self .get_default_commits (self .get_commit (), default_commits )
82- default_commits = self .sort_and_flatten (default_commits )
82+ flat_default_commits = self .sort_and_flatten (default_commits )
8383
8484 reached_base = False
85- for commit in default_commits :
86- if commit != "dark" and self .branch in self .repo .git .branch (
87- "--contains" , commit
88- ):
89- reached_base = True
85+ merge_base = self .repo .git .merge_base (self .branch , self .repo .active_branch .name )
86+ if merge_base in self .drawnCommits :
87+ reached_base = True
9088
9189 self .parse_commits (head_commit , shift = 4 * m .DOWN )
9290 self .parse_all ()
9391 self .center_frame_on_commit (branch_commit )
9492
9593 to_rebase = []
96- for c in default_commits :
94+ for c in flat_default_commits :
9795 if self .branch not in self .repo .git .branch ("--contains" , c ):
9896 to_rebase .append (c )
9997
10098 parent = branch_commit .hexsha
10199 branch_counts = {}
102100 rebased_shas = []
101+ rebased_sha_map = {}
103102 for j , tr in enumerate (reversed (to_rebase )):
104- if len (tr .parents ) > 1 :
105- continue
103+ if not self .rebase_merges :
104+ if len (tr .parents ) > 1 :
105+ continue
106106 if not reached_base and j == 0 :
107107 message = "..."
108108 else :
@@ -112,18 +112,20 @@ def construct(self):
112112 branch_counts [color_index ] = 0
113113 branch_counts [color_index ] += 1
114114 commit_color = self .alt_colors [color_index % len (self .alt_colors )][1 ]
115- parent = self .setup_and_draw_parent (parent , message , color = commit_color )
115+ parent = self .setup_and_draw_parent (parent , tr . hexsha , message , color = commit_color , branch_index = color_index , default_commits = default_commits )
116116 rebased_shas .append (parent )
117+ rebased_sha_map [tr .hexsha ] = parent
117118
118119 self .recenter_frame ()
119120 self .scale_frame ()
120121
121122 branch_counts = {}
122123 k = 0
123124 for j , tr in enumerate (reversed (to_rebase )):
124- if len (tr .parents ) > 1 :
125- k += 1
126- continue
125+ if not self .rebase_merges :
126+ if len (tr .parents ) > 1 :
127+ k += 1
128+ continue
127129 color_index = int (self .drawnCommits [tr .hexsha ].get_center ()[1 ] / - 4 ) - 1
128130 if color_index not in branch_counts :
129131 branch_counts [color_index ] = 0
@@ -132,7 +134,10 @@ def construct(self):
132134 arrow_color = self .alt_colors [color_index % len (self .alt_colors )][1 if branch_counts [color_index ] % 2 == 0 else 1 ]
133135 self .draw_arrow_between_commits (tr .hexsha , rebased_shas [j - k ], color = arrow_color )
134136
135- self .reset_head_branch (parent )
137+ if self .rebase_merges :
138+ self .reset_head_branch (rebased_sha_map [default_commits [0 ][0 ].hexsha ])
139+ else :
140+ self .reset_head_branch (parent )
136141 self .color_by (offset = 2 * len (to_rebase ))
137142 self .show_command_as_title ()
138143 self .fadeout ()
@@ -141,10 +146,13 @@ def construct(self):
141146 def setup_and_draw_parent (
142147 self ,
143148 child ,
149+ orig ,
144150 commitMessage = "New commit" ,
145151 shift = numpy .array ([0.0 , 0.0 , 0.0 ]),
146152 draw_arrow = True ,
147153 color = m .RED ,
154+ branch_index = 0 ,
155+ default_commits = {},
148156 ):
149157 circle = m .Circle (
150158 stroke_color = color ,
@@ -153,25 +161,48 @@ def setup_and_draw_parent(
153161 fill_opacity = 0.25 ,
154162 )
155163 circle .height = 1
156- circle .next_to (
157- self .drawnCommits [child ],
158- m .LEFT if settings .reverse else m .RIGHT ,
159- buff = 1.5 ,
160- )
164+ if self .rebase_merges and branch_index > 0 :
165+ circle .move_to (
166+ self .drawnCommits [orig ].get_center (),
167+ ).shift (m .UP * 4 + (m .LEFT if settings .reverse else m .RIGHT ) * len (default_commits [0 ]) * 2.5 )
168+ else :
169+ circle .next_to (
170+ self .drawnCommits [child ],
171+ m .LEFT if settings .reverse else m .RIGHT ,
172+ buff = 1.5 ,
173+ )
161174 circle .shift (shift )
162175
176+ arrow_start_ends = []
177+ arrows = []
163178 start = circle .get_center ()
164- end = self .drawnCommits [child ].get_center ()
165- arrow = m .Arrow (
166- start ,
167- end ,
168- color = self .fontColor ,
169- stroke_width = self .arrow_stroke_width ,
170- tip_shape = self .arrow_tip_shape ,
171- max_stroke_width_to_length_ratio = 1000 ,
172- )
173- length = numpy .linalg .norm (start - end ) - (1.5 if start [1 ] == end [1 ] else 3 )
174- arrow .set_length (length )
179+ if not self .rebase_merges or branch_index == 0 :
180+ end = self .drawnCommits [child ].get_center ()
181+ arrow_start_ends .append ((start , end ))
182+ if self .rebase_merges :
183+ for p in self .get_commit (orig ).parents :
184+ if self .branch in self .repo .git .branch (
185+ "--contains" , p
186+ ):
187+ continue
188+ try :
189+ end = self .drawnCommits [p .hexsha ].get_center () + m .UP * 4 + (m .LEFT if settings .reverse else m .RIGHT ) * len (default_commits [0 ]) * 2.5
190+ arrow_start_ends .append ((start , end ))
191+ except KeyError :
192+ pass
193+
194+ for start , end in arrow_start_ends :
195+ arrow = m .Arrow (
196+ start ,
197+ end ,
198+ color = self .fontColor ,
199+ stroke_width = self .arrow_stroke_width ,
200+ tip_shape = self .arrow_tip_shape ,
201+ max_stroke_width_to_length_ratio = 1000 ,
202+ )
203+ length = numpy .linalg .norm (start - end ) - (1.5 if start [1 ] == end [1 ] else 3 )
204+ arrow .set_length (length )
205+ arrows .append (arrow )
175206
176207 sha = None
177208 while not sha or sha in self .drawnCommits :
@@ -212,18 +243,23 @@ def setup_and_draw_parent(
212243
213244 if draw_arrow :
214245 if settings .animate :
215- self .play (m .Create (arrow ), run_time = 1 / settings .speed )
246+ for arrow in arrows :
247+ self .play (m .Create (arrow ), run_time = 1 / settings .speed )
248+ self .toFadeOut .add (arrow )
216249 else :
217- self .add (arrow )
218- self .toFadeOut .add (arrow )
250+ for arrow in arrows :
251+ self .add (arrow )
252+ self .toFadeOut .add (arrow )
219253
220254 return sha
221255
222256 def get_default_commits (self , commit , default_commits , branch_index = 0 ):
223257 if branch_index not in default_commits :
224258 default_commits [branch_index ] = []
225259 if len (default_commits [branch_index ]) < self .n :
226- if commit not in self .sort_and_flatten (default_commits ):
260+ if commit not in self .sort_and_flatten (default_commits ) and self .branch not in self .repo .git .branch (
261+ "--contains" , commit
262+ ):
227263 default_commits [branch_index ].append (commit )
228264 for i , parent in enumerate (commit .parents ):
229265 self .get_default_commits (parent , default_commits , branch_index + i )
0 commit comments