def _remap_mistral_audio_args(config: dict) -> dict:
whisper_args = config["multimodal"].pop("whisper_model_args")
encoder_args = whisper_args["encoder_args"]
downsample_args = whisper_args["downsample_args"]
downsample_factor = downsample_args["downsample_factor"]
# make sure that k/v blocks can be allocated with
# unified k/v cache class and pool whisper k/v cache blocks
# with downsample_factor:1 ratio
if encoder_args.get("causal"):
block_pool_size = downsample_factor
config["projection_size"] = downsample_factor * encoder_args["dim"]
else:
block_pool_size = 1
_maybe_sliding_window = encoder_args.get("ragged_attention", None)
if _maybe_sliding_window is None:
sliding_window = None
elif _maybe_sliding_window.isdigit():
sliding_window = int(_maybe_sliding_window)
else:
raise NotImplementedError(f"Unsupported: {_maybe_sliding_window=}")
architecture = (
"VoxtralStreamingGeneration"
if encoder_args.get("causal")
else "VoxtralForConditionalGeneration"
)
quant_config = config.get("quantization_config")
config = {
"model_type": "voxtral",
"architectures": [architecture],
"text_config": PretrainedConfig.from_dict(config),
"audio_config": WhisperConfig(
num_mel_bins=encoder_args["audio_encoding_args"]["num_mel_bins"],
window_size=encoder_args["audio_encoding_args"]["window_size"],
sampling_rate=encoder_args["audio_encoding_args"]["sampling_rate"],
hop_length=encoder_args["audio_encoding_args"]["hop_length"],
downsample_factor=downsample_factor,
d_model=encoder_args["dim"],
encoder_layers=encoder_args["n_layers"],
encoder_ffn_dim=encoder_args["hidden_dim"],
encoder_attention_heads=encoder_args["n_heads"],
vocab_size=encoder_args["vocab_size"],
max_source_positions=encoder_args["max_source_positions"],
is_encoder_decoder=False, # Override WhisperConfig default
is_causal=encoder_args.get("causal", False),
sliding_window=sliding_window,
block_pool_size=block_pool_size,
pos_embed=encoder_args.get("pos_embed", "sinusoidal"),
),
}
if quant_config:
config["quantization_config"] = quant_config
return config