[docs][pytorch] Add examples for compiling with external weights. (#18658)

Progress on https://github.com/iree-org/iree/issues/18564.

Adds examples to the PyTorch guide showing how to externalize module
parameters and load them at runtime, both through the command line
(`iree-run-module`) and through the `iree.runtime` Python API (using
`ParameterIndex`).
diff --git a/docs/website/docs/guides/ml-frameworks/pytorch.md b/docs/website/docs/guides/ml-frameworks/pytorch.md
index 739db73..e80dec9 100644
--- a/docs/website/docs/guides/ml-frameworks/pytorch.md
+++ b/docs/website/docs/guides/ml-frameworks/pytorch.md
@@ -321,6 +321,127 @@
         self.value = new_value
     ```
 
+#### :octicons-file-symlink-file-16: Using external parameters
+
+Model parameters can be stored in standalone files that can be efficiently
+stored and loaded separately from model compute graphs. See the
+[Parameters guide](../parameters.md) for more general information about
+parameters in IREE.
+
+When using iree-turbine, the `aot.externalize_module_parameters()` function
+separates parameters from program modules and encodes a symbolic relationship
+between them so they can be loaded at runtime.
+
+We use [Safetensors](https://huggingface.co/docs/safetensors/) here to store
+the model's parameters on disk so that they can be loaded separately at
+runtime.
+
+```python
+import torch
+from safetensors.torch import save_file
+import numpy as np
+import iree.turbine.aot as aot
+
+class LinearModule(torch.nn.Module):
+    def __init__(self, in_features, out_features):
+        super().__init__()
+        self.weight = torch.nn.Parameter(torch.randn(in_features, out_features))
+        self.bias = torch.nn.Parameter(torch.randn(out_features))
+
+    def forward(self, input):
+        return (input @ self.weight) + self.bias
+
+linear_module = LinearModule(4, 3)
+
+# Create a params dictionary. Note that the keys here match LinearModule's
+# attribute names. The saved safetensors file will also be used later from
+# the command line.
+wt = linear_module.weight.t().contiguous()
+bias = linear_module.bias.t().contiguous()
+params = { "weight": wt, "bias": bias }
+save_file(params, "params.safetensors")
+
+# Externalize the model parameters. This removes weight tensors from the IR
+# module, allowing them to be loaded at runtime. Symbolic references to these
+# parameters are still retained in the IR.
+aot.externalize_module_parameters(linear_module)
+
+input = torch.randn(4)
+exported_module = aot.export(linear_module, input)
+
+# Compile the exported module to generate the binary. When `save_to` is not
+# None, the binary is written to the path passed in to `save_to`. Here we
+# pass None so that the binary is returned and stored in a variable.
+binary = exported_module.compile(save_to=None)
+
+# Save the input as an npy tensor, so that it can be passed in through the
+# command line to `iree-run-module`.
+input_np = input.numpy()
+np.save("input.npy", input_np)
+```
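+
+As a sanity check, the same computation can be reproduced with plain NumPy and
+compared against what the compiled program produces. The values below are
+illustrative stand-ins (not the randomly initialized parameters from the
+module above), so only the shapes carry over.
+
+```python
+import numpy as np
+
+# Illustrative parameters matching LinearModule(4, 3)'s shapes.
+weight = np.ones((4, 3), dtype=np.float32)
+bias = np.zeros(3, dtype=np.float32)
+x = np.arange(4, dtype=np.float32)
+
+# Same computation as LinearModule.forward: input @ weight + bias.
+expected = x @ weight + bias
+print(expected)  # [6. 6. 6.]
+```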
+
+=== "Python runtime"
+
+    Runtime invocation now requires loading the parameters as a separate
+    module. iree.runtime provides a convenience function for this,
+    `create_io_parameters_module()`.
+
+    ```python
+    import iree.runtime as ireert
+
+    # To load the parameters, create a ParameterIndex and add an entry for
+    # each parameter, keyed by the same names used when saving them.
+    idx = ireert.ParameterIndex()
+    idx.add_buffer("weight", wt.detach().numpy().tobytes())
+    idx.add_buffer("bias", bias.detach().numpy().tobytes())
+
+    # Create the runtime configuration and get its VM instance.
+    config = ireert.Config(driver_name="local-task")
+    instance = config.vm_instance
+
+    param_module = ireert.create_io_parameters_module(
+        instance, idx.create_provider(scope="model"),
+    )
+
+    # Load the modules: the parameter module for the weights, a HAL module
+    # for device support, and the main module compiled from the program.
+    # `VmModule.copy_buffer` copies the compiled binary into the VM, so the
+    # mapped memory does not need to stay alive afterwards.
+    vm_modules = ireert.load_vm_modules(
+        param_module,
+        ireert.create_hal_module(instance, config.device),
+        ireert.VmModule.copy_buffer(instance, binary.map_memory()),
+        config=config,
+    )
+
+    # vm_modules is a list of modules. The last module in the list is the one
+    # generated from the binary, so we use that to generate an output.
+    result = vm_modules[-1].main(input)
+    print(result.to_host())
+    ```
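+
+    The example above reuses the in-memory `wt` and `bias` tensors. If the
+    runtime script runs in a separate process from the export script, the
+    index entries can instead be rebuilt from the saved `params.safetensors`
+    file, as sketched below (stand-in parameter values are re-created here so
+    the snippet stands alone; `ParameterIndex` also has a `load()` method for
+    loading parameter files directly).
+
+    ```python
+    import numpy as np
+    from safetensors.numpy import save_file, load_file
+
+    # Stand-in parameters with LinearModule(4, 3)'s shapes, saved and loaded
+    # back to mimic a separate export process.
+    save_file({"weight": np.ones((4, 3), np.float32),
+               "bias": np.zeros(3, np.float32)}, "params.safetensors")
+
+    loaded = load_file("params.safetensors")
+    buffers = {name: arr.tobytes() for name, arr in loaded.items()}
+    print(sorted(buffers))  # ['bias', 'weight']
+    ```
+
+    Each entry in `buffers` can then be passed to `idx.add_buffer()` as
+    before.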
+
+=== "Command line tools"
+
+    It is also possible to save the VMFB binary to disk and then invoke
+    `iree-run-module` from the command line to generate outputs.
+
+    ```python
+    # When `save_to` is not None, the binary is saved to the given path
+    # and None is returned.
+    binary = exported_module.compile(save_to="compiled_module.vmfb")
+    ```
+
+    The stored safetensors file, the input tensor, and the VMFB can now be
+    passed to IREE through the command line.
+
+    ```bash
+    iree-run-module --module=compiled_module.vmfb --parameters=model=params.safetensors \
+                    --input=@input.npy
+    ```
+
+    Note that the value of the `--parameters` flag is prefixed with `model=`.
+    This prefix specifies the scope of the parameters, matching the scope
+    recorded in the compiled module.
+
 #### :octicons-code-16: Samples
 
 | Code samples |  |