local service_name = ARGV local consumer_name = ARGV local CONSUMER_TTL = ARGV

– Performs deep equality between two tables local function table_eq(table1, table2)

local avoid_loops = {}
local function recurse(t1, t2)
 -- compare value types
 if type(t1) ~= type(t2) then return false end
 -- Base case: compare simple values
 if type(t1) ~= "table" then return t1 == t2 end
 -- Now, on to tables.
 -- First, let's avoid looping forever.
 if avoid_loops[t1] then return avoid_loops[t1] == t2 end
 avoid_loops[t1] = t2
 -- Copy keys from t2
 local t2keys = {}
 local t2tablekeys = {}
 for k, _ in pairs(t2) do
    if type(k) == "table" then table.insert(t2tablekeys, k) end
    t2keys[k] = true
 end
 -- Let's iterate keys from t1
 for k1, v1 in pairs(t1) do
    local v2 = t2[k1]
    if type(k1) == "table" then
       -- if key is a table, we need to find an equivalent one.
       local ok = false
       for i, tk in ipairs(t2tablekeys) do
          if table_eq(k1, tk) and recurse(v1, t2[tk]) then
             table.remove(t2tablekeys, i)
             t2keys[tk] = nil
             ok = true
             break
          end
       end
       if not ok then return false end
    else
       -- t1 has a key which t2 doesn't have, fail.
       if v2 == nil then return false end
       t2keys[k1] = nil
       if not recurse(v1, v2) then return false end
    end
 end
 -- if t2 has a key which t1 doesn't have, fail.
 if next(t2keys) then return false end
 return true
end
return recurse(table1, table2)

end

local function distribute(consumers, partition_count)

local distribution   = {}
local consumer_count = table.getn(consumers)
local remainder      = partition_count % consumer_count

for i=1,consumer_count do
  distribution[i] = math.floor(partition_count/consumer_count)
end

for i=1,remainder do
  distribution[i] = distribution[i] + 1
end

return distribution

end

local function getdesiredstate(service_name, consumers, partition_count)

local state                    = {}
local distribution             = distribute(consumers, partition_count)
local consumer_count           = table.getn(consumers)
local assigned_partition_count = 0

for i=1,consumer_count do
  state[consumers[i]] = {}

  for j=1,distribution[i] do
    table.insert(state[consumers[i]], 1, service_name .. ":" .. j + assigned_partition_count - 1)
  end

  assigned_partition_count = assigned_partition_count + distribution[i]
end

return state

end

local function getcurrentstate(service_name, consumers)

local state = {}

for _, consumer in ipairs(consumers) do
  local assigned_key = service_name .. ":" .. consumer .. ":assigned"
  state[consumer] = redis.call("LRANGE", assigned_key, 0, -1)
end

return state

end

local function states_match(state1, state2)

return table_eq(state1, state2)

end

local function all_free(workers)

local total_count = 0

for _, partitions in pairs(workers) do
  total_count = total_count + table.getn(partitions)
end

return total_count == 0

end

local function save_state(service_name, state)

for worker, partitions in pairs(state) do
  for _, partition in ipairs(partitions) do
    redis.call("RPUSH", service_name .. ":" .. worker .. ":assigned", partition)
    redis.call("EXPIRE", service_name .. ":" .. worker .. ":assigned", CONSUMER_TTL)
  end
end

end

local function rebalance(service_name, consumer_name)

local consumers = redis.call("SMEMBERS", service_name .. ":consumers")
table.sort(consumers)
local partition_count = tonumber(redis.call("GET", service_name .. ":partition_count"))

local desired_state = getdesiredstate(service_name, consumers, partition_count)

local current_state = getcurrentstate(service_name, consumers)

local is_stable_state = states_match(desired_state, current_state)

if not is_stable_state then
  if all_free(current_state) then
    save_state(service_name, desired_state)

    return desired_state[consumer_name]
  else
    redis.call("DEL", service_name .. ":" .. consumer_name .. ":assigned")
    return {}
  end
else
  return desired_state[consumer_name]
end

end

return rebalance(service_name, consumer_name)