BigNum

From Second Life Wiki
Jump to: navigation, search

Introduction

(http://www.gnu.org/copyleft/fdl.html) in the spirit of which this script is GPL'd. Copyright (C) 2010 Xaviar Czervik

This program is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation; either version 2 of the License, or (at your option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.


What follows is my feeble attempt to create a working BigNum library for LSL. The purpose I designed it for was modular exponentiation; however all of the helper functions (add, subtract, multiply, divide, modulus) work just fine on their own. I say feeble because while working, it is quite slow. It is, however, faster at modular exponentiation than any other LSL function that I have seen.

While this is exactly what RSA is, I did not call it an RSA Library because it can be used for many other encryption techniques -- I am currently using it for Paillier encryption.


Explanation

Getting Started

Using this script is quite simple. To get started, simply copy the code below into a script and save it. By default, it will use RSA to encrypt and decrypt a message using a 128 bit key (encryption takes about three and a half minutes; decryption, longer). However, that is probably not what you will want it to do. What follows is an explanation of how to use this script.

Creating BigNums

Bitnums are stored a lists, with each element in the list essentially a base-65536 integer. For example, [1, 0] represents the number 65536, [1, 0, 0] represents 4294967296. Negative signs are denoted by a "N" prepended to the list. The simplest way to create a bignum is to call fromHex(), where the argument is a string of hex characters: for example, fromHex("10000") would return [1, 0]. You can also make a bignum by calling fromBase with what ever base you want; for example, fromBase("65536", 10) would return [1, 0].

Manipulating BigNums

All of the functions do exactly what it seems that they do... add adds two bignums, multiply multiplies them. There are three functions that are not exactly trivial to use, however; and they are euclid, montStep, and modPow2. They are described later on.

Montgomery Reduction Algorithm

Because this script was designed to do modular exponentiation very quickly, I made use of the Montgomery reduction algorithm. This is a very efficient way to do repeated modular multiplication. I won't describe it in detail here, read the Wikipedia page if you want to understand it. However, in order to use this, a value must first be obtained from the Extended Euclidean algorithm. Hence the function euclid.

Euclid

The Euclid funtion is a helper function to calculate the needed value for the Montgomery Reductions.

MontStep

This MontStep function does a single Montgomery step, based off of the inputs.

ModPow2

The ModPow2 function is quite similar to the ModPow function, except that it also takes xp as a parameter. Xp is the value obtained from the Extended Euclidean algorithm, which along with the modulus and exponent are fixed. This means that it can be pre-calculated and saved as part of the public key. This can result in savings of up to 60% for 128 bit keys, but slowly drops down to 30% for 512 bit keys.



Revision History

Version 1.1

  1. Achieved 40% (or higher) increase in speed with moduli larger than 50 bits. (Changed Montgomery Reduction to pick R a multiple of 16 bits.)
  2. Fixed example code to actually work...

Version 1.0

  1. First release.


To Do

Below are a list of tasks that need to be done to make this script more useful. If you know how to do any of the below, and have some extra time, feel free to add in any of these. (Or anything you may think of on your own!)


Optimize Code: This code has not been optimized in the slightest! I still use a O(n^2) algorithm for multiplication and it just gets worse for division (which, thanks to the Montgomery Reductions, is almost never used). Not to mention the micro-optimizations that could be made.

1024+ Bit Key Sizes: This can only encrypt using an key size of 512 bits, or a stack-heap collision is imminent! This is pathetically low in terms of security.

Floating-Point Number Support: This script only works with integer values, allowing infinite precision floats would be nice (but not required for cryptography).


The Code

integer compare(list a, list b) {
    integer an;
    integer bn;
    if (llList2String(a, 0) == "N") {
        a = llListReplaceList(a, [], 0, 0);
        an = 1;
    }
    if (llList2String(b, 0) == "N") {
        b = llListReplaceList(b, [], 0, 0);
        bn = 1;
    }
    while (llList2Integer(b, 0) == 0 && llGetListLength(b) > 0)
        b = llListReplaceList(b, [], 0, 0);
    while (llList2Integer(a, 0) == 0 && llGetListLength(a) > 0)
        a = llListReplaceList(a, [], 0, 0);
 
    if (a == [] && b == []) 
        return 0;
 
    if (an && !bn)
        return 2;
    if (!an && bn)
        return 1;
 
    if (llGetListLength(a) > llGetListLength(b)) {
        if (an && bn)
            return 2;
        if (!an && !bn)
            return 1;
    } else if (llGetListLength(b) > llGetListLength(a)) {
        if (an && bn)
            return 1;
        if (!an && !bn)
            return 2;
    } else {
        integer i = 0;
        while (i < llGetListLength(b)) {
            if (llList2Integer(b, i) > llList2Integer(a, i)) {
                if (an && bn)
                    return 1;
                if (!an && !bn)
                    return 2;
            } else if (llList2Integer(b, i) < llList2Integer(a, i)) {
                if (an && bn)
                    return 2;
                if (!an && !bn)
                    return 1;
            }
            i++;
        }
        return 0;
    }
    return 0;
}
 
list normalize(list ret) {
    integer neg;
    while (llList2String(ret, 0) == "N") {
        ret = llListReplaceList(ret, [], 0, 0);
        neg = !neg;
    }
    integer i = llGetListLength(ret)-1;
    integer rem;
    while (i >= 0) {
        integer newVal = llList2Integer(ret, i)+rem;
        ret = (ret = []) + llListReplaceList(ret, [newVal & 0xFFFF], i, i);
        rem = (newVal&0x7FFF0000) >> 16;
        if (newVal < 0) rem += 1 << 15;
        i--;
    }
    if (rem != 0) {
        ret = (ret = []) + [rem] + ret;
    }
    while (llList2Integer(ret, 0) == 0 && llGetListLength(ret) > 1)
        ret = (ret = []) + llListReplaceList(ret, [], 0, 0);
    if (ret == [])
        ret = [0];
    if (neg)
        ret = (ret = []) + "N" + ret;
    return ret;
}
 
string hex(integer bits) {
    string nybbles = "";
    while (bits) {
        integer lsn = bits & 0xF;
        string nybble = llGetSubString("0123456789ABCDEF", lsn, lsn);
        nybbles = nybble + nybbles;
        bits = bits >> 4;
        bits = bits & 0xfffFFFF;
    }
    while (llStringLength(nybbles) < 4)
        nybbles = "0"+nybbles;
    return nybbles;
}
 
 
string toHex(list a) {
    string ret;
    if (llList2String(a, 0) == "N") {
        a = llListReplaceList(a, [], 0, 0);
        ret += "-";
    }
    integer i = 0;
    while (i < llGetListLength(a)) {
        ret += hex(llList2Integer(a, i));
        i++;
    }
    return ret;
}
 
list fromHex(string a) {
    list ret;
    integer neg;
    if (llGetSubString(a, 0, 0) == "-") {
        a = llGetSubString(a, 1, -1);
        neg = 1;
    }
    integer i = llStringLength(a);
    while (i > 4) {
        i -= 4;
        ret = (integer)("0x"+llGetSubString(a, i, i+3)) + ret;
    }
    if (i != 0) ret = (integer)("0x"+llGetSubString(a, 0, i-1)) + ret;
    if (neg)
    ret = ["N"] + ret;
    return ret;
}
 
list add(list a, list b) {
    if (llList2String(b, 0) == "N")
        return subtract(a, llList2List(b, 1, -1));
    if (llList2String(a, 0) == "N")
        return subtract(b, llList2List(a, 1, -1));
    if (compare(a, b) == 2)
        return add(b, a);
    while (llGetListLength(b) != llGetListLength(a)) {
        b = (b = []) + [0] + b;
    }
    integer i = 0;
    while (i < llGetListLength(a)) {
        a = (a = []) + llListReplaceList(a, [llList2Integer(a, i)+llList2Integer(b, i)], i, i);
        i++;
    }
    return normalize(a);
}
 
list subtract(list a, list b) {
    if (llList2String(b, 0) == "N")
        return add(a, llList2List(b, 1, -1));
    if (llList2String(a, 0) == "N")
        return normalize(["N"]+add(b, llList2List(a, 1, -1)));
    if (compare(a, b) == 2)
        return normalize(["N"]+subtract(b, a));
 
    while (llGetListLength(b) != llGetListLength(a)) {
        b = (b = []) + [0] + b;
    }
    integer c;
    integer i = llGetListLength(a)-1;
    while (i >= 0) {
        integer oldC = c;
        if (llList2Integer(a, i)-llList2Integer(b, i)-oldC < 0) {
            c = 1;
        } else {
            c = 0;
        }
        a = (a = []) + llListReplaceList(a, [llList2Integer(a, i)-llList2Integer(b, i) - oldC + (c << 16)], i, i);
        i--;
    }
    return normalize(a);
}
 
list multiply(list a, list b) {
    if (llList2String(a, 0) == "N" && llList2String(b, 0) == "N")
        return multiply(llList2List(a, 1, -1), llList2List(b, 1, -1));
    if (llList2String(b, 0) == "N")
        return normalize("N"+multiply(a, llList2List(b, 1, -1)));
    if (llList2String(a, 0) == "N")
        return normalize("N"+multiply(b, llList2List(a, 1, -1))); 
    if (compare(a, b) == 2)
        return multiply(b, a);
    if (compare(a, [0]) == 0)
        return [0];
    if (compare(b, [0]) == 0)
        return [0];
 
    list ret = [0];
    list partialSum = [0];
 
    integer i = 0;
    while (i < llGetListLength(b)) {
        integer subPart = llList2Integer(b, i);
        partialSum = [];
        integer j = 0;
        while (j < llGetListLength(a)) {
            partialSum = (partialSum = []) + partialSum + llList2Integer(a, j)*subPart;
            j++;
        }
        partialSum = (partialSum = []) + normalize(partialSum);
        integer k = 0;
        while (k < llGetListLength(b)-i-1) {
            partialSum = (partialSum = []) + partialSum + [0];
            k++;
        }
        ret = (ret = []) + normalize(add(partialSum, ret));
        i++;
    }
    return normalize(ret);
}
 
list bigDivide(list dividend, list divisor, integer remainder) {
    if (llList2String(dividend, 0) == "N" && llList2String(divisor, 0) == "N")
        return divide(llList2List(divisor, 1, -1), llList2List(dividend, 1, -1));
    if (llList2String(dividend, 0) == "N")
        return normalize("N"+divide(divisor, llList2List(dividend, 1, -1)));
    if (llList2String(divisor, 0) == "N")
        return normalize("N"+multiply(llList2List(divisor, 1, -1), dividend));
    integer i = 0;
    integer li = 0;
    integer len = llGetListLength(dividend);
    list zeros;
    list ret;
    while (i < len) {
        if (compare(llList2List(dividend, li, i), divisor) != 2) {
            list workingOn = llList2List(dividend, li, i);
            integer mid;
            integer first = 0;
            integer last = 65535;
            list val;
            while (first <= last) {
                mid = (first + last)/2;
                val = (val = []) + multiply([mid], divisor);
                if (compare(val, workingOn) == 0) {
                    first = last + 1;
                } else if (compare(val, workingOn) == 1) {
                    last = mid-1;
                } else {
                    first = mid+1;
                }
            }
            if (last < mid)
                mid = last;
            if (first < mid)
                mid = last;
            ret = (ret = []) + ret + [mid];
            zeros = [];
            integer k;
            for (k = 1; k < llGetListLength(dividend)-i; k++) zeros += [0];
            dividend = (dividend = []) + subtract(dividend, multiply(divisor, [mid])+zeros);
            zeros = [];
            for (k = llGetListLength(dividend); k < len; k++) zeros += [0];
            dividend = (dividend = []) + zeros + dividend;
        } else {
            ret = (ret = []) + ret + 0;
        }
        i++;
    }
    if (remainder)
        return normalize(dividend);
    return normalize(ret);
}
 
list divide(list divisor, list dividend) {
    return bigDivide(divisor, dividend, 0);
}
list mod(list divisor, list dividend) {
    return bigDivide(divisor, dividend, 1);
}
 
list shiftLeft(list a, integer n) {
    if (llList2String(a, 0) == "N")
        return normalize("N"+shiftLeft(llList2List(a, 1, -1), n));
    while (n >= 16) {
        a += [0];
        n -= 16;
    }
    integer mask = 0xFFFF ^ ((1 << (16-n)) - 1);
    integer rem;
    integer i = llGetListLength(a)-1;
    while (i >= 0) {
        integer this = llList2Integer(a, i);
        integer oldRem = rem >> (16-n);
        rem = this & mask;
        this = this ^ rem;
        this = (this << n) | oldRem;
        a = (a = []) + llListReplaceList(a, [this], i, i);
        i--;
    }
    rem = rem >> (16-n);
    if (rem != 0) {
        a = [rem] + a;
    }
    return a;
}
 
list shiftRight(list a, integer n) {
    if (llList2String(a, 0) == "N")
        return normalize("N"+shiftRight(llList2List(a, 1, -1), n));
    if (n >= 16) {
        a = llListReplaceList(a, [], -n/16, -1);
        n = n % 16;
    }
    integer mask = (1 << n) - 1;
    integer rem;
    integer i = 0;
    while (i < llGetListLength(a)) {
        integer this = llList2Integer(a, i);
        integer oldRem = rem << (16-n);
        rem = this & mask;
        this = (this >> n) | oldRem;
        a = (a = []) + llListReplaceList(a, [this], i, i);
        i++;
    }
    return a;
}
 
list cutBits(list l, integer r) {
    if (llList2String(l, 0) == "N") {
        return normalize("N"+cutBits(llList2List(l, 1, -1), r));
    }
    list ret;
    integer full = r/16;
    if (full > 0)
        ret = (ret = []) + ret + llList2List(l, -full, -1);
    integer leftOver = llList2Integer(l, -full-1);
    integer ext = r % 16;
    leftOver = leftOver & ((1 << ext) - 1);
    return leftOver + ret;
}
 
integer getBits(list n) {
    if (llList2String(n, 0) == "N")
        return getBits(llList2List(n, 1, -1));
    integer tmp = llGetListLength(n)*16 - 15;
    integer extra = llList2Integer(n, 0);
    extra = (integer)(llLog(extra)/llLog(2));
    return tmp+extra;
}
 
integer getBit(list n, integer bit) {
    integer loc = bit/16;
    integer val = llList2Integer(n, -loc-1);
    return (val & (1 << (bit % 16))) > 0;
}
 
list euclid(list big1, list big2) {
    list x = [0];
    list y = [1];
    list xx = [1];
    list yy = [0];
 
    list q;
    list t;
 
    while (compare(big2, [0]) != 0) {
        q = (q = []) + divide(big1, big2);
 
        t = (t = []) + big2;
        big2 = (big2 = []) + subtract(big1, multiply(q, big2));
        big1 = (big1 = []) + t;
 
        t = (t = []) + x;
        x = (x = []) + subtract(xx, multiply(q, x));
        xx = (xx = []) + t;
 
        t = (t = []) + y;
        y = (y = []) + subtract(yy, multiply(q, y));
        yy = (yy = []) + t;
    }
    return yy;
}
 
 
list montStep(list a, list b, list n, integer r, list np) {
    list t = multiply(a, b);
    list m = cutBits(multiply(t, np), r);
    if (llList2String(m, 0) == "N") {
    list om = m;
        m = add(shiftLeft([1], r), m);
    }
    list u = shiftRight(add(t, multiply(m, n)), r);
    if (compare(u, n) == 1) {
        return normalize(subtract(u, n));
    } else {
        return normalize(u);
    }
}
 
list modExp(list message, list exponent, list modulus) {
    integer k = getBits(modulus);
    k = 16*llCeil(k/16.0);
    integer k2 = getBits(exponent);
    list r = shiftLeft([1], k);
    list ap = mod(multiply(message, r), modulus);
    list xp = mod(r, modulus);
 
    list np = euclid(r, modulus);
    if (llList2String(np, 0) == "N")
        np = llList2List(np, 1, -1);
    else
        np = ["N"]+np;
 
    integer i = k2;
    while (i >= 0) {
        xp = (xp = []) + montStep(xp, xp, modulus, k, np);
        if (getBit(exponent, i)) {
            xp = (xp = []) + montStep(ap, xp, modulus, k, np);
        }
        i--;
    }
    return montStep(xp, [1], modulus, k, np);
}
 
 
list modExp2(list message, list exponent, list modulus, list np) {
    integer k = getBits(modulus);
    k = 16*llCeil(k/16.0);
    integer k2 = getBits(exponent);
    list r = shiftLeft([1], k);
    list ap = mod(multiply(message, r), modulus);
    list xp = mod(r, modulus);
 
    integer i = k2;
    while (i >= 0) {
        xp = (xp = []) + montStep(xp, xp, modulus, k, np);
        if (getBit(exponent, i)) {
            xp = (xp = []) + montStep(ap, xp, modulus, k, np);
        }
        i--;
    }
    return montStep(xp, [1], modulus, k, np);
}
 
default {
    state_entry() {
        list message = fromHex("123456789ABCDEF0");
        list cipher;
        list decrypt;
 
        float t = llGetWallclock();
        cipher = modExp(message, fromHex("10001"), fromHex("64f3c5db7c9db53dda3ee77a3dc7097"));    
        llOwnerSay("When encrypted, " + toHex(message) + " is " + toHex(cipher) + ".");
        llOwnerSay("It took " + (string)(llGetWallclock()-t) + " seconds to complete this calculation.");
 
        t = llGetWallclock();
        decrypt = modExp(cipher, fromHex("2f61db73f14174556fe950f3b3f2cc1"), fromHex("64f3c5db7c9db53dda3ee77a3dc7097"));
        llOwnerSay("When decrypted, " + toHex(cipher) + " is " + toHex(decrypt) + ".");
        llOwnerSay("It took " + (string)(llGetWallclock()-t) + " seconds to complete this calculation.");
    }
}

Generalized Bases

Be warned that this is VERY slow, about 1 second per digit.

list fromBase(string in, integer base){
    integer i = llStringLength(in)-1;
    list mul = fromHex("1");
    list ten = [base];
    list sum = fromHex("0");
    for(;i>=0;--i){
        sum = add(sum, multiply(mul, fromHex(llGetSubString(in, i, i))));
        mul = multiply(mul, ten);
    }
    return sum;
}
 
string toBase(list in, integer base){
    string chars = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz+="; //This is the list of characters each digit will be replaced with
    //For example, 10 -> A, 36 -> a, etc.
    string out;
    list zero = [0];
    list ten = [base];
    while(compare(in, zero) == 1){
        integer pos = (integer)mod(in, ten);
        in = divide(in, ten);
        out = llGetSubString(chars, pos, pos) + out;
    }
    return out;
}