-
Notifications
You must be signed in to change notification settings - Fork 199
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Why doesn't the s_mp_sqr function consider an inner product u larger than a single-precision integer? #492
Comments
There is at least one unused bit per […]. Let […]. But we can use only 30 bits for the two […]. But those results are not used in the output. Let's follow all of these numbers along their merry ways for […]:
#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <inttypes.h>
/* Adapt to your paths */
#include "/home/czurnieden/GITHUB/libtommath/tommath.h"
#include "/home/czurnieden/GITHUB/libtommath/tommath_private.h"
/* printf conversion specifiers used by the tracing code below:
 *   PRINT_DIGIT - for single-precision mp_digit values (t.dp[k], carry u)
 *   PRINT_WORD  - for double-precision mp_word values (r)
 * chosen to match the digit configuration libtommath was built with.
 * NOTE(review): in MP_64BIT builds mp_word is presumably a 128-bit type
 * (confirm in tommath_private.h); "llx" then does not match the argument
 * and cannot show values wider than 64 bits. */
#if defined(MP_16BIT)
#   define PRINT_DIGIT PRIx16
#   define PRINT_WORD  PRIx32
#elif defined(MP_64BIT)
#   define PRINT_DIGIT PRIx64
#   define PRINT_WORD  "llx"
#else   /* 32-bit digits in the default configuration */
#   define PRINT_DIGIT PRIx32
#   define PRINT_WORD  PRIx64
#endif
/* clang -Weverything -std=c11 -DMP_16BIT -o check_squareing check_squareing.c /home/czurnieden/GITHUB/libtommath/libtommath.a */
/* clang -Weverything -std=c11 -DMP_64BIT -o check_squareing check_squareing.c /home/czurnieden/GITHUB/libtommath/libtommath.a */
static mp_err local_s_mp_sqr(const mp_int *a, mp_int *b)
{
mp_int t;
int ix, pa;
mp_err err;
pa = a->used;
if ((err = mp_init_size(&t, (2 * pa) + 1)) != MP_OKAY) {
return err;
}
/* default used is maximum possible size */
t.used = (2 * pa) + 1;
printf("pa = %d\n", pa);
for (ix = 0; ix < pa; ix++) {
mp_digit u;
int iy;
/* first calculate the digit at 2*ix */
/* calculate double precision result */
printf("\nOUTER START ix = %d\n", ix);
mp_word r = (mp_word)t.dp[2*ix] +
((mp_word)a->dp[ix] * (mp_word)a->dp[ix]);
printf(" r = %" PRINT_WORD "\n", r);
/* store lower part in result */
t.dp[ix+ix] = (mp_digit)(r & (mp_word)MP_MASK);
printf(" t.dp[%d] = %" PRINT_DIGIT "\n", ix + ix, t.dp[ix+ix]);
/* get the carry */
u = (mp_digit)(r >> (mp_word)MP_DIGIT_BIT);
printf(" u = %" PRINT_DIGIT "\n", u);
printf(" INNER START ix = %d, iy = %d\n", ix, ix + 1);
for (iy = ix + 1; iy < pa; iy++) {
/* first calculate the product */
r = (mp_word)a->dp[ix] * (mp_word)a->dp[iy];
printf(" r = %" PRINT_WORD "\n", r);
/* now calculate the double precision result, note we use
* addition instead of *2 since it's easier to optimize
*/
r = (mp_word)t.dp[ix + iy] + r + r + (mp_word)u;
printf(" r = %" PRINT_WORD "\n", r);
/* store lower part */
t.dp[ix + iy] = (mp_digit)(r & (mp_word)MP_MASK);
printf(" t.dp[%d] = %" PRINT_DIGIT "\n", ix + iy, t.dp[ix+iy]);
/* get carry */
u = (mp_digit)(r >> (mp_word)MP_DIGIT_BIT);
printf(" u = %" PRINT_DIGIT "\n", u);
}
/* propagate upwards */
printf(" WHILE LOOP START ix = %d, iy = %d\n", ix, iy);
while (u != 0uL) {
r = (mp_word)t.dp[ix + iy] + (mp_word)u;
printf(" r = %" PRINT_WORD "\n", r);
t.dp[ix + iy] = (mp_digit)(r & (mp_word)MP_MASK);
printf(" t.dp[%d] = %" PRINT_DIGIT "\n", ix + iy, t.dp[ix+iy]);
u = (mp_digit)(r >> (mp_word)MP_DIGIT_BIT);
printf(" u = %" PRINT_DIGIT "\n", u);
++iy;
printf(" iy = %d\n", iy);
}
printf(" WHILE LOOP END ix = %d, iy = %d\n", ix, iy);
}
mp_clamp(&t);
mp_exch(&t, b);
mp_clear(&t);
return MP_OKAY;
}
int main(void)
{
mp_int a, c;
int err, m, n;
printf("MP_DIGIT_BIT = %d\n", MP_DIGIT_BIT);
if ((err = mp_init_multi(&a, &c, NULL)) != MP_OKAY) {
fprintf(stderr,"Something went wrong inside LTM: %s\n", mp_error_to_string(err));
exit(EXIT_FAILURE);
}
m = 3;
/* Produce numbers of the form 2^n -1 = 0b111...111 so all a.dp[0..m-1] = MP_MASK */
n = m * (MP_DIGIT_BIT);
if ((err = mp_2expt(&a, n)) != MP_OKAY) goto LTM_ERR;
if ((err = mp_decr(&a)) != MP_OKAY) goto LTM_ERR;
if ((err = local_s_mp_sqr(&a, &c)) != MP_OKAY) goto LTM_ERR;
fprintf(stderr,"\na = 0x");
if ((err = mp_fwrite(&a, 16, stderr)) != MP_OKAY) goto LTM_ERR;
fprintf(stderr,"\n");
fprintf(stderr,"a^2 = 0x");
if ((err = mp_fwrite(&c, 16, stderr)) != MP_OKAY) goto LTM_ERR;
fprintf(stderr,"\n");
err = MP_OKAY;
LTM_ERR:
mp_clear_multi(&a, &c, NULL);
if (err != MP_OKAY) {
fprintf(stderr,"Something went wrong inside LTM: %s\n", mp_error_to_string(err));
exit(EXIT_FAILURE);
}
exit(EXIT_SUCCESS);
} That prints:
As you can see, none of the output digits in […]. A proper proof is left as an exercise for the dear reader; if you actually do a proof of correctness, you are more than welcome to put it into […].
The lack of further comments suggests that it is safe to close this issue.
Hi Tom,
I am reading the HAC book (Handbook of Applied Cryptography) and see the following:

14.17 Note (computational efficiency of Algorithm 14.16)
But in your function s_mp_sqr, I didn't see any special handling for r (the (u,v) value in Algorithm 14.16) being larger than 3 single-precision integers. Is this a bug?
The text was updated successfully, but these errors were encountered: