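The solution is split into two parts:
- Check whether the prerequisite graph contains a loop (cycle); if it does, no valid order exists.
- Otherwise, return the vertex (course) indices in a valid order.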
The loop check is implemented twice, recursively and iteratively, just as an exercise; `findOrder` uses the iterative version.
The more interesting part is computing the order. To do that we actually solve a slightly different problem: for each vertex, find the earliest moment at which it can be started.
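It is easy to see that a vertex with no prerequisites can start at moment 0, and otherwise:
start[i] = max{ start[j] : j is a prerequisite of i } + 1
Since every vertex starts strictly later than all of its prerequisites, sorting the vertices by start gives a valid order. That is the whole idea.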
// Course Schedule II: return the courses in an order that respects all
// prerequisites, or an empty vector when the prerequisites contain a loop.
class Solution {
public:
    // DFS colouring used by the loop checks:
    //   NotVisited - vertex not reached yet;
    //   InProgress - vertex is on the current DFS path;
    //   Visited    - vertex and everything reachable from it is finished.
    // isThereLoopIter also reuses NotVisited/Visited as stack-entry actions.
    enum class Status {
        NotVisited, InProgress, Visited
    };

    // Recursive loop check: returns true if a loop is reachable from u.
    // `visited` is shared between calls, so a vertex finished by an earlier
    // call is not explored again. Returns as soon as the first loop is found
    // (vertices on the current path are then left InProgress).
    bool isThereLoop(const vector<vector<int>>& adj, vector<Status>& visited, int u) {
        if (visited[u] == Status::Visited)
            return false;
        visited[u] = Status::InProgress;
        for (int v : adj[u]) {
            // An edge back to a vertex on the current path closes a loop.
            if (visited[v] == Status::InProgress)
                return true;
            if (isThereLoop(adj, visited, v))
                return true;
        }
        visited[u] = Status::Visited;
        return false;
    }

    // Iterative version of isThereLoop using an explicit stack instead of
    // recursion. Stack entries are (vertex, action): NotVisited means "expand
    // this vertex", Visited means "its descendants are done, mark it Visited".
    bool isThereLoopIter(const vector<vector<int>>& adj, vector<Status>& visited, int u) {
        stack<pair<int, Status>> st;
        st.emplace(u, Status::NotVisited);
        while (!st.empty()) {
            auto [cur, action] = st.top();
            st.pop();
            if (action == Status::Visited) {
                visited[cur] = Status::Visited;
                continue;
            }
            // A vertex can be pushed more than once (e.g. by two parents) and may
            // already be finished when this entry is popped; re-expanding it would
            // only rescan its edges.
            if (visited[cur] != Status::NotVisited)
                continue;
            visited[cur] = Status::InProgress;
            // Pushed below the children, so it is popped after all of them finish.
            st.emplace(cur, Status::Visited);
            for (int v : adj[cur]) {
                // Edge back to a vertex on the current DFS path closes a loop.
                if (visited[v] == Status::InProgress)
                    return true;
                if (visited[v] == Status::NotVisited)
                    st.emplace(v, Status::NotVisited);
            }
        }
        return false;
    }

    // Computes start[u], the earliest moment vertex u can be started:
    // 0 if u has no outgoing edges (prerequisites), otherwise
    // 1 + the largest start among them. start[u] == -1 means "not computed yet".
    // Loop checking is done separately; on a graph with a loop the result is
    // meaningless, because a partially computed start[u] (0) is reused.
    void visit(const vector<vector<int>>& adj, vector<int>& start, int u) {
        if (start[u] != -1)
            return;
        start[u] = 0; // a leaf unless a prerequisite raises it below
        for (int v : adj[u]) {
            visit(adj, start, v);
            start[u] = max(start[u], start[v] + 1);
        }
    }

    // prerequisites[k] = {a, b} means course b must be taken before course a.
    // Returns all course indices 0..numCourses-1 in a valid order, or an empty
    // vector when the prerequisites contain a loop.
    vector<int> findOrder(int numCourses, vector<vector<int>>& prerequisites) {
        // adj[a] lists the prerequisites of course a.
        vector<vector<int>> adj(numCourses);
        for (const auto& e : prerequisites)
            adj[e[0]].push_back(e[1]);

        vector<Status> visited(numCourses, Status::NotVisited);
        for (int i = 0; i < numCourses; ++i) {
            if (visited[i] == Status::NotVisited && isThereLoopIter(adj, visited, i))
                return {};
        }

        vector<int> start(numCourses, -1);
        for (int i = 0; i < numCourses; ++i)
            if (start[i] == -1)
                visit(adj, start, i);

        // Every course starts strictly after each of its prerequisites, so
        // ordering by start time is valid. stable_sort over ascending indices
        // breaks ties by index, which keeps the output deterministic.
        // (Alternative: sort (start[i], i) pairs and take the second elements.)
        vector<int> order(numCourses);
        iota(order.begin(), order.end(), 0);
        stable_sort(order.begin(), order.end(), [&start](int i, int j) {
            return start[i] < start[j];
        });
        return order;
    }
};