Saturday, November 3, 2018

K-Way Merge

I was recently given a programming test that went something like this: you've got a large number of items, far more than you can fit in memory, and you need to get them sorted. In this case each team, or clan, has a rank, and we want to create a global leaderboard of all clans; there are millions of them.

With a small list this would be easy: read everything into memory, sort it and output the list. However, these won't all fit. The answer I eventually came up with is a K-way merge. When I was researching the topic I found very few concrete examples, so I thought I'd document an example solution here.

I'm sure this solution is a bit simplistic, but perhaps it'll give others some direction, and it's a useful reminder to me of how it works. Given that the limiting resource is available memory, the first step is to grab a batch of unsorted records, as many as fit in RAM, and sort them. Once that's done, the sorted objects are written to a holding area (in my case, files) split into multiple buckets, and that whole set of buckets is called a stack. The reason for multiple buckets is that we need to fit one bucket from every stack back into memory at the same time when we do the final merge.
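To make the memory argument concrete, here's a small worked example using the numbers from the first Javadoc example further down (1000 total records, batches of 200, 2 buckets per batch). It's a sketch that assumes everything divides evenly, as the code in this post does; the class BucketPlan and its methods are hypothetical names for illustration.

```java
// Hypothetical helper: the memory arithmetic behind stacks and buckets.
// Assumes batchRecords divides totalRecords and numberOfBuckets divides batchRecords.
class BucketPlan {

    // Each sorted batch of batchRecords records becomes one stack.
    static long stacks(long totalRecords, long batchRecords) {
        return totalRecords / batchRecords;
    }

    // Each stack is split into numberOfBuckets files of this many records.
    static long recordsPerBucket(long batchRecords, long numberOfBuckets) {
        return batchRecords / numberOfBuckets;
    }

    // During the merge, one bucket from every stack is held in memory at once.
    static long recordsInMemoryDuringMerge(long totalRecords, long batchRecords, long numberOfBuckets) {
        return stacks(totalRecords, batchRecords) * recordsPerBucket(batchRecords, numberOfBuckets);
    }

    public static void main(String[] args) {
        long total = 1000, batch = 200, buckets = 2;
        System.out.println("stacks = " + stacks(total, batch));                       // 5
        System.out.println("recordsPerBucket = " + recordsPerBucket(batch, buckets));   // 100
        System.out.println("in memory while merging = "
                + recordsInMemoryDuringMerge(total, batch, buckets));                   // 500
    }
}
```

Doubling numberOfBuckets to 4 halves the merge-time footprint to 250 records, at the cost of twice as many files and refills. The sorting phase still needs a whole batch (200 records here) in memory at once.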

Once the stacks are created, the first bucket from each one is read back, and the lowest-ranked object across all of those buckets is moved to our final list. Again, in my case this was a file output stream to keep things simple. When a bucket empties, it gets refilled with the next bucket from the same stack, until the entire stack is empty.
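Here's a small in-memory model of that merge phase, with nested lists standing in for the bucket files and a returned List standing in for the final output file. It's an illustrative sketch only: the class BucketMerge and the list-of-buckets layout are assumptions, not the file-based code later in the post. Like the approach described above, it scans the head of every stack's loaded bucket to find the lowest value.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;

// Hypothetical in-memory model of the merge phase. Each stack is a list of
// sorted buckets; only the current bucket of each stack is "loaded" at a time.
class BucketMerge {

    static List<Integer> merge(List<List<List<Integer>>> stacks) {
        int k = stacks.size();
        List<Deque<Integer>> loaded = new ArrayList<>(); // current bucket per stack
        int[] nextBucket = new int[k];                    // next bucket to load per stack
        for (int s = 0; s < k; s++) {
            loaded.add(new ArrayDeque<>());
        }

        List<Integer> output = new ArrayList<>();
        while (true) {
            int foundStack = -1;
            for (int s = 0; s < k; s++) {
                // When a bucket empties, refill it from the same stack if any buckets remain.
                while (loaded.get(s).isEmpty() && nextBucket[s] < stacks.get(s).size()) {
                    loaded.get(s).addAll(stacks.get(s).get(nextBucket[s]++));
                }
                Integer head = loaded.get(s).peekFirst();
                if (head != null && (foundStack == -1 || head < loaded.get(foundStack).peekFirst())) {
                    foundStack = s;
                }
            }
            if (foundStack == -1) {
                return output; // every stack is exhausted
            }
            output.add(loaded.get(foundStack).pollFirst()); // move the lowest value to the output
        }
    }

    public static void main(String[] args) {
        // Two stacks, each already sorted and split into two buckets of two records.
        List<List<List<Integer>>> stacks = List.of(
                List.of(List.of(1, 4), List.of(6, 9)),
                List.of(List.of(2, 3), List.of(5, 8)));
        System.out.println(merge(stacks)); // [1, 2, 3, 4, 5, 6, 8, 9]
    }
}
```

Only two buckets (one per stack) are ever held in the deques at once, which is the whole point of splitting each stack into buckets.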


  1. A set of objects is read from the pool and then sorted
  2. The sorted list is split into buckets and saved out to files
  3. The first bucket from every stack is read back
  4. The lowest-ranked objects are pulled off
  5. As buckets empty, they are refilled with the remaining buckets in the stack

Here's a method to collect the initial records, sort them (with a very simple List.sort call) and output them into stacks and buckets.


import java.io.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;

/**
 * Collect all the clan records available and divide them up into stacks to store on disk
 * e.g. totalRecords = 1000, batchRecords = 200, numberOfBuckets = 2
 *   means we'll get 200 records from the database at a time, called a batch, and have 2 segments for each batch each with 100 records.  This would result in 5 stacks (1000/200=5). 
 * e.g. totalRecords = 100, batchRecords = 50, numberOfBuckets = 5
 *   means we'll get 50 records at a time, with 5 segments each having 10 records.  This would result in 2 stacks
 * Each batch is sorted and written out as one stack, split across numberOfBuckets bucket files
 * output file has the format outputFile[stack]_[bucketRangeStart]_[bucketRangeStop]
 * e.g. totalRecords = 100, batchRecords = 50, numberOfBuckets = 5
 *   from the first batch - 10 records to outputFile1_1_10, next 10 records will go to outputFile1_11_20, next 10 records to outputFile1_21_30, etc
 *   from the second batch - 10 records to outputFile2_1_10, next 10 to outputFile2_11_20, etc 
 * This way, the lowest numbers from each stack can be compared against each other later
 */

private void getClanRecords() {
 while (retrievedRecords < totalRecords) {
  // collect the data and put it into a sorted set
  ArrayList<Clan> clanList = getClanData(batchRecords); // go get the next batch, e.g. 50 records
  clanList.sort(new ClanComparator());
  int fileIndexStart = 1;
  try {
   for (int i = 0; i < numberOfBuckets; i++) {
    System.out.println("Writing out stack = "+stackCounter+", startIndex = "+fileIndexStart+", endIndex = "+(fileIndexStart+recordsPerBucket-1));
    String fileName = "outputFile"+stackCounter+"_"+fileIndexStart+"_"+(fileIndexStart+recordsPerBucket-1);
    // try-with-resources closes the file even if writeObject throws
    try (ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(fileName))) {
     ArrayList<Clan> subList = new ArrayList<>(clanList.subList(fileIndexStart-1, (fileIndexStart+recordsPerBucket-1)));
     oos.writeObject(subList);
     System.out.println("wrote out "+subList.size()+" records");
    }
    fileIndexStart+=recordsPerBucket; 
   }
  } catch (Exception e) {
   System.out.println("Yea, we had an error writing out our buckets - "+e);
   System.exit(1);
  }
  stackCounter+=1;
 }
 bucketRangeStart = new int[stackCounter];
 bucketRangeStop = new int[stackCounter];
 Arrays.fill(bucketRangeStart, 1);
 Arrays.fill(bucketRangeStop, 0);
}
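The first problem listed at the end of this post, the remainder bucket, comes from subList() always being asked for a full recordsPerBucket slice. Here's a hedged sketch of a split that lets the last bucket come up short instead of throwing. The class BucketSplitter and its split() method are hypothetical, and it only covers the writing side: getNextBlock() builds its file names assuming full buckets, so it would need the same Math.min logic (or a directory listing) to find a short final file.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical helper: split one sorted batch into bucket-sized chunks, keyed by
// the outputFile[stack]_[bucketRangeStart]_[bucketRangeStop] naming scheme.
// Unlike getClanRecords(), the last bucket may be short instead of throwing.
class BucketSplitter {

    static <T> Map<String, List<T>> split(List<T> sortedBatch, int stack, int recordsPerBucket) {
        Map<String, List<T>> buckets = new LinkedHashMap<>(); // keeps buckets in order
        for (int from = 0; from < sortedBatch.size(); from += recordsPerBucket) {
            int to = Math.min(from + recordsPerBucket, sortedBatch.size());
            // File ranges are 1-based and inclusive, matching the post's file names.
            String fileName = "outputFile" + stack + "_" + (from + 1) + "_" + to;
            buckets.put(fileName, new ArrayList<>(sortedBatch.subList(from, to)));
        }
        return buckets;
    }

    public static void main(String[] args) {
        List<Integer> batch = List.of(3, 7, 12, 15, 21, 30, 42); // 7 records, already sorted
        System.out.println(split(batch, 1, 3).keySet());
        // [outputFile1_1_3, outputFile1_4_6, outputFile1_7_7]
    }
}
```

In the real method each map entry would be written with its own ObjectOutputStream, exactly as getClanRecords() does today.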

Once each batch has been sorted, we merge across the stacks, refilling buckets as required.


/**
 * The second level sorting, this method will compare records across all the available stacks (which are already sorted)
 * removing the lowest one to the final output and updating the stack with the next set of records when it runs out
 * clanMaps holds the bucket currently loaded for each stack, keyed by stack number
 */

private void sortList() {
 // collect our first set of stacks to start the process
 HashMap<Integer, ArrayList<Clan>> clanMaps = new HashMap<>();
 for (int stack = 1; stack < stackCounter; stack++) {
  clanMaps.put(stack, getNextBlock(stack));
 }

 // setup our final output stream to disk
 ObjectOutputStream oos = null;
 try {
  BufferedOutputStream bufferedStream = new BufferedOutputStream(new FileOutputStream(finalFileList));
  oos = new ObjectOutputStream(bufferedStream);
 } catch (Exception e) {
  System.out.println("Unable to open final output stream - "+e);
  System.exit(1); // without a stream, every later write would throw a NullPointerException
 }

 ArrayList<Clan> finalList = new ArrayList<>();
 long finalRecordCounter = 0;
 
 // process all of our expected records; the loop stops early if every stack runs dry
 while (finalRecordCounter < totalRecords) {
  double smallestNumber = Double.MAX_VALUE; // any real rank will be lower than this
  //System.out.println("final record counter = "+finalRecordCounter);
  int foundStack = -1;
  for (int stackNumber : clanMaps.keySet()) {
   ArrayList<Clan> s = clanMaps.get(stackNumber);
   if (s.size() == 0) {
    System.out.println("ran out of records for stack "+stackNumber+", getting next block");
    s = getNextBlock(stackNumber);
    clanMaps.put(stackNumber,s);
   }
   if (s.size() != 0 && s.get(0).getClanRank() < smallestNumber) {
    foundStack = stackNumber;
    smallestNumber = s.get(0).getClanRank();
   }
  }
  // assuming we found the stack with the lowest value, move it to finalList and when enough records are found, write that out to disk
  if (foundStack > -1) {
   //System.out.println("Found record from stack "+foundStack); // this is kind of fun to see which stack it's pulling records from
   Clan lowest = clanMaps.get(foundStack).remove(0); // take the lowest item off its stack
   finalList.add(lowest);
   try {
    oos.writeObject(lowest);
    if (finalList.size() > batchRecords) {
     System.out.println("Flushing records to disk...");
     oos.flush();
     finalList.clear();
    }
   } catch (IOException e) {
    System.out.println("Unable to write records out to the finalList - "+e);
   }
   finalRecordCounter++;
  } else {
   // every stack is empty but we expected more records (bad data or a remainder bucket); stop instead of looping forever
   System.out.println("All stacks are empty after "+finalRecordCounter+" records, stopping");
   break;
  }
 }
 try {
  oos.close();
 } catch (IOException e) {
  System.out.println("Unable to close finalList file, this will result in partial or no final data - "+e);
 }
 displayFinalRecords();
}

/**
 * Retrieves the next available set of Clan items in a stack
 * It does this by keeping track of the start and end range for each stack and increments them as records are retrieved
 * @param stackNumber the stack to retrieve records from
 * @return an ArrayList of clan objects or an empty list if no more are available
 */

@SuppressWarnings("unchecked")
private ArrayList<Clan> getNextBlock(int stackNumber) {
 ArrayList<Clan> clanList = new ArrayList<>();
 bucketRangeStop[stackNumber]+=(recordsPerBucket);
 System.out.println("reading block "+stackNumber+" bucketRangeStart = "+bucketRangeStart[stackNumber]+", bucketRangeStop = "+bucketRangeStop[stackNumber]+", recordsPerBucket = "+recordsPerBucket);
 if (bucketRangeStop[stackNumber] > batchRecords) {
  System.out.println("no more records for this block");
  return clanList;
 }
 try {
  String fileName = "outputFile"+(stackNumber)+"_"+bucketRangeStart[stackNumber]+"_"+(bucketRangeStart[stackNumber]+recordsPerBucket-1);
  System.out.println("Reading records for file "+fileName);
  // try-with-resources closes the file even if readObject throws
  try (ObjectInputStream ois = new ObjectInputStream(new FileInputStream(fileName))) {
   clanList = (ArrayList<Clan>)ois.readObject();
  }
  bucketRangeStart[stackNumber] = bucketRangeStop[stackNumber]+1;
 } catch (Exception e) {
  System.out.println("Yea, we had an input file error; this is bad, quitting now - "+e);
  System.exit(1);
 }
 return clanList;
}
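To see how the bucketRangeStart/bucketRangeStop bookkeeping in getNextBlock() walks through a stack, here's a hypothetical replay of just that arithmetic, with the file reads swapped for collecting the names that would be opened. The class RangeTrace and its method are made up for illustration; the numbers match the batchRecords = 50, 5-bucket example from the first Javadoc.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical replay of getNextBlock()'s range arithmetic for one stack.
// No files are read; each call records the file name that would be opened,
// or an empty string for a call that would return an empty list.
class RangeTrace {

    static List<String> fileNamesRead(int stackNumber, int batchRecords, int recordsPerBucket, int calls) {
        int bucketRangeStart = 1, bucketRangeStop = 0; // as initialised in getClanRecords()
        List<String> names = new ArrayList<>();
        for (int call = 0; call < calls; call++) {
            bucketRangeStop += recordsPerBucket;
            if (bucketRangeStop > batchRecords) {
                names.add(""); // "no more records for this block"
                continue;
            }
            names.add("outputFile" + stackNumber + "_" + bucketRangeStart + "_"
                    + (bucketRangeStart + recordsPerBucket - 1));
            bucketRangeStart = bucketRangeStop + 1;
        }
        return names;
    }

    public static void main(String[] args) {
        // batchRecords = 50 split into 5 buckets of 10; the sixth call finds nothing left
        fileNamesRead(1, 50, 10, 6).forEach(name ->
                System.out.println(name.isEmpty() ? "(empty list)" : name));
    }
}
```

Assuming recordsPerBucket is batchRecords / numberOfBuckets with integer division, running it with 50 and 16 (three buckets) reads files covering records 1 to 48 and then stops, so records 49 and 50 are never read back; that's one way the remainder problem below shows up.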
There are a few problems with this code:
  • If the records don't divide evenly into batches and buckets, that is, you end up with a remainder bucket, it'll fail
  • The initial sort is slow, so that should be replaced with something faster
  • This uses hard drive space to hold the intermediate and final lists, which might not work for your use case
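One more idea beyond the list above: sortList() finds the lowest record by scanning the head of every stack, which costs O(k) per record for k stacks. A common alternative in k-way merges is a min-heap (Java's PriorityQueue) holding one head per stack, which brings that down to O(log k). Here's a minimal in-memory sketch of that swap, assuming each stack fits in a list; the class HeapMerge and its Entry record are hypothetical, and in the file-based version each stack would instead be fed by getNextBlock() as its bucket empties. The record needs Java 16 or later.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.PriorityQueue;

// Hypothetical alternative to the linear scan in sortList(): keep the head of each
// sorted stack in a min-heap so picking the lowest value costs O(log k), not O(k).
class HeapMerge {

    // One heap entry: the value plus which stack and position it came from.
    record Entry(int value, int stack, int index) {}

    static List<Integer> merge(List<List<Integer>> sortedStacks) {
        PriorityQueue<Entry> heap = new PriorityQueue<>((a, b) -> Integer.compare(a.value(), b.value()));
        for (int s = 0; s < sortedStacks.size(); s++) {
            if (!sortedStacks.get(s).isEmpty()) {
                heap.add(new Entry(sortedStacks.get(s).get(0), s, 0)); // seed with each stack's head
            }
        }
        List<Integer> output = new ArrayList<>();
        while (!heap.isEmpty()) {
            Entry lowest = heap.poll(); // smallest head across all stacks
            output.add(lowest.value());
            List<Integer> stack = sortedStacks.get(lowest.stack());
            int next = lowest.index() + 1;
            if (next < stack.size()) { // push that stack's next record
                heap.add(new Entry(stack.get(next), lowest.stack(), next));
            }
        }
        return output;
    }

    public static void main(String[] args) {
        List<List<Integer>> stacks = List.of(List.of(1, 4, 9), List.of(2, 3, 8), List.of(5, 6, 7));
        System.out.println(merge(stacks)); // [1, 2, 3, 4, 5, 6, 7, 8, 9]
    }
}
```

With only a handful of stacks the linear scan is perfectly fine; the heap starts to pay off as the number of stacks grows.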