-- Module:patterns
--
-- From Linguifex
-- Jump to navigation Jump to search
--
-- Documentation for this module may be created at Module:patterns/doc

local export = {}

local byte = string.byte
local find = string.find
local gsub = string.gsub
local match = string.match

do
	-- Maps every pattern magic character (and the null character) to its escaped form. Left unset until the first call, so the table is only built if it is actually needed.
	local escapes

	local function build_escapes()
		escapes = {
			["\0"] = "%z",
			["$"] = "%$",
			["%"] = "%%",
			["("] = "%(",
			[")"] = "%)",
			["*"] = "%*",
			["+"] = "%+",
			["-"] = "%-",
			["."] = "%.",
			["?"] = "%?",
			["["] = "%[",
			["]"] = "%]",
			["^"] = "%^",
		}
		-- The builder is single-use, so release it once the table exists.
		build_escapes = nil
		return escapes
	end

	--[==[Returns a copy of `str` in which every pattern magic character ({$%()*+-.?[]^}) is prefixed with {%}, and every null character is replaced by {%z}. For example, {"^$()%.[]*+-?\0"} becomes {"%^%$%(%)%%%.%[%]%*%+%-%?%z"}. Use this when building a pattern out of arbitrary text (e.g. user input) that must be matched literally.]==]
	function export.pattern_escape(str)
		-- Assigning to a single local discards the substitution count returned by gsub.
		local escaped = gsub(str, "[%z$%%()*+%-.?[%]^]", escapes or build_escapes())
		return escaped
	end
end

do
	-- Maps the characters that are magic inside a charset (and the null character) to their escaped form. Built on the first call only.
	local escapes

	local function build_escapes()
		escapes = {
			["\0"] = "%z",
			["%"] = "%%",
			["-"] = "%-",
			["]"] = "%]",
			["^"] = "%^",
		}
		-- The builder is single-use, so release it once the table exists.
		build_escapes = nil
		return escapes
	end

	--[==[Returns a copy of `str` in which the characters that are magic inside a pattern character set ({%-]^}) are prefixed with {%}, and every null character is replaced by {%z}. All other characters are left unchanged.]==]
	function export.charset_escape(str)
		-- Assigning to a single local discards the substitution count returned by gsub.
		local escaped = gsub(str, "[%z%%%-%]^]", escapes or build_escapes())
		return escaped
	end
end

--[==[Doubles every {%} in `str`. {%} is the only magic character in replacement strings (the third argument to {string.gsub} and {mw.ustring.gsub}), so the result can be used as a replacement that inserts `str` literally.]==]
function export.replacement_escape(str)
	-- Assigning to a single local discards the substitution count returned by gsub.
	local escaped = gsub(str, "%%", "%%%%")
	return escaped
end

do
	-- Scans a charset whose contents start at byte `pos` of `pattern` (i.e. just after the opening "["). Returns the position immediately after the closing "]", or false plus an error message if the charset is never closed.
	local function parse_charset(pattern, pos)
		local first = byte(pattern, pos)
		-- A leading "^" negates the charset; the contents proper start after it.
		if first == 0x5E then -- ^
			pos = pos + 1
			first = byte(pattern, pos)
		end
		-- A "]" in first position (including directly after "^") is a literal member, not the terminator.
		if first == 0x5D then -- ]
			pos = pos + 1
		end
		while true do
			local found, after = match(pattern, "([%%%]])()", pos)
			if found == nil then
				-- Ran off the end of the pattern without finding the terminator.
				return false, "unclosed charset: must be closed with ']'"
			elseif found == "]" then
				return after
			end
			-- Found "%": it escapes the following character, so resume the search after that one.
			pos = after + 1
		end
	end
	
	-- Statically checks whether `pattern` is well-formed for the pattern library named by `str_lib` ("string" or "ustring"). Returns true if no problem is found, or false plus a message describing the first problem encountered.
	local function _validate_pattern(pattern, str_lib)
		if pattern == "" then
			return true
		-- "\0" can be used in ustring patterns, and with string.find iff the `plain` flag is set.
		elseif str_lib == "string" and find(pattern, "\0", nil, true) then
			return false, "string library pattern cannot contain the null character '\\000'"
		elseif str_lib == "ustring" and #pattern > 10000 then
			return false, "ustring library pattern cannot be more than 10,000 bytes"
		end
		-- Capture groups are numbered in the order in which they are opened, so nesting has to be tracked:
		-- `cap_total` is the number of groups opened so far, `open_stack[1..cap_open]` holds the indices of
		-- the groups that are still open (innermost last), and `closed[n]` is true once group n is complete.
		local pos, cap_total, cap_open, open_stack, closed, ch = 1, 0, 0, {}, {}
		repeat
			ch, pos = match(pattern, "([%%()[])()", pos)
			-- Escaping "%".
			if ch == "%" then
				local nxt = byte(pattern, pos)
				-- End of string throws an error, as the escape sequence is incomplete. This has to be checked first, since the numeric comparisons below cannot be applied to nil.
				if nxt == nil then
					return false, "incomplete escape sequence: final '%' must be followed by a character"
				-- Balanced string "%bxy".
				elseif nxt == 0x62 then -- b
					-- Must be followed by two characters, which are always treated as literals.
					if byte(pattern, pos + 2) == nil then
						return false, "incomplete balanced string: '%b' must be followed by two characters"
					end
					pos = pos + 3
				-- Frontier pattern "%f[abc]".
				elseif nxt == 0x66 then -- f
					if byte(pattern, pos + 1) ~= 0x5B then -- [
						return false, "incomplete frontier pattern: '%f' must be followed by a charset"
					end
					-- Charset after "%f".
					local result, err_msg = parse_charset(pattern, pos + 2)
					if not result then
						return false, err_msg
					end
					pos = result
				-- Back-reference to a complete capture group (e.g. "(foo)%1"); not possible to reference groups above "%9".
				elseif nxt >= 0x31 and nxt <= 0x39 then -- 1-9
					-- References to open, unstarted or undefined capture groups are invalid (e.g. "%1(foo)" or "(foo%1)"), whereas a reference to a closed group nested inside a still-open one is fine (e.g. "((foo)%2)").
					local n = nxt - 0x30
					if not closed[n] then
						return false, "invalid capture index '%" .. n .. "'"
					end
					pos = pos + 1
				-- "%0" is a reference to the full match, which can never be valid in patterns since it will always be incomplete; only valid in replacement strings.
				elseif nxt == 0x30 then -- 0
					return false, "invalid capture index '%0'"
				else
					-- Any other escaped character: skip over it.
					pos = pos + 1
				end
			-- New capture group.
			elseif ch == "(" then
				-- String library patterns cannot have more than 32 capture groups.
				if str_lib == "string" and cap_total >= 32 then
					return false, "string library pattern cannot contain more than 32 capture groups"
				end
				-- Assign the next index to the group and push it onto the stack of open groups.
				cap_total = cap_total + 1
				cap_open = cap_open + 1
				open_stack[cap_open] = cap_total
			-- End of capture group.
			elseif ch == ")" then
				-- There must be at least one open group.
				if cap_open < 1 then
					return false, "cannot close a capture group with ')' when none are open"
				end
				-- ")" always closes the innermost open group: mark it as complete and pop it off the stack.
				closed[open_stack[cap_open]] = true
				open_stack[cap_open] = nil
				cap_open = cap_open - 1
			-- Charset "[abc]".
			elseif ch == "[" then
				local result, err_msg = parse_charset(pattern, pos)
				if not result then
					return false, err_msg
				end
				pos = result
			end
		until not ch
		-- End of string throws an error if any capture groups are open, as they are incomplete.
		if cap_open > 0 then
			return false, "unclosed capture group: must be closed with ')'"
		end
		return true
	end
	
	--[==[Checks whether `pattern` is a well-formed pattern for the library named by `str_lib` ({"string"} or {"ustring"}; defaults to {"string"}). Returns {true} if it is. If it is not, the result depends on `safe`: when `safe` is set, returns {false} plus an error message; otherwise, throws an error with that message.]==]
	function export.validate_pattern(pattern, safe, str_lib)
		local valid, reason = _validate_pattern(pattern, str_lib or "string")
		if not valid then
			if not safe then
				error(reason)
			end
			return valid, reason
		end
		return valid
	end
end

return export