Investigating Discrepancies in Optimal Transport Costs: Using Python Optimal Transport (POT)
Optimal transport theory studies strategies for moving mass from one place to another so as to minimize the total transport cost.
Let’s introduce the mathematical model of optimal transport. Consider a set of $n$ points $X = \{x_1, \dots, x_n\}$ in the plane, with a distance matrix $M = (M_{ij})$, where $M_{ij}$ is the distance between the points $x_i$ and $x_j$. A vector $a = (a_1, \dots, a_n)$ is said to be a mass distribution on $X$ if each component $a_i$ is non-negative and the total mass $\sum_i a_i$ is equal to $1$.
The core question is to determine the minimum transportation cost of rearranging mass from one distribution $a$ to another distribution $b$. Denoted $OT(a, b)$, this cost uses the distances $M_{ij}$ as the cost of transferring one unit of mass from $x_i$ to $x_j$. The total optimal cost is expressed as:

$$OT(a, b) = \min_{\gamma} \sum_{i,j} \gamma_{ij} M_{ij}. \tag{1}$$
Here, $\gamma_{ij} \ge 0$ signifies the mass transferred from $a$ at $x_i$ to $b$ at $x_j$. Preservation of mass is ensured by the conditions $\sum_j \gamma_{ij} = a_i$ and $\sum_i \gamma_{ij} = b_j$ for all $i, j$.
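The objective and the two mass-preservation conditions can be checked directly for any candidate plan. The following is a minimal sketch in NumPy; the helper names `is_feasible` and `transport_cost` and the two-point toy data are illustrative assumptions, not part of POT.

```python
import numpy as np

def is_feasible(plan, a, b, tol=1e-9):
    """Return True if `plan` is non-negative with row sums a and column sums b."""
    plan = np.asarray(plan, dtype=float)
    return bool(
        (plan >= -tol).all()
        and np.allclose(plan.sum(axis=1), a, atol=tol)
        and np.allclose(plan.sum(axis=0), b, atol=tol)
    )

def transport_cost(plan, M):
    """Total cost sum_ij plan_ij * M_ij of a given plan."""
    return float((np.asarray(plan) * np.asarray(M)).sum())

# Toy data (hypothetical): two points at unit cost from each other.
M = np.array([[0.0, 1.0], [1.0, 0.0]])
a = np.array([0.75, 0.25])
b = np.array([0.25, 0.75])

# Keep 0.25 in place at each point and move 0.5 from point 0 to point 1.
plan = np.array([[0.25, 0.5], [0.0, 0.25]])
print(is_feasible(plan, a, b), transport_cost(plan, M))
```

The cost of any feasible plan is an upper bound on the optimal cost; the optimal cost is the smallest such value over all feasible plans.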
An alternative cost function emerges when considering a straightforward mass rearrangement approach. Specifically, we focus on moving mass only from “heavy” locations $x_i$ such that $a_i > b_i$ to “light” locations $x_j$ with $a_j < b_j$. The cost in this scenario is expressed as:

$$OT([a-b]_+, [a-b]_-). \tag{2}$$
Here, $[a-b]_+ = \max(a-b, 0)$ and $[a-b]_- = \max(b-a, 0)$, taken componentwise, represent the positive and negative components of $a-b$, respectively.
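As a small illustration of this decomposition, the sketch below splits a difference of two distributions into its positive and negative parts. The vectors `a` and `b` are made-up values chosen only for the example.

```python
import numpy as np

# Hypothetical distributions on three points (each sums to 1).
a = np.array([0.5, 0.3, 0.2])
b = np.array([0.2, 0.3, 0.5])

diff = a - b
pos = np.maximum(diff, 0.0)    # [a-b]_+ : excess mass at "heavy" locations
neg = np.maximum(-diff, 0.0)   # [a-b]_- : missing mass at "light" locations

# The two parts recombine to a - b and never overlap.
assert np.allclose(pos - neg, diff)
assert np.all(pos * neg == 0)

# Because a and b have the same total mass, so do the two parts.
assert np.isclose(pos.sum(), neg.sum())
print(pos, neg)
```

Note that the two parts share the same total mass, but this total is generally smaller than 1, so they are not normalized distributions.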
Our investigation asks whether the optimal transport costs in (1) and (2) are equal. Using numerical examples, we demonstrate that these two cost functions are not equal in general. The divergence stems from the fact that, in (2), the mass transferred at $x_i$ (whether sent or received) always equals $|a_i - b_i|$, while in (1), this quantity can be exceeded.
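A hand-checkable case may help make this concrete. The sketch below uses three made-up points on a line with squared distances as costs (an assumption for this example) and compares one feasible plan for (1) with the only possible plan for (2).

```python
import numpy as np

# Three points on a line (hypothetical data) with squared-distance costs.
x = np.array([0.0, 1.0, 2.0])
M = (x[:, None] - x[None, :]) ** 2

a = np.array([0.5, 0.5, 0.0])
b = np.array([0.0, 0.5, 0.5])

# A feasible plan for (1): relay 0.5 from x_1 to x_2 and 0.5 from x_2 to x_3.
plan1 = np.array([[0.0, 0.5, 0.0],
                  [0.0, 0.0, 0.5],
                  [0.0, 0.0, 0.0]])
assert np.allclose(plan1.sum(axis=1), a) and np.allclose(plan1.sum(axis=0), b)

# For (2) the only source is x_1 and the only sink is x_3, so the plan is forced.
c = np.maximum(a - b, 0.0)
d = np.maximum(b - a, 0.0)
plan2 = np.outer(c, d) / c.sum()
assert np.allclose(plan2.sum(axis=1), c) and np.allclose(plan2.sum(axis=0), d)

cost1_upper = float((plan1 * M).sum())  # upper bound on OT(a, b)
cost2 = float((plan2 * M).sum())        # equals OT([a-b]_+, [a-b]_-)
print(cost1_upper, cost2)

# Mass leaving x_2 under plan1, compared with |a_2 - b_2|.
sent_at_x2 = float(plan1[1].sum() - plan1[1, 1])
print(sent_at_x2, abs(a[1] - b[1]))
```

Here the relay plan costs 1, so $OT(a, b) \le 1$, while the forced plan for (2) costs 2. The middle point sends 0.5 (and receives 0.5) even though $|a_2 - b_2| = 0$. With the plain distances $|x_i - x_j|$ in place of their squares, the same two plans would both cost 1.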
import ot
import numpy as np
import matplotlib.pyplot as plt
np.random.seed(111)
# generate 5 Gaussian points in the plane and the pairwise cost matrix
n = 5
mu_s = np.array([0, 0])
cov_s = np.array([[1, 0], [0, 1]])
xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s)
# note: ot.dist defaults to metric="sqeuclidean", so M holds squared Euclidean distances
M = ot.dist(xs, xs)
# OT(a, b): two random mass distributions, each normalized to total mass 1
a = np.random.rand(n)
a /= a.sum()
b = np.random.rand(n)
b /= b.sum()
opt_plan1 = ot.emd(a, b, M)
cost1 = (opt_plan1*M).sum()
print(f"Optimal cost 1 = {cost1}")
# OT([a-b]+, [a-b]-): transport only the positive part of a-b onto its negative part
diff = a-b
c = diff.clip(min=0)
d = -diff.clip(max=0.)
opt_plan2 = ot.emd(c, d, M)
cost2 = (opt_plan2*M).sum()
print(f"Optimal cost 2 = {cost2}")
f, axs = plt.subplots(1, 2, figsize=(5*2, 5))
ax=axs[0]
ax.imshow(opt_plan1)
ax.axis("off")
ax.set_title(f"$OT(a, b)$={cost1:.4f}")
ax=axs[1]
ax.imshow(opt_plan2)
ax.axis("off")
ax.set_title(f"$OT([a-b]_+, [a-b]_-)$={cost2:.4f}")
plt.show()