Skip to content

Commit 7d8a459

Browse files
authoredSep 9, 2024
Merge pull request larymak#398 from EugeneMMF/main
sudoku solver script.
2 parents f1b1de7 + 2609693 commit 7d8a459

22 files changed

+1321
-0
lines changed
 
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# Sudoku Solver
2+
3+
* This app was built to allow users to solve their sudokus using a computer.
4+
* There is a Flask based webserver `web_interface.py` which when run gives a web interface to upload an image of a sudoku to be solved. The response is a solved sudoku.
5+
* There is a file `full_stack_http.py` which needs to be run alongside the webserver for the full app to run. This is in charge of opening multiple process channels to process the images that are sent to the webserver.
6+
* The app relies of Pytesseract to identify the characters in the sudoku image.
7+
8+
# Operation
9+
10+
* The image is first stripped of color.
11+
* It is then cropped to select the section of the sudoku. NOTE: This section is not dependent on the sudoku but has been hardcoded.
12+
* The resulting image is passed to `Pytesseract` to extract the characters and their position.
13+
* Using the characters and their position the grid size is determined.
14+
* The appropriate grid is created and filled with the discovered characters.
15+
* The grid is then solved with an algorithm contained in `sudoku.py`.
16+
* A snapshot of the solved grid is then created and sent back to the user.
17+
* The resultant snapshot is rendered on the browser page.
18+
19+
# To Run
20+
21+
* First install `Pytesseract`
22+
* Install `Flask`
23+
* Then run the `full_stack_http.py` file.
24+
* Then run the `web_interface.py` file.
25+
* Go to the browser and load the URL provided in the previous step.
26+
* Click the upload button.
27+
* Select your image and submit the form.
28+
* Wait for the result to be loaded.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
UPLOAD_FOLDER="uploads"
2+
SECRET_KEY="secret"
3+
SOLVER_IP="localhost"
4+
SOLVER_PORT=3535
178 KB
Loading
185 KB
Loading
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
import multiprocessing.util
2+
import socket
3+
from perspective import resolve_image
4+
from sudoku import Grid
5+
import argparse
6+
import multiprocessing
7+
import os
8+
9+
temp_result_file = "resultfile.png"
10+
temp_input_file = "tempfile.jpg"
11+
12+
def process_handle_transaction(proc_num:int, sock:socket.socket):
13+
print(f"[{proc_num}] Waiting for client...")
14+
sock2, address2 = sock.accept()
15+
print(f"[{proc_num}] Connected to client with address: {address2}")
16+
sock2.settimeout(1)
17+
rec_buf = b''
18+
split = temp_input_file.split('.')
19+
my_temp_input_file = ".".join(i for i in split[:-1]) + str(proc_num) + "." + split[-1]
20+
split = temp_result_file.split('.')
21+
my_temp_result_file = ".".join(i for i in split[:-1]) + str(proc_num) + "." + split[-1]
22+
try:
23+
while True:
24+
try:
25+
rec = sock2.recv(1)
26+
rec_buf += rec
27+
if len(rec) == 0:
28+
print(f"[{proc_num}] Lost connection")
29+
break
30+
except socket.timeout:
31+
with open(my_temp_input_file, "wb") as f:
32+
f.write(rec_buf)
33+
rec_buf = b''
34+
grid_size, points = resolve_image(my_temp_input_file)
35+
grid = Grid(rows=grid_size[0], columns=grid_size[1])
36+
assignment_values = {}
37+
for val,loc in points:
38+
assignment_values[loc] = val
39+
grid.preassign(assignment_values)
40+
grid.solve()
41+
grid.save_grid_image(path=my_temp_result_file, size=(400,400))
42+
with open(my_temp_result_file, "rb") as f:
43+
sock2.send(f.read())
44+
os.remove(my_temp_input_file)
45+
os.remove(my_temp_result_file)
46+
sock2.close()
47+
print(f"[{proc_num}] Finished!")
48+
break
49+
finally:
50+
sock2.close()
51+
52+
class Manager():
53+
def __init__(self, address:tuple[str,int]):
54+
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
55+
self.address = address
56+
57+
def wait_for_connect(self):
58+
print("Waiting for client...")
59+
self.sock2, self.address2 = self.sock.accept()
60+
print(f"Connected to client with address: {self.address2}")
61+
self.sock2.settimeout(1)
62+
63+
def run(self):
64+
self.sock.bind(self.address)
65+
self.sock.listen()
66+
print(f"Listening from address: {self.address}")
67+
try:
68+
while True:
69+
self.wait_for_connect()
70+
rec_buf = b''
71+
while True:
72+
try:
73+
rec = self.sock2.recv(1)
74+
rec_buf += rec
75+
if len(rec) == 0:
76+
print("Lost connection")
77+
break
78+
except socket.timeout:
79+
with open(temp_input_file, "wb") as f:
80+
f.write(rec_buf)
81+
rec_buf = b''
82+
grid_size, points = resolve_image(temp_input_file)
83+
grid = Grid(rows=grid_size[0], columns=grid_size[1])
84+
assignment_values = {}
85+
for val,loc in points:
86+
assignment_values[loc] = val
87+
grid.preassign(assignment_values)
88+
grid.solve()
89+
grid.save_grid_image(path=temp_result_file, size=(400,400))
90+
with open(temp_result_file, "rb") as f:
91+
self.sock2.send(f.read())
92+
os.remove(temp_input_file)
93+
os.remove(temp_result_file)
94+
self.sock2.close()
95+
break
96+
finally:
97+
try:
98+
self.sock2.close()
99+
except socket.error:
100+
pass
101+
except AttributeError:
102+
pass
103+
self.sock.close()
104+
105+
def run_multiprocessing(self, max_clients:int=8):
106+
self.sock.bind(self.address)
107+
self.sock.listen()
108+
print(f"Listening from address: {self.address}")
109+
processes:dict[int,multiprocessing.Process]= {}
110+
proc_num = 0
111+
try:
112+
while True:
113+
if len(processes) <= max_clients:
114+
proc = multiprocessing.Process(target=process_handle_transaction, args=(proc_num, self.sock))
115+
proc.start()
116+
processes[proc_num] = proc
117+
proc_num += 1
118+
proc_num%=(max_clients*2)
119+
keys = list(processes.keys())
120+
for proc_n in keys:
121+
if not processes[proc_n].is_alive():
122+
processes.pop(proc_n)
123+
finally:
124+
if len(processes):
125+
for proc in processes.values():
126+
proc.kill()
127+
self.sock.close()
128+
129+
if "__main__" == __name__:
130+
parser = argparse.ArgumentParser()
131+
parser.add_argument("--port", type=int, default=3535, help="The port to host the server.")
132+
parser.add_argument("--host", type=str, default="localhost", help="The host or ip-address to host the server.")
133+
args = parser.parse_args()
134+
address = (args.host, args.port)
135+
manager = Manager(address)
136+
manager.run_multiprocessing(max_clients=multiprocessing.cpu_count())
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
import torch
2+
from torch.utils.data import Dataset, DataLoader
3+
import PIL.Image as Image
4+
import pandas as pd
5+
from tqdm import tqdm
6+
import numpy as np
7+
8+
9+
class SudokuDataset(Dataset):
10+
def __init__(self, grid_locations_file:str, input_shape:tuple[int, int]) -> None:
11+
super().__init__()
12+
self.grid_locations = []
13+
self.image_filenames = []
14+
self.input_shape = input_shape
15+
self.all_data = pd.read_csv(grid_locations_file, header=0)
16+
self.image_filenames = list(self.all_data['filepath'].to_numpy())
17+
self.grid_locations = [list(a[1:]) for a in self.all_data.values]
18+
to_pop = []
19+
for i,file in enumerate(self.image_filenames):
20+
try:
21+
Image.open(file)
22+
except FileNotFoundError:
23+
to_pop.append(i)
24+
print(f"{file} not found.")
25+
for i in reversed(to_pop):
26+
self.image_filenames.pop(i)
27+
self.grid_locations.pop(i)
28+
# print(self.all_data.columns)
29+
# print(self.grid_locations)
30+
31+
def __len__(self) -> int:
32+
return len(self.image_filenames)
33+
34+
def __getitem__(self, index) -> dict[str, torch.Tensor]:
35+
image = Image.open(self.image_filenames[index]).convert("L")
36+
size = image.size
37+
image = image.resize(self.input_shape)
38+
image = np.array(image)
39+
image = image.reshape((1,*image.shape))
40+
location = self.grid_locations[index]
41+
for i in range(len(location)):
42+
if i%2:
43+
location[i] /= size[1]
44+
else:
45+
location[i] /= size[0]
46+
return {
47+
"image": torch.tensor(image, dtype=torch.float32)/255.,
48+
"grid": torch.tensor(location, dtype=torch.float32)
49+
}
50+
51+
class Model(torch.nn.Module):
52+
def __init__(self, input_shape:tuple[int,int], number_of_layers:int, dims:int, *args, **kwargs) -> None:
53+
super().__init__(*args, **kwargs)
54+
self.input_shape = input_shape
55+
self.conv_layers:list = []
56+
self.conv_layers.append(torch.nn.Conv2d(1, dims, (3,3), padding='same'))
57+
for _ in range(number_of_layers-1):
58+
self.conv_layers.append(torch.nn.Conv2d(dims, dims, (3,3), padding='same'))
59+
self.conv_layers.append(torch.nn.LeakyReLU(negative_slope=0.01))
60+
self.conv_layers.append(torch.nn.MaxPool2d((2,2)))
61+
self.conv_layers.append(torch.nn.BatchNorm2d(dims))
62+
self.flatten = torch.nn.Flatten()
63+
self.location = [
64+
torch.nn.Linear(4107, 8),
65+
torch.nn.Sigmoid()
66+
]
67+
self.conv_layers = torch.nn.ModuleList(self.conv_layers)
68+
self.location = torch.nn.ModuleList(self.location)
69+
70+
def forward(self, x:torch.Tensor) -> torch.Tensor:
71+
for layer in self.conv_layers:
72+
x = layer(x)
73+
x = self.flatten(x)
74+
location = x
75+
for layer in self.location:
76+
location = layer(location)
77+
return location
78+
79+
def create_model(input_shape:tuple[int,int], number_of_layers:int, dims:int):
80+
model = Model(input_shape, number_of_layers, dims)
81+
for p in model.parameters():
82+
if p.dim() > 1:
83+
torch.nn.init.xavier_uniform_(p)
84+
return model
85+
86+
def get_dataset(filename:str, input_shape:tuple[int,int], batch_size:int) -> DataLoader:
87+
train_dataset = SudokuDataset(filename, input_shape)
88+
train_dataloader = DataLoader(train_dataset, batch_size, shuffle=True)
89+
return train_dataloader
90+
91+
def train(epochs:int, config:dict, model:None|Model = None) -> Model:
92+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
93+
if not model:
94+
print("========== Using new model =========")
95+
model = create_model(config['input_shape'], config['number_of_layers'], config['dims']).to(device)
96+
optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'])
97+
loss = torch.nn.MSELoss().to(device)
98+
dataset = get_dataset(config['filename'], config['input_shape'], config['batch_size'])
99+
prev_error = 0
100+
try:
101+
for epoch in range(1, epochs+1):
102+
batch_iterator = tqdm(dataset, f"Epoch {epoch}/{epochs}:")
103+
for batch in batch_iterator:
104+
x = batch['image'].to(device)
105+
y_true = batch['grid'].to(device)
106+
# print(batch['grid'])
107+
# return
108+
y_pred = model(x)
109+
error = loss(y_true, y_pred)
110+
batch_iterator.set_postfix({"loss":f"Loss: {error.item():6.6f}"})
111+
error.backward()
112+
optimizer.step()
113+
# optimizer.zero_grad()
114+
if abs(error-0.5) < 0.05:# or (prev_error-error)<0.000001:
115+
del(model)
116+
model = create_model(config['input_shape'], config['number_of_layers'], config['dims']).to(device)
117+
print("New model created")
118+
prev_error = error
119+
except KeyboardInterrupt:
120+
torch.save(model, "model.pt")
121+
return model
122+
123+
def test(config:dict, model_filename:str):
124+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
125+
model = torch.load("model.pt").to(device)
126+
loss = torch.nn.MSELoss().to(device)
127+
dataset = get_dataset(config['filename'], config['input_shape'], config['batch_size'])
128+
129+
130+
if __name__ == '__main__':
131+
config = {
132+
"input_shape": (300,300),
133+
"filename": "archive/outlines_sorted.csv",
134+
"number_of_layers": 4,
135+
"dims": 3,
136+
"batch_size": 8,
137+
"lr": 1e-5
138+
}
139+
# model = train(50, config)
140+
model = torch.load("model.pt")
141+
test(config, model)
Binary file not shown.
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
import cv2
2+
import numpy as np
3+
from pytesseract import pytesseract as pt
4+
5+
def resolve_perspective(source_image:np.ndarray, points:np.ndarray, target_shape:tuple[int,int]) -> np.ndarray:
6+
"""Takes an source image and transforms takes the region demarkated by points and creates a rectangular image of target.
7+
8+
Args:
9+
source_image (np.ndarray): the source image.
10+
points (np.ndarray): a numpy array of 4 points that will demarkate the vertices of the region to be transformed.\n
11+
\tShould be in the form of points from the point that would be transformed to the top left of the rectangle, clockwise
12+
target_shape (tuple[int,int]): the target shape of the rectangular output image. Format [height, width].
13+
14+
Returns:
15+
np.ndarray: the output image transformed
16+
"""
17+
output_points:np.ndarray = np.array([
18+
[0,0],
19+
[target_shape[0]-1, 0],
20+
[target_shape[0]-1, target_shape[1]-1],
21+
[0,target_shape[1]-1]
22+
], dtype=np.float32)
23+
transformation_matrix:cv2.typing.MatLike = cv2.getPerspectiveTransform(points.astype(np.float32), output_points)
24+
output:cv2.typing.MatLike = cv2.warpPerspective(source_image, transformation_matrix, (target_shape[1], target_shape[0]), flags=cv2.INTER_LINEAR)
25+
return output
26+
27+
def get_grid_size(image:np.ndarray, boxes:list[list[int]], allowed_sizes:list[tuple[int,int]]=[(2,3),(3,3),(4,4)]) -> tuple[int,int]:
28+
h,w = image.shape
29+
for size in allowed_sizes:
30+
s1 = float(w)/float(size[0])
31+
s2 = float(h)/float(size[1])
32+
for box in boxes:
33+
_,x1,y1,x2,y2 = box
34+
if (abs(int(x1/s1) - int(x2/s1)) + abs(int((h - y1)/s2) - int((h - y2)/s2))) > 0:
35+
break
36+
else:
37+
return size
38+
39+
def get_points(image:np.ndarray, boxes:list[list[int]], grid_size:tuple[int,int]) -> list[tuple[int,tuple]]:
40+
h,w = image.shape
41+
size = grid_size[0] * grid_size[1]
42+
s1 = float(w)/float(size)
43+
s2 = float(h)/float(size)
44+
results = []
45+
for box in boxes:
46+
val,x1,y1,x2,y2 = box
47+
center_x = int((x1+x2)/2)
48+
center_y = int((y1+y2)/2)
49+
results.append((val, (int((h-center_y)/s2), int(center_x/s1))))
50+
return results
51+
52+
def resolve_image(path:str) -> tuple[tuple,list[tuple[int,tuple]]]:
53+
# img = cv2.imread("images/image210.jpg")
54+
img = cv2.imread(path)
55+
numbers = [str(i) for i in range(10)]
56+
max_size = 500
57+
min_area = 150
58+
*img_shape,_ = img.shape
59+
max_ind = np.argmax(img_shape)
60+
min_ind = np.argmin(img_shape)
61+
next_shape = [0,0]
62+
if max_ind != min_ind:
63+
next_shape[max_ind] = max_size
64+
next_shape[min_ind] = int(img_shape[min_ind]*max_size/img_shape[max_ind])
65+
else:
66+
next_shape = [max_size, max_size]
67+
img = cv2.resize(img, tuple(reversed(next_shape)))
68+
points = np.array([6,97,219,99,216,309,7,310])
69+
points = points.reshape((4,2))
70+
target_shape = (400,400)
71+
output = resolve_perspective(img, points, target_shape)
72+
output = cv2.cvtColor(output, cv2.COLOR_BGR2GRAY)
73+
norm_img = np.zeros((output.shape[0], output.shape[1]))
74+
output = cv2.normalize(output, norm_img, 0, 255, cv2.NORM_MINMAX)
75+
output1 = cv2.threshold(output, 140, 255, cv2.THRESH_BINARY_INV)[1]
76+
if np.average(output1.flatten()) > 128:
77+
output = cv2.threshold(output, 140, 255, cv2.THRESH_BINARY)[1]
78+
else:
79+
output = output1
80+
output = cv2.GaussianBlur(output, (1,1), 0)
81+
boxes = pt.image_to_boxes(output, "eng", config=r'-c tessedit_char_whitelist=0123456789 --psm 13 --oem 3')
82+
print(boxes)
83+
h,w = output.shape
84+
new_boxes_str = ""
85+
new_boxes = []
86+
for bt in boxes.splitlines():
87+
b = bt.split(' ')
88+
area = (int(b[1]) - int(b[3]))*(int(b[2]) - int(b[4]))
89+
if b[0] in numbers and area > min_area:
90+
output = cv2.rectangle(output, (int(b[1]), h - int(b[2])), (int(b[3]), h - int(b[4])), (255, 255, 255), 2)
91+
new_boxes_str += bt + "\n"
92+
new_boxes.append(list(int(i) for i in b[:5]))
93+
grid_size = get_grid_size(output, new_boxes)
94+
final_points = get_points(output, new_boxes, grid_size)
95+
return grid_size,final_points
96+
97+
if "__main__" == __name__:
98+
print(resolve_image("f2.jpg"))
14.9 KB
Loading

‎MachineLearning Projects/sudoku_solver/sudoku.py

Lines changed: 373 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 22,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"class Node:\n",
10+
" def __init__(self,val):\n",
11+
" self.val = val\n",
12+
" self.to = {}"
13+
]
14+
},
15+
{
16+
"cell_type": "code",
17+
"execution_count": 137,
18+
"metadata": {},
19+
"outputs": [],
20+
"source": [
21+
"class Node:\n",
22+
" def __init__(self,val):\n",
23+
" self.val:int = val\n",
24+
" self.to:dict[Node,tuple[int,int]] = {} # destinationNode:(steps,price)\n",
25+
" \n",
26+
" def __str__(self) -> str:\n",
27+
" children = ','.join(str(i.val) for i in self.to.keys())\n",
28+
" return f\"Node({self.val})\"\n",
29+
" \n",
30+
" def __repr__(self) -> str:\n",
31+
" children = ','.join(str(i.val) for i in self.to.keys())\n",
32+
" return f\"Node({self.val})\"\n",
33+
" \n",
34+
" def full(self) -> str:\n",
35+
" children = ','.join(str(i.val) for i in self.to.keys())\n",
36+
" return f\"Node({self.val})->[{children}]\"\n",
37+
"\n",
38+
"def update(node:Node, start:list[int]):\n",
39+
" # print(\"iter\", node, start)\n",
40+
" if node.val in start:\n",
41+
" # print(\"found: \", node, \" => \", start)\n",
42+
" return {}\n",
43+
" ret:dict[Node,set[tuple[int,int]]] = {\n",
44+
" i:set([tuple(node.to[i]),]) for i in node.to.keys()\n",
45+
" } # destinationNode:[(steps1,price1), (steps2,price2), ...]\n",
46+
" for destinationNode,(steps,price) in node.to.items():\n",
47+
" # print(f\"step {node} to {destinationNode}\")\n",
48+
" returned = update(destinationNode, [*start,node.val])\n",
49+
" # print(f\"{node.val} going to {destinationNode.val} got {returned}\")\n",
50+
" if returned == {}:\n",
51+
" # print(f\"here on\")\n",
52+
" ret[destinationNode].add((steps,price))\n",
53+
" continue\n",
54+
" for v,mylist in returned.items():\n",
55+
" # v is the a possible destination from our destination node\n",
56+
" # my list is a list of the steps and prices to that possible destination\n",
57+
" for (stp,prc) in mylist:\n",
58+
" newTuple = (stp+steps,prc+price)\n",
59+
" if ret.get(v):\n",
60+
" ret[v].add(newTuple)\n",
61+
" else:\n",
62+
" ret[v] = set([newTuple,])\n",
63+
" return ret"
64+
]
65+
},
66+
{
67+
"cell_type": "code",
68+
"execution_count": 176,
69+
"metadata": {},
70+
"outputs": [],
71+
"source": [
72+
"from cmath import inf\n",
73+
"\n",
74+
"def findCheapestPrice(n: int, flights: list[list[int]], src: int, dst: int, k: int) -> int:\n",
75+
" nodes:dict[int,Node] = {}\n",
76+
" for s,d,p in flights:\n",
77+
" dnode = nodes.get(d)\n",
78+
" if dnode:\n",
79+
" snode = nodes.get(s)\n",
80+
" if snode:\n",
81+
" snode.to[dnode] = (1,p)\n",
82+
" else:\n",
83+
" nd = Node(s)\n",
84+
" nd.to[dnode] = (1,p)\n",
85+
" nodes[s] = nd\n",
86+
" else:\n",
87+
" snode = nodes.get(s)\n",
88+
" if snode:\n",
89+
" nd = Node(d)\n",
90+
" snode.to[nd] = (1,p)\n",
91+
" nodes[d] = nd\n",
92+
" else:\n",
93+
" nd1 = Node(s)\n",
94+
" nd2 = Node(d)\n",
95+
" nd1.to[nd2] = (1,p)\n",
96+
" nodes[s] = nd1\n",
97+
" nodes[d] = nd2\n",
98+
" for _,node in nodes.items():\n",
99+
" print(node.full())\n",
100+
" return method2(nodes, src, dst, k)\n",
101+
"\n",
102+
"def method1(nodes:dict[int,Node], src:int, dst:int, k:int) -> int:\n",
103+
" results = {}\n",
104+
" for val,node in nodes.items():\n",
105+
" ret = update(node, [])\n",
106+
" results[val] = ret\n",
107+
" desired = results[src].get(nodes[dst])\n",
108+
" if not desired:\n",
109+
" return -1\n",
110+
" filtered = []\n",
111+
" k = k + 1\n",
112+
" for d in desired:\n",
113+
" if d[0] <= k:\n",
114+
" filtered.append(d)\n",
115+
" return min(filtered, key=lambda x:x[1])\n",
116+
"\n",
117+
"def method2(nodes:dict[int,Node], src:int, dst:int, k:int) -> int:\n",
118+
" def recurse(node:Node, dst:int, k:int, visited:list[int]):\n",
119+
" results = []\n",
120+
" if k == 1:\n",
121+
" for nd in node.to.keys():\n",
122+
" if nd.val == dst:\n",
123+
" return node.to[nd][1]\n",
124+
" return inf\n",
125+
" if node.val in visited:\n",
126+
" return inf\n",
127+
" for nd in node.to.keys():\n",
128+
" if nd.val == dst:\n",
129+
" results.append(node.to[nd][1])\n",
130+
" else:\n",
131+
" temp = recurse(nd, dst, k-1, [*visited, node.val]) + node.to[nd][1]\n",
132+
" results.append(temp)\n",
133+
" if len(results):\n",
134+
" return min(results)\n",
135+
" return inf\n",
136+
" \n",
137+
" k = k+1\n",
138+
" node = nodes[src]\n",
139+
" result = recurse(node, dst, k, [])\n",
140+
" if result == inf:\n",
141+
" return -1\n",
142+
" return result"
143+
]
144+
},
145+
{
146+
"cell_type": "code",
147+
"execution_count": 157,
148+
"metadata": {},
149+
"outputs": [
150+
{
151+
"data": {
152+
"text/plain": [
153+
"100"
154+
]
155+
},
156+
"execution_count": 157,
157+
"metadata": {},
158+
"output_type": "execute_result"
159+
}
160+
],
161+
"source": [
162+
"findCheapestPrice(n = 3, flights = [[0,1,100],[1,2,100],[0,2,500]], src = 0, dst = 2, k = 1)"
163+
]
164+
},
165+
{
166+
"cell_type": "code",
167+
"execution_count": 178,
168+
"metadata": {},
169+
"outputs": [
170+
{
171+
"name": "stdout",
172+
"output_type": "stream",
173+
"text": [
174+
"Node(0)->[12,8,15,10]\n",
175+
"Node(12)->[4,3,14,13,9,0,16,6]\n",
176+
"Node(5)->[6,14,13,16,10,9,7]\n",
177+
"Node(6)->[14,10,2,12]\n",
178+
"Node(8)->[6,10,11,9,2,13,3]\n",
179+
"Node(13)->[15,12,6,16,0,5,11,7,8]\n",
180+
"Node(15)->[3,0,6,13,12,11,14,2]\n",
181+
"Node(10)->[12,2,15,11,5,4,9,0,7]\n",
182+
"Node(3)->[4,12,5,6,7,10]\n",
183+
"Node(7)->[11,3,1,14,0,12,2]\n",
184+
"Node(11)->[16,1,0,2,6,9]\n",
185+
"Node(9)->[4,6,1,12,7,10,15,5]\n",
186+
"Node(4)->[7,9,8,5,11,10]\n",
187+
"Node(2)->[12,0,11,5,13,10,7]\n",
188+
"Node(14)->[15,1,9,7,11,6]\n",
189+
"Node(16)->[4,12,1,3,8,11,9,14]\n",
190+
"Node(1)->[11,4,3,7]\n"
191+
]
192+
},
193+
{
194+
"data": {
195+
"text/plain": [
196+
"47"
197+
]
198+
},
199+
"execution_count": 178,
200+
"metadata": {},
201+
"output_type": "execute_result"
202+
}
203+
],
204+
"source": [
205+
"findCheapestPrice(n = 4, flights = [[0,12,28],[5,6,39],[8,6,59],[13,15,7],[13,12,38],[10,12,35],[15,3,23],[7,11,26],[9,4,65],[10,2,38],[4,7,7],[14,15,31],[2,12,44],[8,10,34],[13,6,29],[5,14,89],[11,16,13],[7,3,46],[10,15,19],[12,4,58],[13,16,11],[16,4,76],[2,0,12],[15,0,22],[16,12,13],[7,1,29],[7,14,100],[16,1,14],[9,6,74],[11,1,73],[2,11,60],[10,11,85],[2,5,49],[3,4,17],[4,9,77],[16,3,47],[15,6,78],[14,1,90],[10,5,95],[1,11,30],[11,0,37],[10,4,86],[0,8,57],[6,14,68],[16,8,3],[13,0,65],[2,13,6],[5,13,5],[8,11,31],[6,10,20],[6,2,33],[9,1,3],[14,9,58],[12,3,19],[11,2,74],[12,14,48],[16,11,100],[3,12,38],[12,13,77],[10,9,99],[15,13,98],[15,12,71],[1,4,28],[7,0,83],[3,5,100],[8,9,14],[15,11,57],[3,6,65],[1,3,45],[14,7,74],[2,10,39],[4,8,73],[13,5,77],[10,0,43],[12,9,92],[8,2,26],[1,7,7],[9,12,10],[13,11,64],[8,13,80],[6,12,74],[9,7,35],[0,15,48],[3,7,87],[16,9,42],[5,16,64],[4,5,65],[15,14,70],[12,0,13],[16,14,52],[3,10,80],[14,11,85],[15,2,77],[4,11,19],[2,7,49],[10,7,78],[14,6,84],[13,7,50],[11,6,75],[5,10,46],[13,8,43],[9,10,49],[7,12,64],[0,10,76],[5,9,77],[8,3,28],[11,9,28],[12,16,87],[12,6,24],[9,15,94],[5,7,77],[4,10,18],[7,2,11],[9,5,41]], src = 13, dst = 4, k = 13)"
206+
]
207+
}
208+
],
209+
"metadata": {
210+
"kernelspec": {
211+
"display_name": "base",
212+
"language": "python",
213+
"name": "python3"
214+
},
215+
"language_info": {
216+
"codemirror_mode": {
217+
"name": "ipython",
218+
"version": 3
219+
},
220+
"file_extension": ".py",
221+
"mimetype": "text/x-python",
222+
"name": "python",
223+
"nbconvert_exporter": "python",
224+
"pygments_lexer": "ipython3",
225+
"version": "3.12.4"
226+
}
227+
},
228+
"nbformat": 4,
229+
"nbformat_minor": 2
230+
}
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
while True:
2+
pass
268 KB
Loading
267 KB
Loading
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
<!DOCTYPE html>
2+
<html lang="en">
3+
<head>
4+
<meta charset="UTF-8">
5+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
6+
<title>Sudoku Solver</title>
7+
</head>
8+
<body>
9+
<h1>Sudoku Solver</h1>
10+
<hr>
11+
<h3>To solve a sudoku select the image of the sudoku and upload it to the page then hit submit.</h3>
12+
<h3>The solution will be returned as an image on the next page.</h3>
13+
<div>
14+
<form action="" method="post" enctype="multipart/form-data">
15+
<input type="file" name="image" id="image" value={{request.files.image}}>
16+
<input type="submit" value="Submit" >
17+
</form>
18+
</div>
19+
</body>
20+
</html>
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
<!DOCTYPE html>
2+
<html lang="en">
3+
<head>
4+
<meta charset="UTF-8">
5+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
6+
<title>Solution</title>
7+
</head>
8+
<body>
9+
<hr>
10+
<a href="/">Back to Main Page</a>
11+
<hr>
12+
<div>
13+
<h1>Solution</h1>
14+
</div>
15+
<div>
16+
<img src="{{img}}" alt="img" style="height: max-content; width: max-content;">
17+
</div>
18+
</body>
19+
</html>
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import socket
2+
3+
result_file = "resultfile2_server.png"
4+
input_file = "f1.jpg"
5+
6+
def main(address:tuple[str,int]):
7+
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
8+
sock.connect(address)
9+
sock.settimeout(10)
10+
with open(input_file, "rb") as f:
11+
sock.send(f.read())
12+
res_buf = b''
13+
try:
14+
while True:
15+
try:
16+
res = sock.recv(1)
17+
res_buf += res
18+
if 0 == len(res):
19+
sock.close()
20+
with open(result_file, "wb") as f:
21+
f.write(res_buf)
22+
break
23+
except socket.timeout:
24+
with open(result_file, "wb") as f:
25+
f.write(res_buf)
26+
break
27+
finally:
28+
sock.close()
29+
30+
if "__main__" == __name__:
31+
main(("localhost", 3535))
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
"""This code is to verify the image dataset and check that all the labels of the grid location are in the correct place.
2+
"""
3+
4+
import PIL.Image as Image
5+
from matplotlib import pyplot as plt
6+
import numpy as np
7+
from image import SudokuDataset, get_dataset, tqdm, Model
8+
import torch
9+
10+
img_size = (300,300)
11+
12+
def mark(positions, image, color_value):
13+
print(positions)
14+
print(image.shape)
15+
x0,y0,x1,y1,x2,y2,x3,y3 = positions
16+
image = image.transpose()
17+
grad = (y1 - y0)/(x1 - x0)
18+
if x1 > x0:
19+
for i in range(x1 - x0):
20+
image[x0 + i, int(y0 + i * grad)] = color_value
21+
else:
22+
for i in range(x0 - x1):
23+
image[x0 - i, int(y0 - i * grad)] = color_value
24+
25+
grad = (y2 - y1)/(x2 - x1)
26+
if x2 > x1:
27+
for i in range(x2 - x1):
28+
image[x1 + i, int(y1 + i * grad)] = color_value
29+
else:
30+
for i in range(x1 - x2):
31+
image[x1 - i, int(y1 - i * grad)] = color_value
32+
33+
grad = (y3 - y2)/(x3 - x2)
34+
if x3 > x2:
35+
for i in range(x3 - x2):
36+
image[x2 + i, int(y2 + i * grad)] = color_value
37+
else:
38+
for i in range(x2 - x3):
39+
image[x2 - i, int(y2 - i * grad)] = color_value
40+
41+
grad = (y0 - y3)/(x0 - x3)
42+
if x0 > x3:
43+
for i in range(x0 - x3):
44+
image[x3 + i, int(y3 + i * grad)] = color_value
45+
else:
46+
for i in range(x3 - x0):
47+
image[x3 - i, int(y3 - i * grad)] = color_value
48+
return image.transpose()
49+
50+
# dataset = SudokuDataset("./archive/outlines_sorted.csv", img_size)
51+
# for item in dataset:
52+
# try:
53+
# image = item['image']
54+
# grid = item['grid']
55+
# x0,y0,x1,y1,x2,y2,x3,y3 = list(grid.numpy())
56+
# x0 = int(x0 * img_size[0])
57+
# x1 = int(x1 * img_size[0])
58+
# x2 = int(x2 * img_size[0])
59+
# x3 = int(x3 * img_size[0])
60+
# y0 = int(y0 * img_size[1])
61+
# y1 = int(y1 * img_size[1])
62+
# y2 = int(y2 * img_size[1])
63+
# y3 = int(y3 * img_size[1])
64+
# image = mark((x0,y0,x1,y1,x2,y2,x3,y3), image.numpy()[0], 0.7)
65+
# plt.imshow(image)
66+
# plt.colorbar()
67+
# plt.show()
68+
# except KeyboardInterrupt:
69+
# break
70+
71+
def test(config:dict, model_filename:str):
72+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
73+
model = torch.load(model_filename).to(device)
74+
model.eval()
75+
loss = torch.nn.MSELoss().to(device)
76+
dataset = get_dataset(config['filename'], config['input_shape'], config['batch_size'])
77+
batch_iterator = tqdm(dataset)
78+
for batch in batch_iterator:
79+
x = batch['image'].to(device)
80+
y_true = batch['grid'].to(device)
81+
# print(batch['grid'])
82+
# return
83+
y_pred = model(x)
84+
error = loss(y_true, y_pred)
85+
batch_iterator.set_postfix({"loss":f"Loss: {error.item():6.6f}"})
86+
x0,y0,x1,y1,x2,y2,x3,y3 = list(y_pred.detach().numpy()[1])
87+
print(x0,y0,x1,y1,x2,y2,x3,y3)
88+
x0 = int(x0 * img_size[0])
89+
x1 = int(x1 * img_size[0])
90+
x2 = int(x2 * img_size[0])
91+
x3 = int(x3 * img_size[0])
92+
y0 = int(y0 * img_size[1])
93+
y1 = int(y1 * img_size[1])
94+
y2 = int(y2 * img_size[1])
95+
y3 = int(y3 * img_size[1])
96+
image = mark((x0,y0,x1,y1,x2,y2,x3,y3), x.detach().numpy()[0][0], 0.7)
97+
plt.imshow(image)
98+
plt.colorbar()
99+
plt.show()
100+
101+
config = {
102+
"input_shape": (300,300),
103+
"filename": "archive/outlines_sorted.csv",
104+
"number_of_layers": 4,
105+
"dims": 3,
106+
"batch_size": 8,
107+
"lr": 1e-5
108+
}
109+
# model = train(50, config)
110+
test(config, "model.pt")
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
from flask import Flask, render_template, redirect, url_for, request, flash, session
2+
from werkzeug.utils import secure_filename
3+
import os
4+
from random import choices, choice
5+
from string import ascii_letters, digits
6+
from time import sleep
7+
from datetime import datetime
8+
import socket
9+
10+
app = Flask(__name__)
11+
12+
app.config.from_pyfile("config.cfg")
13+
14+
def manage_solution(input_file, result_file) -> int:
15+
def send(input_file:str, sock:socket.socket) -> int:
16+
try:
17+
with open(input_file, "rb") as f:
18+
sock.send(f.read())
19+
return 1
20+
except FileNotFoundError:
21+
return -2
22+
except socket.error:
23+
return -1
24+
25+
def connect() -> socket.socket:
26+
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
27+
sock.connect((app.config['SOLVER_IP'], int(app.config['SOLVER_PORT'])))
28+
sock.settimeout(10)
29+
return sock
30+
31+
def manage_full_send(input_file:str, sock:socket.socket):
32+
tries = 0
33+
while tries < 5:
34+
send_state = send(input_file, sock)
35+
if send_state == 1:
36+
break
37+
elif send_state == -2:
38+
return -2
39+
elif send_state == -1:
40+
sock = connect()
41+
tries += 1
42+
return send_state
43+
44+
sock = connect()
45+
send_state = manage_full_send(input_file, sock)
46+
if send_state == -1:
47+
return -1
48+
elif send_state == -2:
49+
return -2
50+
res_buf = b''
51+
try:
52+
while True:
53+
try:
54+
res = sock.recv(1)
55+
res_buf += res
56+
if 0 == len(res):
57+
sock.close()
58+
with open(result_file, "wb") as f:
59+
f.write(res_buf)
60+
break
61+
except socket.timeout:
62+
with open(result_file, "wb") as f:
63+
f.write(res_buf)
64+
break
65+
finally:
66+
sock.close()
67+
return 0
68+
69+
@app.route('/', methods=['POST', 'GET'])
70+
def index():
71+
if "POST" == request.method:
72+
print(request)
73+
if 'image' not in request.files:
74+
flash('No file part.', "danger")
75+
else:
76+
file = request.files['image']
77+
if '' == file.filename:
78+
flash("No file selected.", "danger")
79+
else:
80+
ext = "." + file.filename.split('.')[-1]
81+
filename = datetime.now().strftime("%d%m%y%H%M%S") + "_" + "".join(i for i in choices(ascii_letters+digits, k=3)) + ext
82+
filename = os.path.join(app.config['UPLOAD_FOLDER'], filename)
83+
print(filename)
84+
file.save(filename)
85+
session['filename'] = filename
86+
return redirect(url_for('result'))
87+
else:
88+
if session.get('solved'):
89+
session.pop('solved')
90+
if session.get('filename'):
91+
try:
92+
os.remove(session['filename'])
93+
session.pop('filename')
94+
except FileNotFoundError:
95+
pass
96+
return render_template('index.html', request=request)
97+
98+
@app.route('/result', methods=['GET'])
99+
def result():
100+
if not session.get('solved'):
101+
filename = session.get('filename')
102+
if not filename:
103+
return redirect(url_for('/'))
104+
solution = ""
105+
result_file = ".".join(i for i in filename.split(".")[:-1]) + "_sol.png"
106+
result_file = result_file.split("/")[-1]
107+
full_result_file = "static/" + result_file
108+
result_file = f"../static/{result_file}"
109+
result = manage_solution(filename, full_result_file)
110+
os.remove(session['filename'])
111+
if result == 0:
112+
session['filename'] = full_result_file
113+
print("solved")
114+
solution = result_file
115+
session['solved'] = solution
116+
else:
117+
session.pop('filename')
118+
flash(f"There was an issue, Error {result}", "danger")
119+
redirect(url_for('/'))
120+
else:
121+
solution = session['solved']
122+
return render_template('result.html', img=solution)
123+
124+
if "__main__" == __name__:
125+
app.run(
126+
host="192.168.1.88",
127+
port=5000,
128+
debug=True
129+
)

0 commit comments

Comments
 (0)
Please sign in to comment.