Skip to content

Commit cf0ba2f

Browse files
Add xor_codec script (LAION-AI#2824)
Original PR history was broken. Encode: `xor_codec.py output_dir/ model_dir/ llama_dir/ --encode`. Decode: `xor_codec.py output_dir/ delta_dir/ llama_dir/`. Co-authored-by: umbra-scientia <44982143+umbra-scientia@users.noreply.github.com>
1 parent b6c4f72 commit cf0ba2f

File tree

1 file changed

+100
-0
lines changed

1 file changed

+100
-0
lines changed

scripts/xor-codec/xor_codec.py

+100
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
import gzip
2+
import os
3+
import shutil
4+
import sys
5+
from pathlib import Path
6+
7+
import numpy
8+
9+
10+
def xor_uncompressed(dst, src_payload, src_base, block_size=4096):
    """XOR ``src_payload`` with ``src_base`` byte-for-byte and write the result to ``dst``.

    All three files are plain (uncompressed) binary files. The output always
    has the same length as the payload: when the base file is shorter it is
    treated as if it were extended with zero bytes, and base bytes beyond the
    end of the payload are ignored. Because XOR is its own inverse, applying
    this function to its own output with the same base restores the payload.

    Args:
        dst: Path of the file to create (overwritten if it already exists).
        src_payload: Path of the file whose bytes are transformed.
        src_base: Path of the file supplying the bytes to XOR with.
        block_size: Number of bytes processed per iteration; must be positive.

    Raises:
        ValueError: If ``block_size`` is not positive.
        OSError: If a file cannot be opened, read, or written.
    """
    if block_size < 1:
        raise ValueError("block_size must be a positive integer")
    # Both sources are opened before the destination so a missing input does
    # not leave an empty output file behind. The ``with`` blocks close every
    # handle even when a read or write fails part-way through.
    with open(src_payload, "rb") as fp_payload, open(src_base, "rb") as fp_base:
        with open(dst, "wb") as fp:
            while True:
                chunk = fp_payload.read(block_size)
                if not chunk:
                    # End of payload: the output length follows the payload.
                    break
                buf1 = numpy.frombuffer(chunk, dtype=numpy.uint8)
                # Request exactly as many base bytes as payload bytes so the
                # two streams stay aligned and nothing needs to be truncated.
                buf2 = numpy.frombuffer(fp_base.read(len(chunk)), dtype=numpy.uint8)
                padding = len(buf1) - len(buf2)
                if padding > 0:
                    # Base exhausted: XOR with zero leaves payload bytes as-is.
                    buf2 = numpy.pad(buf2, (0, padding), "constant", constant_values=(0,))
                fp.write(numpy.bitwise_xor(buf1, buf2).tobytes())
28+
29+
30+
def xor_encode(dst, src_payload, src_base, block_size=4096):
    """XOR ``src_payload`` with ``src_base`` and write the result gzip-compressed to ``dst``.

    The two sources are plain binary files; only the destination is gzip
    compressed. The uncompressed output has the same length as the payload:
    a shorter base file is treated as if it were extended with zero bytes,
    and base bytes beyond the end of the payload are ignored.

    Args:
        dst: Path of the gzip file to create (overwritten if it exists).
        src_payload: Path of the file whose bytes are transformed.
        src_base: Path of the file supplying the bytes to XOR with.
        block_size: Number of bytes processed per iteration; must be positive.

    Raises:
        ValueError: If ``block_size`` is not positive.
        OSError: If a file cannot be opened, read, or written.
    """
    if block_size < 1:
        raise ValueError("block_size must be a positive integer")
    # Both sources are opened before the destination so a missing input does
    # not leave an empty archive behind. The ``with`` blocks close every
    # handle even when a read or write fails part-way through.
    with open(src_payload, "rb") as fp_payload, open(src_base, "rb") as fp_base:
        with gzip.open(dst, "wb") as fp:
            while True:
                chunk = fp_payload.read(block_size)
                if not chunk:
                    # End of payload: the output length follows the payload.
                    break
                buf1 = numpy.frombuffer(chunk, dtype=numpy.uint8)
                # Request exactly as many base bytes as payload bytes so the
                # two streams stay aligned and nothing needs to be truncated.
                buf2 = numpy.frombuffer(fp_base.read(len(chunk)), dtype=numpy.uint8)
                padding = len(buf1) - len(buf2)
                if padding > 0:
                    # Base exhausted: XOR with zero leaves payload bytes as-is.
                    buf2 = numpy.pad(buf2, (0, padding), "constant", constant_values=(0,))
                fp.write(numpy.bitwise_xor(buf1, buf2).tobytes())
48+
49+
50+
def xor_decode(dst, src_payload, src_base, block_size=4096):
    """Decompress the gzip file ``src_payload``, XOR it with ``src_base`` and write ``dst``.

    Only the payload is gzip compressed; the base and the destination are
    plain binary files. The output has the same length as the decompressed
    payload: a shorter base file is treated as if it were extended with zero
    bytes, and base bytes beyond the end of the payload are ignored.

    Args:
        dst: Path of the file to create (overwritten if it already exists).
        src_payload: Path of the gzip-compressed file to transform.
        src_base: Path of the file supplying the bytes to XOR with.
        block_size: Number of decompressed bytes processed per iteration;
            must be positive.

    Raises:
        ValueError: If ``block_size`` is not positive.
        OSError: If a file cannot be opened, read, or written (this includes
            a payload that is not valid gzip data).
    """
    if block_size < 1:
        raise ValueError("block_size must be a positive integer")
    # Both sources are opened before the destination so a missing input does
    # not leave an empty output file behind. The ``with`` blocks close every
    # handle even when a read or write fails part-way through.
    with gzip.open(src_payload, "rb") as fp_payload, open(src_base, "rb") as fp_base:
        with open(dst, "wb") as fp:
            while True:
                chunk = fp_payload.read(block_size)
                if not chunk:
                    # End of payload: the output length follows the payload.
                    break
                buf1 = numpy.frombuffer(chunk, dtype=numpy.uint8)
                # Request exactly as many base bytes as payload bytes so the
                # two streams stay aligned and nothing needs to be truncated.
                buf2 = numpy.frombuffer(fp_base.read(len(chunk)), dtype=numpy.uint8)
                padding = len(buf1) - len(buf2)
                if padding > 0:
                    # Base exhausted: XOR with zero leaves payload bytes as-is.
                    buf2 = numpy.pad(buf2, (0, padding), "constant", constant_values=(0,))
                fp.write(numpy.bitwise_xor(buf1, buf2).tobytes())
68+
69+
70+
def xor_dir(dst, src_payload, src_base, decode=True, compress=True):
    """XOR every entry of ``src_payload`` with the same-named entry of ``src_base`` into ``dst``.

    The destination directory is created if necessary. ``added_tokens.json``
    is first copied unchanged from the payload directory. The loop below then
    visits every entry of the payload directory, that file included, so if the
    base directory also holds an ``added_tokens.json`` the plain copy is
    replaced by the XOR result. Processing is best effort: a failure on one
    entry is reported on stdout and the remaining entries are still processed.

    Args:
        dst: Directory that receives the output files.
        src_payload: Directory whose entries are transformed.
        src_base: Directory supplying the files to XOR with.
        decode: With ``compress`` set, selects ``xor_decode`` (gzip payload,
            plain output) when true and ``xor_encode`` (plain payload, gzip
            output) when false. Ignored when ``compress`` is false.
        compress: When false, ``xor_uncompressed`` is used for every entry.

    Raises:
        OSError: If ``dst`` cannot be created, ``added_tokens.json`` cannot be
            copied from ``src_payload``, or ``src_payload`` cannot be listed.
    """
    if compress:
        xor = xor_decode if decode else xor_encode
    else:
        xor = xor_uncompressed
    Path(dst).mkdir(parents=True, exist_ok=True)
    shutil.copy(Path(src_payload) / "added_tokens.json", Path(dst) / "added_tokens.json")
    # Sorted so the progress log is reproducible; os.listdir order is arbitrary.
    for path in sorted(os.listdir(src_payload)):
        print("[*] Processing '%s'" % path)
        try:
            xor(
                os.path.join(dst, path),
                os.path.join(src_payload, path),
                os.path.join(src_base, path),
            )
        except Exception as exc:
            # Deliberately non-fatal, but report the cause so that an entry
            # missing from the base can be told apart from a real I/O failure.
            print("Exception when processing '%s': %r" % (path, exc))
83+
84+
85+
if __name__ == "__main__":
    # Command-line entry point: three positional directories followed by
    # optional flags. Without flags the script decodes without compression.
    if len(sys.argv) < 4:
        print("Usage: xor_codec.py <DESTINATION> <PAYLOAD SOURCE> <LLAMA SOURCE> [--encode] [--compress]")
        # sys.exit is always available (the bare exit() helper is injected by
        # the site module); a non-zero status signals the usage error.
        sys.exit(1)
    dst, src_payload, src_base = sys.argv[1:4]
    flags = sys.argv[4:]
    # Unrecognised extra arguments are ignored.
    decode = "--encode" not in flags
    compress = "--compress" in flags
    xor_dir(dst, src_payload, src_base, decode=decode, compress=compress)

0 commit comments

Comments
 (0)