;;;; bignum.scm: A user-level unsigned bignum arithmetic library
;;; J. Welsh, August 2017
;;; Trimmed for gbw, March 2020

;; A bignum is a list of words, least significant first. It must not have
;; trailing (i.e. most significant) zero words. Thus each number has a unique
;; representation, and zero is the empty list.
;;
;; This file is a single thunk; its body defines the library and ends by
;; `export`ing the public names listed at the bottom.

(lambda ()

  ;;; Constants

  ;; Hex nibbles per word. A word product plus a carry (and the two-word value
  ;; built by slice-2) must fit in a fixnum, i.e. base^2 <= 2^(*fixnum-width* - 1),
  ;; assuming *fixnum-width* counts the sign bit as R6RS fixnum-width does
  ;; (NOTE(review): confirm for gbw; if it does not, this is merely conservative).
  ;; That requires 8 * base-nibbles <= *fixnum-width* - 1. The former
  ;; (quotient *fixnum-width* 8) violated it whenever the width was a multiple
  ;; of 8, and is identical to this formula for every other width.
  ;; fx-/wrap is used directly because the local `-` (defined below) may not
  ;; be initialized yet at this point.
  (define base-nibbles (quotient (fx-/wrap *fixnum-width* 1) 8))
  ;; These depend on the local arithmetic operators defined below, so they are
  ;; delayed here and forced at the end of the body.
  (define base-bits (delay (fx*/wrap base-nibbles 4)))
  (define base/2 (delay (expt 2 (- base-bits 1))))
  (define base (delay (* 2 base/2))) ;; must be <= sqrt of largest fixnum
  (define base-1 (delay (- base 1)))
  (define neg-base-bits (delay (- base-bits)))

  ;; Fixnum operators shadowing the generic ones for the rest of the body.
  (define - fx-/wrap)
  (define + fx+/wrap)
  (define * fx*/wrap)
  (define (zero? x) (fx= x 0))
  (define (even? x) (fx= (fxand x 1) 0))
  (define = fx=)
  (define < fx<)
  (define <= fx<=)

  (define hex "0123456789abcdef")
  (define char0 (char->integer #\0))
  (define char10-A (fx-/wrap 10 (char->integer #\A)))
  (define bn0 '())
  (define bn1 '(1))

  ;;; Helpers

  ;; Lowercase hex digits of a non-negative fixnum, without padding.
  (define (fix->hex n) ;; note 0 -> empty string
    (do ((n n (fxshift n -4))
         (acc '() (cons (string-ref hex (fxand n 15)) acc)))
        ((zero? n) (list->string acc))))

  ;; Value of a hex digit character (either case); signals an error otherwise.
  ;; NOTE(review): char-numeric? may accept non-ASCII digits in Unicode-aware
  ;; implementations, which would yield out-of-range values.
  (define (hexdigit->fix c)
    (if (char-numeric? c)
        (- (char->integer c) char0)
        (let ((i (+ (char->integer (char-upcase c)) char10-A)))
          (if (and (<= 10 i) (< i 16))
              i
              (error "bad hex digit:" c)))))

  ;; Value of a decimal digit character; signals an error otherwise.
  (define (decdigit->fix c)
    (if (char-numeric? c)
        (- (char->integer c) char0)
        (error "bad decimal digit:" c)))

  ;; Left-pad string s with char up to length len (s must not exceed len).
  (define (left-pad s len char)
    (string-append (make-string (- len (string-length s)) char) s))

  ;; Hex digits of one word, zero-padded to the full word width.
  (define (bn-pad-word->hex w)
    (left-pad (fix->hex w) base-nibbles #\0))

  ;; Single-word bignum, respecting the "zero is the empty list" rule.
  (define (word->bn w)
    (if (zero? w) '() (list w)))

  ;; Construct bignum from big-endian, vector-like sequence of nibbles.
  ;; nibble-ref maps an index in [0, len) to a nibble value.
  ;; Ugly, but linear time.
  (define (nibbles->bn nibble-ref len)
    ;; Consume whole words from index start; acc is built least significant
    ;; first and skips zero words while it is still empty (no trailing zeros).
    (define (loop-words start acc)
      (if (= start len)
          acc
          (let* ((next (+ start base-nibbles))
                 (word (get-word start (- next 1))))
            (loop-words next (if (and (null? acc) (zero? word))
                                 acc
                                 (cons word acc))))))
    ;; Word formed by nibbles start..stop inclusive, big-endian.
    (define (get-word start stop)
      (define (loop start acc)
        (if (< stop start)
            acc
            (loop (+ start 1) (+ (* 16 acc) (nibble-ref start)))))
      (loop (+ start 1) (nibble-ref start)))
    (if (zero? len)
        '()
        ;; The most significant word may be partial; all others are full.
        (let* ((msw-end (remainder (- len 1) base-nibbles))
               (msw (get-word 0 msw-end)))
          (loop-words (+ msw-end 1) (word->bn msw)))))

  ;; Rather than "shift left/right", which needlessly invoke endianness, I'm
  ;; using "shift" for multiplications and "unshift" for divisions.

  ;; a * base^k (zero stays zero).
  (define (shift-words a k)
    (if (null? a) a (shift-words-nz a k)))

  ;; As shift-words, for nonzero a.
  (define (shift-words-nz a k)
    (if (zero? k) a (shift-words-nz (cons 0 a) (- k 1))))

  ;; floor(a / base^k).
  (define (unshift-words a k)
    (if (or (zero? k) (null? a)) a (unshift-words (cdr a) (- k 1))))

  ;;; Type conversion

  ;; Lowercase hex string; zero is "0".
  (define (bn->hex n)
    (let ((n (reverse n)))
      (if (null? n)
          "0"
          (apply string-append
                 (fix->hex (car n)) ;; optimize?
                 (map bn-pad-word->hex (cdr n))))))

  (define (hex->bn s)
    (nibbles->bn (lambda (k) (hexdigit->fix (string-ref s k)))
                 (string-length s)))

  ;; v is a vector of byte values, most significant first.
  (define (bytes->bn v)
    (nibbles->bn (lambda (k)
                   (let ((byte (vector-ref v (fxshift k -1))))
                     (if (even? k)
                         (fxshift byte -4) ;; big-endian
                         (fxand byte 15))))
                 (* 2 (vector-length v))))

  ;; ~Cubic algorithm! (one full bignum division by 10 per output digit)
  ;; Zero is rendered as "0", matching bn->hex (the digit loop alone would
  ;; produce the empty string).
  (define (bn->dec n)
    (if (null? n)
        "0"
        (let loop ((n n) (acc '()))
          (if (null? n)
              (list->string acc)
              (bn-divrem n '(10)
                         (lambda (q r)
                           (loop q (cons (string-ref hex (bn->fix r)) acc))))))))

  ;; Quadratic algorithm!
  (define (dec->bn s)
    (do ((i 0 (+ i 1))
         (acc bn0 (bn+fix (bn*fix acc 10) (decdigit->fix (string-ref s i)))))
        ((= i (string-length s)) acc)))

  ;; Can overflow (obviously)
  (define (bn->fix n)
    (do ((n (reverse n) (cdr n))
         (acc 0 (+ (* acc base) (car n))))
        ((null? n) acc)))

  ;; Assumes n >= 0.
  (define (fix->bn n)
    (do ((n n (fxshift n neg-base-bits))
         (acc '() (cons (fxand n base-1) acc)))
        ((zero? n) (reverse acc))))

  ;;; Predicates

  (define bn-zero? null?)
  (define (bn-even? a) (or (null? a) (even? (car a))))
  (define (bn-odd? a) (not (bn-even? a)))

  ;; Three-way fixnum compare: -1, 0 or 1.
  (define (cmp a b)
    (cond ((< a b) -1)
          ((< b a) 1)
          (else 0)))

  ;; Three-way bignum compare. Relies on the canonical form: a shorter list
  ;; is always the smaller number; otherwise higher words decide first.
  (define (bn-cmp a b)
    (cond ((null? a) (if (null? b) 0 -1))
          ((null? b) 1)
          (else (let ((c (bn-cmp (cdr a) (cdr b))))
                  (if (zero? c) (cmp (car a) (car b)) c)))))

  (define bn= equal?)
  (define (bn< a b) (< (bn-cmp a b) 0))
  (define (bn> a b) (< 0 (bn-cmp a b)))
  (define (bn<= a b) (<= (bn-cmp a b) 0))
  (define (bn>= a b) (<= 0 (bn-cmp a b)))

  ;;; Addition

  (define (bn+1 a)
    (if (null? a)
        bn1
        (let ((head (car a)))
          (if (= head base-1)
              (cons 0 (bn+1 (cdr a)))
              (cons (+ head 1) (cdr a))))))

  (define (bn+ a b)
    (cond ((null? a) b)
          ((null? b) a)
          (else
           (let ((sum (+ (car a) (car b))))
             (if (< sum base)
                 (cons sum (bn+ (cdr a) (cdr b)))
                 (cons (- sum base) (bn+carry (cdr a) (cdr b))))))))

  ;; a + b + 1
  (define (bn+carry a b)
    (cond ((null? a) (bn+1 b))
          ((null? b) (bn+1 a))
          (else
           (let ((sum (+ (car a) (car b) 1)))
             (if (< sum base)
                 (cons sum (bn+ (cdr a) (cdr b)))
                 (cons (- sum base) (bn+carry (cdr a) (cdr b))))))))

  ;; CAUTION: assumes 0 <= b < base
  (define (bn+fix a b)
    (cond ((zero? b) a)
          ((null? a) (list b))
          (else
           (let ((sum (+ (car a) b)))
             (if (< sum base)
                 (cons sum (cdr a))
                 (cons (- sum base) (bn+1 (cdr a))))))))

  ;;; Subtraction

  ;; a - 1; signals an error for zero.
  (define (bn-1 a)
    (if (null? a) (error "bn-1: subtract from zero"))
    (let ((head (car a))
          (tail (cdr a)))
      (cond ((zero? head) (cons base-1 (bn-1 tail)))
            ((and (= head 1) (null? tail)) '())
            (else (cons (- head 1) tail)))))

  ;; a - b; signals an error if b > a. A zero word is only kept when a
  ;; nonzero word follows it (canonical form).
  (define (bn- a b)
    (cond ((null? a) (if (null? b) b (error "bn-: subtract from zero")))
          ((null? b) a)
          (else
           (let ((diff (- (car a) (car b))))
             (if (< diff 0)
                 (cons (+ diff base) (bn-sub-borrow (cdr a) (cdr b)))
                 (let ((tail (bn- (cdr a) (cdr b))))
                   (if (and (= diff 0) (null? tail))
                       '()
                       (cons diff tail))))))))

  ;; a - b - 1
  (define (bn-sub-borrow a b)
    (cond ((null? a) (error "bn-: subtract from zero"))
          ((null? b) (bn-1 a))
          (else
           (let ((diff (- (car a) (car b) 1)))
             (if (< diff 0)
                 (let ((tail (bn-sub-borrow (cdr a) (cdr b)))
                       (diff (+ diff base)))
                   (if (and (= diff 0) (null? tail))
                       '()
                       (cons diff tail)))
                 (let ((tail (bn- (cdr a) (cdr b))))
                   (if (and (= diff 0) (null? tail))
                       '()
                       (cons diff tail))))))))

  ;;; Multiplication

  (define (bn*2 a)
    (if (null? a)
        '()
        (let ((product (* (car a) 2)))
          (if (< product base)
              (cons product (bn*2 (cdr a)))
              (cons (- product base) (bn*2+carry (cdr a)))))))

  ;; 2a + 1
  (define (bn*2+carry a)
    (if (null? a)
        bn1
        (let ((product (+ (* (car a) 2) 1)))
          (if (< product base)
              (cons product (bn*2 (cdr a)))
              (cons (- product base) (bn*2+carry (cdr a)))))))

  ;; CAUTION: assumes 0 <= scale < base
  (define (bn*fix a scale)
    (if (or (null? a) (zero? scale))
        '()
        (let ((product (* (car a) scale)))
          (if (< product base)
              (cons product (bn*fix (cdr a) scale))
              (cons (fxand product base-1)
                    (bn*fix+carry (cdr a) scale (fxshift product neg-base-bits)))))))

  ;; a * scale + carry, where carry is the (nonzero) high word of the
  ;; previous product.
  (define (bn*fix+carry a scale carry)
    (if (or (null? a) (zero? scale))
        (list carry)
        (let ((product (+ (* (car a) scale) carry)))
          (if (< product base)
              (cons product (bn*fix (cdr a) scale))
              (cons (fxand product base-1)
                    (bn*fix+carry (cdr a) scale (fxshift product neg-base-bits)))))))

  ;; a * 2^bits
  (define (bn-shift a bits)
    (cond ((< bits 0) (error "bn-shift: negative bits"))
          ((null? a) a)
          (else (bn*fix (shift-words-nz a (quotient bits base-bits))
                        (expt 2 (remainder bits base-bits))))))

  ;; Schoolbook multiplication, one word of b at a time.
  (define (bn* a b)
    (define (a* b)
      (if (null? b)
          b
          (bn+ (bn*fix a (car b))
               (shift-words (a* (cdr b)) 1))))
    (if (null? a) a (a* b)))

  ;; Still quadratic, but ~30% faster than generic multiplication
  ;; Uses (hd + base*tl)^2 = hd^2 + base*(base*tl^2 + hd*(2*tl)).
  (define (bn^2 a)
    (if (null? a)
        '()
        (let* ((hd (car a))
               (tl (cdr a))
               (hd^2 (* hd hd))
               (hd^2 (if (< hd^2 base)
                         (word->bn hd^2)
                         (list (fxand hd^2 base-1) (fxshift hd^2 neg-base-bits)))))
          (if (null? tl)
              hd^2
              (bn+ hd^2 (cons 0 (bn+ (cons 0 (bn^2 tl))
                                     (bn*fix (bn*2 tl) hd))))))))

  ;; Drop zero elements from the front of a list (used on a
  ;; most-significant-first word list).
  (define (strip-leading-zeros l)
    (cond ((null? l) l)
          ((zero? (car l)) (strip-leading-zeros (cdr l)))
          (else l)))

  ;; Split a into its low k words and the remaining high words, both in
  ;; canonical form, and call (cont low high).
  (define (bn-split a k cont)
    (do ((head '() (cons (car tail) head))
         (tail a (cdr tail))
         (k k (- k 1)))
        ((or (null? tail) (zero? k))
         (cont (reverse (strip-leading-zeros head)) tail))))

  ;;; Division

  ;; floor(a/2): multiply by base/2, then drop the low word.
  (define (bn/2 a)
    (if (null? a) a (cdr (bn*fix a base/2))))

  ;; floor(a / 2^bits)
  (define (bn-unshift a bits)
    (if (< bits 0) (error "bn-unshift: negative bits"))
    (let* ((full-words (quotient bits base-bits))
           (extra-bits (remainder bits base-bits))
           (a (unshift-words a full-words)))
      (if (or (null? a) (zero? extra-bits))
          a
          (cdr (bn*fix a (expt 2 (- base-bits extra-bits)))))))

  ;; Number of doublings of start needed to reach at least target.
  (define (num-bit-shifts start target) ;; optimize?
    (define (loop start n)
      (if (<= target start)
          n
          (loop (* start 2) (+ n 1))))
    (loop start 0))

  ;; Last element of a non-empty list (the most significant word of a bignum).
  (define (last l)
    (if (null? (cdr l)) (car l) (last (cdr l))))

  ;; Divide a by b and call (return quotient remainder).
  (define (bn-divrem a b return)
    (if (null? b) (error "division by zero"))
    ;; Normalize: most sig. bit of most sig. word of divisor must be 1
    (let ((msw (last b)))
      (if (<= base/2 msw)
          (div-normalized a b return)
          (let ((s (expt 2 (num-bit-shifts msw base/2))))
            (div-normalized (bn*fix a s) (bn*fix b s) ;; optimize
                            (lambda (q r)
                              ;; Undo the scaling of the remainder: r/s is
                              ;; computed as r*(base/s) with the low word dropped.
                              (return q (if (null? r)
                                            r
                                            (cdr (bn*fix r (quotient base s)))))))))))

  (define (divrem-q q r) q)
  (define (divrem-r q r) r)
  (define (bn-quotient a b) (bn-divrem a b divrem-q)) ;; optimize?
  (define (bn-remainder a b) (bn-divrem a b divrem-r))

  ;; (word k) + base * (word k+1) of a; missing words count as 0.
  (define (slice-2 a k)
    (if (null? a)
        0
        (let ((tail (cdr a)))
          (if (zero? k)
              (if (null? tail)
                  (car a)
                  (+ (car a) (* (car tail) base)))
              (slice-2 tail (- k 1))))))

  (define (most-sig-word a) ;; assumes not null
    (define (loop a tail)
      (if (null? tail)
          (car a)
          (loop tail (cdr tail))))
    (loop a (cdr a)))

  ;; Long division of A by a normalized divisor B (top bit of its top word
  ;; set). Each quotient word is estimated from the top two words of the
  ;; current remainder and B's top word (capped at base-1), then decremented
  ;; until B*qi*base^i no longer exceeds the remainder.
  (define (div-normalized A B return)
    (let* ((n (length B))
           (m (- (length A) n)))
      (if (< m 0)
          (return '() A)
          (let ((n-1 (- n 1))
                (B-msw (most-sig-word B)))
            (define (get-qi i Q A)
              (define (try-qi qi)
                (let ((prod (shift-words (bn*fix B qi) i)))
                  (if (bn< A prod)
                      (try-qi (- qi 1))
                      (get-qi (- i 1) (cons qi Q) (bn- A prod)))))
              (if (< i 0)
                  (return Q A)
                  (try-qi (min base-1 (quotient (slice-2 A (+ n-1 i)) B-msw)))))
            ;; The top quotient word (position m) is 0 or 1.
            (let ((B-shift (shift-words B m)))
              (if (bn< A B-shift)
                  (get-qi (- m 1) '() A)
                  (get-qi (- m 1) '(1) (bn- A B-shift))))))))

  ;; Multiplicative inverse of a mod n:
  ;; (bn-remainder (bn* a (bn-mod-inverse a n)) n) -> bn1
  ;; Assumes reduced input (0 < a < n)
  (define (bn-mod-inverse a n)
    ;; Extended Euclidean algorithm: find x where ax + by = gcd(a, b)
    ;; If the gcd is 1, it follows that ax == 1 mod b
    ;; See [HAC] Algorithm 2.142 / 2.107
    ;; Simplified / adjusted for unsigned bignums
    ;; Invariants:
    ;; [1] 0 <= r < b (loop terminates when b would reach 0)
    ;; [2] 0 < b < a (by [1], as b is last r and a is last b)
    ;; [3] q > 0 (by [2])
    ;; [4] If x > 0, then last x <= 0 and next x < 0
    ;;     If x < 0, then last x > 0 and next x > 0
    ;;     (by [3], as next x is last x - qx)
    ;; Full proofs (mostly by induction) left as an exercise to the reader.
    ;; x and last-x hold magnitudes; neg records the (alternating) sign of x.
    (define (loop a b x neg last-x)
      (bn-divrem a b
                 (lambda (q r)
                   (if (bn-zero? r)
                       (if (bn= b bn1)
                           (if neg (bn- n x) x)
                           (error "not invertible (modulus not prime?)"))
                       (loop b r (bn+ last-x (bn* q x)) (not neg) x)))))
    (loop n a bn1 #f bn0))

  ;;; Misc

  ;; Number of significant bits
  ;; = least integer b such that 2^b > a
  ;; = ceil(log_2(a+1))
  (define (bn-bits a)
    (if (null? a)
        0
        (+ (num-bit-shifts 1 (+ (last a) 1))
           (* (- (length a) 1) base-bits))))

  ;; Read n characters from port into a vector of their integer codes
  ;; (presumably byte values from a binary/Latin-1 source -- confirm with
  ;; callers). Signals an explicit error if the port runs out early.
  (define (read-bytes n port)
    (let ((v (make-vector n)))
      (do ((k 0 (+ k 1)))
          ((= k n) v)
        (let ((c (read-char port)))
          (if (eof-object? c)
              (error "read-bytes: unexpected end of file"))
          (vector-set! v k (char->integer c))))))

  ;; ceil(a / b) for non-negative a and positive b.
  (define (ceil-quotient a b)
    (quotient (+ a b -1) b))

  ;; Unbiased random integer generator in the interval [0, n).
  ;; Returns a procedure taking a port from which random bytes are read.
  (define (rand-bn n)
    (if (bn-zero? n) (error "rand-bn: zero range"))
    ;; Collecting one more byte than strictly necessary avoids cases where a
    ;; large part of the range is invalid (e.g. n=130)
    (let* ((nbytes (+ (ceil-quotient (bn-bits (bn-1 n)) 8) 1))
           (rand-range (bn-shift bn1 (* nbytes 8)))
           (unbiased-range (bn- rand-range (bn-remainder rand-range n))))
      (lambda (rng-port)
        ;; Rejection sampling: draws at or above unbiased-range would skew
        ;; the final reduction mod n, so they are discarded.
        (define (retry)
          (let ((r (bytes->bn (read-bytes nbytes rng-port))))
            (if (bn< r unbiased-range)
                (bn-remainder r n)
                (retry))))
        (retry))))

  ;; Deferred initializations (order matters: each may depend on the previous)
  (set! base-bits (force base-bits))
  (set! base/2 (force base/2))
  (set! base (force base))
  (set! base-1 (force base-1))
  (set! neg-base-bits (force neg-base-bits))

  (export bn0 bn1
          hexdigit->fix decdigit->fix ;; not strictly bignum ops, but handy
          bn->hex hex->bn bytes->bn bn->dec dec->bn bn->fix fix->bn
          bn-zero? bn-even? bn-odd? bn= bn< bn> bn<= bn>=
          bn+1 bn+ bn-1 bn-
          bn*2 bn-shift bn*fix bn* bn^2
          bn/2 bn-unshift bn-divrem bn-quotient bn-remainder
          bn-mod-inverse bn-bits rand-bn))