Training an AI model to perform advanced speedrunner inputs in the popular racing title Mario Kart DS
- Clone the repository
git clone https://github.com/blayyyyyk/marIOkart.git
- Install the dependencies
Tip
The uv package manager is highly recommended for its ease of use, but you can still install the dependencies with pip install -r requirements.txt
cd marIOkart
uv sync
source .venv/bin/activate
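If later commands fail with missing modules, a common cause is that the virtual environment created by uv sync was never activated. The optional helper below is not part of the repository; it uses the standard comparison of sys.prefix against sys.base_prefix to report whether the current interpreter is running inside a venv. The function name in_virtualenv is hypothetical.

```python
import sys


def in_virtualenv() -> bool:
    """Return True when the interpreter runs inside a venv (e.g. the .venv from uv sync)."""
    # Inside a venv, sys.prefix points at the venv directory while
    # sys.base_prefix points at the base Python installation.
    # Outside a venv, the two are equal.
    return sys.prefix != sys.base_prefix
```

Running python -c "import sys; print(sys.prefix != sys.base_prefix)" from your shell performs the same check without saving a file.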
To start training, run:
python -m mariokart_ml train -c example_configs/train_rl.json
Tip
You can add your own config JSON and/or override specific options with CLI flags:
--window - Opens a window that displays the emulator during training
Caution
This may reduce performance
--scale [integer] - Specify the scale of the visual window
--num-procs [integer] - Specify the number of emulator instances to run in parallel.
--save-model-path [.zip path] - Specify the output file name to save the trained model weights to after training.
--load-model-path [.zip path] - Specify the model weights to load before beginning training.
--env-name [environment name] - The Gymnasium environment to train in.
Tip
Recommended environments: mariokart_ml/TimeTrial-v1, mariokart_ml/TimeTrial-streamlit-v1
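A JSON config combined with CLI overrides is usually resolved by loading the file first and then letting any flag the user explicitly passed take precedence. The sketch below shows that pattern with argparse for the flags listed above. It is a hypothetical loader, not the project's actual implementation, and it assumes the JSON keys use the same underscore names that argparse derives from the flags (num_procs, env_name, save_model_path, and so on).

```python
import argparse
import json
from typing import Any, Dict, List, Optional


def parse_cli(argv: Optional[List[str]] = None) -> Dict[str, Any]:
    """Parse the documented flags; options that were not passed come back as None."""
    parser = argparse.ArgumentParser(description="Hypothetical config loader")
    parser.add_argument("-c", "--config", required=True, help="path to a JSON config")
    # default=None (instead of False) distinguishes "not passed" from "passed".
    parser.add_argument("--window", action="store_true", default=None)
    parser.add_argument("--scale", type=int)
    parser.add_argument("--num-procs", type=int)
    parser.add_argument("--save-model-path")
    parser.add_argument("--load-model-path")
    parser.add_argument("--env-name")
    return vars(parser.parse_args(argv))


def resolve_config(argv: Optional[List[str]] = None) -> Dict[str, Any]:
    """Load the JSON config, then let explicitly passed CLI flags win."""
    args = parse_cli(argv)
    with open(args.pop("config")) as f:
        config = json.load(f)
    for key, value in args.items():
        if value is not None:  # skip flags the user did not pass
            config[key] = value
    return config
```

Using None as the "unset" marker keeps values from the JSON file intact unless a flag is given on the command line, so a shared config can be reused while tweaking one option per run.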
Training telemetry is recorded both by Stable Baselines 3 and by our own in-house data collection. Both can be viewed by starting a TensorBoard server with the corresponding log directory:
Stable Baselines
tensorboard --logdir=tensorboard_logs/PPO_7
Custom Telemetry
tensorboard --logdir=tensorboard_logs/mkds_ppo_20260426_152202
Once the server is running, you can access TensorBoard in your browser at http://localhost:[specified port number]. TensorBoard serves on port 6006 unless you choose another with --port.
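The custom telemetry directory name appears to embed a timestamp (mkds_ppo_20260426_152202 reads as YYYYMMDD_HHMMSS), so the newest run can be picked automatically instead of copying the name by hand. The helper below is a minimal sketch under that assumption; the prefix, the format string, and the function names are inferred or hypothetical, not taken from the project.

```python
import os
from datetime import datetime
from typing import Optional

# ASSUMPTION: naming scheme inferred from "mkds_ppo_20260426_152202",
# i.e. a fixed prefix followed by a YYYYMMDD_HHMMSS timestamp.
RUN_PREFIX = "mkds_ppo_"
TIMESTAMP_FORMAT = "%Y%m%d_%H%M%S"


def run_timestamp(name: str) -> Optional[datetime]:
    """Parse the timestamp suffix of a run directory name, or return None."""
    if not name.startswith(RUN_PREFIX):
        return None
    try:
        return datetime.strptime(name[len(RUN_PREFIX):], TIMESTAMP_FORMAT)
    except ValueError:
        return None


def latest_run(log_root: str = "tensorboard_logs") -> Optional[str]:
    """Return the most recent custom-telemetry run directory under log_root."""
    if not os.path.isdir(log_root):
        return None
    runs = []
    for name in os.listdir(log_root):
        ts = run_timestamp(name)
        # Only count directories whose names match the assumed pattern.
        if ts is not None and os.path.isdir(os.path.join(log_root, name)):
            runs.append((ts, name))
    if not runs:
        return None
    # Tuples compare by timestamp first, so max() picks the newest run.
    return os.path.join(log_root, max(runs)[1])
```

The returned path can be passed straight to tensorboard --logdir. Directories that do not match the pattern, such as the Stable Baselines PPO_7 folder, are ignored.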
To test the emulator with human input, run:
python -m mariokart_ml debug --play -c example_configs/debug.json
A web-based Streamlit dashboard can be used to view real-time observations and reward distributions for each emulator instance. View the demo here
To enable this dashboard, you must use the mariokart_ml/TimeTrial-streamlit-v1 environment.
Debugging:
python -m mariokart_ml debug --play -c example_configs/debug.json --env-name mariokart_ml/TimeTrial-streamlit-v1
Training:
python -m mariokart_ml train -c example_configs/train_rl.json --env-name mariokart_ml/TimeTrial-streamlit-v1
Start the dashboard server with:
streamlit run src/mariokart_ml/streamlit/app.py -- --num-instances 16 --base-data-port 64000
Note
In almost all circumstances, --num-instances should be equal to --num-procs.
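The relationship between --num-instances, --num-procs, and --base-data-port can be sanity-checked before launching the dashboard. The sketch below assumes, without confirmation from the project, that instance i streams on base_data_port + i; the function name is hypothetical. It warns when the two counts differ and rejects a port range that would run past 65535.

```python
import warnings
from typing import List


def plan_dashboard_ports(num_instances: int, num_procs: int,
                         base_data_port: int = 64000) -> List[int]:
    """Return the data port each instance is assumed to stream on.

    ASSUMPTION: instance i uses base_data_port + i. This mapping is
    hypothetical and not confirmed by the project.
    """
    if num_instances < 1:
        raise ValueError("num_instances must be at least 1")
    if num_instances != num_procs:
        warnings.warn(
            f"--num-instances ({num_instances}) differs from --num-procs ({num_procs})"
        )
    last_port = base_data_port + num_instances - 1
    if last_port > 65535:
        raise ValueError(f"port range ends at {last_port}, above the maximum port 65535")
    return [base_data_port + i for i in range(num_instances)]
```

With the values from the command above, plan_dashboard_ports(16, 16) would reserve ports 64000 through 64015 under this assumed mapping.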
Warning
Make sure all streamlit tabs are closed before starting the streamlit server. If you are starting up a new training run, you DO NOT need to restart the streamlit server; your existing tab should pick up the data stream automatically when the new run starts up.
Under the hood, we use py-desmume, a popular Python library for interfacing with the DeSmuME C API. We forked the project and added custom functionality for reading game attributes from memory. By optimizing the C/C++-to-Python interoperability and bypassing traditional hooking overhead, the fork achieves 100x performance improvements in memory reads, and it also lets users extract variables in an intuitive, object-oriented fashion. This high-speed data pipeline is essential for preventing bottlenecks during intensive model training.
With this API, you can easily access complex internal observation states directly from the emulator's RAM without needing to rely on heavier computer vision techniques or screen captures. Learn more
Installation
pip install py-desmume-mkds
Example Usage
from desmume.emulator_mkds import MarioKart
import torch
emu = MarioKart()
emu.open('path/to/rom.nds')  # replace with the path to your Mario Kart DS ROM

# Create the window for the emulator
window = emu.create_sdl_window()

# Run the emulation as fast as possible until quit
while not window.has_quit():
    window.process_input()  # Controls are the default DeSmuME controls, see below.
    emu.cycle()
    window.draw()

    # Skip memory reads until a race has started
    if not emu.memory.race_ready:
        continue

    # Access the player's current kart position
    kart_position: torch.Tensor = emu.memory.driver.position
    # Access the player's current boost timer
    boost_timer: float = emu.memory.driver.boostTimer
    # Access the player's current race progress
    race_progress: float = emu.memory.race_status.driverStatus[0].raceProgress
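Values read this way can feed a reinforcement learning reward directly. As one illustration, the class below turns successive raceProgress readings into a per-step progress delta. This is a hypothetical reward signal written for this example, not the reward used by mariokart_ml, and it makes no assumption about the scale of raceProgress beyond it being a number.

```python
from typing import Optional


class ProgressReward:
    """Hypothetical reward: the per-step change in race progress.

    Feed it emu.memory.race_status.driverStatus[0].raceProgress each frame.
    """

    def __init__(self) -> None:
        self._last: Optional[float] = None

    def reset(self) -> None:
        # Call at the start of each race so the first step yields 0.0.
        self._last = None

    def step(self, race_progress: float) -> float:
        # No previous reading yet: report zero instead of a spurious jump.
        delta = 0.0 if self._last is None else race_progress - self._last
        self._last = race_progress
        return delta
```

Inside the loop above, reward = tracker.step(race_progress) would run right after the race_ready check. If raceProgress decreases when the kart drives backwards (not verified here), the delta becomes negative and penalizes going the wrong way.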