내 풀이

  1. 트리의 노드를 순차적으로 방문
  2. 노드의 자식 노드를 방문하여 해당 자식의 자식 노드를 순회하며 제거할 리스트에 추가
  3. 제거할 리스트가 나오면 순차적 방문 중인 노드의 자식 노드 중 해당하는 노드를 제거
  4. 자식 노드는 하나만 남게 되는 데 해당 노드는 진짜 자식 노드가 된다.
  • 코드

      package contest.java.p02;
    
      import java.io.*;
      import java.util.*;
    
      public class Main {
          private static final HashMap<String, HashSet<String>> childTree = new HashMap<>();
          private static final HashMap<String, LinkedList<String>> ancestorTree = new HashMap<>();
    
          public static void main(String[] args) throws Exception {
              BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
              BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(System.out));
    
              int n = Integer.parseInt(br.readLine());
              StringTokenizer st = new StringTokenizer(br.readLine());
              ArrayList<String> names = new ArrayList<>();
              for (int i = 0; i < n; i++) {
                  String temp = st.nextToken();
                  names.add(temp);
                  childTree.put(temp, new HashSet<>());
                  ancestorTree.put(temp, new LinkedList<>());
              }
    
              int m = Integer.parseInt(br.readLine());
              for (int i = 0; i < m; i++) {
                  st = new StringTokenizer(br.readLine());
                  String lower = st.nextToken();
                  String upper = st.nextToken();
    
                  childTree.get(upper).add(lower);
                  ancestorTree.get(lower).add(upper);
              }
    
              //childTree.forEach((s, strings) -> strings.forEach(s1 -> childTree.get(s1).forEach(s2 -> removeList.get(s).add(s2))));
    
              for(Map.Entry<String, HashSet<String>> s : childTree.entrySet()) {
                  //String name = s.getKey();
                  HashSet<String> children = s.getValue();
                  HashSet<String> removeList = new HashSet<>();
    
                  for(String child : children) {
                      removeList.addAll(childTree.get(child));
                  }
    
                  for(String target : removeList) {
                      children.remove(target);
                  }
              }
    
              ArrayList<String> roots = new ArrayList<>();
              for(Map.Entry<String, LinkedList<String>> s : ancestorTree.entrySet()) {
                  if(s.getValue().size() == 0) {
                      roots.add(s.getKey());
                  }
              }
              bw.write(Integer.toString(roots.size()));
              bw.newLine();
              Collections.sort(roots);
              for(String rName : roots) {
                  bw.write(rName + " ");
              }
              bw.newLine();
              Collections.sort(names);
              for (String name : names) {
                  Object[] children = childTree.get(name).toArray();
                  Arrays.sort(children);
                  bw.write(name + " " + children.length + " ");
                  for (Object child : children) {
                      bw.write(child + " ");
                  }
                  bw.newLine();
              }
              bw.flush();
          }
      }
    

문제점

  • 자식노드의 모든 자식 노드들을 순회하며 제거 리스트에 추가하는 것은 시간을 많이 소모된다. O(n^3)
  • 자식과 부모를 기록한 트리를 HashMap을 사용해서 해싱에 시간을 많이 소모된다.
  • 노드의 자식 노드들을 기록한 LinkedList를 순회하며 일일이 삭제하는 것은 시간이 많이 소요된다.

옳은 풀이

  1. 전체 트리의 노드를 순차적으로 방문하여 큐에 조상 노드가 없는 노드를 넣는다.
  2. 큐에서 노드를 꺼낸다.
  3. 방문한 트리의 노드의 자식 노드를 순회하며 자식 노드의 조상 노드의 개수(미리 기록된)를 하나씩 줄여나간다. (조상 노드들과 연결을 끊는다.)
  4. 자식 노드들 중 조상이 된 노드를 큐에 넣고 2.의 큐에서 꺼낸 노드의 자식 노드로 확정한다.
  5. 큐가 비워질 때 까지 반복한다.

문제에서 입력으로 준 이름들을 숫자로 변형하여 문제를 풀어야 한다. 이를 통해 해쉬 함수의 동작 횟수를 줄일 수 있다.

  • 코드

      package contest.java.p02;
    
      import java.io.BufferedReader;
      import java.io.BufferedWriter;
      import java.io.InputStreamReader;
      import java.io.OutputStreamWriter;
      import java.util.*;
    
      public class Main2 {
          private static int N, M;
          private static HashMap<String, Integer> name2Int = new HashMap<>();
          private static ArrayList<String> int2String = new ArrayList<>();
    
          public static void main(String[] args) throws Exception {
              BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
              BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(System.out));
    
              N = Integer.parseInt(br.readLine());
              StringTokenizer st = new StringTokenizer(br.readLine());
              while (st.hasMoreTokens()) int2String.add(st.nextToken());
              Collections.sort(int2String);
              for (int i = 0; i < N; i++)
                  name2Int.put(int2String.get(i), i);
    
              ArrayList<ArrayList<Integer>> ret = new ArrayList<>();
              ArrayList<Integer> rootList = new ArrayList<>();
              int[] ancestors = new int[N];
              ArrayList<ArrayList<Integer>> children = new ArrayList<>();
              for (int i = 0; i < N; i++) {
                  children.add(new ArrayList<>());
                  ret.add(new ArrayList<>());
              }
    
              M = Integer.parseInt(br.readLine());
              for (int i = 0; i < M; i++) {
                  st = new StringTokenizer(br.readLine());
                  int lower = name2Int.get(st.nextToken());
                  int upper = name2Int.get(st.nextToken());
                  ancestors[lower]++;
                  children.get(upper).add(lower);
              }
    
              Queue<Integer> q = new LinkedList<>();
              for (int i = 0; i < N; i++) {
                  if (ancestors[i] == 0) {
                      q.add(i);
                      rootList.add(i);
                  }
              }
    
              while (!q.isEmpty()) {
                  Integer top = q.poll();
                  for (int child : children.get(top)) {
                      ancestors[child]--;
                      if (ancestors[child] == 0) {
                          q.add(child);
                          ret.get(top).add(child);
                      }
                  }
              }
    
              bw.write((rootList.size()) + "\n");
              for (int root : rootList) {
                  bw.write(int2String.get(root) + " ");
              }
              bw.newLine();
    
              for (int i = 0; i < N; i++) {
                  bw.write(int2String.get(i) + " ");
                  ArrayList<Integer> t = ret.get(i);
                  Collections.sort(t);
                  bw.write(t.size() + " ");
                  for (int n : t) bw.write(int2String.get(n) + " ");
                  bw.newLine();
              }
    
              bw.flush();
          }
      }
    

댓글남기기