Browse Source

Add API for restoring PseudoRandom and PcgRandom state (#14123)

sfence 3 months ago
parent
commit
ceaa7e2fb0

+ 1 - 0
builtin/game/features.lua

@@ -30,6 +30,7 @@ core.features = {
 	sound_params_start_time = true,
 	physics_overrides_v2 = true,
 	hud_def_type_field = true,
+	random_state_restore = true,
 }
 
 function core.has_feature(arg)

+ 9 - 1
doc/lua_api.md

@@ -5284,6 +5284,10 @@ Utilities
       physics_overrides_v2 = true,
       -- In HUD definitions the field `type` is used and `hud_elem_type` is deprecated (5.9.0)
       hud_def_type_field = true,
+      -- PseudoRandom and PcgRandom state is restorable
+      -- PseudoRandom has get_state method
+      -- PcgRandom has get_state and set_state methods (5.9.0)
+      random_state_restore = true,
   }
   ```
 
@@ -8056,7 +8060,7 @@ child will follow movement and rotation of that bone.
 * `get_lighting()`: returns the current state of lighting for the player.
     * Result is a table with the same fields as `light_definition` in `set_lighting`.
 * `respawn()`: Respawns the player using the same mechanism as the death screen,
-  including calling on_respawnplayer callbacks.
+  including calling `on_respawnplayer` callbacks.
 
 `PcgRandom`
 -----------
@@ -8079,6 +8083,8 @@ offering very strong randomness.
     * `mean = (max - min) / 2`, and
     * `variance = (((max - min + 1) ^ 2) - 1) / (12 * num_trials)`
     * Increasing `num_trials` improves accuracy of the approximation
+* `get_state()`: return generator state encoded in string
+* `set_state(state_string)`: restore generator state from encoded string
 
 `PerlinNoise`
 -------------
@@ -8171,6 +8177,8 @@ Uses a well-known LCG algorithm introduced by K&R.
 * `next(min, max)`: return next integer random number [`min`...`max`]
     * Either `max - min == 32767` or `max - min <= 6553` must be true
       due to the simple implementation making a bad distribution otherwise.
+* `get_state()`: return state of pseudorandom generator as number
+    * use returned number as seed in PseudoRandom constructor to restore
 
 `Raycast`
 ---------

+ 1 - 0
games/devtest/.luacheckrc

@@ -27,6 +27,7 @@ read_globals = {
 	"Settings",
 	"check",
 	"PseudoRandom",
+	"PcgRandom",
 
 	string = {fields = {"split", "trim"}},
 	table  = {fields = {"copy", "getn", "indexof", "insert_all"}},

+ 30 - 8
games/devtest/mods/unittests/misc.lua

@@ -1,15 +1,37 @@
-local function test_random()
+local function test_pseudo_random()
 	-- We have comprehensive unit tests in C++, this is just to make sure the API code isn't messing up
-	local pr = PseudoRandom(13)
-	assert(pr:next() == 22290)
-	assert(pr:next() == 13854)
+	local gen1 = PseudoRandom(13)
+	assert(gen1:next() == 22290)
+	assert(gen1:next() == 13854)
 
-	local pr2 = PseudoRandom(-101)
-	assert(pr2:next(0, 100) == 35)
+	local gen2 = PseudoRandom(gen1:get_state())
+	for n = 0, 16 do
+		assert(gen1:next() == gen2:next())
+	end
+
+	local pr3 = PseudoRandom(-101)
+	assert(pr3:next(0, 100) == 35)
 	-- unusual case that is normally disallowed:
-	assert(pr2:next(10000, 42767) == 12485)
+	assert(pr3:next(10000, 42767) == 12485)
+end
+unittests.register("test_pseudo_random", test_pseudo_random)
+
+local function test_pcg_random()
+	-- We have comprehensive unit tests in C++, this is just to make sure the API code isn't messing up
+	local gen1 = PcgRandom(55)
+
+	for n = 0, 16 do
+		gen1:next()
+	end
+
+	local gen2 = PcgRandom(26)
+	gen2:set_state(gen1:get_state())
+
+	for n = 16, 32 do
+		assert(gen1:next() == gen2:next())
+	end
 end
-unittests.register("test_random", test_random)
+unittests.register("test_pcg_random", test_pcg_random)
 
 local function test_dynamic_media(cb, player)
 	if core.get_player_information(player:get_player_name()).protocol_version < 40 then

+ 12 - 0
src/noise.cpp

@@ -152,6 +152,18 @@ s32 PcgRandom::randNormalDist(s32 min, s32 max, int num_trials)
 	return myround((float)accum / num_trials);
 }
 
+void PcgRandom::getState(u64 state[2]) const
+{
+	state[0] = m_state;
+	state[1] = m_inc;
+}
+
+void PcgRandom::setState(const u64 state[2])
+{
+  m_state = state[0];
+  m_inc = state[1];
+}
+
 ///////////////////////////////////////////////////////////////////////////////
 
 float noise2d(int x, int y, s32 seed)

+ 8 - 0
src/noise.h

@@ -76,6 +76,11 @@ public:
 		return (next() % (max - min + 1)) + min;
 	}
 
+	// Allow save and restore of state
+	inline s32 getState() const
+	{
+		return m_next;
+	}
 private:
 	s32 m_next;
 };
@@ -94,6 +99,9 @@ public:
 	void bytes(void *out, size_t len);
 	s32 randNormalDist(s32 min, s32 max, int num_trials=6);
 
+	// Allow save and restore of state
+	void getState(u64 state[2]) const;
+	void setState(const u64 state[2]);
 private:
 	u64 m_state;
 	u64 m_inc;

+ 53 - 0
src/script/lua_api/l_noise.cpp

@@ -425,6 +425,17 @@ int LuaPseudoRandom::l_next(lua_State *L)
 	return 1;
 }
 
+int LuaPseudoRandom::l_get_state(lua_State *L)
+{
+	NO_MAP_LOCK_REQUIRED;
+
+	LuaPseudoRandom *o = checkObject<LuaPseudoRandom>(L, 1);
+	PseudoRandom &pseudo = o->m_pseudo;
+	int val = pseudo.getState();
+	lua_pushinteger(L, val);
+	return 1;
+}
+
 
 int LuaPseudoRandom::create_object(lua_State *L)
 {
@@ -462,6 +473,7 @@ void LuaPseudoRandom::Register(lua_State *L)
 const char LuaPseudoRandom::className[] = "PseudoRandom";
 const luaL_Reg LuaPseudoRandom::methods[] = {
 	luamethod(LuaPseudoRandom, next),
+	luamethod(LuaPseudoRandom, get_state),
 	{0,0}
 };
 
@@ -496,6 +508,45 @@ int LuaPcgRandom::l_rand_normal_dist(lua_State *L)
 	return 1;
 }
 
+int LuaPcgRandom::l_get_state(lua_State *L)
+{
+	NO_MAP_LOCK_REQUIRED;
+
+	LuaPcgRandom *o = checkObject<LuaPcgRandom>(L, 1);
+
+	u64 state[2];
+	o->m_rnd.getState(state);
+
+	std::ostringstream oss;
+	oss << std::hex << std::setw(16) << std::setfill('0')
+		<< state[0] << state[1];
+
+	lua_pushstring(L, oss.str().c_str());
+	return 1;
+}
+
+int LuaPcgRandom::l_set_state(lua_State *L)
+{
+	NO_MAP_LOCK_REQUIRED;
+
+	LuaPcgRandom *o = checkObject<LuaPcgRandom>(L, 1);
+
+	std::string l_string = readParam<std::string>(L, 2);
+	if (l_string.size() != 32) {
+		throw LuaError("PcgRandom:set_state: Expected hex string of 32 characters");
+	}
+
+	std::istringstream s_state_0(l_string.substr(0, 16));
+	std::istringstream s_state_1(l_string.substr(16, 16));
+
+	u64 state[2];
+	s_state_0 >> std::hex >> state[0];
+	s_state_1 >> std::hex >> state[1];
+	
+	o->m_rnd.setState(state);
+	
+	return 0;
+}
 
 int LuaPcgRandom::create_object(lua_State *L)
 {
@@ -536,6 +587,8 @@ const char LuaPcgRandom::className[] = "PcgRandom";
 const luaL_Reg LuaPcgRandom::methods[] = {
 	luamethod(LuaPcgRandom, next),
 	luamethod(LuaPcgRandom, rand_normal_dist),
+	luamethod(LuaPcgRandom, get_state),
+	luamethod(LuaPcgRandom, set_state),
 	{0,0}
 };
 

+ 5 - 0
src/script/lua_api/l_noise.h

@@ -116,6 +116,8 @@ private:
 	// next(self, min=0, max=32767) -> get next value
 	static int l_next(lua_State *L);
 
+	// save state
+	static int l_get_state(lua_State *L);
 public:
 	LuaPseudoRandom(s32 seed) : m_pseudo(seed) {}
 
@@ -150,6 +152,9 @@ private:
 	// get next normally distributed random value
 	static int l_rand_normal_dist(lua_State *L);
 
+	// save and restore state
+	static int l_get_state(lua_State *L);
+	static int l_set_state(lua_State *L);
 public:
 	LuaPcgRandom(u64 seed) : m_rnd(seed) {}
 	LuaPcgRandom(u64 seed, u64 seq) : m_rnd(seed, seq) {}

+ 37 - 21
src/unittest/test_random.cpp

@@ -65,32 +65,39 @@ void TestRandom::testPseudoRandom()
 	for (u32 i = 0; i != 256; i++)
 		UASSERTEQ(s32, pr.next(), expected_pseudorandom_results[i]);
 
-	PseudoRandom pr2(0);
-	UASSERTEQ(int, pr2.next(), 0);
-	UASSERTEQ(int, pr2.next(), 21469);
-	UASSERTEQ(int, pr2.next(), 9989);
-
-	PseudoRandom pr3(-101);
-	UASSERTEQ(int, pr3.next(), 3267);
-	UASSERTEQ(int, pr3.next(), 2485);
-	UASSERTEQ(int, pr3.next(), 30057);
+	s32 state = pr.getState();
+	PseudoRandom pr2(state);
+
+	for (u32 i = 0; i != 256; i++) {
+		UASSERTEQ(s32, pr.next(), pr2.next());
+	}
+
+	PseudoRandom pr3(0);
+	UASSERTEQ(s32, pr3.next(), 0);
+	UASSERTEQ(s32, pr3.next(), 21469);
+	UASSERTEQ(s32, pr3.next(), 9989);
+
+	PseudoRandom pr4(-101);
+	UASSERTEQ(s32, pr4.next(), 3267);
+	UASSERTEQ(s32, pr4.next(), 2485);
+	UASSERTEQ(s32, pr4.next(), 30057);
 }
 
 
 void TestRandom::testPseudoRandomRange()
 {
-	PseudoRandom pr((int)time(NULL));
+	PseudoRandom pr((s32)time(NULL));
 
 	EXCEPTION_CHECK(PrngException, pr.range(2000, 8600));
 	EXCEPTION_CHECK(PrngException, pr.range(5, 1));
 
 	for (u32 i = 0; i != 32768; i++) {
-		int min = (pr.next() % 3000) - 500;
-		int max = (pr.next() % 3000) - 500;
+		s32 min = (pr.next() % 3000) - 500;
+		s32 max = (pr.next() % 3000) - 500;
 		if (min > max)
-			SWAP(int, min, max);
+			SWAP(s32, min, max);
 
-		int randval = pr.range(min, max);
+		s32 randval = pr.range(min, max);
 		UASSERT(randval >= min);
 		UASSERT(randval <= max);
 	}
@@ -103,12 +110,21 @@ void TestRandom::testPcgRandom()
 
 	for (u32 i = 0; i != 256; i++)
 		UASSERTEQ(u32, pr.next(), expected_pcgrandom_results[i]);
+
+	PcgRandom pr2(0, 0);
+	u64 state[2];
+	pr.getState(state);
+	pr2.setState(state);
+
+	for (u32 i = 0; i != 256; i++) {
+		UASSERTEQ(u32, pr.next(), pr2.next());
+	}
 }
 
 
 void TestRandom::testPcgRandomRange()
 {
-	PcgRandom pr((int)time(NULL));
+	PcgRandom pr((u64)time(NULL));
 
 	EXCEPTION_CHECK(PrngException, pr.range(5, 1));
 
@@ -116,12 +132,12 @@ void TestRandom::testPcgRandomRange()
 	pr.range(pr.RANDOM_MIN, pr.RANDOM_MAX);
 
 	for (u32 i = 0; i != 32768; i++) {
-		int min = (pr.next() % 3000) - 500;
-		int max = (pr.next() % 3000) - 500;
+		s32 min = (pr.next() % 3000) - 500;
+		s32 max = (pr.next() % 3000) - 500;
 		if (min > max)
-			SWAP(int, min, max);
+			SWAP(s32, min, max);
 
-		int randval = pr.range(min, max);
+		s32 randval = pr.range(min, max);
 		UASSERT(randval >= min);
 		UASSERT(randval <= max);
 	}
@@ -147,8 +163,8 @@ void TestRandom::testPcgRandomBytes()
 
 void TestRandom::testPcgRandomNormalDist()
 {
-	static const int max = 120;
-	static const int min = -120;
+	static const s32 max = 120;
+	static const s32 min = -120;
 	static const int num_trials = 20;
 	static const u32 num_samples = 61000;
 	s32 bins[max - min + 1];