Skip to content

Commit 3b97d7b

Browse files
committedSep 5, 2017
Shot Algorithm Python 3.6
Some more Quantum Computing fun with Shortchanged Algorithm.
1 parent 0412656 commit 3b97d7b

File tree

2 files changed

+633
-0
lines changed

2 files changed

+633
-0
lines changed
 

‎sympy/Shor'ly Yours!.ipynb

+164
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"## More of Shor Algorithm "
8+
]
9+
},
10+
{
11+
"cell_type": "code",
12+
"execution_count": 1,
13+
"metadata": {},
14+
"outputs": [
15+
{
16+
"name": "stdout",
17+
"output_type": "stream",
18+
"text": [
19+
"N = 15\n",
20+
"Neighborhood = 1\n",
21+
"Number of periods = 2\n",
22+
"\n",
23+
"Attempt #0\n",
24+
"Finding the period...\n",
25+
"Q = 256\ta = 2\n",
26+
"Registers generated\n",
27+
"Performing Hadamard on input register\n",
28+
"Hadamard complete\n",
29+
"Mapping input register to output register, where f(x) is a^x mod N\n",
30+
"Modular exponentiation complete\n",
31+
"Performing quantum Fourier transform on output register\n",
32+
"Quantum Fourier transform complete\n",
33+
"Performing a measurement on the output register\n",
34+
"Output register measured\ty = 2\n",
35+
"Performing a measurement on the periodicity register\n",
36+
"QFT register measured\tx = 64\n",
37+
"Finding the period via continued fractions\n",
38+
"Candidate period\tr = 4\n",
39+
"Checking candidate period, nearby values, and multiples\n",
40+
"Period found\tr = 4\n",
41+
"\n",
42+
"Attempt #1\n",
43+
"Finding the period...\n",
44+
"Q = 256\ta = 7\n",
45+
"Registers generated\n",
46+
"Performing Hadamard on input register\n",
47+
"Hadamard complete\n",
48+
"Mapping input register to output register, where f(x) is a^x mod N\n",
49+
"Modular exponentiation complete\n",
50+
"Performing quantum Fourier transform on output register\n",
51+
"Quantum Fourier transform complete\n",
52+
"Performing a measurement on the output register\n",
53+
"Output register measured\ty = 4\n",
54+
"Performing a measurement on the periodicity register\n",
55+
"QFT register measured\tx = 192\n",
56+
"Finding the period via continued fractions\n",
57+
"Candidate period\tr = 4\n",
58+
"Checking candidate period, nearby values, and multiples\n",
59+
"Period found\tr = 4\n",
60+
"\n",
61+
"Finding least common multiple of all periods\n",
62+
"Factors:\t5, 3\n"
63+
]
64+
}
65+
],
66+
"source": [
67+
"!python shor2.py 15"
68+
]
69+
},
70+
{
71+
"cell_type": "code",
72+
"execution_count": null,
73+
"metadata": {},
74+
"outputs": [
75+
{
76+
"name": "stdout",
77+
"output_type": "stream",
78+
"text": [
79+
"N = 23\n",
80+
"Neighborhood = 1\n",
81+
"Number of periods = 2\n",
82+
"\n",
83+
"Attempt #0\n",
84+
"Finding the period...\n",
85+
"Q = 1024\ta = 15\n",
86+
"Registers generated\n",
87+
"Performing Hadamard on input register\n",
88+
"Hadamard complete\n",
89+
"Mapping input register to output register, where f(x) is a^x mod N\n",
90+
"Modular exponentiation complete\n",
91+
"Performing quantum Fourier transform on output register\n",
92+
"Quantum Fourier transform complete\n",
93+
"Performing a measurement on the output register\n",
94+
"Output register measured\ty = 7\n",
95+
"Performing a measurement on the periodicity register\n",
96+
"QFT register measured\tx = 372\n",
97+
"Finding the period via continued fractions\n",
98+
"Candidate period\tr = 11\n",
99+
"Checking candidate period, nearby values, and multiples\n",
100+
"Period was trivial, re-attempt\n",
101+
"\n",
102+
"Attempt #1\n",
103+
"Finding the period...\n",
104+
"Q = 1024\ta = 6\n",
105+
"Registers generated\n",
106+
"Performing Hadamard on input register\n",
107+
"Hadamard complete\n",
108+
"Mapping input register to output register, where f(x) is a^x mod N\n",
109+
"Modular exponentiation complete\n",
110+
"Performing quantum Fourier transform on output register\n",
111+
"Quantum Fourier transform complete\n",
112+
"Performing a measurement on the output register\n",
113+
"Output register measured\ty = 18\n",
114+
"Performing a measurement on the periodicity register\n",
115+
"QFT register measured\tx = 186\n",
116+
"Finding the period via continued fractions\n",
117+
"Candidate period\tr = 11\n",
118+
"Checking candidate period, nearby values, and multiples\n",
119+
"Period was odd, re-attempt\n",
120+
"\n",
121+
"Attempt #2\n",
122+
"Finding the period...\n",
123+
"Q = 1024\ta = 9\n",
124+
"Registers generated\n",
125+
"Performing Hadamard on input register\n"
126+
]
127+
}
128+
],
129+
"source": [
130+
"! python shor2.py 23"
131+
]
132+
},
133+
{
134+
"cell_type": "code",
135+
"execution_count": null,
136+
"metadata": {
137+
"collapsed": true
138+
},
139+
"outputs": [],
140+
"source": []
141+
}
142+
],
143+
"metadata": {
144+
"kernelspec": {
145+
"display_name": "Python 3",
146+
"language": "python",
147+
"name": "python3"
148+
},
149+
"language_info": {
150+
"codemirror_mode": {
151+
"name": "ipython",
152+
"version": 3
153+
},
154+
"file_extension": ".py",
155+
"mimetype": "text/x-python",
156+
"name": "python",
157+
"nbconvert_exporter": "python",
158+
"pygments_lexer": "ipython3",
159+
"version": "3.6.1"
160+
}
161+
},
162+
"nbformat": 4,
163+
"nbformat_minor": 2
164+
}

‎sympy/shor2.py

+469
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,469 @@
1+
#!/usr/bin/env python
2+
3+
"""shors.py: Shor's algorithm for quantum integer factorization"""
4+
5+
import math
6+
import random
7+
import argparse
8+
9+
__author__ = "Todd Wildey"
10+
__copyright__ = "Copyright 2013"
11+
__credits__ = ["Todd Wildey"]
12+
13+
__license__ = "MIT"
14+
__version__ = "1.0.0"
15+
__maintainer__ = "Todd Wildey"
16+
__email__ = "toddwildey@gmail.com"
17+
__status__ = "Prototype"
18+
19+
def printNone(str):
20+
pass
21+
22+
def printVerbose(str):
23+
print(str)
24+
25+
printInfo = printNone
26+
27+
####################################################################################################
28+
#
29+
# Quantum Components
30+
#
31+
####################################################################################################
32+
33+
class Mapping:
34+
def __init__(self, state, amplitude):
35+
self.state = state
36+
self.amplitude = amplitude
37+
38+
39+
class QuantumState:
40+
def __init__(self, amplitude, register):
41+
self.amplitude = amplitude
42+
self.register = register
43+
self.entangled = {}
44+
45+
def entangle(self, fromState, amplitude):
46+
register = fromState.register
47+
entanglement = Mapping(fromState, amplitude)
48+
try:
49+
self.entangled[register].append(entanglement)
50+
except KeyError:
51+
self.entangled[register] = [entanglement]
52+
53+
def entangles(self, register = None):
54+
entangles = 0
55+
if register is None:
56+
for states in self.entangled.values():
57+
entangles += len(states)
58+
else:
59+
entangles = len(self.entangled[register])
60+
61+
return entangles
62+
63+
64+
class QubitRegister:
65+
def __init__(self, numBits):
66+
self.numBits = numBits
67+
self.numStates = 1 << numBits
68+
self.entangled = []
69+
self.states = [QuantumState(complex(0.0), self) for x in range(self.numStates)]
70+
self.states[0].amplitude = complex(1.0)
71+
72+
def propagate(self, fromRegister = None):
73+
if fromRegister is not None:
74+
for state in self.states:
75+
amplitude = complex(0.0)
76+
77+
try:
78+
entangles = state.entangled[fromRegister]
79+
for entangle in entangles:
80+
amplitude += entangle.state.amplitude * entangle.amplitude
81+
82+
state.amplitude = amplitude
83+
except KeyError:
84+
state.amplitude = amplitude
85+
86+
for register in self.entangled:
87+
if register is fromRegister:
88+
continue
89+
90+
register.propagate(self)
91+
92+
# Map will convert any mapping to a unitary tensor given each element v
93+
# returned by the mapping has the property v * v.conjugate() = 1
94+
#
95+
def map(self, toRegister, mapping, propagate = True):
96+
self.entangled.append(toRegister)
97+
toRegister.entangled.append(self)
98+
99+
# Create the covariant/contravariant representations
100+
mapTensorX = {}
101+
mapTensorY = {}
102+
for x in range(self.numStates):
103+
mapTensorX[x] = {}
104+
codomain = mapping(x)
105+
for element in codomain:
106+
y = element.state
107+
mapTensorX[x][y] = element
108+
109+
try:
110+
mapTensorY[y][x] = element
111+
except KeyError:
112+
mapTensorY[y] = { x: element }
113+
114+
# Normalize the mapping:
115+
def normalize(tensor, p = False):
116+
lSqrt = math.sqrt
117+
for vectors in tensor.values():
118+
sumProb = 0.0
119+
for element in vectors.values():
120+
amplitude = element.amplitude
121+
sumProb += (amplitude * amplitude.conjugate()).real
122+
123+
normalized = lSqrt(sumProb)
124+
for element in vectors.values():
125+
element.amplitude = element.amplitude / normalized
126+
127+
normalize(mapTensorX)
128+
normalize(mapTensorY, True)
129+
130+
# Entangle the registers
131+
for x, yStates in mapTensorX.items():
132+
for y, element in yStates.items():
133+
amplitude = element.amplitude
134+
toState = toRegister.states[y]
135+
fromState = self.states[x]
136+
toState.entangle(fromState, amplitude)
137+
fromState.entangle(toState, amplitude.conjugate())
138+
139+
if propagate:
140+
toRegister.propagate(self)
141+
142+
def measure(self):
143+
measure = random.random()
144+
sumProb = 0.0
145+
146+
# Pick a state
147+
finalX = None
148+
finalState = None
149+
for x, state in enumerate(self.states):
150+
amplitude = state.amplitude
151+
sumProb += (amplitude * amplitude.conjugate()).real
152+
153+
if sumProb > measure:
154+
finalState = state
155+
finalX = x
156+
break
157+
158+
# If state was found, update the system
159+
if finalState is not None:
160+
for state in self.states:
161+
state.amplitude = complex(0.0)
162+
163+
finalState.amplitude = complex(1.0)
164+
self.propagate()
165+
166+
return finalX
167+
168+
def entangles(self, register = None):
169+
entangles = 0
170+
for state in self.states:
171+
entangles += state.entangles(None)
172+
173+
return entangles
174+
175+
def amplitudes(self):
176+
amplitudes = []
177+
for state in self.states:
178+
amplitudes.append(state.amplitude)
179+
180+
return amplitudes
181+
182+
def printEntangles(register):
183+
printInfo("Entagles: " + str(register.entangles()))
184+
185+
def printAmplitudes(register):
186+
amplitudes = register.amplitudes()
187+
for x, amplitude in enumerate(amplitudes):
188+
printInfo('State #' + str(x) + '\'s amplitude: ' + str(amplitude))
189+
190+
def hadamard(x, Q):
191+
codomain = []
192+
for y in range(Q):
193+
amplitude = complex(pow(-1.0, bitCount(x & y) & 1))
194+
codomain.append(Mapping(y, amplitude))
195+
196+
return codomain
197+
198+
# Quantum Modular Exponentiation
199+
def qModExp(a, exp, mod):
200+
state = modExp(a, exp, mod)
201+
amplitude = complex(1.0)
202+
return [Mapping(state, amplitude)]
203+
204+
# Quantum Fourier Transform
205+
def qft(x, Q):
206+
fQ = float(Q)
207+
k = -2.0 * math.pi
208+
codomain = []
209+
210+
for y in range(Q):
211+
theta = (k * float((x * y) % Q)) / fQ
212+
amplitude = complex(math.cos(theta), math.sin(theta))
213+
codomain.append(Mapping(y, amplitude))
214+
215+
return codomain
216+
217+
def findPeriod(a, N):
218+
nNumBits = N.bit_length()
219+
inputNumBits = (2 * nNumBits) - 1
220+
inputNumBits += 1 if ((1 << inputNumBits) < (N * N)) else 0
221+
Q = 1 << inputNumBits
222+
223+
printInfo("Finding the period...")
224+
printInfo("Q = " + str(Q) + "\ta = " + str(a))
225+
226+
inputRegister = QubitRegister(inputNumBits)
227+
hmdInputRegister = QubitRegister(inputNumBits)
228+
qftInputRegister = QubitRegister(inputNumBits)
229+
outputRegister = QubitRegister(inputNumBits)
230+
231+
printInfo("Registers generated")
232+
printInfo("Performing Hadamard on input register")
233+
234+
inputRegister.map(hmdInputRegister, lambda x: hadamard(x, Q), False)
235+
# inputRegister.hadamard(False)
236+
237+
printInfo("Hadamard complete")
238+
printInfo("Mapping input register to output register, where f(x) is a^x mod N")
239+
240+
hmdInputRegister.map(outputRegister, lambda x: qModExp(a, x, N), False)
241+
242+
printInfo("Modular exponentiation complete")
243+
printInfo("Performing quantum Fourier transform on output register")
244+
245+
hmdInputRegister.map(qftInputRegister, lambda x: qft(x, Q), False)
246+
inputRegister.propagate()
247+
248+
printInfo("Quantum Fourier transform complete")
249+
printInfo("Performing a measurement on the output register")
250+
251+
y = outputRegister.measure()
252+
253+
printInfo("Output register measured\ty = " + str(y))
254+
255+
# Interesting to watch - simply uncomment
256+
# printAmplitudes(inputRegister)
257+
# printAmplitudes(qftInputRegister)
258+
# printAmplitudes(outputRegister)
259+
# printEntangles(inputRegister)
260+
261+
printInfo("Performing a measurement on the periodicity register")
262+
263+
x = qftInputRegister.measure()
264+
265+
printInfo("QFT register measured\tx = " + str(x))
266+
267+
if x is None:
268+
return None
269+
270+
printInfo("Finding the period via continued fractions")
271+
272+
r = cf(x, Q, N)
273+
274+
printInfo("Candidate period\tr = " + str(r))
275+
276+
return r
277+
278+
####################################################################################################
279+
#
280+
# Classical Components
281+
#
282+
####################################################################################################
283+
284+
BIT_LIMIT = 12
285+
286+
def bitCount(x):
287+
sumBits = 0
288+
while x > 0:
289+
sumBits += x & 1
290+
x >>= 1
291+
292+
return sumBits
293+
294+
# Greatest Common Divisor
295+
def gcd(a, b):
296+
while b != 0:
297+
tA = a % b
298+
a = b
299+
b = tA
300+
301+
return a
302+
303+
# Extended Euclidean
304+
def extendedGCD(a, b):
305+
fractions = []
306+
while b != 0:
307+
fractions.append(a // b)
308+
tA = a % b
309+
a = b
310+
b = tA
311+
312+
return fractions
313+
314+
# Continued Fractions
315+
def cf(y, Q, N):
316+
fractions = extendedGCD(y, Q)
317+
depth = 2
318+
319+
def partial(fractions, depth):
320+
c = 0
321+
r = 1
322+
323+
for i in reversed(range(depth)):
324+
tR = fractions[i] * r + c
325+
c = r
326+
r = tR
327+
328+
return c
329+
330+
r = 0
331+
for d in range(depth, len(fractions) + 1):
332+
tR = partial(fractions, d)
333+
if tR == r or tR >= N:
334+
return r
335+
336+
r = tR
337+
338+
return r
339+
340+
# Modular Exponentiation
341+
def modExp(a, exp, mod):
342+
fx = 1
343+
while exp > 0:
344+
if (exp & 1) == 1:
345+
fx = fx * a % mod
346+
a = (a * a) % mod
347+
exp = exp >> 1
348+
349+
return fx
350+
351+
def pick(N):
352+
a = math.floor((random.random() * (N - 1)) + 0.5)
353+
return a
354+
355+
def checkCandidates(a, r, N, neighborhood):
356+
if r is None:
357+
return None
358+
359+
# Check multiples
360+
for k in range(1, neighborhood + 2):
361+
tR = k * r
362+
if modExp(a, a, N) == modExp(a, a + tR, N):
363+
return tR
364+
365+
# Check lower neighborhood
366+
for tR in range(r - neighborhood, r):
367+
if modExp(a, a, N) == modExp(a, a + tR, N):
368+
return tR
369+
370+
# Check upper neigborhood
371+
for tR in range(r + 1, r + neighborhood + 1):
372+
if modExp(a, a, N) == modExp(a, a + tR, N):
373+
return tR
374+
375+
return None
376+
377+
def shors(N, attempts = 1, neighborhood = 0.0, numPeriods = 1):
378+
if(N.bit_length() > BIT_LIMIT or N < 3):
379+
return False
380+
381+
periods = []
382+
neighborhood = math.floor(N * neighborhood) + 1
383+
384+
printInfo("N = " + str(N))
385+
printInfo("Neighborhood = " + str(neighborhood))
386+
printInfo("Number of periods = " + str(numPeriods))
387+
388+
for attempt in range(attempts):
389+
printInfo("\nAttempt #" + str(attempt))
390+
391+
a = pick(N)
392+
while a < 2:
393+
a = pick(N)
394+
395+
d = gcd(a, N)
396+
if d > 1:
397+
printInfo("Found factors classically, re-attempt")
398+
continue
399+
400+
r = findPeriod(a, N)
401+
402+
printInfo("Checking candidate period, nearby values, and multiples")
403+
404+
r = checkCandidates(a, r, N, neighborhood)
405+
406+
if r is None:
407+
printInfo("Period was not found, re-attempt")
408+
continue
409+
410+
if (r % 2) > 0:
411+
printInfo("Period was odd, re-attempt")
412+
continue
413+
414+
d = modExp(a, (r // 2), N)
415+
if r == 0 or d == (N - 1):
416+
printInfo("Period was trivial, re-attempt")
417+
continue
418+
419+
printInfo("Period found\tr = " + str(r))
420+
421+
periods.append(r)
422+
if(len(periods) < numPeriods):
423+
continue
424+
425+
printInfo("\nFinding least common multiple of all periods")
426+
427+
r = 1
428+
for period in periods:
429+
d = gcd(period, r)
430+
r = (r * period) // d
431+
432+
b = modExp(a, (r // 2), N)
433+
f1 = gcd(N, b + 1)
434+
f2 = gcd(N, b - 1)
435+
436+
return [f1, f2]
437+
438+
return None
439+
440+
####################################################################################################
441+
#
442+
# Command-line functionality
443+
#
444+
####################################################################################################
445+
446+
def parseArgs():
447+
parser = argparse.ArgumentParser(description='Simulate Shor\'s algorithm for N.')
448+
parser.add_argument('-a', '--attempts', type=int, default=20, help='Number of quantum attemtps to perform')
449+
parser.add_argument('-n', '--neighborhood', type=float, default=0.01, help='Neighborhood size for checking candidates (as percentage of N)')
450+
parser.add_argument('-p', '--periods', type=int, default=2, help='Number of periods to get before determining least common multiple')
451+
parser.add_argument('-v', '--verbose', type=bool, default=True, help='Verbose')
452+
parser.add_argument('N', type=int, help='The integer to factor')
453+
return parser.parse_args()
454+
455+
def main():
456+
args = parseArgs()
457+
458+
global printInfo
459+
if args.verbose:
460+
printInfo = printVerbose
461+
else:
462+
printInfo = printNone
463+
464+
factors = shors(args.N, args.attempts, args.neighborhood, args.periods)
465+
if factors is not None:
466+
print("Factors:\t" + str(factors[0]) + ", " + str(factors[1]))
467+
468+
if __name__ == "__main__":
469+
main()

0 commit comments

Comments
 (0)
Please sign in to comment.