combine.py 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. from faster_whisper.transcribe import Segment
  2. def combine(whisper_segments: list[Segment], diarization_turns: list[dict]) -> list[dict]:
  3. raw: list[dict] = []
  4. for segment in whisper_segments:
  5. if not segment.words:
  6. continue
  7. current_speaker = None
  8. current_text = ""
  9. seg_start = segment.words[0].start
  10. for word in segment.words:
  11. mid = word.start + (word.end - word.start) / 2
  12. speaker = None
  13. for turn in diarization_turns:
  14. if turn["start"] <= mid < turn["end"]:
  15. speaker = turn["speaker"]
  16. break
  17. if not speaker:
  18. if current_speaker:
  19. speaker = current_speaker
  20. else:
  21. closest = min(
  22. diarization_turns,
  23. key=lambda t: min(abs(t["start"] - mid), abs(t["end"] - mid)),
  24. )
  25. speaker = closest["speaker"]
  26. if current_speaker is None:
  27. current_speaker = speaker
  28. seg_start = word.start
  29. current_text = word.word
  30. elif speaker != current_speaker:
  31. raw.append({"start": seg_start, "end": word.start, "speaker": current_speaker, "text": current_text.strip()})
  32. current_speaker = speaker
  33. seg_start = word.start
  34. current_text = word.word
  35. else:
  36. current_text += " " + word.word
  37. if current_text:
  38. raw.append({"start": seg_start, "end": segment.end, "speaker": current_speaker, "text": current_text.strip()})
  39. # Merge consecutive same-speaker segments within 0.5s gap
  40. merged: list[dict] = []
  41. if raw:
  42. curr = raw[0]
  43. for nxt in raw[1:]:
  44. if nxt["speaker"] == curr["speaker"] and (nxt["start"] - curr["end"]) < 0.5:
  45. curr["text"] += " " + nxt["text"]
  46. curr["end"] = nxt["end"]
  47. else:
  48. merged.append(curr)
  49. curr = nxt
  50. merged.append(curr)
  51. return merged