来源:一道美图的笔试题
题目:
已知斐波那契数列 F(n) 满足 F(0) = 0, F(1) = F(2) = 1,当 n >= 3 时有 F(n) = F(n-1) + F(n-2),此外当 n >= 1 时有 F(2n-1) = F(n) * F(n) + F(n-1) * F(n-1),F(2n) = (2*F(n-1) + F(n)) * F(n),求数列第 n 项末尾的 9 个数字。注意输出中不应该包含前导 0,即输出 0,而不是 000000000,输出 6765,而不是 000006765。一些示例输入输出如下表所示:
输入 | 输出 |
---|---|
1 | 1 |
2 | 1 |
8 | 21 |
20 | 6765 |
46 | 836311903 |
60 | 8755920 |
3749999998 | 499999999 |
3749999999 | 500000001 |
3750000000 | 0 |
3750000001 | 500000001 |
281474976710656 | 309764667 |
分析:
显然这是一道大数问题,应当用字符串来表示数字,最后一个示例输入表明输入可能非常大,这时若只是一项一项地算,就会超时,而是应当利用题目中给出的后两个递推公式,采用递归的方式解决该问题,在解决该问题前,先总结回顾大数运算的字符串操作。
加法
定义宏 _MAX_BITS_ 表示字符串数组的长度(不包括 ‘\0’),也即加数的最大位数,按照写法习惯采用顺序存储的方式,数字的高位存储在数组索引低端,如果数字位数不及最大位数,则在数组低端补 ‘0’,例如,当 _MAX_BITS_ 取 9 时,6765 存储在字符串数组 a 中,则 a = “000006765”。数字的输入程序和打印程序如下:
#define _MAX_BITS_ 9 void inputBigNum(char *number) { cin.getline(number, _MAX_BITS_ + 1); int i = 0, diff = _MAX_BITS_ - static_cast<int>(strlen(number)); for (i = _MAX_BITS_ - 1; i >= diff; --i) number[i] = number[i - diff]; for (; i >= 0; --i) number[i] = '0'; } void printBigNum(char *number) { size_t i = 0, L = strlen(number); for (i = 0; i < L - 1; ++i) if (number[i] != 48) break; for (; i < L; ++i) cout << number[i]; cout << endl; }
在输入函数中,若位数不够最大位数,则进行移位补 ‘0’,输出函数中跳过了前导 ‘0’。
加法函数中,从数字低位往高位计算,每次逐位加的时候都记录是否有进位,因为两个加数最高位相加后可能发生进位,因此存储和的数组中整体往后偏移了一位:
void bigPlus(char *a, char *b, char *sum) { size_t L = std::max(strlen(a), strlen(b)); if (strlen(sum) < L + 1) throw std::exception("sum bits inadequate!"); char carry = 0; for (int i = _MAX_BITS_ - 1; i >= 0; --i) { sum[i+1] = a[i] - 48 + b[i] + carry; if (sum[i+1] > 57) { sum[i+1] -= 10; carry = 1; } else carry = 0; } sum[0] = 48 + carry; }
测试程序如下:
char *a = new char[_MAX_BITS_ + 1]; a[_MAX_BITS_] = '\0'; char *b = new char[_MAX_BITS_ + 1]; b[_MAX_BITS_] = '\0'; char *c = new char[_MAX_BITS_ + 2]; c[_MAX_BITS_ + 1] = '\0'; cout << "input a: "; inputBigNum(a); cout << "input b: "; inputBigNum(b); try { bigPlus(a, b, c); cout << "sum c: "; printBigNum(c); } catch (std::exception& e) { cout << e.what() << endl; } delete []a; delete []b; delete []c;
后面减法和乘法的测试程序中不再重复 a, b 两个数组,只体现存储结果的数组 c 的差异。
减法
这里的实现中不考虑被减数小于减数的情况,减法需要借位,从低位到高位逐位计算后,若最高位相减后,需要借位的话就表明被减数小于减数了:
void bigMinus(char *a /* minuend */, char *b /* subtractor */, char *difference) { size_t L = strlen(a); if (strlen(difference) < L) throw std::exception("difference bits inadequate!"); char borrow = 0; for (int i = _MAX_BITS_ - 1; i >= 0; --i) { difference[i] = a[i] + 48 - b[i] - borrow; if (difference[i] < 48) { difference[i] += 10; borrow = 1; } else borrow = 0; } if (borrow == 1) throw std::exception("minuend < subtractor!"); }
测试代码如下:
char *c = new char[_MAX_BITS_ + 1]; c[_MAX_BITS_] = '\0'; try { bigMinus(a, b, c); cout << "difference c: "; printBigNum(c); } ...
乘法
此处乘法的实现将逐位相乘后形成的值都累加到一个暂存的 int 数组中,最后再从低位到高位处理进位,具体原理参考博客 http://www.cnblogs.com/king-ding/p/bigIntegerMul.html 的解法二:
void bigMultiply(char *a, char *b, char *product) { size_t L = strlen(a) + strlen(b); if (strlen(product) < L) throw std::exception("product bits inadequate!"); int *result = new int[L]; memset(result, 0, L*sizeof(int)); for (int j = _MAX_BITS_ - 1; j >= 0; --j) for (int i = _MAX_BITS_ - 1; i >= 0; --i) result[i+j+1] += (a[i] - 48) * (b[j] - 48); for (size_t l = L - 1; l > 0; --l) { result[l-1] += result[l]/10; result[l] %= 10; product[l] = result[l] + 48; } product[0] = result[0] + 48; delete []result; }
测试代码如下:
char *c = new char[2*_MAX_BITS_ + 1]; c[2*_MAX_BITS_] = '\0'; try { bigMultiply(a, b, c); cout << "product c: "; printBigNum(c); }
解答:
有了上述的加法函数和乘法函数后(大数减法操作在本题中用不到),最棘手的地方就已经解决了,不过由于是需要递归实现,递归函数具备返回值比较方便,这样一来,上述的加法函数和乘法函数也要转换成具备返回值的形式:
char* bigPlus(char *a, char *b) { char carry = 0; char *sum = new char[_MAX_BITS_ + 1]; sum[_MAX_BITS_] = '\0'; for (int i = _MAX_BITS_ - 1; i >= 0; --i) { sum[i] = a[i] - 48 + b[i] + carry; if (sum[i] > 57) { sum[i] -= 10; carry = 1; } else carry = 0; } return sum; } char* bigMultiply(char *a, char *b) { int L = 2 * _MAX_BITS_; char *product = new char[_MAX_BITS_ + 1]; product[_MAX_BITS_] = '\0'; int *result = new int[L]; memset(result, 0, L*sizeof(int)); for (int j = _MAX_BITS_ - 1; j >= 0; --j) for (int i = _MAX_BITS_ - 1; i >= 0; --i) result[i+j+1] += (a[i] - 48) * (b[j] - 48); for (int l = L - 1; l > 0; --l) { result[l-1] += result[l]/10; result[l] %= 10; } for (int i = 0; i < _MAX_BITS_; ++i) product[i] = result[i+_MAX_BITS_] + '0'; delete []result; return product; }
注意到上述返回的字符串数组是动态分配内存的,并且回收内存的操作由调用者负责,在乘法函数中,由于我们只关心最后 9 位数,因此 product[] 只需要取出暂存 int 数组中最后 9 个数字就可以了,此外加分函数中也不关心倒数第 9 位相加后得到倒数的第 10 位数字如何,sum[] 数组也没有进行一位的偏移。
由于递推式中有个 2,而且这个 2 是需要在递归函数中被反复使用的,因此就需要创建一个全局的字符串数组专门来表示这个系数 2,此外该递归式是树形的递归结构,在递归的执行过程中,需要将已经计算得到的存储第 n 项最后 9 位数字的数组记录下来,这儿将其放置在一个全局哈希表中(需要 C++11):
static char *constant2 = NULL; static std::unordered_map<int64_t, char*> tailMap;
注意到递归函数的终止条件是当 n 为 1 或是 2 的时候终止,因此需要将这两项的结果在该递归函数被调用前就事先保存在哈希表中,下面的代码在 main() 函数中被调用:
constant2 = new char[_MAX_BITS_ + 1]; memset(constant2, '0', _MAX_BITS_*sizeof(char)); constant2[_MAX_BITS_ - 1] = '2'; constant2[_MAX_BITS_] = '\0'; char *F1 = new char[_MAX_BITS_ + 1]; memset(F1, '0', _MAX_BITS_*sizeof(char)); F1[_MAX_BITS_ - 1] = '1'; F1[_MAX_BITS_] = '\0'; char *F2 = new char[_MAX_BITS_ + 1]; F2[_MAX_BITS_] = '\0'; strcpy(F2, F1); tailMap.insert(std::pair<int64_t, char*>(1, F1)); tailMap.insert(std::pair<int64_t, char*>(2, F2));
哈希表中有了这两项后就可以根据递推关系式写出如下的递归函数了:
char* Fn(int64_t n) { if (n < 3) return tailMap[n]; char *result = NULL; if (n & 1) // F(2n-1) = F(n)^2 + F(n-1)^2 { char *fn = tailMap.find(n/2 + 1) != tailMap.end() ? tailMap[n/2 + 1] : Fn(n/2 + 1); char *fn_1 = tailMap.find(n/2) != tailMap.end() ? tailMap[n/2] : Fn(n/2); char *fnSquare = bigMultiply(fn, fn); char *fn_1Square = bigMultiply(fn_1, fn_1); result = bigPlus(fnSquare, fn_1Square); delete []fnSquare; delete []fn_1Square; } else // F(2n) = (2*F(n-1) + F(n)) * F(n) { char *fn_1 = tailMap.find(n/2 - 1) != tailMap.end() ? tailMap[n/2 - 1] : Fn(n/2 - 1); char *dualFn_1 = bigMultiply(constant2, fn_1); char *fn = tailMap.find(n/2) != tailMap.end() ? tailMap[n/2] : Fn(n/2); char *parenthesesSum = bigPlus(dualFn_1, fn); result = bigMultiply(parenthesesSum, fn); delete []dualFn_1; delete []parenthesesSum; } tailMap.insert(std::pair<int64_t, char*>(n, result)); return result; }
注意其中的中间计算结果并不是保存在哈希表中的,因此在使用过后应当立即释放其内存,而且在每次计算得到新的项后立即将其插入到哈希表中,这样若以后再当碰到 n 为该值时,就直接从哈希表中里面取值就好了,从而树形的递归结构退化成了接近于线形的结构了,而无论输入的 n 有多大(在 64 位范围内),递归的深度都不会太深,不会超过 70 层,另外经过测试,哈希表中存储的项最多也就几百项,因此算法的效率还是非常高的。
最后注意在 main() 函数中输出结果之后,要记得释放掉所有保存在哈希表中的字符串数组。
完整的程序代码如下(“……” 表示前述的那些函数),调整宏 _MAX_BITS_ 的值可改变显示末尾数字的个数,题目是中指定是 9,因此就设置为 9
#include <cstdlib> #include <cstdint> #include <cstring> #include <exception> #include <unordered_map> #include "Timer.hpp" using std::cin; using std::cout; using std::endl; #define _MAX_BITS_ 9 static char *constant2 = NULL; static std::unordered_map<int64_t, char*> tailMap; ...... int main(void) { Timer programTimer("Algorithm cost: "); int64_t m; cin >> m; constant2 = new char[_MAX_BITS_ + 1]; memset(constant2, '0', _MAX_BITS_*sizeof(char)); constant2[_MAX_BITS_ - 1] = '2'; constant2[_MAX_BITS_] = '\0'; char *F1 = new char[_MAX_BITS_ + 1]; memset(F1, '0', _MAX_BITS_*sizeof(char)); F1[_MAX_BITS_ - 1] = '1'; F1[_MAX_BITS_] = '\0'; char *F2 = new char[_MAX_BITS_ + 1]; F2[_MAX_BITS_] = '\0'; strcpy(F2, F1); tailMap.insert(std::pair<int64_t, char*>(1, F1)); tailMap.insert(std::pair<int64_t, char*>(2, F2)); char *answer = Fn(m); printBigNum(answer); delete []constant2; for (std::unordered_map<int64_t, char*>::iterator it = tailMap.begin(); it != tailMap.end(); ++it) delete []it->second; tailMap.clear(); return 0; }
可以将 m 设置为固定的大数以测试算法用时(因为输入要用时,用户输入就没法计时了),Timer.hpp 头文件如下:
#ifndef __TIMER_HPP__ #define __TIMER_HPP__ #include <time.h> #include <iostream> #include <string> /** Convenient class to calculate time elapsing. */ class Timer { public: explicit Timer(const char *s) : _note(s) { _start_time = clock(); } ~Timer() { std::cout << _note.c_str() << _elapsed() << "s.\n"; } inline clock_t elapsed() { return clock() - _start_time; } inline double _elapsed() { return static_cast<double>(clock() - _start_time) / CLOCKS_PER_SEC; } private: Timer(const Timer&); Timer& operator=(const Timer&); private: std::string _note; clock_t _start_time; }; #endif /* __TIMER_HPP__ */