Why s_mp_sqr function doesn't consider inner product u larger than a single-precision integer. #492

Closed
mazhenke opened this issue Nov 23, 2020 · 3 comments

@mazhenke

mazhenke commented Nov 23, 2020

Hi Tom,

I am reading the HAC book, and see that:
14.17 Note (computational efficiency of Algorithm 14.16)
[image: text of Note 14.17]

But in your function s_mp_sqr, I don't see any special handling for the case where r (the (u, v) value in Algorithm 14.16) is larger than 3 single-precision integers. Is this a bug?

/* low level squaring, b = a*a, HAC pp.596-597, Algorithm 14.16 */
mp_err s_mp_sqr(const mp_int *a, mp_int *b)
{
   mp_int   t;
   int      ix, pa;
   mp_err   err;

   pa = a->used;
   if ((err = mp_init_size(&t, (2 * pa) + 1)) != MP_OKAY) {
      return err;
   }

   /* default used is maximum possible size */
   t.used = (2 * pa) + 1;

   for (ix = 0; ix < pa; ix++) {
      mp_digit u;
      int iy;

      /* first calculate the digit at 2*ix */
      /* calculate double precision result */
      mp_word r = (mp_word)t.dp[2*ix] +
                  ((mp_word)a->dp[ix] * (mp_word)a->dp[ix]);

      /* store lower part in result */
      t.dp[ix+ix] = (mp_digit)(r & (mp_word)MP_MASK);

      /* get the carry */
      u           = (mp_digit)(r >> (mp_word)MP_DIGIT_BIT);

      for (iy = ix + 1; iy < pa; iy++) {
         /* first calculate the product */
         r       = (mp_word)a->dp[ix] * (mp_word)a->dp[iy];

         /* now calculate the double precision result, note we use
          * addition instead of *2 since it's easier to optimize
          */
         r       = (mp_word)t.dp[ix + iy] + r + r + (mp_word)u;    /* <------- r may be larger than 3 mp_digit */

         /* store lower part */
         t.dp[ix + iy] = (mp_digit)(r & (mp_word)MP_MASK);

         /* get carry */
         u       = (mp_digit)(r >> (mp_word)MP_DIGIT_BIT);
      }
      /* propagate upwards */
      while (u != 0uL) {
         r       = (mp_word)t.dp[ix + iy] + (mp_word)u;
         t.dp[ix + iy] = (mp_digit)(r & (mp_word)MP_MASK);
         u       = (mp_digit)(r >> (mp_word)MP_DIGIT_BIT);
         ++iy;
      }
   }

   mp_clamp(&t);
   mp_exch(&t, b);
   mp_clear(&t);
   return MP_OKAY;
}
@czurnieden
Contributor

There is at least one unused bit per mp_digit. The 16-bit version of LTM, for example, uses 15-bit mp_digits. An mp_word is twice the size of an mp_digit and hence has two unused bits (assuming "normal" architectures, of course!). So far, so good? OK:

Let b = 15 be the number of used bits in an mp_digit, assuming unsigned ints of at least 16 bits (so 16 usable bits), and let B = 32 be the number of usable bits in an mp_word, assuming unsigned longs of at least 32 bits.
Then (2^15)^2 = 2^(2*15) = 2^30 < 2^32, so not only is (2^b)^2 < 2^B but also 2*(2^b)^2 < 2^B, because 2*2^30 = 2^(30 + 1) = 2^31 < 2^32. Together with 2*2^15 = 2^(15 + 1) = 2^16, this lets us state that 2*(2^b) + 2*(2^b)^2 = 2^16 + 2^31 < 2^32, which fits into an mp_word.

But two mp_digits together hold only 30 bits, and the sum ends up nearly twice as large as that allows! Let's take a look at the actual numbers: the maximum value of an mp_digit is 2^15 - 1 = 0x7FFF. If we keep the hexadecimal representation we get 0x7FFF + 0x3FFF0001 + 0x3FFF0001 + 0x7FFF = 0x7FFF0000. Split into two mp_digits, that gives high = (mp_digit)(0x7FFF0000 >> 15) = 0xFFFE = 2^16 - 2, which is too big for 15 bits, and low = (mp_digit)(0x7FFF0000 & ((1 << 15) - 1)) = 0.

But those results are not used in the output.

Let's follow all of these numbers along their merry ways for MP_16BIT. All a->dp[0..m-1] = 0x7FFF in that case.

#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <inttypes.h>
/* Adapt to your paths */
#include "/home/czurnieden/GITHUB/libtommath/tommath.h"
#include "/home/czurnieden/GITHUB/libtommath/tommath_private.h"


#if defined(MP_16BIT)
#define PRINT_WORD      PRIx32
#define PRINT_DIGIT     PRIx16
#elif defined(MP_64BIT)
#define PRINT_WORD      "llx"
#define PRINT_DIGIT     PRIx64
#else
#define PRINT_WORD      PRIx64
#define PRINT_DIGIT     PRIx32
#endif


/* clang -Weverything -std=c11 -DMP_16BIT -o check_squareing check_squareing.c /home/czurnieden/GITHUB/libtommath/libtommath.a */
/* clang -Weverything -std=c11 -DMP_64BIT -o check_squareing check_squareing.c /home/czurnieden/GITHUB/libtommath/libtommath.a */
static mp_err local_s_mp_sqr(const mp_int *a, mp_int *b)
{
   mp_int   t;
   int      ix, pa;
   mp_err   err;

   pa = a->used;
   if ((err = mp_init_size(&t, (2 * pa) + 1)) != MP_OKAY) {
      return err;
   }

   /* default used is maximum possible size */
   t.used = (2 * pa) + 1;

   printf("pa = %d\n", pa);

   for (ix = 0; ix < pa; ix++) {
      mp_digit u;
      int iy;

      /* first calculate the digit at 2*ix */
      /* calculate double precision result */
      printf("\nOUTER START ix = %d\n", ix);
      mp_word r = (mp_word)t.dp[2*ix] +
                  ((mp_word)a->dp[ix] * (mp_word)a->dp[ix]);
      printf("   r = %" PRINT_WORD "\n", r);
      /* store lower part in result */
      t.dp[ix+ix] = (mp_digit)(r & (mp_word)MP_MASK);
      printf("   t.dp[%d] = %" PRINT_DIGIT "\n", ix + ix, t.dp[ix+ix]);

      /* get the carry */
      u           = (mp_digit)(r >> (mp_word)MP_DIGIT_BIT);
      printf("   u = %" PRINT_DIGIT "\n", u);
      printf("   INNER START ix = %d, iy = %d\n", ix, ix + 1);
      for (iy = ix + 1; iy < pa; iy++) {
         /* first calculate the product */
         r       = (mp_word)a->dp[ix] * (mp_word)a->dp[iy];
         printf("      r = %" PRINT_WORD "\n", r);
         /* now calculate the double precision result, note we use
          * addition instead of *2 since it's easier to optimize
          */
         r       = (mp_word)t.dp[ix + iy] + r + r + (mp_word)u;
         printf("      r = %" PRINT_WORD "\n", r);
         /* store lower part */
         t.dp[ix + iy] = (mp_digit)(r & (mp_word)MP_MASK);
         printf("      t.dp[%d] = %" PRINT_DIGIT "\n", ix + iy, t.dp[ix+iy]);
         /* get carry */
         u       = (mp_digit)(r >> (mp_word)MP_DIGIT_BIT);
         printf("      u = %" PRINT_DIGIT "\n", u);
      }
      /* propagate upwards */
      printf("   WHILE LOOP START ix = %d, iy = %d\n", ix, iy);
      while (u != 0uL) {
         r       = (mp_word)t.dp[ix + iy] + (mp_word)u;
         printf("      r = %" PRINT_WORD "\n", r);
         t.dp[ix + iy] = (mp_digit)(r & (mp_word)MP_MASK);
         printf("      t.dp[%d] = %" PRINT_DIGIT "\n", ix + iy, t.dp[ix+iy]);
         u       = (mp_digit)(r >> (mp_word)MP_DIGIT_BIT);
         printf("      u = %" PRINT_DIGIT "\n", u);
         ++iy;
         printf("      iy = %d\n", iy);
      }
      printf("   WHILE LOOP END ix = %d, iy = %d\n", ix, iy);
   }

   mp_clamp(&t);
   mp_exch(&t, b);
   mp_clear(&t);
   return MP_OKAY;
}

int main(void)
{
   mp_int a, c;
   int err, m, n;

   printf("MP_DIGIT_BIT = %d\n", MP_DIGIT_BIT);

   if ((err = mp_init_multi(&a, &c, NULL)) != MP_OKAY) {
      fprintf(stderr,"Something went wrong inside LTM: %s\n", mp_error_to_string(err));
      exit(EXIT_FAILURE);
   }

   m = 3;
   /* Produce numbers of the form 2^n -1 = 0b111...111 so all a.dp[0..m-1] = MP_MASK */
   n = m * (MP_DIGIT_BIT);
   if ((err = mp_2expt(&a, n)) != MP_OKAY)                                            goto LTM_ERR;
   if ((err = mp_decr(&a)) != MP_OKAY)                                                goto LTM_ERR;

   if ((err = local_s_mp_sqr(&a, &c)) != MP_OKAY)                                     goto LTM_ERR;

   fprintf(stderr,"\na   = 0x");
   if ((err = mp_fwrite(&a, 16, stderr)) != MP_OKAY)                                  goto LTM_ERR;
   fprintf(stderr,"\n");
   fprintf(stderr,"a^2 = 0x");
   if ((err = mp_fwrite(&c, 16, stderr)) != MP_OKAY)                                  goto LTM_ERR;
   fprintf(stderr,"\n");


   err = MP_OKAY;
LTM_ERR:
   mp_clear_multi(&a, &c, NULL);
   if (err != MP_OKAY) {
      fprintf(stderr,"Something went wrong inside LTM: %s\n", mp_error_to_string(err));
      exit(EXIT_FAILURE);
   }
   exit(EXIT_SUCCESS);
}

That prints:

MP_DIGIT_BIT = 15
pa = 3

OUTER START ix = 0
   r = 3fff0001
   t.dp[0] = 1
   u = 7ffe
   INNER START ix = 0, iy = 1
      r = 3fff0001
      r = 7ffe8000
      t.dp[1] = 0
      u = fffd
      r = 3fff0001
      r = 7ffeffff
      t.dp[2] = 7fff
      u = fffd
   WHILE LOOP START ix = 0, iy = 3
      r = fffd
      t.dp[3] = 7ffd
      u = 1
      iy = 4
      r = 1
      t.dp[4] = 1
      u = 0
      iy = 5
   WHILE LOOP END ix = 0, iy = 5

OUTER START ix = 1
   r = 3fff8000
   t.dp[2] = 0
   u = 7fff
   INNER START ix = 1, iy = 2
      r = 3fff0001
      r = 7ffefffe
      t.dp[3] = 7ffe
      u = fffd
   WHILE LOOP START ix = 1, iy = 3
      r = fffe
      t.dp[4] = 7ffe
      u = 1
      iy = 4
      r = 1
      t.dp[5] = 1
      u = 0
      iy = 5
   WHILE LOOP END ix = 1, iy = 5

OUTER START ix = 2
   r = 3fff7fff
   t.dp[4] = 7fff
   u = 7ffe
   INNER START ix = 2, iy = 3
   WHILE LOOP START ix = 2, iy = 3
      r = 7fff
      t.dp[5] = 7fff
      u = 0
      iy = 4
   WHILE LOOP END ix = 2, iy = 4

a   = 0x1FFFFFFFFFFF
a^2 = 0x3FFFFFFFFFFC00000000001

As you can see, none of the output digits in t is too large; only some of the intermediate values are. The carry u can be bigger than 0x7FFF but is always smaller than 0xFFFF, so it fits into an mp_digit, and we have shown above that the biggest values always fit into an mp_word.

A proper proof is left as an exercise for the dear reader, and if you actually do a proof of correctness you are more than welcome to put it into /doc. If you are not familiar with LaTeX: just ask, we'll help.

@czurnieden
Contributor

Lack of further comments leads to the assumption that it is safe to close this issue.

@mazhenke
Author

@czurnieden

Hi czurnieden,

Thanks, it seems the extra 2 bits in mp_word do the job. Here is a simple proof; I think it is correct, just for your reference:

(u, v) = w[i+j] + 2 * x[j] * x[i] + c
r = w[i+j] + 2 * x[j] * x[i] + c

Let b be the maximum number of bits used in an mp_digit, then:
[image: formula]

and the result r:
[image: formula]

Because the number of bits in an mp_word is twice that of an mp_digit plus 2, that is:
2 x b + 2
the maximum value of an mp_word can be:
[image: formula]

and obviously, we get:
[image: formula]
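
The formulas in this comment were posted as images. As a hedged reconstruction of the argument (the original images may have used different notation and a different bound on the carry), let b be the number of bits used per mp_digit and assume the incoming carry c is at most 2^(b+1) - 2:

```latex
\begin{align*}
w_{i+j} &\le 2^{b} - 1, \qquad x_i,\, x_j \le 2^{b} - 1, \qquad c \le 2^{b+1} - 2, \\
r &= w_{i+j} + 2\,x_j x_i + c \\
  &\le (2^{b} - 1) + 2\,(2^{b} - 1)^{2} + (2^{b+1} - 2) \\
  &= 2^{2b+1} - 2^{b} - 1 \\
  &< 2^{2b+2}, \\
\left\lfloor \frac{r}{2^{b}} \right\rfloor &\le 2^{b+1} - 2 .
\end{align*}
```

So r fits into an mp_word of 2b + 2 bits, and the outgoing carry again satisfies the assumption made on c, which is what makes the bound hold for every iteration of the inner loop.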
