I treat each top-left–to–bottom-right diagonal of the n × n grid as a group keyed by d = i - j. Every cell on a diagonal shares the same d, so the whole diagonal lies either on or below the main diagonal (d ≥ 0) or strictly above it (d < 0). For each d I collect the diagonal's values: if d ≥ 0 (bottom-left triangle, including the main diagonal), I sort them in non-increasing order; if d < 0 (top-right triangle), I sort them in non-decreasing order. Then I write the sorted values back along the same diagonal. The diagonals do not share cells, so each can be processed independently, and together they satisfy both ordering rules.
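
from typing import List


class Solution:
    def sortMatrix(self, grid: List[List[int]]) -> List[List[int]]:
        n = len(grid)
        # d = i - j identifies one diagonal; d ranges over -(n - 1) .. n - 1.
        for d in range(-(n - 1), n):
            # Rows that this diagonal passes through.
            i_start = max(0, d)
            i_end = min(n - 1, n - 1 + d)
            # Collect the diagonal from top-left to bottom-right (j = i - d).
            elems = [grid[i][i - d] for i in range(i_start, i_end + 1)]
            # On/below the main diagonal: non-increasing; above it: non-decreasing.
            elems.sort(reverse=(d >= 0))
            # Write the sorted values back along the same diagonal.
            for idx, i in enumerate(range(i_start, i_end + 1)):
                grid[i][i - d] = elems[idx]
        return grid
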
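- Time: O(n^2 log n), since the 2n - 1 diagonals hold n^2 elements in total and each diagonal has at most n elements to sort
- Space: O(n) extra for the buffer holding one diagonal; the grid itself is updated in place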
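
To make the "group keyed by i - j" idea concrete, here is a minimal alternative sketch that literally buckets the cells in a dictionary instead of computing index ranges. The function name `sort_matrix_by_diagonals` is my own placeholder, not part of the problem's interface, and I assume a square grid as in the solution above. Unlike the in-place version, it builds a new grid and uses O(n^2) extra space.

```python
from collections import defaultdict
from typing import Dict, List


def sort_matrix_by_diagonals(grid: List[List[int]]) -> List[List[int]]:
    """Hypothetical helper: same ordering rules, grouping cells by d = i - j."""
    n = len(grid)  # assumption: grid is n x n

    # Bucket every value under its diagonal key. Row-major traversal visits
    # each diagonal's cells from top-left to bottom-right.
    diagonals: Dict[int, List[int]] = defaultdict(list)
    for i in range(n):
        for j in range(n):
            diagonals[i - j].append(grid[i][j])

    # d >= 0: non-increasing (bottom-left, including the main diagonal);
    # d < 0: non-decreasing (top-right).
    for d, values in diagonals.items():
        values.sort(reverse=(d >= 0))

    # Write values back in the same traversal order, tracking how many
    # values of each diagonal have already been placed.
    result = [[0] * n for _ in range(n)]
    placed: Dict[int, int] = defaultdict(int)
    for i in range(n):
        for j in range(n):
            d = i - j
            result[i][j] = diagonals[d][placed[d]]
            placed[d] += 1
    return result


print(sort_matrix_by_diagonals([[1, 7, 3], [9, 8, 2], [4, 5, 6]]))
# → [[8, 2, 3], [9, 6, 7], [4, 5, 1]]
```

In the example, the main diagonal 1, 8, 6 becomes 8, 6, 1 (non-increasing), while the diagonal above it, 7, 2, becomes 2, 7 (non-decreasing). The dictionary version is easier to read, but the index-range version above is the better choice when the O(n) space bound matters.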
