Module:fun: Difference between revisions

2,072 bytes removed ,  8 January 2025
no edit summary
No edit summary
No edit summary
Line 1: Line 1:
local export = {}
local export = {}


local libraryUtil = require("libraryUtil")
local checkType = libraryUtil.checkType
local checkTypeMulti = libraryUtil.checkTypeMulti
local format = string.format
local getmetatable = getmetatable
local getmetatable = getmetatable
local gmatch = string.gmatch
local ipairs = ipairs
local ipairs = ipairs
local is_callable -- defined as export.is_callable below
local pairs = pairs
local pairs = pairs
local pcall = pcall
local rawequal = rawequal
local rawget = rawget
local select = select
local select = select
local setmetatable = setmetatable
local tostring = tostring
local tostring = tostring
local type = type
local type = type
local unpack = unpack
local unpack = unpack


local iterableTypes = { "table", "string" }
--[==[
Loaders for functions in other modules, which overwrite themselves with the target function when called. This ensures modules are only loaded when needed, retains the speed/convenience of locally-declared pre-loaded functions, and has no overhead after the first call, since the target functions are called directly in any subsequent calls.]=


local function _check(funcName, expectType)
local function _iterString(iter, i)
if type(expectType) == "string" then
i = i + 1
return function(argIndex, arg, nilOk)
local char = iter()
return checkType(funcName, argIndex, arg, expectType, nilOk)
if char ~= nil then
end
return i, char
else
return function(argIndex, arg, expectType, nilOk)
if type(expectType) == "table" then
if not (nilOk and arg == nil) then
return checkTypeMulti(funcName, argIndex, arg, expectType)
end
else
return checkType(funcName, argIndex, arg, expectType, nilOk)
end
end
end
end
end
end
Line 37: Line 27:
-- Iterate over UTF-8-encoded codepoints in string.
-- Iterate over UTF-8-encoded codepoints in string.
local function iterString(str)
local function iterString(str)
local iter = string.gmatch(str, ".[\128-\191]*")
return _iterString, gmatch(str, ".[\128-\191]*"), 0
local i = 0
local function iterator()
i = i + 1
local char = iter()
if char then
return i, char
end
end
return iterator
end
end


--[==[
--[==[
Return {true} if the input is a function or functor (a table which can be called like a function, because it has a {__call} metamethod).
Return {true} if the input is a function or functor (a table which can be called like a function, because it has a {__call} metamethod).
]==]
 
Note: if the input is a table with a protected metatable (i.e. one hidden using the `__metatable` metamethod), then this function will treat the value of `__metatable` as though it is the metatable, as that is what gets returned by `getmetatable` in such cases. If you are making use of the `__metatable` metamethod, make sure that `__metatable` is a table with a function at the `__call` key to ensure that this function returns the correct result; it does not matter if this function is the true `__call` metamethod.]==]
function export.is_callable(f)
function export.is_callable(f)
local f_type = type(f)
local f_type = type(f)
Line 60: Line 41:
return false
return false
end
end
-- A table is a functor if it has a `__call` metamethod. The only way to truly confirm this is by trying to call the table, but that could modify the table or other variables out of scope, so look for a `__call` metamethod instead. If the metatable is protected with `__metatable`, this may not be possible.
local mt = getmetatable(f)
local mt = getmetatable(f)
-- __call metamethods have to be functions, not functors.
if mt == nil then
return mt and type(mt.__call) == "function" or false
return false
end
-- Check if the metatable is protected: `setmetatable` will throw an error if so.
local success = pcall(setmetatable, f, mt)
-- If it's protected, then `mt` could be anything, but use the heuristic that if a `__call` key exists then that's probably intentional.
-- This also builds in ways to ensure that this function always returns the correct result when implementing protected metatables.
if not success then
if type(mt) ~= "table" then
return false
end
local __metatable = rawget(mt, "__metatable")
-- If the value of `__metatable` is also `mt`, then `mt` must be the true metatable anyway (e.g. mw.loadData does this).
end
local __call = rawget(mt, "__call")
-- `__call` metamethods have to be functions, so don't recurse when checking it.
return __call ~= nil and type(__call) == "function"
end
end
is_callable = export.is_callable


function export.chain(func1, func2, ...)
function export.chain(func1, func2, ...)
return func1(func2(...))
return func1(func2(...))
end
do
local function catch_values(start, iter, state, k, ...)
if start == k or k == nil then
return k, ...
end
return catch_values(start, iter, state, iter(state, k))
end
function export.iterateFrom(start, iter, state, k)
local first = true
return function(state, k)
if first then
first = false
return catch_values(start, iter, state, iter(state, k))
end
return iter(state, k)
end, state, k
end
end
end


Line 75: Line 91:
-- "abc") --> { "A", "B", "C" }
-- "abc") --> { "A", "B", "C" }
function export.map(func, iterable, isArray)
function export.map(func, iterable, isArray)
local check = _check 'map'
check(1, func, "function")
check(2, iterable, iterableTypes)
local array = {}
local array = {}
local iterator = type(iterable) == "string" and iterString
for k, v in (type(iterable) == "string" and iterString or (isArray or iterable[1] ~= nil) and ipairs or pairs)(iterable) do
or (isArray or iterable[1] ~= nil) and ipairs or pairs
array[k] = func(v, k, iterable)
for i_or_k, val in iterator(iterable) do
array[i_or_k] = func(val, i_or_k, iterable)
end
end
return array
return array
end
end


function export.mapIter(func, iter, iterable, initVal)
function export.mapIter(func, iter, state, init)
local check = _check 'mapIter'
-- init could be anything
check(1, func, "function")
local array, i = {}, 0
check(2, iter, "function")
for x, y in iter, state, init do
check(3, iterable, iterableTypes, true)
-- initVal could be anything
local array = {}
local i = 0
for x, y in iter, iterable, initVal do
i = i + 1
i = i + 1
array[i] = func(y, x, iterable)
array[i] = func(y, x, state)
end
end
return array
return array
end
do
local function iter_tuples(tuples)
local i = tuples.i
if i > 1 then
i = i - 1
tuples.i = i
return unpack(tuples[i])
end
end
-- Takes an iterator function, and returns a new iterator that iterates in reverse, given the same arguments.
-- Note: changes to the state during iteration are not taken into account, since all the return values are calculated in advance.
function export.reverseIter(func)
return function(...)
-- Store all returned values as a list of tuples, then iterate in reverse over that list.
local tuples, i, iter, state, val1 = {}, 0, func(...)
while true do
i = i + 1
local vals = {iter(state, val1)}
-- Terminates if the first return value is nil, even if other values are non-nil.
val1 = vals[1]
if val1 == nil then
tuples.i = i
return iter_tuples, tuples
end
tuples[i] = vals
end
end
end
end
end


function export.forEach(func, iterable, isArray)
function export.forEach(func, iterable, isArray)
local check = _check 'forEach'
for k, v in (type(iterable) == "string" and iterString or (isArray or iterable[1] ~= nil) and ipairs or pairs)(iterable) do
check(1, func, "function")
func(v, k, iterable)
check(2, iterable, iterableTypes)
local iterator = type(iterable) == "string" and iterString
or (isArray or iterable[1] ~= nil) and ipairs or pairs
for i_or_k, val in iterator(iterable) do
func(val, i_or_k, iterable)
end
end
return nil
return nil
Line 126: Line 154:
-- reverse args by building a function to do it, similar to the unpack() example
-- reverse args by building a function to do it, similar to the unpack() example
local function reverseHelper(acc, v, ...)
local function reverseHelper(acc, v, ...)
if select('#', ...) == 0 then
if select("#", ...) == 0 then
return v, acc()
return v, acc()
else
else
Line 166: Line 194:
-- { 2, 3, 5, 7, 11 }) --> true
-- { 2, 3, 5, 7, 11 }) --> true
function export.some(func, t, isArray)
function export.some(func, t, isArray)
if isArray or t[1] ~= nil then -- array
for k, v in ((isArray or t[1] ~= nil) and ipairs or pairs)(t) do
for i, v in ipairs(t) do
if func(v, k, t) then
if func(v, i, t) then
return true
return true
end
end
else
for k, v in pairs(t) do
if func(v, k, t) then
return true
end
end
end
end
end
Line 185: Line 205:
-- { 2, 4, 8, 10, 12 }) --> true
-- { 2, 4, 8, 10, 12 }) --> true
function export.all(func, t, isArray)
function export.all(func, t, isArray)
if isArray or t[1] ~= nil then -- array
for k, v in ((isArray or t[1] ~= nil) and ipairs or pairs)(t) do
for i, v in ipairs(t) do
if not func(v, k, t) then
if not func(v, i, t) then
return false
return false
end
end
else
for k, v in pairs(t) do
if not func(v, k, t) then
return false
end
end
end
end
end
Line 232: Line 244:
-- Fancy stuff
-- Fancy stuff
local function capture(...)
local function capture(...)
local vals = { n = select('#', ...), ... }
local vals = {n = select("#", ...), ...}
return function()
return function()
return unpack(vals, 1, vals.n)
return unpack(vals, 1, vals.n)
Line 265: Line 277:
end
end
return t
return t
end
----- M E M O I Z A T I O N-----
-- Memoizes a function or callable table.
-- Supports any number of arguments and return values.
-- If the optional parameter `simple` is set, then the memoizer will use a faster implementation, but this is only compatible with one argument and one return value. If `simple` is set, additional arguments will be accepted, but this should only be done if those arguments will always be the same.
do
-- Placeholders.
local args, nil_, pos_nan, neg_nan, neg_0
-- Certain potential argument values can't be used as table keys, so we use placeholders for them instead: e.g. f("foo", nil, "bar") would be memoized at f["foo"][nil_]["bar"][args]. These values are:
-- nil.
-- -0, which is equivalent to 0 in most situations, but becomes "-0" on conversion to string; it also behaves differently in some operations (e.g. 1/a evaluates to inf if a is 0, but -inf if a is -0).
-- NaN and -NaN, which are the only values for which n == n is false; they only seem to differ on conversion to string ("nan" and "-nan").
local function get_key(input)
-- nil
if input == nil then
if not nil_ then
nil_ = {}
end
return nil_
-- -0
elseif input == 0 and 1 / input < 0 then
if not neg_0 then
neg_0 = {}
end
return neg_0
-- Default
elseif input == input then
return input
-- NaN
elseif format("%f", input) == "nan" then
if not pos_nan then
pos_nan = {}
end
return pos_nan
-- -NaN
elseif not neg_nan then
neg_nan = {}
end
return neg_nan
end
-- Return values are memoized as tables of return values, which are looked up using each input argument as a key, followed by args. e.g. if the input arguments were (1, 2, 3), the memo would be located at t[1][2][3][args]. args is always used as the final lookup key so that (for example) the memo for f(1, 2, 3), f[1][2][3][args], doesn't interfere with the memo for f(1, 2), f[1][2][args].
local function get_memo(memo, n, nargs, key, ...)
key = get_key(key)
local next_memo = memo[key]
if next_memo == nil then
next_memo = {}
memo[key] = next_memo
end
memo = next_memo
return n == nargs and memo or get_memo(memo, n + 1, nargs, ...)
end
-- Catch the function output values, and return the hidden variable arg (which is {...}, and available when a function has ...). We do this instead of catching the output in a table directly, because arg also contains the key "n", which is equal to select("#", ...). i.e. it's the number of arguments in ..., including any nils returned after the last non-nil value (e.g. select("#", nil) == 1, select("#") == 0, select("#", nil, "foo", nil, nil) == 4 etc.). The distinction between nil and nothing affects some native functions (e.g. tostring() throws an error, but tostring(nil) returns "nil"), so it needs to be reconstructable from the memo.
local function catch_output(...)
return arg
end
function export.memoize(func, simple)
if not is_callable(func) then
local _type = type(func)
error(format(
"Only functions and callable tables are memoizable. Received %s.",
_type == "table" and "non-callable table" or _type
))
end
local memo = {}
return simple and function(...)
local key = get_key(...)
local output = memo[key]
if output ~= nil then
if output == nil_ then
return nil
end
return output
end
output = func(...)
if output ~= nil then
memo[key] = output
return output
elseif not nil_ then
nil_ = {}
end
memo[key] = nil_
return nil
end or function(...)
local nargs = select("#", ...)
local memo = nargs == 0 and memo or get_memo(memo, 1, nargs, ...)
if not args then
args = {}
end
local output = memo[args]
if output == nil then
output = catch_output(func(...))
memo[args] = output
end
-- Unpack from 1 to the original number of return values (memoized as output.n); unpack returns nil for any values not in output.
return unpack(output, 1, output.n)
end
end
end
end


return export
return export