Module:memoize: Difference between revisions

From Linguifex
Jump to navigation Jump to search
(Created page with "local format = string.format local select = select local unpack = unpack ----- M E M O I Z A T I O N----- -- Memoizes a function or callable table. -- Supports any number of arguments and return values. -- If the optional parameter `simple` is set, then the memoizer will use a faster implementation, but this is only compatible with one argument and one return value. If `simple` is set, additional arguments will be accepted, but this should only be done if those argument...")
 
(Optimisations.)
Line 1: Line 1:
local format = string.format
local math_module = "Module:math"
local table_pack_module = "Module:table/pack"
 
local require = require
local select = select
local select = select
local unpack = unpack
local unpack = unpack or table.unpack -- Lua 5.2 compatibility
 
-- table.pack: in Lua 5.2+, this is a function that wraps the parameters given
-- into a table with the additional key `n` that contains the total number of
-- parameters given. This is not available on Lua 5.1, so [[Module:table/pack]]
-- provides the same functionality.
local function pack(...)
pack = require(table_pack_module)
return pack(...)
end
 
local function sign(...)
sign = require(math_module).sign
return sign(...)
end


----- M E M O I Z A T I O N-----
----- M E M O I Z A T I O N-----
Line 9: Line 26:


-- Sentinels.
-- Sentinels.
local nil_, neg_0, pos_nan, neg_nan
local _nil, neg_0, pos_nan, neg_nan = {}, {}, {}, {}


-- Certain values can't be used as table keys, so they require sentinels as well: e.g. f("foo", nil, "bar") would be memoized at memo["foo"][nil_]["bar"][memo]. These values are:
-- Certain values can't be used as table keys, so they require sentinels as well: e.g. f("foo", nil, "bar") would be memoized at memo["foo"][_nil]["bar"][memo]. These values are:
-- nil.
-- nil.
-- -0, which is equivalent to 0 in most situations, but becomes "-0" on conversion to string; it also behaves differently in some operations (e.g. 1/a evaluates to inf if a is 0, but -inf if a is -0).
-- -0, which is equivalent to 0 in most situations, but becomes "-0" on conversion to string; it also behaves differently in some operations (e.g. 1/a evaluates to inf if a is 0, but -inf if a is -0).
-- NaN and -NaN, which are the only values for which n == n is false; they only seem to differ on conversion to string ("nan" and "-nan").
-- NaN and -NaN, which are the only values for which n == n is false; they only seem to differ on conversion to string ("nan" and "-nan").
local function get_key(input)
local function get_key(x)
-- nil
if x == x then
if input == nil then
return x == nil and _nil or x == 0 and 1 / x < 0 and neg_0 or x
if not nil_ then
nil_ = {}
end
return nil_
-- -0
elseif input == 0 and 1 / input < 0 then
if not neg_0 then
neg_0 = {}
end
return neg_0
-- Default
elseif input == input then
return input
-- NaN
elseif format("%f", input) == "nan" then
if not pos_nan then
pos_nan = {}
end
return pos_nan
-- -NaN
elseif not neg_nan then
neg_nan = {}
end
end
return neg_nan
return sign(x) == 1 and pos_nan or neg_nan
end
end


Line 56: Line 51:
end
end


-- Catch the function output values, and return the hidden variable arg (which is {...}, and available when a function has ...). We do this instead of catching the output in a table directly, because arg also contains the key "n", which is equal to select("#", ...). i.e. it's the number of arguments in ..., including any nils returned after the last non-nil value (e.g. select("#", nil) == 1, select("#") == 0, select("#", nil, "foo", nil, nil) == 4 etc.). The distinction between nil and nothing affects some native functions (e.g. tostring() throws an error, but tostring(nil) returns "nil"), so it needs to be reconstructable from the memo.
-- Used to catch the function output values instead of using a table directly,
local function catch_output(...)
-- since pack() returns a table with the key `n`, giving the number of return
-- TODO uses arg; will not work if Scribunto is upgraded to Lua 5.2, 5.3, etc.
-- values, even if they are nil. This ensures that any nil return values after
return arg
-- the last non-nil value will always be present (e.g. pack() gives {n = 0},
-- pack(nil) gives {n = 1}, pack(nil, "foo", nil) gives {[2] = "foo", n = 3}
-- etc.). The distinction between nil and nothing affects some native functions
-- (e.g. tostring() throws an error, but tostring(nil) returns "nil"), so it
-- needs to be reconstructable from the memo.
local function memoize_then_return(memo, _memo, ...)
_memo[memo] = pack(...)
return ...
end
end


return function(func, simple)
return function(func, simple)
local memo
local memo = {}
return simple and function(...)
 
local key = get_key(...)
if simple then
if not memo then
 
memo = {}
return function(...)
end
local key = get_key((...))
local output = memo[key]
local output = memo[key]
if output == nil then
if output == nil then
output = func(...)
output = func(...)
if output ~= nil then
memo[key] = output == nil and _nil or output
memo[key] = output
return output
return output
elseif not nil_ then
elseif output == _nil then
nil_ = {}
return nil
end
end
memo[key] = nil_
return output
return nil
elseif output == nil_ then
return nil
end
end
return output
 
end or function(...)
end
 
return function(...)
local nargs = select("#", ...)
local nargs = select("#", ...)
if not memo then
-- Since all possible inputs need to be memoized (including true, false
memo = {}
-- and nil), the memo table itself is used as a sentinel to ensure that
end
-- the table of arguments will always have a unique key.
-- Since all possible inputs need to be memoized (including true, false and nil), the memo table itself is used as the key for the arguments.
local _memo = nargs == 0 and memo or get_memo(memo, 1, nargs, ...)
local _memo = nargs == 0 and memo or get_memo(memo, 1, nargs, ...)
local output = _memo[memo]
local output = _memo[memo]
-- If get_memo() returned nil, call `func` with the arguments and catch
-- the output with memoize_then_return(); this packs the return values
-- into a table to memoize them, then returns them. Since the return
-- values are available to it as `...`, this avoids the need to call
-- unpack() on the memoized table on the first call, as they can be
-- returned directly.
if output == nil then
if output == nil then
output = catch_output(func(...))
return memoize_then_return(memo, _memo, func(...))
_memo[memo] = output
end
end
-- Unpack from 1 to the original number of return values (memoized as output.n); unpack returns nil for any values not in output.
-- Unpack from 1 to the original number of return values (memoized at
-- key `n`); unpack() returns nil for any values not in output.
return unpack(output, 1, output.n)
return unpack(output, 1, output.n)
end
end
end
end

Revision as of 02:18, 12 May 2025

Documentation for this module may be created at Module:memoize/doc

local math_module = "Module:math"
local table_pack_module = "Module:table/pack"

local require = require
local select = select
local unpack = unpack or table.unpack -- Lua 5.2 compatibility

-- table.pack: in Lua 5.2+, this is a function that wraps the parameters given
-- into a table with the additional key `n` that contains the total number of
-- parameters given. This is not available on Lua 5.1, so [[Module:table/pack]]
-- provides the same functionality.
local function pack(...)
	pack = require(table_pack_module)
	return pack(...)
end

local function sign(...)
	sign = require(math_module).sign
	return sign(...)
end

----- M E M O I Z A T I O N-----
-- Memoizes a function or callable table.
-- Supports any number of arguments and return values.
-- If the optional parameter `simple` is set, then the memoizer will use a faster implementation, but this is only compatible with one argument and one return value. If `simple` is set, additional arguments will be accepted, but this should only be done if those arguments will always be the same.

-- Sentinels.
local _nil, neg_0, pos_nan, neg_nan = {}, {}, {}, {}

-- Certain values can't be used as table keys, so they require sentinels as well: e.g. f("foo", nil, "bar") would be memoized at memo["foo"][_nil]["bar"][memo]. These values are:
	-- nil.
	-- -0, which is equivalent to 0 in most situations, but becomes "-0" on conversion to string; it also behaves differently in some operations (e.g. 1/a evaluates to inf if a is 0, but -inf if a is -0).
	-- NaN and -NaN, which are the only values for which n == n is false; they only seem to differ on conversion to string ("nan" and "-nan").
local function get_key(x)
	if x == x then
		return x == nil and _nil or x == 0 and 1 / x < 0 and neg_0 or x
	end
	return sign(x) == 1 and pos_nan or neg_nan
end

-- Return values are memoized as tables of return values, which are looked up using each input argument as a key, followed by `memo`. e.g. if the input arguments were (1, 2, 3), the memo would be located at t[1][2][3][memo]. `memo` is always used as the final lookup key so that (for example) the memo for f(1, 2, 3), f[1][2][3][memo], doesn't interfere with the memo for f(1, 2), f[1][2][memo].
local function get_memo(memo, n, nargs, key, ...)
	key = get_key(key)
	local next_memo = memo[key]
	if next_memo == nil then
		next_memo = {}
		memo[key] = next_memo
	end
	memo = next_memo
	return n == nargs and memo or get_memo(memo, n + 1, nargs, ...)
end

-- Used to catch the function output values instead of using a table directly,
-- since pack() returns a table with the key `n`, giving the number of return
-- values, even if they are nil. This ensures that any nil return values after
-- the last non-nil value will always be present (e.g. pack() gives {n = 0},
-- pack(nil) gives {n = 1}, pack(nil, "foo", nil) gives {[2] = "foo", n = 3}
-- etc.). The distinction between nil and nothing affects some native functions
-- (e.g. tostring() throws an error, but tostring(nil) returns "nil"), so it
-- needs to be reconstructable from the memo.
local function memoize_then_return(memo, _memo, ...)
	_memo[memo] = pack(...)
	return ...
end

return function(func, simple)
	local memo = {}

	if simple then

		return function(...)
			local key = get_key((...))
			local output = memo[key]
			if output == nil then
				output = func(...)
				memo[key] = output == nil and _nil or output
				return output
			elseif output == _nil then
				return nil
			end
			return output
		end

	end

	return function(...)
		local nargs = select("#", ...)
		-- Since all possible inputs need to be memoized (including true, false
		-- and nil), the memo table itself is used as a sentinel to ensure that
		-- the table of arguments will always have a unique key.
		local _memo = nargs == 0 and memo or get_memo(memo, 1, nargs, ...)
		local output = _memo[memo]
		-- If get_memo() returned nil, call `func` with the arguments and catch
		-- the output with memoize_then_return(); this packs the return values
		-- into a table to memoize them, then returns them. Since the return
		-- values are available to it as `...`, this avoids the need to call
		-- unpack() on the memoized table on the first call, as they can be
		-- returned directly.
		if output == nil then
			return memoize_then_return(memo, _memo, func(...))
		end
		-- Unpack from 1 to the original number of return values (memoized at
		-- key `n`); unpack() returns nil for any values not in output.
		return unpack(output, 1, output.n)
	end
end