DataStructure/Algorithm

Karatsuba Algorithm

S0LL 2024. 9. 11. 17:15

Karatsuba Algorithm은 큰 문제를 잘게 쪼개서 하나씩 하나씩 결합하는 분할&정복 알고리즘 중에 하나입니다.

 

먼저, 문제를 어떻게 잘게 쪼개는지 살펴보겠습니다.

 

1234*5678을 예시로 살펴보죠.

이런 식으로 곱해야 하는 수가 한자리 수가 될 때까지 위와 같은 방법으로 분할시킵니다.

 

그럼 합칠 때는 어떻게 해야할까요?

 

위의 과정을 볼 때, 어떤 두 수의 곱셈을 3개의 부분으로 분할시키고 있습니다.

 

가장 왼쪽에 위치하는 어떤 수들의 앞부분끼리 곱하고 있는 부분. (이를 a라 칭하겠습니다)

중간에 위치한 각 수를 분할한 값끼리 더하고 곱셈하는 부분. (이를 b라 칭하겠습니다)

가장 오른쪽에 위치한 어떤 수들의 뒷부분들끼리 곱하고 있는 부분. (이를 c라 칭하겠습니다)

 

원래의 식 1234*5678을 분할시켜보면 어떻게 해야 할지 감이 잡히실겁니다.

 

아래의 식을 한번 보시죠.

 

어떤가요?? 어떻게 계산해야 할지 감이 오시나요?

계산 값들을 알파벳으로 치환해두면 조금 더 보기 편하실지도 몰라 알파벳으로 표시를 해두었습니다.

 

위의 과정에서 큰 수의 곱셈을 작은 수의 곱셈으로 분할시킴으로써 곱셈의 횟수를 줄였습니다.

 

분할시킬수록 계산해야 하는 값은 단순해지고 곱셈의 횟수도 점차 줄어들겠죠.

 

하지만 분할시키고 다시 합치는데 걸리는 시간도 꽤나 걸릴 것 같습니다.

 

이제 시간 복잡도에 대해 알아볼 시간입니다.

 

 

-시간 복잡도

이 알고리즘은 문제를 3개의 작은 문제로 나누고 크기는 절반으로 줄이기 때문에

T(n) = 3T(n/2) 라고 표현할 수 있습니다. 

 

문제가 3개로 쪼개지고 크기가 절반으로 줄어드는 과정이 계속 반복되면

 

T(n) = 3T(n/2) -> 9T(n/4) -> . . . . . ->3^s * T(1)

 

결국 문제의 개수는 3의 거듭제곱꼴로 늘어나고, 문제의 크기는 1로 수렴하게 됩니다.

 

여기서, n = 2^s 라고 가정하면 s=log2 n 이고, 이를 대입해보면,

 

이와 같은 결과가 나옵니다. log2 3은 약 1.59정도이므로 

분할을 하지 않고 계산했을 때의 n^2보다 빠른 것을 확인할 수 있습니다.

 

마스터 정리라는 것이 있지만 공식을 유도해보는것도 재미있는 것 같아 유도해보았습니다.

 

 

-코드

곱해야 하는 두 숫자의 자릿수가 너무 커서 int형이나 long long형으로도 해결할 수 없을 때도 이용할 수 있도록 

문자열을 이용한 덧셈, 뺄셈, 곱셈을 구현했습니다.

 

더 이상 나누어지지 않을 때까지(1자리가 될 때까지) 나누고 아래서부터 하나씩 결합하며 합치는 방식입니다.

#include <assert.h>

#include <iostream>
#include <string>
#include <vector>

std::string Add(std::string str1, std::string str2) {
  if (!str1.size() && !str2.size()) return "0";

  int dig_1 = str1.size();
  int dig_2 = str2.size();
  int N = std::max(dig_1, dig_2);

  if (dig_1 > dig_2) {
    for (int i = 0; i < dig_1 - dig_2; i++) {
      str2.insert(0, "0");
    }
  } else if (dig_1 < dig_2) {
    for (int i = 0; i < dig_2 - dig_1; i++) {
      str1.insert(0, "0");
    }
  }

  std::string result(N, '0');
  int carry = 0;
  for (int i = N - 1; i >= 0; i--) {
    int n1 = str1[i] - '0';
    int n2 = str2[i] - '0';
    int sum = n1 + n2 + carry;
    carry = sum / 10;
    result[i] = sum % 10 + '0';
  }
  if (carry > 0) {
    result.insert(0, "1");
  }

  return result;
}

std::string Subtract(std::string str1, std::string str2) {
  if (str1 == str2) return "0";

  int dig_1 = str1.size();
  int dig_2 = str2.size();
  int N = std::max(dig_1, dig_2);

  if (dig_1 > dig_2) {
    for (int i = 0; i < dig_1 - dig_2; i++) {
      str2.insert(0, "0");
    }
  } else if (dig_1 < dig_2) {
    for (int i = 0; i < dig_2 - dig_1; i++) {
      str1.insert(0, "0");
    }
  }

  std::string result(N, '0');
  int carry = 0;
  for (int i = N - 1; i >= 0; i--) {
    int n1 = str1[i] - '0';
    int n2 = str2[i] - '0';
    int sum = n1 - n2 + carry + 10;
    carry = sum / 10 - 1;
    result[i] = sum % 10 + '0';
  }
  int idx = 0;
  while (result[idx] == '0') idx++;
  result = result.substr(idx, N - idx);

  return result;
}

std::string KaratsubaHelper(std::string str1, std::string str2, int level) {
  std::cout << "Level " << level << " : " << str1 << " x " << str2 << '\n';

  int dig_1 = str1.size();
  int dig_2 = str2.size();
  int N = std::max(dig_1, dig_2);

  if (dig_1 > dig_2) {
    for (int i = 0; i < dig_1 - dig_2; i++) {
      str2.insert(0, "0");
    }
  } else if (dig_1 < dig_2) {
    for (int i = 0; i < dig_2 - dig_1; i++) {
      str1.insert(0, "0");
    }
  }

  if (N == 1) {
    std::string result = std::to_string(stoi(str1) * stoi(str2));
    return result;
  }

  int mid = N / 2;

  std::string a = str1.substr(0, mid);
  std::string b = str1.substr(mid, N - mid);

  std::string c = str2.substr(0, mid);
  std::string d = str2.substr(mid, N - mid);

  std::string ac = KaratsubaHelper(a, c, level + 1);
  std::string bd = KaratsubaHelper(b, d, level + 1);
  std::string abcd =
      KaratsubaHelper(std::to_string(stoi(a) + stoi(b)),
                      std::to_string(stoi(c) + stoi(d)), level + 1);
  abcd = Subtract(abcd, Add(ac, bd));

  ac.append(std::string((N - mid) * 2, '0'));
  abcd.append(std::string((N - mid), '0'));
  std::string result = Add(Add(ac, bd), abcd);

  return result;
}

std::string Karatsuba(std::string str1, std::string str2) {
  if (!str1.size() && !str2.size()) return "0";

  std::string result = KaratsubaHelper(str1, str2, 0);

  int idx = 0;
  while (result[idx] == '0') idx++;
  result = result.substr(idx, result.size() - idx);

  return result;
}

int main(void) {
  std::vector<std::vector<std::string>> tests = {
      {"1234", "5678", std::to_string(1234 * 5678)},
      {"12", "34", std::to_string(12 * 34)},
      {"123", "2", std::to_string(123 * 2)},
      {"123", "45", std::to_string(123 * 45)},
      {"110", "110", std::to_string(110 * 110)},
      {"5555", "55", std::to_string(5555 * 55)},
      {"5555", "5555", std::to_string(5555 * 5555)},
      {"98234712354214154", "171454654654655",
       "16842798681791158832220782986870"}
      // , {"9823471235421415454545454545454544",
      // "1714546546546545454544548544544545",
      // "16842798681791114273590624445460185389471221520083884298838480662480"}
  };

  for (const auto& t : tests) {
    const std::string str1 = t[0];
    const std::string str2 = t[1];
    const std::string expected = t[2];

    std::cout << str1 << " * " << str2 << " = " << expected << '\n';

    const std::string result = Karatsuba(str1, str2);

    std::cout << result << " " << expected << " ";

    if (result == expected)
      std::cout << "OK";
    else {
      std::cout << "Not Ok";
      exit(-1);
    }
    std::cout << '\n' << '\n';
  }
  std::cout << "All OK!" << '\n';

  return 0;
}

'DataStructure > Algorithm' 카테고리의 다른 글

QuickSort (퀵정렬)  (0) 2024.09.16
MergeSort (병합 정렬)  (2) 2024.09.15
Sequential Search vs Binary Search ( 순차 탐색 vs 이진 탐색)  (0) 2024.09.13
마스터 정리와 증명  (2) 2024.09.12
점근 표기  (0) 2024.09.11