1
+ #!/usr/bin/env python
2
+
3
+ """shors.py: Shor's algorithm for quantum integer factorization"""
4
+
5
+ import math
6
+ import random
7
+ import argparse
8
+
9
+ __author__ = "Todd Wildey"
10
+ __copyright__ = "Copyright 2013"
11
+ __credits__ = ["Todd Wildey" ]
12
+
13
+ __license__ = "MIT"
14
+ __version__ = "1.0.0"
15
+ __maintainer__ = "Todd Wildey"
16
+ __email__ = "toddwildey@gmail.com"
17
+ __status__ = "Prototype"
18
+
19
+ def printNone (str ):
20
+ pass
21
+
22
+ def printVerbose (str ):
23
+ print (str )
24
+
25
+ printInfo = printNone
26
+
27
+ ####################################################################################################
28
+ #
29
+ # Quantum Components
30
+ #
31
+ ####################################################################################################
32
+
33
+ class Mapping :
34
+ def __init__ (self , state , amplitude ):
35
+ self .state = state
36
+ self .amplitude = amplitude
37
+
38
+
39
+ class QuantumState :
40
+ def __init__ (self , amplitude , register ):
41
+ self .amplitude = amplitude
42
+ self .register = register
43
+ self .entangled = {}
44
+
45
+ def entangle (self , fromState , amplitude ):
46
+ register = fromState .register
47
+ entanglement = Mapping (fromState , amplitude )
48
+ try :
49
+ self .entangled [register ].append (entanglement )
50
+ except KeyError :
51
+ self .entangled [register ] = [entanglement ]
52
+
53
+ def entangles (self , register = None ):
54
+ entangles = 0
55
+ if register is None :
56
+ for states in self .entangled .values ():
57
+ entangles += len (states )
58
+ else :
59
+ entangles = len (self .entangled [register ])
60
+
61
+ return entangles
62
+
63
+
64
+ class QubitRegister :
65
+ def __init__ (self , numBits ):
66
+ self .numBits = numBits
67
+ self .numStates = 1 << numBits
68
+ self .entangled = []
69
+ self .states = [QuantumState (complex (0.0 ), self ) for x in range (self .numStates )]
70
+ self .states [0 ].amplitude = complex (1.0 )
71
+
72
+ def propagate (self , fromRegister = None ):
73
+ if fromRegister is not None :
74
+ for state in self .states :
75
+ amplitude = complex (0.0 )
76
+
77
+ try :
78
+ entangles = state .entangled [fromRegister ]
79
+ for entangle in entangles :
80
+ amplitude += entangle .state .amplitude * entangle .amplitude
81
+
82
+ state .amplitude = amplitude
83
+ except KeyError :
84
+ state .amplitude = amplitude
85
+
86
+ for register in self .entangled :
87
+ if register is fromRegister :
88
+ continue
89
+
90
+ register .propagate (self )
91
+
92
+ # Map will convert any mapping to a unitary tensor given each element v
93
+ # returned by the mapping has the property v * v.conjugate() = 1
94
+ #
95
+ def map (self , toRegister , mapping , propagate = True ):
96
+ self .entangled .append (toRegister )
97
+ toRegister .entangled .append (self )
98
+
99
+ # Create the covariant/contravariant representations
100
+ mapTensorX = {}
101
+ mapTensorY = {}
102
+ for x in range (self .numStates ):
103
+ mapTensorX [x ] = {}
104
+ codomain = mapping (x )
105
+ for element in codomain :
106
+ y = element .state
107
+ mapTensorX [x ][y ] = element
108
+
109
+ try :
110
+ mapTensorY [y ][x ] = element
111
+ except KeyError :
112
+ mapTensorY [y ] = { x : element }
113
+
114
+ # Normalize the mapping:
115
+ def normalize (tensor , p = False ):
116
+ lSqrt = math .sqrt
117
+ for vectors in tensor .values ():
118
+ sumProb = 0.0
119
+ for element in vectors .values ():
120
+ amplitude = element .amplitude
121
+ sumProb += (amplitude * amplitude .conjugate ()).real
122
+
123
+ normalized = lSqrt (sumProb )
124
+ for element in vectors .values ():
125
+ element .amplitude = element .amplitude / normalized
126
+
127
+ normalize (mapTensorX )
128
+ normalize (mapTensorY , True )
129
+
130
+ # Entangle the registers
131
+ for x , yStates in mapTensorX .items ():
132
+ for y , element in yStates .items ():
133
+ amplitude = element .amplitude
134
+ toState = toRegister .states [y ]
135
+ fromState = self .states [x ]
136
+ toState .entangle (fromState , amplitude )
137
+ fromState .entangle (toState , amplitude .conjugate ())
138
+
139
+ if propagate :
140
+ toRegister .propagate (self )
141
+
142
+ def measure (self ):
143
+ measure = random .random ()
144
+ sumProb = 0.0
145
+
146
+ # Pick a state
147
+ finalX = None
148
+ finalState = None
149
+ for x , state in enumerate (self .states ):
150
+ amplitude = state .amplitude
151
+ sumProb += (amplitude * amplitude .conjugate ()).real
152
+
153
+ if sumProb > measure :
154
+ finalState = state
155
+ finalX = x
156
+ break
157
+
158
+ # If state was found, update the system
159
+ if finalState is not None :
160
+ for state in self .states :
161
+ state .amplitude = complex (0.0 )
162
+
163
+ finalState .amplitude = complex (1.0 )
164
+ self .propagate ()
165
+
166
+ return finalX
167
+
168
+ def entangles (self , register = None ):
169
+ entangles = 0
170
+ for state in self .states :
171
+ entangles += state .entangles (None )
172
+
173
+ return entangles
174
+
175
+ def amplitudes (self ):
176
+ amplitudes = []
177
+ for state in self .states :
178
+ amplitudes .append (state .amplitude )
179
+
180
+ return amplitudes
181
+
182
+ def printEntangles (register ):
183
+ printInfo ("Entagles: " + str (register .entangles ()))
184
+
185
+ def printAmplitudes (register ):
186
+ amplitudes = register .amplitudes ()
187
+ for x , amplitude in enumerate (amplitudes ):
188
+ printInfo ('State #' + str (x ) + '\' s amplitude: ' + str (amplitude ))
189
+
190
+ def hadamard (x , Q ):
191
+ codomain = []
192
+ for y in range (Q ):
193
+ amplitude = complex (pow (- 1.0 , bitCount (x & y ) & 1 ))
194
+ codomain .append (Mapping (y , amplitude ))
195
+
196
+ return codomain
197
+
198
+ # Quantum Modular Exponentiation
199
+ def qModExp (a , exp , mod ):
200
+ state = modExp (a , exp , mod )
201
+ amplitude = complex (1.0 )
202
+ return [Mapping (state , amplitude )]
203
+
204
+ # Quantum Fourier Transform
205
+ def qft (x , Q ):
206
+ fQ = float (Q )
207
+ k = - 2.0 * math .pi
208
+ codomain = []
209
+
210
+ for y in range (Q ):
211
+ theta = (k * float ((x * y ) % Q )) / fQ
212
+ amplitude = complex (math .cos (theta ), math .sin (theta ))
213
+ codomain .append (Mapping (y , amplitude ))
214
+
215
+ return codomain
216
+
217
+ def findPeriod (a , N ):
218
+ nNumBits = N .bit_length ()
219
+ inputNumBits = (2 * nNumBits ) - 1
220
+ inputNumBits += 1 if ((1 << inputNumBits ) < (N * N )) else 0
221
+ Q = 1 << inputNumBits
222
+
223
+ printInfo ("Finding the period..." )
224
+ printInfo ("Q = " + str (Q ) + "\t a = " + str (a ))
225
+
226
+ inputRegister = QubitRegister (inputNumBits )
227
+ hmdInputRegister = QubitRegister (inputNumBits )
228
+ qftInputRegister = QubitRegister (inputNumBits )
229
+ outputRegister = QubitRegister (inputNumBits )
230
+
231
+ printInfo ("Registers generated" )
232
+ printInfo ("Performing Hadamard on input register" )
233
+
234
+ inputRegister .map (hmdInputRegister , lambda x : hadamard (x , Q ), False )
235
+ # inputRegister.hadamard(False)
236
+
237
+ printInfo ("Hadamard complete" )
238
+ printInfo ("Mapping input register to output register, where f(x) is a^x mod N" )
239
+
240
+ hmdInputRegister .map (outputRegister , lambda x : qModExp (a , x , N ), False )
241
+
242
+ printInfo ("Modular exponentiation complete" )
243
+ printInfo ("Performing quantum Fourier transform on output register" )
244
+
245
+ hmdInputRegister .map (qftInputRegister , lambda x : qft (x , Q ), False )
246
+ inputRegister .propagate ()
247
+
248
+ printInfo ("Quantum Fourier transform complete" )
249
+ printInfo ("Performing a measurement on the output register" )
250
+
251
+ y = outputRegister .measure ()
252
+
253
+ printInfo ("Output register measured\t y = " + str (y ))
254
+
255
+ # Interesting to watch - simply uncomment
256
+ # printAmplitudes(inputRegister)
257
+ # printAmplitudes(qftInputRegister)
258
+ # printAmplitudes(outputRegister)
259
+ # printEntangles(inputRegister)
260
+
261
+ printInfo ("Performing a measurement on the periodicity register" )
262
+
263
+ x = qftInputRegister .measure ()
264
+
265
+ printInfo ("QFT register measured\t x = " + str (x ))
266
+
267
+ if x is None :
268
+ return None
269
+
270
+ printInfo ("Finding the period via continued fractions" )
271
+
272
+ r = cf (x , Q , N )
273
+
274
+ printInfo ("Candidate period\t r = " + str (r ))
275
+
276
+ return r
277
+
278
+ ####################################################################################################
279
+ #
280
+ # Classical Components
281
+ #
282
+ ####################################################################################################
283
+
284
+ BIT_LIMIT = 12
285
+
286
+ def bitCount (x ):
287
+ sumBits = 0
288
+ while x > 0 :
289
+ sumBits += x & 1
290
+ x >>= 1
291
+
292
+ return sumBits
293
+
294
+ # Greatest Common Divisor
295
+ def gcd (a , b ):
296
+ while b != 0 :
297
+ tA = a % b
298
+ a = b
299
+ b = tA
300
+
301
+ return a
302
+
303
+ # Extended Euclidean
304
+ def extendedGCD (a , b ):
305
+ fractions = []
306
+ while b != 0 :
307
+ fractions .append (a // b )
308
+ tA = a % b
309
+ a = b
310
+ b = tA
311
+
312
+ return fractions
313
+
314
+ # Continued Fractions
315
+ def cf (y , Q , N ):
316
+ fractions = extendedGCD (y , Q )
317
+ depth = 2
318
+
319
+ def partial (fractions , depth ):
320
+ c = 0
321
+ r = 1
322
+
323
+ for i in reversed (range (depth )):
324
+ tR = fractions [i ] * r + c
325
+ c = r
326
+ r = tR
327
+
328
+ return c
329
+
330
+ r = 0
331
+ for d in range (depth , len (fractions ) + 1 ):
332
+ tR = partial (fractions , d )
333
+ if tR == r or tR >= N :
334
+ return r
335
+
336
+ r = tR
337
+
338
+ return r
339
+
340
+ # Modular Exponentiation
341
+ def modExp (a , exp , mod ):
342
+ fx = 1
343
+ while exp > 0 :
344
+ if (exp & 1 ) == 1 :
345
+ fx = fx * a % mod
346
+ a = (a * a ) % mod
347
+ exp = exp >> 1
348
+
349
+ return fx
350
+
351
+ def pick (N ):
352
+ a = math .floor ((random .random () * (N - 1 )) + 0.5 )
353
+ return a
354
+
355
+ def checkCandidates (a , r , N , neighborhood ):
356
+ if r is None :
357
+ return None
358
+
359
+ # Check multiples
360
+ for k in range (1 , neighborhood + 2 ):
361
+ tR = k * r
362
+ if modExp (a , a , N ) == modExp (a , a + tR , N ):
363
+ return tR
364
+
365
+ # Check lower neighborhood
366
+ for tR in range (r - neighborhood , r ):
367
+ if modExp (a , a , N ) == modExp (a , a + tR , N ):
368
+ return tR
369
+
370
+ # Check upper neigborhood
371
+ for tR in range (r + 1 , r + neighborhood + 1 ):
372
+ if modExp (a , a , N ) == modExp (a , a + tR , N ):
373
+ return tR
374
+
375
+ return None
376
+
377
+ def shors (N , attempts = 1 , neighborhood = 0.0 , numPeriods = 1 ):
378
+ if (N .bit_length () > BIT_LIMIT or N < 3 ):
379
+ return False
380
+
381
+ periods = []
382
+ neighborhood = math .floor (N * neighborhood ) + 1
383
+
384
+ printInfo ("N = " + str (N ))
385
+ printInfo ("Neighborhood = " + str (neighborhood ))
386
+ printInfo ("Number of periods = " + str (numPeriods ))
387
+
388
+ for attempt in range (attempts ):
389
+ printInfo ("\n Attempt #" + str (attempt ))
390
+
391
+ a = pick (N )
392
+ while a < 2 :
393
+ a = pick (N )
394
+
395
+ d = gcd (a , N )
396
+ if d > 1 :
397
+ printInfo ("Found factors classically, re-attempt" )
398
+ continue
399
+
400
+ r = findPeriod (a , N )
401
+
402
+ printInfo ("Checking candidate period, nearby values, and multiples" )
403
+
404
+ r = checkCandidates (a , r , N , neighborhood )
405
+
406
+ if r is None :
407
+ printInfo ("Period was not found, re-attempt" )
408
+ continue
409
+
410
+ if (r % 2 ) > 0 :
411
+ printInfo ("Period was odd, re-attempt" )
412
+ continue
413
+
414
+ d = modExp (a , (r // 2 ), N )
415
+ if r == 0 or d == (N - 1 ):
416
+ printInfo ("Period was trivial, re-attempt" )
417
+ continue
418
+
419
+ printInfo ("Period found\t r = " + str (r ))
420
+
421
+ periods .append (r )
422
+ if (len (periods ) < numPeriods ):
423
+ continue
424
+
425
+ printInfo ("\n Finding least common multiple of all periods" )
426
+
427
+ r = 1
428
+ for period in periods :
429
+ d = gcd (period , r )
430
+ r = (r * period ) // d
431
+
432
+ b = modExp (a , (r // 2 ), N )
433
+ f1 = gcd (N , b + 1 )
434
+ f2 = gcd (N , b - 1 )
435
+
436
+ return [f1 , f2 ]
437
+
438
+ return None
439
+
440
+ ####################################################################################################
441
+ #
442
+ # Command-line functionality
443
+ #
444
+ ####################################################################################################
445
+
446
+ def parseArgs ():
447
+ parser = argparse .ArgumentParser (description = 'Simulate Shor\' s algorithm for N.' )
448
+ parser .add_argument ('-a' , '--attempts' , type = int , default = 20 , help = 'Number of quantum attemtps to perform' )
449
+ parser .add_argument ('-n' , '--neighborhood' , type = float , default = 0.01 , help = 'Neighborhood size for checking candidates (as percentage of N)' )
450
+ parser .add_argument ('-p' , '--periods' , type = int , default = 2 , help = 'Number of periods to get before determining least common multiple' )
451
+ parser .add_argument ('-v' , '--verbose' , type = bool , default = True , help = 'Verbose' )
452
+ parser .add_argument ('N' , type = int , help = 'The integer to factor' )
453
+ return parser .parse_args ()
454
+
455
+ def main ():
456
+ args = parseArgs ()
457
+
458
+ global printInfo
459
+ if args .verbose :
460
+ printInfo = printVerbose
461
+ else :
462
+ printInfo = printNone
463
+
464
+ factors = shors (args .N , args .attempts , args .neighborhood , args .periods )
465
+ if factors is not None :
466
+ print ("Factors:\t " + str (factors [0 ]) + ", " + str (factors [1 ]))
467
+
468
+ if __name__ == "__main__" :
469
+ main ()
0 commit comments