47,637
edits
(Created page with "local export = {} local ustring = mw.ustring local libraryUtil = require "libraryUtil" local checkType = libraryUtil.checkType local checkTypeMulti = libraryUtil.checkTypeMul...") |
No edit summary |
||
Line 1: | Line 1: | ||
local export = {} | local export = {} | ||
local libraryUtil = require("libraryUtil") | |||
local libraryUtil = require "libraryUtil" | |||
local checkType = libraryUtil.checkType | local checkType = libraryUtil.checkType | ||
local checkTypeMulti = libraryUtil.checkTypeMulti | local checkTypeMulti = libraryUtil.checkTypeMulti | ||
local format = string.format | |||
local getmetatable = getmetatable | |||
local ipairs = ipairs | |||
local is_callable -- defined as export.is_callable below | |||
local pairs = pairs | |||
local select = select | |||
local tostring = tostring | |||
local type = type | |||
local unpack = unpack | |||
local iterableTypes = { "table", "string" } | local iterableTypes = { "table", "string" } | ||
Line 28: | Line 37: | ||
-- Iterate over UTF-8-encoded codepoints in string. | -- Iterate over UTF-8-encoded codepoints in string. | ||
local function iterString(str) | local function iterString(str) | ||
local iter = string.gmatch(str, " | local iter = string.gmatch(str, ".[\128-\191]*") | ||
local i = 0 | local i = 0 | ||
local function iterator() | local function iterator() | ||
Line 40: | Line 49: | ||
return iterator | return iterator | ||
end | end | ||
--[==[
Report whether the input can be called like a function: returns {true} for a function, or for a functor (a table whose metatable supplies a {__call} metamethod), and {false} for anything else.
]==]
function export.is_callable(f)
	local kind = type(f)
	if kind == "function" then
		return true
	end
	if kind == "table" then
		local meta = getmetatable(f)
		-- Only a genuine function counts as a usable __call; a functor stored there does not.
		if meta and type(meta.__call) == "function" then
			return true
		end
	end
	return false
end
-- Fill in the forward-declared local so code in this module can skip the table lookup.
is_callable = export.is_callable
function export.chain(func1, func2, ...) | function export.chain(func1, func2, ...) | ||
Line 243: | Line 268: | ||
----- M E M O I Z A T I O N----- | ----- M E M O I Z A T I O N----- | ||
-- | -- Memoizes a function or callable table. | ||
-- | -- Supports any number of arguments and return values. | ||
local | -- If the optional parameter `simple` is set, then the memoizer will use a faster implementation, but this is only compatible with one argument and one return value. If `simple` is set, additional arguments will be accepted, but this should only be done if those arguments will always be the same. | ||
local function | do | ||
local output = | -- Placeholders. | ||
local args, nil_, pos_nan, neg_nan, neg_0 | |||
output = | |||
-- Map an argument value to something that is safe to use as a table key. Most values are their own key, but a few cannot be used as keys (or would collide with another value), so each of those is swapped for a unique placeholder table, created the first time it is needed. e.g. f("foo", nil, "bar") is memoized at f["foo"][nil_]["bar"][args]. The special cases are:
	-- nil, which is never a valid table key.
	-- -0, which compares equal to 0 (and so would share its slot), yet converts to the string "-0" and behaves differently in some operations (e.g. 1/a is inf for 0 but -inf for -0).
	-- NaN and -NaN, the only values for which n == n is false; they only seem to differ on conversion to string ("nan" and "-nan").
local function get_key(input)
	-- nil
	if input == nil then
		nil_ = nil_ or {}
		return nil_
	end
	-- -0: equal to 0, but its reciprocal is negative.
	if input == 0 and 1 / input < 0 then
		neg_0 = neg_0 or {}
		return neg_0
	end
	-- Any value equal to itself is an ordinary key.
	if input == input then
		return input
	end
	-- Only the NaNs get this far; tell them apart by their formatted form.
	if format("%f", input) == "nan" then
		pos_nan = pos_nan or {}
		return pos_nan
	end
	-- -NaN
	neg_nan = neg_nan or {}
	return neg_nan
end
-- Walk (creating as needed) the nested memo tables for a list of arguments and return the table belonging to the final argument. Each argument, converted by get_key, is used as the key into the next level: for arguments (1, 2, 3) the returned table is t[1][2][3]. The caller then stores the results in that table under the args placeholder, so the memo for f(1, 2, 3), f[1][2][3][args], cannot interfere with the memo for f(1, 2), f[1][2][args].
-- `n` is the position of `key` among the arguments and `nargs` is the total number of arguments; the arguments after `key` arrive in `...`.
local function get_memo(memo, n, nargs, key, ...)
	local slot = get_key(key)
	local child = memo[slot]
	if child == nil then
		child = {}
		memo[slot] = child
	end
	-- The last argument has been consumed: this is the table for the full argument list.
	if n == nargs then
		return child
	end
	return get_memo(child, n + 1, nargs, ...)
end
-- Catch the function's output values in a table whose field "n" holds select("#", ...), i.e. the number of values received, including any nils after the last non-nil value (e.g. select("#", nil) == 1, select("#") == 0, select("#", nil, "foo", nil, nil) == 4 etc.). The distinction between nil and nothing affects some native functions (e.g. tostring() throws an error, but tostring(nil) returns "nil"), so it needs to be reconstructable from the memo.
-- The table is built explicitly instead of returning the implicit `arg` table: `arg` is a Lua 5.0 compatibility feature that only exists in Lua 5.1 builds with LUA_COMPAT_VARARG, and is absent (or means something else) in later versions and in LuaJIT. The resulting table has the same shape as `arg` had.
local function catch_output(...)
	return {n = select("#", ...), ...}
end
-- Return a memoized wrapper around `func`, which must be a function or a callable table; anything else raises an error.
-- Without `simple`, the wrapper caches by the full argument list and replays every return value. With `simple`, it caches by the first argument only and stores a single return value.
function export.memoize(func, simple)
	if not is_callable(func) then
		local received = type(func)
		if received == "table" then
			received = "non-callable table"
		end
		error(format(
			"Only functions and callable tables are memoizable. Received %s.",
			received
		))
	end
	local memo = {}
	if simple then
		return function(...)
			-- get_key only looks at the first argument; the rest are passed through to func untouched.
			local key = get_key(...)
			local cached = memo[key]
			if cached ~= nil then
				-- A cached nil result is stored as the nil_ placeholder, since nil itself can't be told apart from "not cached yet".
				if cached == nil_ then
					return nil
				end
				return cached
			end
			local result = func(...)
			if result == nil then
				nil_ = nil_ or {}
				memo[key] = nil_
				return nil
			end
			memo[key] = result
			return result
		end
	end
	return function(...)
		local nargs = select("#", ...)
		-- With no arguments the results live directly in the root memo table; otherwise descend one level per argument.
		local node = memo
		if nargs ~= 0 then
			node = get_memo(memo, 1, nargs, ...)
		end
		args = args or {}
		local packed = node[args]
		if packed == nil then
			packed = catch_output(func(...))
			node[args] = packed
		end
		-- Unpack from 1 to the original number of return values (memoized as packed.n); unpack returns nil for any values not in the table.
		return unpack(packed, 1, packed.n)
	end
end
end | end | ||
return export | return export |