分析
我连组合数学sbt都不会做了
一开始先把$n$和$m$都加上$1$,从网格变为坐标系。
直接算很难算。考虑计算总数减去不合法的。
总数非常容易,为$C_{n\times m}^3$。
不合法的有三种情况:
- 三个点在同一行上。每一行有$C_m^3$种不合法情况,共$n$行,总数为$n\cdot C_m^3$。
- 三个点在同一列上。每一列有$C_n^3$种不合法情况,共$m$列,总数为$m\cdot C_n^3$。
- 三个点在同一条斜线上。这一部分不太好算。
考虑一下斜线怎么算。
首先枚举两个端点$(a,b)$和$(c,d)$,于是中间的点有$\gcd(c-a,d-b)-1$种可能。
然后这样子是$O(n^2m^2)$的,考虑怎么优化。
发现斜率相同的直线可以放到一起去算。
先只考虑斜率为正的情况,也就是把$(a,b)$平移到原点。这样子就可以只枚举一个点$(i,j)$。
然后这样的直线有$(n-i)\times(m-j)$条。
这里只考虑了斜率为正的情况。发现斜率为负的情况是对称的,于是乘个$2$即可。
所以第三种情况的总数为$2\times\sum\limits_{i=1}^{n-1}\sum\limits_{j=1}^{m-1}(\gcd(i,j)-1)\times(n-i)\times(m-j)$。
所以得出,总的答案为$C_{n\times m}^3-n\cdot C_m^3-m\cdot C_n^3-2\times\sum\limits_{i=1}^{n-1}\sum\limits_{j=1}^{m-1}(\gcd(i,j)-1)\times(n-i)\times(m-j)$。
代码
//It is made by M_sea
#include <algorithm>
#include <iostream>
#include <cstdlib>
#include <cstring>
#include <cstdio>
#include <cmath>
typedef long long ll;
#define re register
using namespace std;
inline int read() {
int X=0,w=1; char c=getchar();
while (c<'0'||c>'9') { if (c=='-') w=-1; c=getchar(); }
while (c>='0'&&c<='9') X=X*10+c-'0',c=getchar();
return X*w;
}
inline ll C(int n) { return 1ll*(n-2)*(n-1)*n/6; } //C(n,3)
int main() {
int n=read()+1,m=read()+1;
ll ans=C(n*m)-n*C(m)-m*C(n);
for (re int i=1;i<n;++i)
for (re int j=1;j<m;++j)
ans-=2ll*(__gcd(i,j)-1)*(n-i)*(m-j);
printf("%lld\n",ans);
return 0;
}