1# worktree.py -- Working tree operations for Git repositories
2# Copyright (C) 2024 Jelmer Vernooij <jelmer@jelmer.uk>
3#
4# SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
5# Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
6# General Public License as published by the Free Software Foundation; version 2.0
7# or (at your option) any later version. You can redistribute it and/or
8# modify it under the terms of either of these two licenses.
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15#
16# You should have received a copy of the licenses; if not, see
17# <http://www.gnu.org/licenses/> for a copy of the GNU General Public License
18# and <http://www.apache.org/licenses/LICENSE-2.0> for a copy of the Apache
19# License, Version 2.0.
20#
21
22"""Working tree operations for Git repositories."""
23
24import os
25import stat
26import sys
27import time
28import warnings
29from collections.abc import Iterable
30from typing import TYPE_CHECKING, Optional, Union
31
32if TYPE_CHECKING:
33 from .repo import Repo
34
35from .errors import CommitError, HookError
36from .objects import Commit, ObjectID, Tag, Tree
37from .refs import Ref
38from .repo import check_user_identity, get_user_identity
39
40
41class WorkTree:
42 """Working tree operations for a Git repository.
43
44 This class provides methods for working with the working tree,
45 such as staging files, committing changes, and resetting the index.
46 """
47
48 def __init__(self, repo: "Repo", path: Union[str, bytes, os.PathLike]) -> None:
49 """Initialize a WorkTree for the given repository.
50
51 Args:
52 repo: The repository this working tree belongs to
53 path: Path to the working tree directory
54 """
55 self._repo = repo
56 raw_path = os.fspath(path)
57 if isinstance(raw_path, bytes):
58 self.path: str = os.fsdecode(raw_path)
59 else:
60 self.path = raw_path
61 self.path = os.path.abspath(self.path)
62
63 def stage(
64 self,
65 fs_paths: Union[
66 str, bytes, os.PathLike, Iterable[Union[str, bytes, os.PathLike]]
67 ],
68 ) -> None:
69 """Stage a set of paths.
70
71 Args:
72 fs_paths: List of paths, relative to the repository path
73 """
74 root_path_bytes = os.fsencode(self.path)
75
76 if isinstance(fs_paths, (str, bytes, os.PathLike)):
77 fs_paths = [fs_paths]
78 fs_paths = list(fs_paths)
79
80 from .index import (
81 _fs_to_tree_path,
82 blob_from_path_and_stat,
83 index_entry_from_directory,
84 index_entry_from_stat,
85 )
86
87 index = self._repo.open_index()
88 blob_normalizer = self._repo.get_blob_normalizer()
89 for fs_path in fs_paths:
90 if not isinstance(fs_path, bytes):
91 fs_path = os.fsencode(fs_path)
92 if os.path.isabs(fs_path):
93 raise ValueError(
94 f"path {fs_path!r} should be relative to "
95 "repository root, not absolute"
96 )
97 tree_path = _fs_to_tree_path(fs_path)
98 full_path = os.path.join(root_path_bytes, fs_path)
99 try:
100 st = os.lstat(full_path)
101 except OSError:
102 # File no longer exists
103 try:
104 del index[tree_path]
105 except KeyError:
106 pass # already removed
107 else:
108 if stat.S_ISDIR(st.st_mode):
109 entry = index_entry_from_directory(st, full_path)
110 if entry:
111 index[tree_path] = entry
112 else:
113 try:
114 del index[tree_path]
115 except KeyError:
116 pass
117 elif not stat.S_ISREG(st.st_mode) and not stat.S_ISLNK(st.st_mode):
118 try:
119 del index[tree_path]
120 except KeyError:
121 pass
122 else:
123 blob = blob_from_path_and_stat(full_path, st)
124 blob = blob_normalizer.checkin_normalize(blob, fs_path)
125 self._repo.object_store.add_object(blob)
126 index[tree_path] = index_entry_from_stat(st, blob.id)
127 index.write()
128
129 def unstage(self, fs_paths: list[str]) -> None:
130 """Unstage specific file in the index
131 Args:
132 fs_paths: a list of files to unstage,
133 relative to the repository path.
134 """
135 from .index import IndexEntry, _fs_to_tree_path
136
137 index = self._repo.open_index()
138 try:
139 tree_id = self._repo[b"HEAD"].tree
140 except KeyError:
141 # no head mean no commit in the repo
142 for fs_path in fs_paths:
143 tree_path = _fs_to_tree_path(fs_path)
144 del index[tree_path]
145 index.write()
146 return
147
148 for fs_path in fs_paths:
149 tree_path = _fs_to_tree_path(fs_path)
150 try:
151 tree = self._repo.object_store[tree_id]
152 assert isinstance(tree, Tree)
153 tree_entry = tree.lookup_path(
154 self._repo.object_store.__getitem__, tree_path
155 )
156 except KeyError:
157 # if tree_entry didn't exist, this file was being added, so
158 # remove index entry
159 try:
160 del index[tree_path]
161 continue
162 except KeyError as exc:
163 raise KeyError(f"file '{tree_path.decode()}' not in index") from exc
164
165 st = None
166 try:
167 st = os.lstat(os.path.join(self.path, fs_path))
168 except FileNotFoundError:
169 pass
170
171 index_entry = IndexEntry(
172 ctime=(self._repo[b"HEAD"].commit_time, 0),
173 mtime=(self._repo[b"HEAD"].commit_time, 0),
174 dev=st.st_dev if st else 0,
175 ino=st.st_ino if st else 0,
176 mode=tree_entry[0],
177 uid=st.st_uid if st else 0,
178 gid=st.st_gid if st else 0,
179 size=len(self._repo[tree_entry[1]].data),
180 sha=tree_entry[1],
181 flags=0,
182 extended_flags=0,
183 )
184
185 index[tree_path] = index_entry
186 index.write()
187
    def commit(
        self,
        message: Optional[bytes] = None,
        committer: Optional[bytes] = None,
        author: Optional[bytes] = None,
        commit_timestamp: Optional[float] = None,
        commit_timezone: Optional[int] = None,
        author_timestamp: Optional[float] = None,
        author_timezone: Optional[int] = None,
        tree: Optional[ObjectID] = None,
        encoding: Optional[bytes] = None,
        ref: Optional[Ref] = b"HEAD",
        merge_heads: Optional[list[ObjectID]] = None,
        no_verify: bool = False,
        sign: Union[bool, str, None] = False,
    ) -> ObjectID:
        """Create a new commit.

        If not specified, committer and author default to
        get_user_identity(..., 'COMMITTER')
        and get_user_identity(..., 'AUTHOR') respectively.

        Args:
            message: Commit message (bytes or callable that takes (repo, commit)
                and returns bytes)
            committer: Committer fullname
            author: Author fullname
            commit_timestamp: Commit timestamp (defaults to now)
            commit_timezone: Commit timestamp timezone (defaults to GMT)
            author_timestamp: Author timestamp (defaults to commit
                timestamp)
            author_timezone: Author timestamp timezone
                (defaults to commit timestamp timezone)
            tree: SHA1 of the tree root to use (if not specified the
                current index will be committed).
            encoding: Encoding
            ref: Optional ref to commit to (defaults to current branch).
                If None, creates a dangling commit without updating any ref.
            merge_heads: Merge heads (defaults to .git/MERGE_HEAD)
            no_verify: Skip pre-commit and commit-msg hooks
            sign: GPG Sign the commit (bool, defaults to False,
                pass True to use default GPG key,
                pass a str containing Key ID to use a specific GPG key;
                pass None to fall back to the commit.gpgSign config value)

        Returns:
            New commit SHA1

        Raises:
            CommitError: if a hook fails or the target ref changed during
                the commit.
            ValueError: if no message is available or the tree sha is
                malformed.
        """
        # Run the pre-commit hook first so a failing hook aborts before any
        # objects are written.
        try:
            if not no_verify:
                self._repo.hooks["pre-commit"].execute()
        except HookError as exc:
            raise CommitError(exc) from exc
        except KeyError:  # no hook defined, silent fallthrough
            pass

        c = Commit()
        if tree is None:
            # No explicit tree given: commit whatever is currently staged.
            index = self._repo.open_index()
            c.tree = index.commit(self._repo.object_store)
        else:
            if len(tree) != 40:
                raise ValueError("tree must be a 40-byte hex sha string")
            c.tree = tree

        config = self._repo.get_config_stack()
        if merge_heads is None:
            merge_heads = self._repo._read_heads("MERGE_HEAD")
        if committer is None:
            committer = get_user_identity(config, kind="COMMITTER")
        check_user_identity(committer)
        c.committer = committer
        if commit_timestamp is None:
            # FIXME: Support GIT_COMMITTER_DATE environment variable
            commit_timestamp = time.time()
        c.commit_time = int(commit_timestamp)
        if commit_timezone is None:
            # FIXME: Use current user timezone rather than UTC
            commit_timezone = 0
        c.commit_timezone = commit_timezone
        if author is None:
            author = get_user_identity(config, kind="AUTHOR")
        c.author = author
        check_user_identity(author)
        if author_timestamp is None:
            # FIXME: Support GIT_AUTHOR_DATE environment variable
            author_timestamp = commit_timestamp
        c.author_time = int(author_timestamp)
        if author_timezone is None:
            author_timezone = commit_timezone
        c.author_timezone = author_timezone
        if encoding is None:
            try:
                encoding = config.get(("i18n",), "commitEncoding")
            except KeyError:
                pass  # No dice
        if encoding is not None:
            c.encoding = encoding
        # Store original message (might be callable); the final message is
        # resolved only after the parents are known, so a callback can
        # inspect the fully-populated commit object.
        original_message = message
        message = None  # Will be set later after parents are set

        # Check if we should sign the commit
        should_sign = sign
        if sign is None:
            # Check commit.gpgSign configuration when sign is not explicitly set
            config = self._repo.get_config_stack()
            try:
                should_sign = config.get_boolean((b"commit",), b"gpgSign")
            except KeyError:
                should_sign = False  # Default to not signing if no config
        # A str value for `sign` selects a specific GPG key id.
        keyid = sign if isinstance(sign, str) else None

        if ref is None:
            # Create a dangling commit
            c.parents = merge_heads
        else:
            try:
                old_head = self._repo.refs[ref]
                c.parents = [old_head, *merge_heads]
            except KeyError:
                # Ref does not exist yet: this will be the first commit on it.
                c.parents = merge_heads

        # Handle message after parents are set
        if callable(original_message):
            message = original_message(self._repo, c)
            if message is None:
                raise ValueError("Message callback returned None")
        else:
            message = original_message

        if message is None:
            # FIXME: Try to read commit message from .git/MERGE_MSG
            raise ValueError("No commit message specified")

        # The commit-msg hook may rewrite the message; None means unchanged.
        try:
            if no_verify:
                c.message = message
            else:
                c.message = self._repo.hooks["commit-msg"].execute(message)
                if c.message is None:
                    c.message = message
        except HookError as exc:
            raise CommitError(exc) from exc
        except KeyError:  # no hook defined, message not modified
            c.message = message

        if ref is None:
            # Create a dangling commit
            if should_sign:
                c.sign(keyid)
            self._repo.object_store.add_object(c)
        else:
            # Sign after the message is final, then update the ref with an
            # atomic compare-and-swap against the old head we based on.
            try:
                old_head = self._repo.refs[ref]
                if should_sign:
                    c.sign(keyid)
                self._repo.object_store.add_object(c)
                ok = self._repo.refs.set_if_equals(
                    ref,
                    old_head,
                    c.id,
                    message=b"commit: " + message,
                    committer=committer,
                    timestamp=commit_timestamp,
                    timezone=commit_timezone,
                )
            except KeyError:
                # Ref did not exist: create it, failing if it appeared since.
                c.parents = merge_heads
                if should_sign:
                    c.sign(keyid)
                self._repo.object_store.add_object(c)
                ok = self._repo.refs.add_if_new(
                    ref,
                    c.id,
                    message=b"commit: " + message,
                    committer=committer,
                    timestamp=commit_timestamp,
                    timezone=commit_timezone,
                )
            if not ok:
                # Fail if the atomic compare-and-swap failed, leaving the
                # commit and all its objects as garbage.
                raise CommitError(f"{ref!r} changed during commit")

        self._repo._del_named_file("MERGE_HEAD")

        # post-commit failures must not undo the commit; warn instead.
        try:
            self._repo.hooks["post-commit"].execute()
        except HookError as e:  # silent failure
            warnings.warn(f"post-commit hook failed: {e}", UserWarning)
        except KeyError:  # no hook defined, silent fallthrough
            pass

        # Trigger auto GC if needed
        from .gc import maybe_auto_gc

        maybe_auto_gc(self._repo)

        return c.id
387
388 def reset_index(self, tree: Optional[bytes] = None):
389 """Reset the index back to a specific tree.
390
391 Args:
392 tree: Tree SHA to reset to, None for current HEAD tree.
393 """
394 from .index import (
395 build_index_from_tree,
396 symlink,
397 validate_path_element_default,
398 validate_path_element_hfs,
399 validate_path_element_ntfs,
400 )
401
402 if tree is None:
403 head = self._repo[b"HEAD"]
404 if isinstance(head, Tag):
405 _cls, obj = head.object
406 head = self._repo.get_object(obj)
407 tree = head.tree
408 config = self._repo.get_config()
409 honor_filemode = config.get_boolean(b"core", b"filemode", os.name != "nt")
410 if config.get_boolean(b"core", b"core.protectNTFS", os.name == "nt"):
411 validate_path_element = validate_path_element_ntfs
412 elif config.get_boolean(b"core", b"core.protectHFS", sys.platform == "darwin"):
413 validate_path_element = validate_path_element_hfs
414 else:
415 validate_path_element = validate_path_element_default
416 if config.get_boolean(b"core", b"symlinks", True):
417 symlink_fn = symlink
418 else:
419
420 def symlink_fn(source, target) -> None: # type: ignore
421 with open(
422 target, "w" + ("b" if isinstance(source, bytes) else "")
423 ) as f:
424 f.write(source)
425
426 blob_normalizer = self._repo.get_blob_normalizer()
427 return build_index_from_tree(
428 self.path,
429 self._repo.index_path(),
430 self._repo.object_store,
431 tree,
432 honor_filemode=honor_filemode,
433 validate_path_element=validate_path_element,
434 symlink_fn=symlink_fn,
435 blob_normalizer=blob_normalizer,
436 )
437
438 def _sparse_checkout_file_path(self) -> str:
439 """Return the path of the sparse-checkout file in this repo's control dir."""
440 return os.path.join(self._repo.controldir(), "info", "sparse-checkout")
441
442 def configure_for_cone_mode(self) -> None:
443 """Ensure the repository is configured for cone-mode sparse-checkout."""
444 config = self._repo.get_config()
445 config.set((b"core",), b"sparseCheckout", b"true")
446 config.set((b"core",), b"sparseCheckoutCone", b"true")
447 config.write_to_path()
448
449 def infer_cone_mode(self) -> bool:
450 """Return True if 'core.sparseCheckoutCone' is set to 'true' in config, else False."""
451 config = self._repo.get_config()
452 try:
453 sc_cone = config.get((b"core",), b"sparseCheckoutCone")
454 return sc_cone == b"true"
455 except KeyError:
456 # If core.sparseCheckoutCone is not set, default to False
457 return False
458
459 def get_sparse_checkout_patterns(self) -> list[str]:
460 """Return a list of sparse-checkout patterns from info/sparse-checkout.
461
462 Returns:
463 A list of patterns. Returns an empty list if the file is missing.
464 """
465 path = self._sparse_checkout_file_path()
466 try:
467 with open(path, encoding="utf-8") as f:
468 return [line.strip() for line in f if line.strip()]
469 except FileNotFoundError:
470 return []
471
472 def set_sparse_checkout_patterns(self, patterns: list[str]) -> None:
473 """Write the given sparse-checkout patterns into info/sparse-checkout.
474
475 Creates the info/ directory if it does not exist.
476
477 Args:
478 patterns: A list of gitignore-style patterns to store.
479 """
480 info_dir = os.path.join(self._repo.controldir(), "info")
481 os.makedirs(info_dir, exist_ok=True)
482
483 path = self._sparse_checkout_file_path()
484 with open(path, "w", encoding="utf-8") as f:
485 for pat in patterns:
486 f.write(pat + "\n")
487
488 def set_cone_mode_patterns(self, dirs: Union[list[str], None] = None) -> None:
489 """Write the given cone-mode directory patterns into info/sparse-checkout.
490
491 For each directory to include, add an inclusion line that "undoes" the prior
492 ``!/*/`` 'exclude' that re-includes that directory and everything under it.
493 Never add the same line twice.
494 """
495 patterns = ["/*", "!/*/"]
496 if dirs:
497 for d in dirs:
498 d = d.strip("/")
499 line = f"/{d}/"
500 if d and line not in patterns:
501 patterns.append(line)
502 self.set_sparse_checkout_patterns(patterns)